Skip to main content

rlx_cpu/
thunk.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Thunks — pre-compiled kernel dispatch with zero per-call overhead.
17//!
18//! At compile time, the graph is lowered into a flat `Vec<Thunk>` where each
19//! thunk holds pre-computed arena offsets, dimensions, and kernel type.
20//! At runtime, the executor just iterates thunks and calls kernels directly.
21
22// Edition 2024: bodies of `unsafe fn` are safe by default; `sl`/`sl_mut` stay `unsafe fn`.
23#![allow(unsafe_op_in_unsafe_fn)]
24//! No match dispatch, no HashMap lookup, no dimension computation.
25
26use crate::arena::Arena;
27use crate::op_registry::CpuKernel;
28use rlx_ir::op::{Activation, BinaryOp, CmpOp, ReduceOp};
29use rlx_ir::{Graph, NodeId, Op, Shape};
30use std::collections::HashMap;
31use std::sync::Arc;
32
33/// A pre-compiled kernel call with all args resolved to arena offsets.
34#[derive(Clone)]
35pub enum Thunk {
36    /// Skip (Input/Param already in arena)
37    Nop,
38    /// C = A @ B (BLAS sgemm)
39    Sgemm {
40        a: usize,
41        b: usize,
42        c: usize,
43        m: u32,
44        k: u32,
45        n: u32,
46    },
47    /// f64 dense solve `x = A⁻¹·b` via LAPACK dgesv.
48    /// `a`, `b`, `x` are byte-offsets into the arena. `n` is the matrix
49    /// dimension; `nrhs` is 1 for a vector RHS or >1 for multi-RHS.
50    /// The kernel materializes scratch copies of A and b internally
51    /// (LAPACK overwrites both with LU factors and solution).
52    DenseSolveF64 {
53        a: usize,
54        b: usize,
55        x: usize,
56        n: u32,
57        nrhs: u32,
58    },
59    /// f32 twin of `DenseSolveF64`. Calls LAPACK `sgesv` (or the
60    /// no-blas Rust fallback). Same arena byte-offset contract.
61    DenseSolveF32 {
62        a: usize,
63        b: usize,
64        x: usize,
65        n: u32,
66        nrhs: u32,
67    },
68    /// Batched f64 dense solve. `a`, `b`, `x` are byte-offsets to
69    /// the leading slice; `batch` is the number of independent
70    /// systems. Per slice the kernel calls `dgesv(A_i, b_i, n, nrhs)`
71    /// — LAPACK has no batched dgesv on Accelerate, so we loop.
72    BatchedDenseSolveF64 {
73        a: usize,
74        b: usize,
75        x: usize,
76        batch: u32,
77        n: u32,
78        nrhs: u32,
79    },
80    /// Batched f32 dense solve — loop of `sgesv` per batch slice.
81    BatchedDenseSolveF32 {
82        a: usize,
83        b: usize,
84        x: usize,
85        batch: u32,
86        n: u32,
87        nrhs: u32,
88    },
89    /// Batched f64 matmul. Both inputs and output have a leading
90    /// batch axis of size `batch`. Per-batch independent dgemm:
91    /// `C[i] = A[i] @ B[i]` for `i in 0..batch`. Used by VJP rules
92    /// that emit per-batch outer products (e.g., BatchedDenseSolve
93    /// VJP). The unbatched `Dgemm` thunk handles the rank-2 case.
94    BatchedDgemmF64 {
95        a: usize,
96        b: usize,
97        c: usize,
98        batch: u32,
99        m: u32,
100        k: u32,
101        n: u32,
102    },
103    /// Batched f32 matmul — same loop-per-batch shape as
104    /// `BatchedDgemmF64` but calling `sgemm`. Needed for attention
105    /// patterns where both operands carry a batch dim (e.g. q@k^T
106    /// and attn@v in decomposed self-attention). The 2-D `Sgemm`
107    /// flatten trick is wrong in that case because it treats `b` as
108    /// a single shared RHS across every batch.
109    BatchedSgemm {
110        a: usize,
111        b: usize,
112        c: usize,
113        batch: u32,
114        m: u32,
115        k: u32,
116        n: u32,
117    },
118    /// C = A @ B via Accelerate cblas_dgemm. Mirror of `Sgemm` at f64.
119    Dgemm {
120        a: usize,
121        b: usize,
122        c: usize,
123        m: u32,
124        k: u32,
125        n: u32,
126    },
127    /// f64 N-D index walk used for both `Op::Transpose` and `Op::Expand`.
128    /// `in_strides` carries 0s on broadcast axes (Expand) or permuted
129    /// strides (Transpose). Mirror of `Thunk::Transpose` at f64.
130    TransposeF64 {
131        src: usize,
132        dst: usize,
133        in_total: u32,
134        out_dims: Vec<u32>,
135        in_strides: Vec<u32>,
136    },
137    /// f64 element-wise activation. Single-input, single-output. The
138    /// kernel always reads from `src` and writes to `dst`, so it works
139    /// whether or not the planner aliased the two slots.
140    ActivationF64 {
141        src: usize,
142        dst: usize,
143        len: u32,
144        kind: Activation,
145    },
146    /// Element-wise complex squared-magnitude: `|z|² = re² + im²`.
147    /// Reads the C64 input at `src` as `2·len` f32 ([re,im] pairs),
148    /// writes `len` f32 to `dst`.
149    ComplexNormSqF32 {
150        src: usize,
151        dst: usize,
152        /// Logical element count (number of complex values).
153        len: u32,
154    },
155    /// Wirtinger backward for [`ComplexNormSqF32`]: `dz = g · z` as
156    /// C64. Reads `z` at `2·len` f32 + `g` at `len` f32; writes
157    /// `2·len` f32 to `dz`.
158    ComplexNormSqBackwardF32 {
159        z: usize,
160        g: usize,
161        dz: usize,
162        len: u32,
163    },
164    /// Element-wise C64 conjugate: writes `[re_i, -im_i]` per element.
165    /// Layout matches the rest of C64 here ([re,im] interleaved f32).
166    ConjugateC64 { src: usize, dst: usize, len: u32 },
167    /// C64 element-wise activation. Only kinds with well-defined
168    /// complex extensions are supported: Neg, Exp, Log, Sqrt.
169    /// Everything else (Sigmoid, Tanh, Relu, Abs, Sin/Cos/Tan/Atan,
170    /// Round, GeLU family) is rejected at lowering — those don't have
171    /// single natural complex definitions. `len` is the **complex
172    /// element count** (the f32 buffer holds `2·len` floats).
173    ActivationC64 {
174        src: usize,
175        dst: usize,
176        len: u32,
177        kind: Activation,
178    },
179    /// f64 contiguous reduction along a single axis range. Layout
180    /// `[outer, reduced, inner]` in memory; output is `[outer, inner]`.
181    /// Sum only for now (Mean composes via 1/N multiply post-pass).
182    ReduceSumF64 {
183        src: usize,
184        dst: usize,
185        outer: u32,
186        reduced: u32,
187        inner: u32,
188    },
189    /// f64 plain copy (Reshape / Cast at the same dtype). Mirrors `Copy`
190    /// but at 8 bytes per element.
191    CopyF64 { src: usize, dst: usize, len: u32 },
192    /// i64 element copy (Reshape/Cast on i64 tensors).
193    CopyI64 { src: usize, dst: usize, len: u32 },
194    /// Round f32 → i64 (ONNX Cast on duration scalar).
195    CastF32ToI64 { src: usize, dst: usize, len: u32 },
196    /// i64 → f32 (ONNX Cast on shape scalars, e.g. Albert head-dim).
197    CastI64ToF32 { src: usize, dst: usize, len: u32 },
198    /// bool → i32 (BERT attention mask grid).
199    CastBoolToI32 { src: usize, dst: usize, len: u32 },
200    /// i32 → f32 (BERT attention mask cast before subtract).
201    CastI32ToF32 { src: usize, dst: usize, len: u32 },
202    /// f64 element-wise binary with broadcast. `len`/`lhs_len`/`rhs_len`
203    /// are element counts; kernel does `out[i] = lhs[i % lhs_len] OP rhs[i % rhs_len]`.
204    /// Mirror of `BinaryFull` at 8 bytes per element.
205    BinaryFullF64 {
206        lhs: usize,
207        rhs: usize,
208        dst: usize,
209        len: u32,
210        lhs_len: u32,
211        rhs_len: u32,
212        op: BinaryOp,
213        /// Output shape dims (row-major). Empty in the fast path. See
214        /// `BinaryFull` doc for the broadcast convention.
215        out_dims_bcast: Vec<u32>,
216        bcast_lhs_strides: Vec<u32>,
217        bcast_rhs_strides: Vec<u32>,
218    },
219    /// f64 concat — byte-for-byte mirror of `Concat` but copies
220    /// 8 bytes per element. Element-counted offsets/strides match
221    /// the f32 variant; the executor scales by elem_size internally.
222    ConcatF64 {
223        dst: usize,
224        outer: u32,
225        inner: u32,
226        total_axis: u32,
227        inputs: Vec<(usize, u32)>,
228    },
229    /// C64 element-wise binary with broadcast. Same `len` /
230    /// `lhs_len` / `rhs_len` semantics as `BinaryFull` but each
231    /// "element" is one complex value (8 bytes = `[re, im]` as two
232    /// f32s). The executor reads the underlying f32 buffer at
233    /// `2·len` floats and walks element pairs. Supports Add / Sub /
234    /// Mul / Div; Max / Min / Pow have no single natural complex
235    /// definition and panic at lowering.
236    BinaryFullC64 {
237        lhs: usize,
238        rhs: usize,
239        dst: usize,
240        /// Complex element count (NOT f32 count). f32 buffer length
241        /// is `2·len`.
242        len: u32,
243        lhs_len: u32,
244        rhs_len: u32,
245        op: BinaryOp,
246        out_dims_bcast: Vec<u32>,
247        bcast_lhs_strides: Vec<u32>,
248        bcast_rhs_strides: Vec<u32>,
249    },
250    /// Bounded scan. Holds a recursively-compiled body schedule + a
251    /// pre-initialized body arena snapshot (constants filled). Each
252    /// outer execution clones the snapshot, copies the carry-in slot
253    /// from the outer arena, runs the body schedule `length` times,
254    /// then writes the final carry to the outer arena.
255    ///
256    /// Single-carry MVP — body has exactly one carry Input (plus any per-step
257    /// `xs` / broadcast Inputs) and one output, same shape and dtype as the carry.
258    Scan {
259        body: Arc<ThunkSchedule>,
260        body_init: Arc<Vec<u8>>, // pristine body arena bytes
261        body_input_off: usize,   // byte offset of the body's carry-Input slot
262        body_output_off: usize,  // byte offset of the body's output slot
263        outer_init_off: usize,   // outer-arena offset of the initial carry
264        outer_final_off: usize,  // outer-arena offset of the final carry / trajectory base
265        length: u32,
266        carry_bytes: u32, // carry size in bytes
267        /// When true, write each step's carry to the outer arena at
268        /// offset `outer_final_off + t * carry_bytes`, producing a
269        /// `[length, *carry]` stacked trajectory. When false, only the
270        /// final carry lands at `outer_final_off`.
271        save_trajectory: bool,
272        /// Per-step `xs` inputs. For each: (body_x_input_off,
273        /// outer_xs_base_off, per_step_bytes). Per iteration `t`, the
274        /// executor copies `outer_xs_base_off + t * per_step_bytes`
275        /// into `body_x_input_off`. Empty when the scan has no xs.
276        xs_inputs: Arc<Vec<(usize, usize, u32)>>,
277        /// Broadcast inputs — values constant across iterations. For
278        /// each: (body_bcast_input_off, outer_bcast_off, total_bytes).
279        /// Filled into `body_buf` ONCE before the scan loop starts
280        /// (xs in contrast are re-filled every iteration). Empty when
281        /// the scan has no bcasts.
282        bcast_inputs: Arc<Vec<(usize, usize, u32)>>,
283        /// Number of trajectory checkpoints (when `save_trajectory`).
284        /// `0` or `length` ⇒ save every iteration. Otherwise save only
285        /// `K` rows at indices `floor((k+1) * length / K) - 1` for
286        /// `k in 0..K`. Last index is always `length-1` so the final
287        /// carry is always cached.
288        num_checkpoints: u32,
289    },
290
291    /// Reverse-mode AD companion to `Thunk::Scan`. Walks `t = length-1
292    /// .. 0`, threading `dcarry` through the body's VJP. Per iteration:
293    /// writes `carry_t` (from outer init or trajectory), each `xs_i[t]`
294    /// slice, and the current `dcarry` into the body_vjp's Input
295    /// slots, runs body_vjp, reads new `dcarry` from its single output.
296    /// Carry may be f32 or f64 — the upstream-accumulation step in trajectory
297    /// mode dispatches its element-wise add on `carry_elem_size`.
298    ScanBackward {
299        body_vjp: Arc<ThunkSchedule>,
300        body_init: Arc<Vec<u8>>,
301        body_carry_in_off: usize, // body_vjp's mirrored body-carry-input slot
302        body_x_offs: Arc<Vec<usize>>, // body_vjp's mirrored x_t_i Input slots, in xs order
303        body_d_output_off: usize, // body_vjp's "d_output" Input slot
304        body_dcarry_out_off: usize, // body_vjp's gradient output
305        outer_init_off: usize,    // original init carry
306        outer_traj_off: usize,    // [length-or-K, *carry] trajectory base
307        outer_upstream_off: usize, // upstream gradient (carry shape, or [length, *carry])
308        /// Per-xs entries: (outer_xs_base_off, per_step_bytes). Read
309        /// `xs_i[t]` from `outer_xs_base_off + t * per_step_bytes`.
310        outer_xs_offs: Arc<Vec<(usize, u32)>>,
311        outer_dinit_off: usize, // output: dinit
312        length: u32,
313        carry_bytes: u32,
314        /// Bytes per element in the carry tensor: 4 for f32, 8 for f64.
315        /// Used to dispatch the trajectory-mode upstream accumulation
316        /// kernel (the dcarry += upstream\[t\] add must use the right
317        /// floating-point type — a hard-coded f64 add silently does
318        /// nothing for an f32 carry whose `cb` isn't divisible by 8).
319        carry_elem_size: u32,
320        save_trajectory: bool, // true → upstream is per-step; false → just final
321        /// Recursive checkpointing config. `0` or `length` ⇒ full
322        /// trajectory cached, no recompute (existing behavior).
323        /// `0 < K < length` ⇒ trajectory has only K rows; the executor
324        /// recomputes intermediate carries via `forward_body` between
325        /// checkpoints. Memory: O(K · carry_bytes); time: O(length).
326        num_checkpoints: u32,
327        /// Forward body schedule (same compiled body as the forward
328        /// Op::Scan), used for recompute when `num_checkpoints` is
329        /// active. `None` for the All strategy.
330        forward_body: Option<Arc<ThunkSchedule>>,
331        /// Pristine forward body arena bytes (constants filled).
332        forward_body_init: Option<Arc<Vec<u8>>>,
333        /// Forward body's carry-Input and output slot offsets — needed
334        /// to seed/read the body during recompute.
335        forward_body_carry_in_off: usize,
336        forward_body_output_off: usize,
337        /// Forward body's per-step xs Input slots (one per outer xs).
338        /// Same indexing convention as `body_x_offs`.
339        forward_body_x_offs: Arc<Vec<usize>>,
340    },
341
342    /// Companion to `ScanBackward` that materializes one stacked
343    /// `dxs_i`. Same backward loop; per iteration, after running
344    /// body_vjp, copies its `body_dxs_out_off` slot into the outer
345    /// arena at `outer_dxs_off + t * per_step_bytes`. dcarry threading
346    /// is identical — we still need it for the body_vjp recurrence
347    /// even though we don't write it back to the outer arena.
348    ScanBackwardXs {
349        body_vjp: Arc<ThunkSchedule>,
350        body_init: Arc<Vec<u8>>,
351        body_carry_in_off: usize,
352        body_x_offs: Arc<Vec<usize>>,
353        body_d_output_off: usize,
354        body_dcarry_out_off: usize,
355        body_dxs_out_off: usize, // the body_vjp output we extract per step
356        outer_init_off: usize,
357        outer_traj_off: usize,
358        outer_upstream_off: usize,
359        outer_xs_offs: Arc<Vec<(usize, u32)>>,
360        outer_dxs_off: usize, // base of the stacked [length, *per_step] output
361        length: u32,
362        carry_bytes: u32,
363        /// Same role as `Thunk::ScanBackward::carry_elem_size`.
364        carry_elem_size: u32,
365        per_step_bytes: u32, // bytes per row of the dxs output
366        save_trajectory: bool,
367        /// Recursive checkpointing config. Same semantics as
368        /// `Thunk::ScanBackward::num_checkpoints` — `0` or `length`
369        /// means "save every step's carry"; `0 < K < length` means
370        /// the trajectory has only K rows and the executor recomputes
371        /// intermediate carries via `forward_body` (which must be
372        /// `Some`). Implemented via segment-cached recompute,
373        /// mirroring the `ScanBackward` path.
374        num_checkpoints: u32,
375        forward_body: Option<Arc<ThunkSchedule>>,
376        forward_body_init: Option<Arc<Vec<u8>>>,
377        forward_body_carry_in_off: usize,
378        forward_body_output_off: usize,
379        forward_body_x_offs: Arc<Vec<usize>>,
380    },
381    /// User-defined sub-graph (`Op::CustomFn`) — runs `fwd_body` once.
382    /// Per execution: clone `body_init`, copy each primal input from the
383    /// outer arena into its body Input slot, run the body schedule,
384    /// copy the body's single output back to the outer arena.
385    CustomFn {
386        body: Arc<ThunkSchedule>,
387        body_init: Arc<Vec<u8>>,
388        /// Per primal input: (body_input_off, outer_input_off, bytes).
389        inputs: Arc<Vec<(usize, usize, u32)>>,
390        body_output_off: usize,
391        outer_output_off: usize,
392        out_bytes: u32,
393    },
394    /// C = A @ B; C += bias; C = act(C)
395    FusedMmBiasAct {
396        a: usize,
397        w: usize,
398        bias: usize,
399        c: usize,
400        m: u32,
401        k: u32,
402        n: u32,
403        act: Option<Activation>,
404    },
405    /// out = LN(x + residual + bias, gamma, beta)
406    FusedResidualLN {
407        x: usize,
408        res: usize,
409        bias: usize,
410        g: usize,
411        b: usize,
412        out: usize,
413        rows: u32,
414        h: u32,
415        eps: f32,
416        has_bias: bool,
417    },
418    /// out = RmsNorm(x + residual + bias, gamma, beta)
419    FusedResidualRmsNorm {
420        x: usize,
421        res: usize,
422        bias: usize,
423        g: usize,
424        b: usize,
425        out: usize,
426        rows: u32,
427        h: u32,
428        eps: f32,
429        has_bias: bool,
430    },
431    /// out = bias_add(data, bias, m, n) for Binary::Add with broadcast
432    BiasAdd {
433        src: usize,
434        bias: usize,
435        dst: usize,
436        m: u32,
437        n: u32,
438    },
439    /// Element-wise binary op with NumPy-style broadcast.
440    ///
441    /// Fast path (`lhs_len == rhs_len == len`): plain element-wise loop,
442    /// SIMD-vectorized on aarch64 for `Add`/`Mul`. `bcast_*` fields
443    /// are unused.
444    ///
445    /// Broadcast path: uses `out_dims_bcast` + `bcast_lhs_strides` +
446    /// `bcast_rhs_strides` to compute per-cell indices into each
447    /// operand. The strides are precomputed at thunk-construction
448    /// time from the operands' true shapes (with stride 0 on any axis
449    /// where the operand has size 1). This is the only correct way
450    /// to handle bidirectional broadcasts like `[N, 1] op [1, S]
451    /// → [N, S]`, which simple `i % lhs_len` modulo indexing maps to
452    /// wrong cells.
453    BinaryFull {
454        lhs: usize,
455        rhs: usize,
456        dst: usize,
457        len: u32,
458        lhs_len: u32,
459        rhs_len: u32,
460        op: BinaryOp,
461        /// Output shape dims (row-major). Empty in the fast path.
462        out_dims_bcast: Vec<u32>,
463        /// Per-dim stride into `lhs` (0 where lhs broadcasts).
464        bcast_lhs_strides: Vec<u32>,
465        /// Per-dim stride into `rhs`.
466        bcast_rhs_strides: Vec<u32>,
467        /// Element size (4 = F32, 8 = I64).
468        elem_bytes: u8,
469    },
470    /// Activation in-place
471    ActivationInPlace {
472        data: usize,
473        len: u32,
474        act: Activation,
475    },
476    /// Gather axis=0: table\[idx\] → out
477    Gather {
478        table: usize,
479        table_len: u32,
480        idx: usize,
481        dst: usize,
482        num_idx: u32,
483        trailing: u32,
484        /// 1 when the index tensor is i64 (ONNX Gather indices).
485        idx_i64: u8,
486        /// Element size of table/output (4 = f32, 8 = i64).
487        table_bytes: u8,
488    },
489    /// Narrow: copy slice (`elem_bytes` = source element size: 4 for f32, 8 for f64).
490    Narrow {
491        src: usize,
492        dst: usize,
493        outer: u32,
494        src_stride: u32,
495        dst_stride: u32,
496        inner: u32,
497        elem_bytes: u8,
498    },
499    /// Copy (reshape, expand)
500    Copy { src: usize, dst: usize, len: u32 },
501    /// LayerNorm standalone
502    LayerNorm {
503        src: usize,
504        g: usize,
505        b: usize,
506        dst: usize,
507        rows: u32,
508        h: u32,
509        eps: f32,
510    },
511    /// GroupNorm on NCHW `[N,C,H,W]`.
512    GroupNorm {
513        src: usize,
514        g: usize,
515        b: usize,
516        dst: usize,
517        n: u32,
518        c: u32,
519        h: u32,
520        w: u32,
521        num_groups: u32,
522        eps: f32,
523    },
524    /// BatchNorm inference: frozen mean/var, feature axis last.
525    BatchNormInference {
526        src: usize,
527        g: usize,
528        b: usize,
529        mean: usize,
530        var: usize,
531        dst: usize,
532        count: u32,
533        channels: u32,
534        eps: f32,
535    },
536    BatchNormInferenceBackwardInput {
537        x: usize,
538        gamma: usize,
539        mean: usize,
540        var: usize,
541        dy: usize,
542        dx: usize,
543        count: u32,
544        channels: u32,
545        eps: f32,
546    },
547    BatchNormInferenceBackwardGamma {
548        x: usize,
549        mean: usize,
550        var: usize,
551        dy: usize,
552        dgamma: usize,
553        count: u32,
554        channels: u32,
555        eps: f32,
556    },
557    BatchNormInferenceBackwardBeta {
558        dy: usize,
559        dbeta: usize,
560        count: u32,
561        channels: u32,
562    },
563    /// LayerNorm2d on NCHW (SAM / candle semantics).
564    LayerNorm2d {
565        src: usize,
566        g: usize,
567        b: usize,
568        dst: usize,
569        n: u32,
570        c: u32,
571        h: u32,
572        w: u32,
573        eps: f32,
574    },
575    /// ConvTranspose2d on NCHW.
576    ConvTranspose2d {
577        src: usize,
578        weight: usize,
579        dst: usize,
580        n: u32,
581        c_in: u32,
582        h: u32,
583        w_in: u32,
584        c_out: u32,
585        h_out: u32,
586        w_out: u32,
587        kh: u32,
588        kw: u32,
589        sh: u32,
590        sw: u32,
591        ph: u32,
592        pw: u32,
593        dh: u32,
594        dw: u32,
595        groups: u32,
596    },
597    /// Nearest 2× upsample on NCHW (per-batch slice).
598    ResizeNearest2x {
599        src: usize,
600        dst: usize,
601        n: u32,
602        c: u32,
603        h: u32,
604        w: u32,
605    },
606    /// SAM2 axial 2-D RoPE on `[batch, seq, num_heads * head_dim]`.
607    AxialRope2d {
608        src: usize,
609        dst: usize,
610        batch: u32,
611        seq: u32,
612        hidden: u32,
613        end_x: u32,
614        end_y: u32,
615        head_dim: u32,
616        num_heads: u32,
617        theta: f32,
618        repeat_factor: u32,
619    },
620    /// RMSNorm: out = (x / sqrt(mean(x^2) + eps)) * gamma + beta. No mean
621    /// subtraction, hence cheaper than LayerNorm. Used by Llama-class models.
622    RmsNorm {
623        src: usize,
624        g: usize,
625        b: usize,
626        dst: usize,
627        rows: u32,
628        h: u32,
629        eps: f32,
630    },
631    /// Softmax
632    Softmax { data: usize, rows: u32, cols: u32 },
633    /// Inclusive (or exclusive) cumulative sum along the last axis
634    /// (callers pre-flatten higher-dim cumsums via reshape views).
635    Cumsum {
636        src: usize,
637        dst: usize,
638        rows: u32,
639        cols: u32,
640        exclusive: bool,
641    },
642    /// Mamba-style selective scan (plan #15).
643    /// Inputs: x, delta \[b,s,h\], a \[h,n\], b \[b,s,n\], c \[b,s,n\].
644    /// Output: y \[b,s,h\]. State h carries through the seq.
645    SelectiveScan {
646        x: usize,
647        delta: usize,
648        a: usize,
649        b: usize,
650        c: usize,
651        dst: usize,
652        batch: u32,
653        seq: u32,
654        hidden: u32,
655        state_size: u32,
656    },
657
658    /// Gated DeltaNet linear-attention scan (Qwen3.5/3.6 trunk).
659    /// Inputs: q, k, v `[b, s, h, n]`; g, beta `[b, s, h]`. Output:
660    /// `[b, s, h, n]`. See `Op::GatedDeltaNet` for math.
661    GatedDeltaNet {
662        q: usize,
663        k: usize,
664        v: usize,
665        g: usize,
666        beta: usize,
667        /// When non-zero, load initial `[b, h, n, n]` state and write
668        /// the final state back in place after the scan.
669        state: usize,
670        dst: usize,
671        batch: u32,
672        seq: u32,
673        heads: u32,
674        state_size: u32,
675    },
676
677    /// 1×1 conv fast path (plan #26). The general Conv2D thunk
678    /// runs the textbook 7-deep loop; a 1×1 stride-1 padding-0
679    /// groups-1 conv is mathematically a per-batch matmul, and
680    /// dispatching it through BLAS is 3-10× faster than the
681    /// scalar nest. Common case: ViT patch-projection follow-on,
682    /// transformer "expert" reductions in some MoE designs.
683    ///
684    /// Per batch: weight `[c_out, c_in]` × input `[c_in, h*w]`
685    ///         = output `[c_out, h*w]`.
686    Conv2D1x1 {
687        src: usize,
688        weight: usize,
689        dst: usize,
690        n: u32,
691        c_in: u32,
692        c_out: u32,
693        hw: u32,
694    },
695
696    /// Fused dequant + matmul (plan #5). Today supports
697    /// `QuantScheme::Int8Block` (symmetric); other schemes panic
698    /// at lowering time with a clear message until kernels are added.
699    DequantMatMul {
700        x: usize,
701        w_q: usize,   // packed i8 bytes for Int8 schemes
702        scale: usize, // [k/block, n] f32 scale
703        zp: usize,    // [k/block, n] f32 zero-point (0 for sym)
704        dst: usize,
705        m: u32,
706        k: u32,
707        n: u32,
708        block_size: u32,
709        is_asymmetric: bool,
710    },
711
712    /// GGUF-format dequant + matmul. Weight is a packed byte tensor
713    /// in one of the K-quant super-block layouts (Q4_K, Q5_K, Q6_K,
714    /// Q8_K). Scales / mins live inside the packed bytes — no
715    /// side-channel scale tensor.
716    ///
717    /// Today this is a "dequant-to-scratch then sgemm" kernel — it
718    /// keeps the *arena* memory footprint down (weights stay packed)
719    /// but the dequant itself happens per matmul. A future fully
720    /// fused tile-streaming kernel would close the compute gap.
721    DequantMatMulGguf {
722        x: usize,   // f32 activations [m, k]
723        w_q: usize, // packed weight bytes (k*n elements packed)
724        dst: usize, // f32 output [m, n]
725        m: u32,
726        k: u32,
727        n: u32,
728        scheme: rlx_ir::quant::QuantScheme,
729    },
730
731    /// Int4 block dequant + matmul (packed nibbles, side scale/zp).
732    DequantMatMulInt4 {
733        x: usize,
734        w_q: usize,
735        scale: usize,
736        zp: usize,
737        dst: usize,
738        m: u32,
739        k: u32,
740        n: u32,
741        block_size: u32,
742        is_asymmetric: bool,
743    },
744
745    /// FP8 dequant + matmul (per-tensor or per-column scale).
746    DequantMatMulFp8 {
747        x: usize,
748        w_q: usize,
749        scale: usize,
750        dst: usize,
751        m: u32,
752        k: u32,
753        n: u32,
754        e5m2: bool,
755    },
756
757    /// NVFP4 (E2M1) block dequant + matmul — 16-wide groups, FP8 scales.
758    DequantMatMulNvfp4 {
759        x: usize,
760        w_q: usize,
761        scale: usize,
762        global_scale: usize,
763        dst: usize,
764        m: u32,
765        k: u32,
766        n: u32,
767    },
768
769    /// Fused LoRA matmul (plan #9): out = x·W + scale * (x·A)·B.
770    /// `r` is the LoRA rank (typically 4-64) — the rank-r
771    /// intermediate `x·A` lives in scratch, never on the arena.
772    LoraMatMul {
773        x: usize,
774        w: usize,
775        a: usize,
776        b: usize,
777        dst: usize,
778        m: u32,
779        k: u32,
780        n: u32,
781        r: u32,
782        scale: f32,
783    },
784    /// Fused sample: logits [batch, vocab] → token ids \[batch\].
785    /// See Op::Sample. Output values are f32-encoded usize indices
786    /// (matches the rest of the IR's "ids as f32" convention).
787    Sample {
788        logits: usize,
789        dst: usize,
790        batch: u32,
791        vocab: u32,
792        top_k: u32,       // 0 = disabled
793        top_p: f32,       // 1.0 = disabled
794        temperature: f32, // 1.0 = neutral
795        seed: u64,
796    },
797    /// Attention SDPA. `mask` is the offset of the optional mask tensor
798    /// (only meaningful when `mask_kind == MaskKind::Custom`); other
799    /// kinds synthesize the mask in-kernel.
800    ///
801    /// Q/K/V each carry a `_row_stride` (elements per source row).
802    /// Defaults to `heads * head_dim` — matches the standalone
803    /// "Q/K/V are their own contiguous buffers" case. The Narrow→
804    /// Attention fusion below rewrites these to the parent QKV stride
805    /// (typically `3 * heads * head_dim`) so the kernel reads QKV
806    /// directly without materializing the per-head buffers (plan #46).
807    Attention {
808        q: usize,
809        k: usize,
810        v: usize,
811        mask: usize,
812        out: usize,
813        batch: u32,
814        /// Query sequence length.
815        seq: u32,
816        /// Key/value sequence length. Differs from `seq` during cached decode.
817        kv_seq: u32,
818        heads: u32,
819        head_dim: u32,
820        mask_kind: rlx_ir::op::MaskKind,
821        q_row_stride: u32,
822        k_row_stride: u32,
823        v_row_stride: u32,
824        /// Memory layout flag. `false` (the historical default) →
825        /// `[B, S, H, D]` row-major: per-head offset is
826        /// `bi*S*H*D + si*H*D + hi*D`. `true` → `[B, H, S, D]`
827        /// (head-major), matching the convention used by rlx-cuda /
828        /// rlx-rocm / rlx-tpu: per-head offset is
829        /// `bi*H*S*D + hi*S*D + si*D`. Detected at lowering time
830        /// from the input shape vs `num_heads` / `head_dim`.
831        bhsd: bool,
832    },
833    /// [`Op::AttentionBackward`] — emits dQ, dK, or dV (see `wrt`).
834    AttentionBackward {
835        q: usize,
836        k: usize,
837        v: usize,
838        dy: usize,
839        mask: usize,
840        out: usize,
841        batch: u32,
842        seq: u32,
843        kv_seq: u32,
844        heads: u32,
845        head_dim: u32,
846        mask_kind: rlx_ir::op::MaskKind,
847        wrt: rlx_ir::op::AttentionBwdWrt,
848        bhsd: bool,
849    },
850    /// RoPE (rotary position embeddings).
851    /// `src_row_stride` is elements per source row (defaults to `hidden`
852    /// for the standalone case; set to `qkv_axis * inner` when the
853    /// thunk fusion pass below rewires Rope to read directly from the
854    /// fused QKV buffer — plan #45).
855    Rope {
856        src: usize,
857        cos: usize,
858        sin: usize,
859        dst: usize,
860        batch: u32,
861        seq: u32,
862        hidden: u32,
863        head_dim: u32,
864        n_rot: u32,
865        cos_len: u32,
866        src_row_stride: u32,
867    },
868    /// Fused attention block: QKV proj → split → \[RoPE\] → SDPA → output proj.
869    /// All intermediates stay in L1 cache. Zero arena writes between ops.
870    FusedAttnBlock {
871        hidden: usize,
872        qkv_w: usize,
873        out_w: usize,
874        mask: usize,
875        out: usize,
876        qkv_b: usize,
877        out_b: usize, // 0 = no bias
878        cos: usize,
879        sin: usize,
880        cos_len: u32, // 0 = no RoPE
881        batch: u32,
882        seq: u32,
883        hs: u32,
884        nh: u32,
885        dh: u32,
886        has_bias: bool,
887        has_rope: bool,
888    },
889    /// Fused ENTIRE transformer layer: attention + residual + LN + FFN + residual + LN.
890    /// Combines ~10 thunks into 1. All intermediates on stack. Zero arena traffic.
891    FusedBertLayer {
892        // attention
893        hidden: usize,
894        qkv_w: usize,
895        qkv_b: usize,
896        out_w: usize,
897        out_b: usize,
898        mask: usize,
899        // LN1
900        ln1_g: usize,
901        ln1_b: usize,
902        eps1: f32,
903        // FFN (GELU)
904        fc1_w: usize,
905        fc1_b: usize,
906        fc2_w: usize,
907        fc2_b: usize,
908        // LN2
909        ln2_g: usize,
910        ln2_b: usize,
911        eps2: f32,
912        // output
913        out: usize,
914        // dims
915        batch: u32,
916        seq: u32,
917        hs: u32,
918        nh: u32,
919        dh: u32,
920        int_dim: u32,
921    },
922    /// Fused Nomic transformer layer: attention+RoPE + residual + LN + SwiGLU FFN + residual + LN.
923    FusedNomicLayer {
924        hidden: usize,
925        qkv_w: usize,
926        out_w: usize,
927        mask: usize,
928        cos: usize,
929        sin: usize,
930        cos_len: u32,
931        ln1_g: usize,
932        ln1_b: usize,
933        eps1: f32,
934        fc11_w: usize,
935        fc12_w: usize,
936        fc2_w: usize,
937        ln2_g: usize,
938        ln2_b: usize,
939        eps2: f32,
940        out: usize,
941        batch: u32,
942        seq: u32,
943        hs: u32,
944        nh: u32,
945        dh: u32,
946        int_dim: u32,
947    },
948    /// Fused SwiGLU: out\[r,i\] = x\[r,i\] * silu(x[r, n_half+i]).
949    /// Input: [outer, 2*n_half] — concatenated up||gate per row.
950    /// Output: [outer, n_half].
951    FusedSwiGLU {
952        src: usize,
953        dst: usize,
954        n_half: u32,
955        total: u32,
956        gate_first: bool,
957    },
958    /// Concat along an axis: output[outer, axis, inner] = inputs concatenated.
959    /// Each entry of `inputs` is (src_offset, axis_len_for_that_input) in u32
960    /// elements. `outer`, `inner`, and `total_axis_len` are pre-computed
961    /// at compile time to avoid per-run shape work.
962    Concat {
963        dst: usize,
964        outer: u32,
965        inner: u32,
966        total_axis: u32,
967        inputs: Vec<(usize, u32)>,
968    },
969    /// Element-wise comparison: out = (lhs CMP rhs) ? 1 : 0 (Bool u8 or F32 0/1).
970    Compare {
971        lhs: usize,
972        rhs: usize,
973        dst: usize,
974        len: u32,
975        op: CmpOp,
976        /// Nonzero when lhs/rhs are i64 (mask/range ops).
977        inputs_i64: u8,
978        /// Input element size (1 = Bool, 4 = F32, 8 = I64).
979        inputs_elem_bytes: u8,
980        /// Output element size (1 = Bool, 4 = F32).
981        dst_elem_bytes: u8,
982    },
983    /// Reduction along a contiguous range of axes. Input layout (after
984    /// shape decomposition) is `[outer, reduced, inner]`; output is
985    /// `[outer, inner]`. The single-axis cases (axis=0 → outer=1;
986    /// axis=last → inner=1) and contiguous multi-axis (e.g. reduce over
987    /// [0, 1] of an [N, C, H, W] tensor → outer=1, reduced=N*C, inner=H*W)
988    /// all map onto this triplet. Non-contiguous axes are not supported
989    /// and bail to Nop in the compile pass.
990    Reduce {
991        src: usize,
992        dst: usize,
993        outer: u32,
994        reduced: u32,
995        inner: u32,
996        op: ReduceOp,
997    },
998    /// Top-K **indices** along the last axis. Input shape `[outer, axis_dim]`,
999    /// output `[outer, k]` (f32 or i64 per `indices_i64`). Ties broken by
1000    /// smaller index. Used by MoE gating + beam search.
1001    TopK {
1002        src: usize,
1003        dst: usize,
1004        outer: u32,
1005        axis_dim: u32,
1006        k: u32,
1007        indices_i64: u8,
1008    },
1009    /// Indexed batched matmul: out\[i\] = input\[i\] @ weight[expert_idx\[i\]].
    /// Naive per-token implementation; for real MoE workloads, sorting by
    /// expert and running a segmented GEMM would amortize the cost. That is
    /// deferred until there is a workload that needs it.
1012    GroupedMatMul {
1013        input: usize,
1014        weight: usize,
1015        expert_idx: usize,
1016        dst: usize,
1017        m: u32,
1018        k_dim: u32,
1019        n: u32,
1020        num_experts: u32,
1021    },
1022    /// GGUF K-quant packed expert stack + grouped matmul (MoE FFN).
1023    DequantGroupedMatMulGguf {
1024        input: usize,
1025        w_q: usize,
1026        expert_idx: usize,
1027        dst: usize,
1028        m: u32,
1029        k_dim: u32,
1030        n: u32,
1031        num_experts: u32,
1032        scheme: rlx_ir::quant::QuantScheme,
1033    },
1034    /// Materialize packed MoE weights to F32 `[E, K, N]` (autodiff helper).
1035    DequantMoEWeightsGguf {
1036        w_q: usize,
1037        dst: usize,
1038        k_dim: u32,
1039        n: u32,
1040        num_experts: u32,
1041        scheme: rlx_ir::quant::QuantScheme,
1042    },
1043    /// Scatter-add: dst[indices\[i\] * trailing + j] += updates[i * trailing + j].
1044    /// Output is zeroed first; multiple updates to the same row accumulate.
1045    ScatterAdd {
1046        updates: usize,
1047        indices: usize,
1048        dst: usize,
1049        num_updates: u32,
1050        out_dim: u32,
1051        trailing: u32,
1052    },
1053    /// Ternary select: out = cond != 0 ? on_true : on_false
1054    Where {
1055        cond: usize,
1056        on_true: usize,
1057        on_false: usize,
1058        dst: usize,
1059        len: u32,
1060        elem_bytes: u8,
1061        /// Element size for cond (1 = Bool mask, 4 = F32 0/1).
1062        cond_elem_bytes: u8,
1063    },
1064    /// General N-D transpose / broadcast. `out_dims[i]` is the output's dim
1065    /// i length; `in_strides[i]` is the input stride (in elements) used to
1066    /// index that dim — 0 for broadcast dims (Expand). `in_total` is the
1067    /// total element count in the source buffer (≤ output total when
1068    /// broadcasting). Strides are pre-computed at compile time.
1069    Transpose {
1070        src: usize,
1071        dst: usize,
1072        in_total: u32,
1073        out_dims: Vec<u32>,
1074        in_strides: Vec<u32>,
1075        elem_bytes: u8,
1076    },
1077    /// Gather along an arbitrary axis. `outer = product(dims[..axis])`,
1078    /// `trailing = product(dims[axis+1..])`, `axis_dim` = the dimension
1079    /// being indexed into. Output: outer × num_idx × trailing.
1080    /// (axis=0 still routes to the simpler Thunk::Gather fast path.)
1081    GatherAxis {
1082        table: usize,
1083        idx: usize,
1084        dst: usize,
1085        outer: u32,
1086        axis_dim: u32,
1087        num_idx: u32,
1088        trailing: u32,
1089        idx_i64: u8,
1090        table_bytes: u8,
1091    },
1092    /// 2D pooling (Max or Mean). Input layout [N, C, H, W], output
1093    /// [N, C, H_out, W_out]. Padding is implicit-zero; Mean divides by
1094    /// the full kernel area (matches torch's `count_include_pad=True`).
1095    Pool2D {
1096        src: usize,
1097        dst: usize,
1098        n: u32,
1099        c: u32,
1100        h: u32,
1101        w: u32,
1102        h_out: u32,
1103        w_out: u32,
1104        kh: u32,
1105        kw: u32,
1106        sh: u32,
1107        sw: u32,
1108        ph: u32,
1109        pw: u32,
1110        kind: ReduceOp,
1111    },
1112    /// 2D convolution. Input [N, C_in, H, W], weight [C_out, C_in_per_group, kH, kW],
1113    /// output [N, C_out, H_out, W_out]. Bias is a separate Op::Binary::Add
1114    /// after the conv (matching the IR's input layout — Op::Conv has 2 inputs).
1115    /// Naive direct convolution; sufficient for correctness, not optimised.
1116    Conv2D {
1117        src: usize,
1118        weight: usize,
1119        dst: usize,
1120        n: u32,
1121        c_in: u32,
1122        h: u32,
1123        w: u32,
1124        c_out: u32,
1125        h_out: u32,
1126        w_out: u32,
1127        kh: u32,
1128        kw: u32,
1129        sh: u32,
1130        sw: u32,
1131        ph: u32,
1132        pw: u32,
1133        dh: u32,
1134        dw: u32,
1135        groups: u32,
1136    },
1137
    // ── Quantization (INT8 / QAT) and backward / training kernels ──
1139    /// Real INT8 matmul with i32 accumulation.
1140    ///   `out[m, n] = requantize(bias[n] + Σₖ (x[m,k]-x_zp)·(w[k,n]-w_zp), mult, out_zp)`
1141    /// Reads `x` and `w` as i8, `bias` as i32; writes `out` as i8.
1142    /// Same kernel shape as `rlx_cortexm::dense::dense_i8` — promoted
1143    /// to a desktop thunk so a quantized graph compiled here doesn't
1144    /// have to round-trip through fake-quant.
1145    QMatMul {
1146        x: usize,
1147        w: usize,
1148        bias: usize,
1149        out: usize,
1150        m: u32,
1151        k: u32,
1152        n: u32,
1153        x_zp: i32,
1154        w_zp: i32,
1155        out_zp: i32,
1156        mult: f32,
1157    },
1158
1159    /// Real INT8 conv2d, NCHW layout. Same loop shape as `Thunk::Conv2D`
1160    /// but with i8 reads, i32 accumulation, and per-output requantize
1161    /// to i8. Bias is i32 in the accumulator scale.
1162    QConv2d {
1163        x: usize,
1164        w: usize,
1165        bias: usize,
1166        out: usize,
1167        n: u32,
1168        c_in: u32,
1169        h: u32,
1170        w_in: u32,
1171        c_out: u32,
1172        h_out: u32,
1173        w_out: u32,
1174        kh: u32,
1175        kw: u32,
1176        sh: u32,
1177        sw: u32,
1178        ph: u32,
1179        pw: u32,
1180        dh: u32,
1181        dw: u32,
1182        groups: u32,
1183        x_zp: i32,
1184        w_zp: i32,
1185        out_zp: i32,
1186        mult: f32,
1187    },
1188
1189    /// INT8 quantize. Reads `x` as f32, writes `q` as i8.
1190    /// `chan = (i / inner) % chan_dim` selects the per-channel
1191    /// scale/zp; `chan_axis` is informational only (the kernel uses
1192    /// `chan_dim` and `inner` directly).
1193    /// For per-tensor, `chan_dim = 1` and `inner = len` so `chan` is
1194    /// always 0.
1195    Quantize {
1196        x: usize,
1197        q: usize,
1198        len: u32,
1199        chan_axis: u32,
1200        chan_dim: u32,
1201        inner: u32,
1202        scales: Vec<f32>,
1203        zero_points: Vec<i32>,
1204    },
1205
1206    /// INT8 dequantize — inverse of `Thunk::Quantize`.
1207    Dequantize {
1208        q: usize,
1209        x: usize,
1210        len: u32,
1211        chan_axis: u32,
1212        chan_dim: u32,
1213        inner: u32,
1214        scales: Vec<f32>,
1215        zero_points: Vec<i32>,
1216    },
1217
1218    /// QAT fake-quantize. Per-channel (or per-tensor) symmetric
1219    /// quantize-then-dequantize on the fly. Computes
1220    ///   `s[c] = max(|x[..., c, ...]|) / q_max`
1221    /// then
1222    ///   `out[i] = clamp(round(x[i]/s[c]), -q_max, q_max) * s[c]`
1223    /// with `q_max = {127, 7, 1}` for `bits = {8, 4, 2}`. Same
1224    /// channel-layout convention as `Thunk::Quantize`: every
1225    /// element's channel is `(i / inner) % chan_dim`. The kernel
1226    /// does two passes — one to scan max-abs per channel, one to
1227    /// quant-dequant per element.
1228    FakeQuantize {
1229        x: usize,
1230        out: usize,
1231        len: u32,
1232        chan_axis: u32,
1233        chan_dim: u32,
1234        inner: u32,
1235        bits: u8,
1236        /// STE variant — informational on the forward side (output is
1237        /// the same regardless), kernel-relevant in the matching
1238        /// `FakeQuantizeBackward` thunk.
1239        ste: rlx_ir::op::SteKind,
1240        /// Scale-tracking strategy. `PerBatch` recomputes
1241        /// `max_abs/q_max` every call (the original path). `EMA{decay}`
1242        /// blends per-batch max-abs into the `state_off` buffer; `Fixed`
1243        /// reads `state_off` and never updates it.
1244        scale_mode: rlx_ir::op::ScaleMode,
1245        /// `Some(off)` for `EMA` and `Fixed`; `None` for `PerBatch`.
1246        /// Points at a `[chan_dim]` f32 buffer holding the running scale
1247        /// per channel.
1248        state_off: Option<usize>,
1249    },
1250
1251    /// Backward pass for `Op::FakeQuantize` under one of four STE
1252    /// variants. Computes `dx[i]` from the f32 forward input `x` and
1253    /// the upstream gradient `dy`, using the same per-channel scale
1254    /// scheme as the forward.
1255    FakeQuantizeBackward {
1256        x: usize,
1257        dy: usize,
1258        dx: usize,
1259        len: u32,
1260        chan_axis: u32,
1261        chan_dim: u32,
1262        inner: u32,
1263        bits: u8,
1264        ste: rlx_ir::op::SteKind,
1265    },
1266
1267    /// LSQ forward — same kernel shape as `FakeQuantize` Fixed mode.
1268    /// Reads scale from `scale_off` (a `[chan_dim]` Param tensor).
1269    FakeQuantizeLSQ {
1270        x: usize,
1271        scale_off: usize,
1272        out: usize,
1273        len: u32,
1274        chan_axis: u32,
1275        chan_dim: u32,
1276        inner: u32,
1277        bits: u8,
1278    },
1279
1280    /// LSQ backward, x-gradient. STE-clipped: passes upstream
1281    /// through inside the quantization range, zeros outside.
1282    FakeQuantizeLSQBackwardX {
1283        x: usize,
1284        scale_off: usize,
1285        dy: usize,
1286        dx: usize,
1287        len: u32,
1288        chan_axis: u32,
1289        chan_dim: u32,
1290        inner: u32,
1291        bits: u8,
1292    },
1293
1294    /// LSQ backward, scale-gradient. Per-channel:
1295    ///   `dscale[c] = sum_i ψ(x[i]/s[c]) · upstream[i]`
1296    /// where `ψ(z) = -z + round(z)` if `|z| ≤ q_max` else
1297    /// `sign(z) · q_max`. Output shape: `[chan_dim]`.
1298    FakeQuantizeLSQBackwardScale {
1299        x: usize,
1300        scale_off: usize,
1301        dy: usize,
1302        dscale: usize,
1303        len: u32,
1304        chan_axis: u32,
1305        chan_dim: u32,
1306        inner: u32,
1307        bits: u8,
1308    },
1309
1310    /// ReLU backward: `dx[i] = dy[i] if x[i] > 0 else 0`.
1311    ReluBackward {
1312        x: usize,
1313        dy: usize,
1314        dx: usize,
1315        len: u32,
1316    },
1317    /// f64 sibling of `ReluBackward` — same shape as the f32 variant
1318    /// but reads/writes 8 bytes per element. Required because
1319    /// `ReluBackward`'s `&[f32]` slot view returns half of every f64
1320    /// otherwise → backward silently produces 0 gradients on an f64
1321    /// graph. Mirrors the `ActivationBackwardF64` split.
1322    ReluBackwardF64 {
1323        x: usize,
1324        dy: usize,
1325        dx: usize,
1326        len: u32,
1327    },
1328
1329    /// Generic element-wise activation backward.
1330    /// `dx[i] = (d/dx act(x))[i] · dy[i]`. The closure dispatch is
1331    /// per-element; expensive activations (Gelu) recompute internals
1332    /// inline rather than threading an extra "saved y" tensor through.
1333    ActivationBackward {
1334        x: usize,
1335        dy: usize,
1336        dx: usize,
1337        len: u32,
1338        kind: Activation,
1339    },
1340    /// f64 sibling of `ActivationBackward` — slot offsets, len in
1341    /// elements; kernel reads/writes 8 bytes per element. Required
1342    /// because `ActivationBackward`'s `&[f32]` slot view silently
1343    /// returns garbage on an f64 graph (cb % 4 still works but every
1344    /// loaded value is half of an f64 → wrong gradient).
1345    ActivationBackwardF64 {
1346        x: usize,
1347        dy: usize,
1348        dx: usize,
1349        len: u32,
1350        kind: Activation,
1351    },
1352
1353    /// LayerNorm backward — input gradient. Recomputes mean/var/x̂ from
1354    /// `x` and emits the closed-form `d_x` per row.
1355    LayerNormBackwardInput {
1356        x: usize,
1357        gamma: usize,
1358        dy: usize,
1359        dx: usize,
1360        rows: u32,
1361        h: u32,
1362        eps: f32,
1363    },
1364
1365    /// LayerNorm backward — gamma gradient. `d_gamma[d] = Σ_row dy·x̂`.
1366    LayerNormBackwardGamma {
1367        x: usize,
1368        dy: usize,
1369        dgamma: usize,
1370        rows: u32,
1371        h: u32,
1372        eps: f32,
1373    },
1374
1375    RmsNormBackwardInput {
1376        x: usize,
1377        gamma: usize,
1378        beta: usize,
1379        dy: usize,
1380        dx: usize,
1381        rows: u32,
1382        h: u32,
1383        eps: f32,
1384    },
1385    RmsNormBackwardGamma {
1386        x: usize,
1387        gamma: usize,
1388        beta: usize,
1389        dy: usize,
1390        dgamma: usize,
1391        rows: u32,
1392        h: u32,
1393        eps: f32,
1394    },
1395    RmsNormBackwardBeta {
1396        x: usize,
1397        gamma: usize,
1398        beta: usize,
1399        dy: usize,
1400        dbeta: usize,
1401        rows: u32,
1402        h: u32,
1403        eps: f32,
1404    },
1405    RopeBackward {
1406        dy: usize,
1407        cos: usize,
1408        sin: usize,
1409        dx: usize,
1410        batch: u32,
1411        seq: u32,
1412        hidden: u32,
1413        head_dim: u32,
1414        n_rot: u32,
1415        cos_len: u32,
1416    },
1417    CumsumBackward {
1418        dy: usize,
1419        dx: usize,
1420        rows: u32,
1421        cols: u32,
1422        exclusive: bool,
1423    },
1424    GatherBackward {
1425        dy: usize,
1426        indices: usize,
1427        dst: usize,
1428        outer: u32,
1429        axis_dim: u32,
1430        num_idx: u32,
1431        trailing: u32,
1432    },
1433
1434    GroupNormBackwardInput {
1435        x: usize,
1436        gamma: usize,
1437        beta: usize,
1438        dy: usize,
1439        dx: usize,
1440        n: u32,
1441        c: u32,
1442        h: u32,
1443        w: u32,
1444        num_groups: u32,
1445        eps: f32,
1446    },
1447    GroupNormBackwardGamma {
1448        x: usize,
1449        dy: usize,
1450        dgamma: usize,
1451        n: u32,
1452        c: u32,
1453        h: u32,
1454        w: u32,
1455        num_groups: u32,
1456        eps: f32,
1457    },
1458    GroupNormBackwardBeta {
1459        dy: usize,
1460        dbeta: usize,
1461        n: u32,
1462        c: u32,
1463        h: u32,
1464        w: u32,
1465    },
1466
1467    /// 2D max-pool backward (NCHW). Recomputes the argmax position
1468    /// inside each window and accumulates `dy` into `dx` at that
1469    /// position. Output is zeroed first; ties resolve to the first
1470    /// hit (lowest (kh,kw) index), matching what the forward kernel
1471    /// does with `acc.max(v)`.
1472    MaxPool2dBackward {
1473        x: usize,
1474        dy: usize,
1475        dx: usize,
1476        n: u32,
1477        c: u32,
1478        h: u32,
1479        w: u32,
1480        h_out: u32,
1481        w_out: u32,
1482        kh: u32,
1483        kw: u32,
1484        sh: u32,
1485        sw: u32,
1486        ph: u32,
1487        pw: u32,
1488    },
1489
1490    /// 2D conv backward w.r.t. input (`dx = conv_transpose(dy, w)`).
1491    /// `dy [N, C_out, H_out, W_out]`, `w [C_out, C_in_per_group, kH, kW]`,
1492    /// `dx [N, C_in, H, W]`.
1493    Conv2dBackwardInput {
1494        dy: usize,
1495        w: usize,
1496        dx: usize,
1497        n: u32,
1498        c_in: u32,
1499        h: u32,
1500        w_in: u32,
1501        c_out: u32,
1502        h_out: u32,
1503        w_out: u32,
1504        kh: u32,
1505        kw: u32,
1506        sh: u32,
1507        sw: u32,
1508        ph: u32,
1509        pw: u32,
1510        dh: u32,
1511        dw: u32,
1512        groups: u32,
1513    },
1514
1515    /// 2D conv backward w.r.t. weight. `x [N, C_in, H, W]`,
1516    /// `dy [N, C_out, H_out, W_out]`, `dw [C_out, C_in_per_group, kH, kW]`.
1517    /// `dw` is zeroed before accumulation.
1518    Conv2dBackwardWeight {
1519        x: usize,
1520        dy: usize,
1521        dw: usize,
1522        n: u32,
1523        c_in: u32,
1524        h: u32,
1525        w: u32,
1526        c_out: u32,
1527        h_out: u32,
1528        w_out: u32,
1529        kh: u32,
1530        kw: u32,
1531        sh: u32,
1532        sw: u32,
1533        ph: u32,
1534        pw: u32,
1535        dh: u32,
1536        dw_dil: u32,
1537        groups: u32,
1538    },
1539
1540    /// NCHW im2col for conv backward-weight matmul. Output `[M, C·kH·kW]`
1541    /// with `M = N · H_out · W_out`. `n == 0` means infer batch from `x`.
1542    Im2Col {
1543        x: usize,
1544        col: usize,
1545        n: u32,
1546        c_in: u32,
1547        h: u32,
1548        w: u32,
1549        h_out: u32,
1550        w_out: u32,
1551        kh: u32,
1552        kw: u32,
1553        sh: u32,
1554        sw: u32,
1555        ph: u32,
1556        pw: u32,
1557        dh: u32,
1558        dw_dil: u32,
1559    },
1560
1561    /// Fused softmax + cross-entropy loss with f32-encoded integer
1562    /// labels. `logits [N, C]`, `labels [N]`, output `[N]` per-row loss.
1563    /// Numerically stable (max-subtract before exp).
1564    SoftmaxCrossEntropy {
1565        logits: usize,
1566        labels: usize,
1567        dst: usize,
1568        n: u32,
1569        c: u32,
1570    },
1571
1572    /// Backward of the fused loss above.
1573    /// `dlogits[n, k] = (softmax(logits[n])[k] - one_hot(labels[n])[k]) * d_loss[n]`.
1574    SoftmaxCrossEntropyBackward {
1575        logits: usize,
1576        labels: usize,
1577        d_loss: usize,
1578        dlogits: usize,
1579        n: u32,
1580        c: u32,
1581    },
1582
1583    /// User-registered custom op (CPU side). Lowered from `Op::Custom`.
1584    /// `kernel` is resolved against the global CPU kernel registry at
1585    /// compile time and stored as `Arc<dyn CpuKernel>` so execution
1586    /// avoids per-call lookups. v1: f32 contiguous only — see
1587    /// `op_registry::CpuKernel::execute_f32`.
1588    CustomOp {
1589        kernel: Arc<dyn CpuKernel>,
1590        inputs: Vec<(usize, u32, Shape)>, // (offset, len_elements, shape)
1591        output: (usize, u32, Shape),      // (offset, len_elements, shape)
1592        attrs: Vec<u8>,
1593    },
1594
    // NOTE: the paragraph below documents `Thunk::Fft1d` (declared further
    // down in this enum), not the Gaussian splat variant that follows it.
    // It is kept as a plain comment so rustdoc does not attach it to
    // `GaussianSplatRender`.
    //
    // 1D FFT along the last axis. Input/output are `[..., 2N]`
    // real-block layout (first N real, second N imag along the
    // transformed axis). `outer` is the product of all leading axes;
    // `n_complex` is N (the number of complex points). Both halves
    // of the real-block layout are read together by the kernel.
    // `dtype` selects the f32 or f64 path; the two share structure
    // but not buffers, so a flag at compile time avoids per-row
    // dispatch.
1603    /// CPU reference 3D Gaussian splat render ([`rlx_ir::Op::GaussianSplatRender`]).
1604    GaussianSplatRender {
1605        positions_off: usize,
1606        positions_len: usize,
1607        scales_off: usize,
1608        scales_len: usize,
1609        rotations_off: usize,
1610        rotations_len: usize,
1611        opacities_off: usize,
1612        opacities_len: usize,
1613        colors_off: usize,
1614        colors_len: usize,
1615        sh_coeffs_off: usize,
1616        sh_coeffs_len: usize,
1617        meta_off: usize,
1618        dst_off: usize,
1619        dst_len: usize,
1620        width: u32,
1621        height: u32,
1622        tile_size: u32,
1623        radius_scale: f32,
1624        alpha_cutoff: f32,
1625        max_splat_steps: u32,
1626        transmittance_threshold: f32,
1627        max_list_entries: u32,
1628    },
1629    GaussianSplatRenderBackward {
1630        positions_off: usize,
1631        positions_len: usize,
1632        scales_off: usize,
1633        scales_len: usize,
1634        rotations_off: usize,
1635        rotations_len: usize,
1636        opacities_off: usize,
1637        opacities_len: usize,
1638        colors_off: usize,
1639        colors_len: usize,
1640        sh_coeffs_off: usize,
1641        sh_coeffs_len: usize,
1642        meta_off: usize,
1643        d_loss_off: usize,
1644        d_loss_len: usize,
1645        packed_off: usize,
1646        packed_len: usize,
1647        width: u32,
1648        height: u32,
1649        tile_size: u32,
1650        radius_scale: f32,
1651        alpha_cutoff: f32,
1652        max_splat_steps: u32,
1653        transmittance_threshold: f32,
1654        max_list_entries: u32,
1655        loss_grad_clip: f32,
1656        sh_band: u32,
1657        max_anisotropy: f32,
1658    },
1659    /// Strict IR stage 1 — project + bin + sort + rays ([`Op::GaussianSplatPrepare`]).
1660    GaussianSplatPrepare {
1661        positions_off: usize,
1662        positions_len: usize,
1663        scales_off: usize,
1664        scales_len: usize,
1665        rotations_off: usize,
1666        rotations_len: usize,
1667        opacities_off: usize,
1668        opacities_len: usize,
1669        colors_off: usize,
1670        colors_len: usize,
1671        sh_coeffs_off: usize,
1672        sh_coeffs_len: usize,
1673        meta_off: usize,
1674        meta_len: usize,
1675        prep_off: usize,
1676        prep_len: usize,
1677        width: u32,
1678        height: u32,
1679        tile_size: u32,
1680        radius_scale: f32,
1681        alpha_cutoff: f32,
1682        max_splat_steps: u32,
1683        transmittance_threshold: f32,
1684        max_list_entries: u32,
1685    },
1686    /// Strict IR stage 2 — tile raster from prepare buffer ([`Op::GaussianSplatRasterize`]).
1687    GaussianSplatRasterize {
1688        prep_off: usize,
1689        prep_len: usize,
1690        meta_off: usize,
1691        meta_len: usize,
1692        dst_off: usize,
1693        dst_len: usize,
1694        count: usize,
1695        width: u32,
1696        height: u32,
1697        tile_size: u32,
1698        alpha_cutoff: f32,
1699        max_splat_steps: u32,
1700        transmittance_threshold: f32,
1701        max_list_entries: u32,
1702    },
1703    Fft1d {
1704        src: usize,
1705        dst: usize,
1706        outer: u32,
1707        n_complex: u32,
1708        inverse: bool,
1709        norm_tag: u32,
1710        dtype: rlx_ir::DType,
1711    },
1712    FftButterflyStage {
1713        state_src: usize,
1714        state_dst: usize,
1715        gate_src: usize,
1716        rev_src: usize,
1717        tw_re_src: usize,
1718        tw_im_src: usize,
1719        batch: u32,
1720        n_fft: u32,
1721        stage: u32,
1722    },
1723    LogMel {
1724        spec: usize,
1725        filters: usize,
1726        dst: usize,
1727        outer: u32,
1728        n_fft: u32,
1729        n_bins: u32,
1730        n_mels: u32,
1731    },
1732    LogMelBackward {
1733        spec: usize,
1734        filters: usize,
1735        dy: usize,
1736        dst: usize,
1737        outer: u32,
1738        n_fft: u32,
1739        n_bins: u32,
1740        n_mels: u32,
1741    },
1742}
1743
1744/// Compiled thunk schedule — the runtime hot path.
1745/// Nop thunks are filtered out at compile time for zero iteration overhead.
1746#[derive(Clone)]
1747pub struct ThunkSchedule {
1748    pub thunks: Vec<Thunk>,
1749    /// TIDE merged placement mask (union across layers).
1750    pub moe_resident: Option<std::sync::Arc<[bool]>>,
1751    /// Per MoE layer placement (`layer[e]`); preferred when set.
1752    pub moe_resident_layers: Option<std::sync::Arc<Vec<std::sync::Arc<[bool]>>>>,
1753    /// MoE router TopK capture (per-layer refresh).
1754    pub moe_topk_capture: Option<std::sync::Arc<crate::moe_topk_capture::MoeTopkCapture>>,
1755    /// Cached config values.
1756    pub mask_threshold: f32,
1757    pub mask_neg_inf: f32,
1758    pub score_skip: f32,
1759    /// Pre-compiled closure dispatch (zero match overhead). `Arc` (not
1760    /// `Box`) so the schedule can be `Clone` — multiple parallel
1761    /// executors share the same compiled closures (they're read-only
1762    /// `Fn(*mut u8)` so concurrent dispatch is safe; the arena pointer
1763    /// they receive is the only mutable state and is per-executor).
1764    pub compiled_fns: Vec<Arc<dyn Fn(*mut u8) + Send + Sync>>,
1765}
1766
1767impl ThunkSchedule {
1768    pub fn strip_nops(&mut self) {
1769        self.thunks.retain(|t| !matches!(t, Thunk::Nop));
1770        // compiled_fns must be rebuilt after stripping — caller should
1771        // call strip_nops() before compile_closures().
1772        self.compiled_fns.clear();
1773    }
1774}
1775
1776/// Get the arena byte offset for a node.
1777fn node_offset(arena: &Arena, id: NodeId) -> usize {
1778    if arena.has_buffer(id) {
1779        arena.byte_offset(id)
1780    } else {
1781        usize::MAX
1782    }
1783}
1784
1785/// Every byte-offset that a thunk reads from. Used by the Narrow→Rope
1786/// fusion (#45) to verify a Narrow's dst has exactly one consumer
1787/// before eliding it. Conservative: when in doubt about reads (an op
1788/// not yet listed here), the fusion will skip — correctness over
1789/// completeness.
fn thunk_read_offsets(t: &Thunk) -> Vec<usize> {
    // Each arm binds only the fields this function reports as reads and
    // returns them as a flat list, in pattern order. Fields that are not
    // bound (e.g. `c` on the GEMM variants) are not reported.
    match t {
        Thunk::Sgemm { a, b, .. } => vec![*a, *b],
        Thunk::DenseSolveF64 { a, b, .. } => vec![*a, *b],
        Thunk::DenseSolveF32 { a, b, .. } => vec![*a, *b],
        Thunk::BatchedDenseSolveF64 { a, b, .. } => vec![*a, *b],
        Thunk::BatchedDgemmF64 { a, b, .. } => vec![*a, *b],
        Thunk::BatchedSgemm { a, b, .. } => vec![*a, *b],
        Thunk::FusedMmBiasAct { a, w, bias, .. } => vec![*a, *w, *bias],
        Thunk::BiasAdd { src, bias, .. } => vec![*src, *bias],
        Thunk::BinaryFull { lhs, rhs, .. } => vec![*lhs, *rhs],
        Thunk::BinaryFullF64 { lhs, rhs, .. } => vec![*lhs, *rhs],
        Thunk::BinaryFullC64 { lhs, rhs, .. } => vec![*lhs, *rhs],
        Thunk::ComplexNormSqF32 { src, .. } => vec![*src],
        Thunk::ComplexNormSqBackwardF32 { z, g, .. } => vec![*z, *g],
        Thunk::ConjugateC64 { src, .. } => vec![*src],
        // Scan: the initial-state offset plus the middle element of every
        // `xs_inputs` tuple.
        Thunk::Scan {
            outer_init_off,
            xs_inputs,
            ..
        } => {
            let mut v = vec![*outer_init_off];
            for (_, outer_xs_off, _) in xs_inputs.iter() {
                v.push(*outer_xs_off);
            }
            v
        }
        // ScanBackward / ScanBackwardXs report the same operand set: init,
        // trajectory, upstream, then the first element of every
        // `outer_xs_offs` pair.
        Thunk::ScanBackward {
            outer_init_off,
            outer_traj_off,
            outer_upstream_off,
            outer_xs_offs,
            ..
        } => {
            let mut v = vec![*outer_init_off, *outer_traj_off, *outer_upstream_off];
            for (off, _) in outer_xs_offs.iter() {
                v.push(*off);
            }
            v
        }
        Thunk::ScanBackwardXs {
            outer_init_off,
            outer_traj_off,
            outer_upstream_off,
            outer_xs_offs,
            ..
        } => {
            let mut v = vec![*outer_init_off, *outer_traj_off, *outer_upstream_off];
            for (off, _) in outer_xs_offs.iter() {
                v.push(*off);
            }
            v
        }
        Thunk::CustomFn { inputs, .. } => {
            inputs.iter().map(|(_, outer_off, _)| *outer_off).collect()
        }
        // Single-buffer variant: the one `data` offset is reported as a read.
        Thunk::ActivationInPlace { data, .. } => vec![*data],
        Thunk::LayerNorm { src, g, b, .. } | Thunk::GroupNorm { src, g, b, .. } => {
            vec![*src, *g, *b]
        }
        Thunk::BatchNormInference {
            src,
            g,
            b,
            mean,
            var,
            ..
        } => vec![*src, *g, *b, *mean, *var],
        Thunk::ResizeNearest2x { src, .. } => vec![*src],
        Thunk::AxialRope2d { src, .. } => vec![*src],
        // `bias` is reported unconditionally. `compile_thunks` stores 0 in
        // that field when `has_bias` is false, so offset 0 can appear here
        // even though no bias buffer exists.
        Thunk::FusedResidualLN {
            x, res, bias, g, b, ..
        } => vec![*x, *res, *bias, *g, *b],
        Thunk::FusedResidualRmsNorm {
            x, res, bias, g, b, ..
        } => vec![*x, *res, *bias, *g, *b],
        Thunk::RmsNorm { src, g, b, .. } => vec![*src, *g, *b],
        Thunk::Softmax { data, .. } => vec![*data],
        Thunk::Cumsum { src, .. } => vec![*src],
        Thunk::Sample { logits, .. } => vec![*logits],
        Thunk::LoraMatMul { x, w, a, b, .. } => vec![*x, *w, *a, *b],
        Thunk::DequantMatMul {
            x, w_q, scale, zp, ..
        } => vec![*x, *w_q, *scale, *zp],
        Thunk::DequantMatMulGguf { x, w_q, .. } => vec![*x, *w_q],
        Thunk::DequantMatMulInt4 {
            x, w_q, scale, zp, ..
        } => vec![*x, *w_q, *scale, *zp],
        Thunk::DequantMatMulFp8 { x, w_q, scale, .. } => vec![*x, *w_q, *scale],
        Thunk::DequantMatMulNvfp4 {
            x,
            w_q,
            scale,
            global_scale,
            ..
        } => vec![*x, *w_q, *scale, *global_scale],
        Thunk::Conv2D1x1 { src, weight, .. } => vec![*src, *weight],
        Thunk::SelectiveScan {
            x, delta, a, b, c, ..
        } => vec![*x, *delta, *a, *b, *c],
        Thunk::GatedDeltaNet {
            q,
            k,
            v,
            g,
            beta,
            state,
            ..
        } => {
            // The right-hand side is evaluated before the new `v` binding
            // takes effect, so `*v` below still refers to the pattern field.
            let mut v = vec![*q, *k, *v, *g, *beta];
            // Presumably offset 0 is the "no state buffer" sentinel —
            // TODO confirm against the lowering code.
            if *state != 0 {
                v.push(*state);
            }
            v
        }
        // NOTE(review): `mask` is pushed unconditionally here, whereas
        // `AttentionBackward` below skips a zero mask offset — confirm
        // whether the difference is intentional.
        Thunk::Attention { q, k, v, mask, .. } => vec![*q, *k, *v, *mask],
        Thunk::AttentionBackward {
            q, k, v, dy, mask, ..
        } => {
            let mut v = vec![*q, *k, *v, *dy];
            // A zero mask offset is not reported (presumably "no mask").
            if *mask != 0 {
                v.push(*mask);
            }
            v
        }
        Thunk::Rope { src, cos, sin, .. } => vec![*src, *cos, *sin],
        Thunk::FusedAttnBlock {
            hidden,
            qkv_w,
            out_w,
            mask,
            qkv_b,
            out_b,
            cos,
            sin,
            ..
        } => vec![*hidden, *qkv_w, *out_w, *mask, *qkv_b, *out_b, *cos, *sin],
        Thunk::FusedSwiGLU { src, .. } => vec![*src],
        Thunk::Concat { inputs, .. } => inputs.iter().map(|(off, _)| *off).collect(),
        Thunk::ConcatF64 { inputs, .. } => inputs.iter().map(|(off, _)| *off).collect(),
        Thunk::Narrow { src, .. } => vec![*src],
        Thunk::Copy { src, .. } => vec![*src],
        Thunk::Gather { table, idx, .. } => vec![*table, *idx],
        // Anything not enumerated yields an EMPTY list, i.e. the thunk is
        // reported as reading nothing.
        // NOTE(review): an earlier version of this comment said the dst is
        // returned as a "read" to force the fusion to bail
        // (read_count >= 2 → skip). The code does not do that. Confirm that
        // callers treat unlisted variants conservatively; otherwise a
        // missing arm here could let a still-needed buffer be elided.
        _ => vec![],
    }
}
1939
/// Fused dequant + matmul (plan #5). Int8-blockwise weights: each
/// `block_size` consecutive elements of a column share one f32
/// scale (and optionally a zero-point). The dequant happens inside
/// the inner accumulate so the f32 weight is never materialized.
///
/// `w_bytes` is the row-major i8 weight matrix `[k, n]`. `scales`
/// and `zps` are `[ceil(k / block_size), n]`. When `asym=false`, `zps`
/// is never read and may be empty.
///
/// Each output element is `sum_p x[i, p] * ((w[p, j] - z) * s)`,
/// accumulated in ascending `p`. The scale and zero-point are loaded
/// once per block rather than once per element; the arithmetic (and
/// therefore the rounding) is unchanged.
///
/// Today this is the reference scalar implementation — the win is
/// memory bandwidth, not flops, since LLM weights dominate the
/// working set. A NEON SIMD path that loads 16 i8 → splat-scale →
/// fused-multiply-add is the natural follow-on.
///
/// # Panics
///
/// Panics if `block_size` is 0, or if an index into one of the
/// slices is out of bounds for the stated shapes.
#[allow(clippy::too_many_arguments)]
fn dequant_matmul_int8(
    x: &[f32],       // [m, k]
    w_bytes: &[i8],  // [k, n]
    scales: &[f32],  // [ceil(k/block), n]
    zps: &[f32],     // [ceil(k/block), n] or empty
    out: &mut [f32], // [m, n]
    m: usize,
    k: usize,
    n: usize,
    block_size: usize,
    asym: bool,
) {
    // Number of quantization blocks along the reduction axis; the last
    // block is short when `k` is not a multiple of `block_size`.
    let blocks_per_col = k.div_ceil(block_size);
    for i in 0..m {
        let x_row = i * k;
        for j in 0..n {
            let mut acc = 0f32;
            for block in 0..blocks_per_col {
                // Scale / zero-point are constant across the block.
                let s = scales[block * n + j];
                let z = if asym { zps[block * n + j] } else { 0.0 };
                let start = block * block_size;
                let end = (start + block_size).min(k);
                for p in start..end {
                    let q = w_bytes[p * n + j] as f32;
                    acc += x[x_row + p] * ((q - z) * s);
                }
            }
            out[i * n + j] = acc;
        }
    }
}
1983
/// Fused dequant + matmul for 4-bit blockwise weights.
///
/// The `[k, n]` weight matrix is stored two codes per byte in row-major
/// flat order: the element at flat index `p * n + j` lives in byte
/// `(p * n + j) / 2`, in the low nibble when the flat index is even and
/// in the high nibble when it is odd. Codes are used as unsigned values
/// `0..=15`. `scales` (and `zps` when `asym` is true) are indexed as
/// `[p / block_size, j]`; `zps` is never read when `asym` is false.
///
/// Each output element is `sum_p x[i, p] * ((code - z) * s)`,
/// accumulated in ascending `p`.
#[allow(clippy::too_many_arguments)]
fn dequant_matmul_int4(
    x: &[f32],
    w_bytes: &[u8],
    scales: &[f32],
    zps: &[f32],
    out: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
    block_size: usize,
    asym: bool,
) {
    for row in 0..m {
        let x_row = row * k;
        for col in 0..n {
            let dot = (0..k).fold(0f32, |sum, depth| {
                let group = depth / block_size;
                let scale = scales[group * n + col];
                let zero = if asym { zps[group * n + col] } else { 0.0 };
                // Unpack the 4-bit code for weight element (depth, col).
                let flat = depth * n + col;
                let code = if flat % 2 == 0 {
                    w_bytes[flat / 2] & 0x0F
                } else {
                    w_bytes[flat / 2] >> 4
                };
                let weight = (code as f32 - zero) * scale;
                sum + x[x_row + depth] * weight
            });
            out[row * n + col] = dot;
        }
    }
}
2017
/// Decode one FP8 E4M3 byte (1 sign, 4 exponent, 3 mantissa bits,
/// exponent bias 7) to `f32`.
///
/// This follows the "FN" (finite-only) E4M3 encoding used for ML
/// weights: there are no infinities, and only the all-ones pattern
/// (exponent `0b1111`, mantissa `0b111`) is NaN. Every other code with
/// exponent `0b1111` is an ordinary finite value, so the largest
/// magnitude is `0x7E` = 448.
///
/// * exponent 0, mantissa 0 → `0.0` (the sign bit is dropped, as before)
/// * exponent 0, mantissa m → subnormal `±m * 2^-9`
/// * otherwise             → `±(1 + m/8) * 2^(exponent - 7)`
fn fp8_e4m3_to_f32(b: u8) -> f32 {
    let sign = if b & 0x80 != 0 { -1.0 } else { 1.0 };
    let exp = (b >> 3) & 0x0F;
    let mant = b & 0x07;
    if exp == 0 {
        if mant == 0 {
            return 0.0;
        }
        return sign * (mant as f32) * 2f32.powi(-9);
    }
    // E4M3FN reserves only S.1111.111 for NaN. Treating the whole top
    // exponent as Inf/NaN (IEEE style) would turn the valid magnitudes
    // 256..=448 into non-finite values.
    if exp == 0x0F && mant == 0x07 {
        return f32::NAN;
    }
    sign * (1.0 + mant as f32 / 8.0) * 2f32.powi(exp as i32 - 7)
}
2037
/// Decode one FP8 E5M2 byte (1 sign, 5 exponent, 2 mantissa bits,
/// exponent bias 15) to `f32`.
///
/// * exponent 0, mantissa 0   → `0.0` (the sign bit is dropped)
/// * exponent 0, mantissa m   → subnormal `±m * 2^-16`
/// * exponent 31, mantissa 0  → `±∞`
/// * exponent 31, mantissa ≠0 → NaN
/// * otherwise                → `±(1 + m/4) * 2^(exponent - 15)`
fn fp8_e5m2_to_f32(b: u8) -> f32 {
    let negative = b & 0x80 != 0;
    let exponent = (b >> 2) & 0x1F;
    let fraction = b & 0x03;

    // Work out the magnitude first; the sign is applied at the end.
    let magnitude = match (exponent, fraction) {
        (0, 0) => return 0.0,
        (0, f) => f32::from(f) * 2f32.powi(-16),
        (0x1F, 0) => f32::INFINITY,
        (0x1F, _) => return f32::NAN,
        (e, f) => (1.0 + f32::from(f) / 4.0) * 2f32.powi(i32::from(e) - 15),
    };
    if negative { -magnitude } else { magnitude }
}
2057
2058#[allow(clippy::too_many_arguments)]
2059fn dequant_matmul_fp8(
2060    x: &[f32],
2061    w_bytes: &[u8],
2062    scales: &[f32],
2063    out: &mut [f32],
2064    m: usize,
2065    k: usize,
2066    n: usize,
2067    e5m2: bool,
2068) {
2069    let dequant = if e5m2 {
2070        fp8_e5m2_to_f32
2071    } else {
2072        fp8_e4m3_to_f32
2073    };
2074    for i in 0..m {
2075        for j in 0..n {
2076            let mut acc = 0f32;
2077            for p in 0..k {
2078                let w = dequant(w_bytes[p * n + j]);
2079                let s = scales.get(j).copied().unwrap_or(1.0);
2080                acc += x[i * k + p] * w * s;
2081            }
2082            out[i * n + j] = acc;
2083        }
2084    }
2085}
2086
2087#[allow(clippy::too_many_arguments)]
2088pub fn dequant_matmul_nvfp4(
2089    x: &[f32],
2090    w_bytes: &[u8],
2091    scale_bytes: &[u8],
2092    global_scale: f32,
2093    out: &mut [f32],
2094    m: usize,
2095    k: usize,
2096    n: usize,
2097) {
2098    use rlx_ir::{NVFP4_GROUP_SIZE, fp4_e2m1_to_f32, fp8_e4m3_scale_to_f32};
2099    let gs = NVFP4_GROUP_SIZE;
2100    for i in 0..m {
2101        for j in 0..n {
2102            let mut acc = 0f32;
2103            for p in 0..k {
2104                let byte_idx = (p * n + j) / 2;
2105                let nibble = if (p * n + j) & 1 == 0 {
2106                    w_bytes[byte_idx] & 0x0F
2107                } else {
2108                    w_bytes[byte_idx] >> 4
2109                };
2110                let block = p / gs;
2111                let scale = fp8_e4m3_scale_to_f32(scale_bytes[block * n + j]);
2112                let w = fp4_e2m1_to_f32(nibble) * scale * global_scale;
2113                acc += x[i * k + p] * w;
2114            }
2115            out[i * n + j] = acc;
2116        }
2117    }
2118}
2119
2120/// Fused sampling step: logits → top-k filter → top-p truncation
2121/// → softmax → multinomial sample. Operates on one row of length
2122/// `vocab` and returns the sampled index. Plan #42.
2123///
2124/// Internal scratch is on the stack via SmallVec-style fallback —
2125/// for `vocab > 8192` we heap-allocate a working buffer; below
2126/// that we keep things in a fixed array. (TODO: thread the
2127/// scratch through ThunkSchedule like sdpa_scores does.)
2128fn sample_row(
2129    logits: &[f32],
2130    top_k: usize,
2131    top_p: f32,
2132    temperature: f32,
2133    rng: &mut rlx_ir::Philox4x32,
2134) -> usize {
2135    let v = logits.len();
2136    if v == 0 {
2137        return 0;
2138    }
2139    let temp = temperature.max(1e-6);
2140    // Copy + temperature-scale into a working buffer.
2141    let mut scaled: Vec<f32> = logits.iter().map(|&x| x / temp).collect();
2142
2143    // Top-k: zero out everything but the k largest by setting to -inf.
2144    if top_k > 0 && top_k < v {
2145        // Partial selection: find k-th largest then mask below.
2146        let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
2147        // Sort descending; partial would be O(n log k), full sort is fine
2148        // for typical vocab sizes (32k-128k) — single-row work.
2149        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2150        let cutoff = indexed[top_k - 1].1;
2151        for x in scaled.iter_mut() {
2152            if *x < cutoff {
2153                *x = f32::NEG_INFINITY;
2154            }
2155        }
2156    }
2157
2158    // Stable softmax.
2159    let mut max_l = f32::NEG_INFINITY;
2160    for &x in &scaled {
2161        if x > max_l {
2162            max_l = x;
2163        }
2164    }
2165    let mut sum = 0.0f32;
2166    for x in scaled.iter_mut() {
2167        *x = (*x - max_l).exp();
2168        sum += *x;
2169    }
2170    let inv = 1.0 / sum.max(f32::MIN_POSITIVE);
2171    for x in scaled.iter_mut() {
2172        *x *= inv;
2173    }
2174
2175    // Top-p: keep the smallest set of tokens whose cumulative
2176    // probability exceeds top_p (after sorting descending).
2177    if top_p < 1.0 {
2178        let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
2179        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2180        let mut cum = 0.0f32;
2181        let mut keep = vec![false; v];
2182        for (idx, p) in indexed.iter() {
2183            keep[*idx] = true;
2184            cum += *p;
2185            if cum >= top_p {
2186                break;
2187            }
2188        }
2189        let mut new_sum = 0.0f32;
2190        for (i, x) in scaled.iter_mut().enumerate() {
2191            if !keep[i] {
2192                *x = 0.0;
2193            }
2194            new_sum += *x;
2195        }
2196        let inv = 1.0 / new_sum.max(f32::MIN_POSITIVE);
2197        for x in scaled.iter_mut() {
2198            *x *= inv;
2199        }
2200    }
2201
2202    // Multinomial sample via inverse-CDF.
2203    let r = rng.next_f32();
2204    let mut acc = 0.0f32;
2205    for (i, &p) in scaled.iter().enumerate() {
2206        acc += p;
2207        if r <= acc {
2208            return i;
2209        }
2210    }
2211    v - 1 // floating-point edge case fallback
2212}
2213
2214/// Apply a synthetic (kernel-generated) attention mask to a `[q_seq, k_seq]`
2215/// scores matrix. Custom masks are read from a tensor and not handled here.
2216/// `None` is a no-op so callers don't need to special-case it.
2217#[inline]
2218fn apply_synthetic_mask(
2219    scores: &mut [f32],
2220    q_seq: usize,
2221    k_seq: usize,
2222    kind: rlx_ir::op::MaskKind,
2223) {
2224    let neg = crate::config::RuntimeConfig::global().attn_mask_neg_inf;
2225    let q_offset = k_seq.saturating_sub(q_seq);
2226    match kind {
2227        rlx_ir::op::MaskKind::None | rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias => {}
2228        rlx_ir::op::MaskKind::Causal => {
2229            for qi in 0..q_seq {
2230                let abs_q = q_offset + qi;
2231                for ki in (abs_q + 1)..k_seq {
2232                    scores[qi * k_seq + ki] = neg;
2233                }
2234            }
2235        }
2236        rlx_ir::op::MaskKind::SlidingWindow(w) => {
2237            for qi in 0..q_seq {
2238                let abs_q = q_offset + qi;
2239                let lo = abs_q.saturating_sub(w);
2240                for ki in 0..k_seq {
2241                    if ki < lo || ki > abs_q {
2242                        scores[qi * k_seq + ki] = neg;
2243                    }
2244                }
2245            }
2246        }
2247    }
2248}
2249
2250/// NCL `[N,C,L]` or NCHW `[N,C,H,W]` → `(n, c, h, w)` for 2D conv/norm thunks.
2251fn conv_nchw_dims(shape: &Shape) -> (u32, u32, u32, u32) {
2252    match shape.rank() {
2253        3 => (
2254            shape.dim(0).unwrap_static() as u32,
2255            shape.dim(1).unwrap_static() as u32,
2256            1,
2257            shape.dim(2).unwrap_static() as u32,
2258        ),
2259        4 => (
2260            shape.dim(0).unwrap_static() as u32,
2261            shape.dim(1).unwrap_static() as u32,
2262            shape.dim(2).unwrap_static() as u32,
2263            shape.dim(3).unwrap_static() as u32,
2264        ),
2265        r => panic!("conv_nchw_dims: expected rank 3 or 4, got {r}"),
2266    }
2267}
2268
2269/// Compile graph into thunk schedule.
2270pub fn compile_thunks(graph: &Graph, arena: &Arena) -> ThunkSchedule {
2271    let mut thunks = Vec::with_capacity(graph.len());
2272
2273    for node in graph.nodes() {
2274        // View ops (Reshape / same-dtype Cast / axis-0 Narrow) are aliased
2275        // to their parent's slot by the memory planner — no copy needed.
2276        // Plan #46.
2277        if rlx_opt::is_pure_view(graph, node) {
2278            thunks.push(Thunk::Nop);
2279            continue;
2280        }
2281        let t = match &node.op {
2282            Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => Thunk::Nop,
2283
2284            Op::FusedMatMulBiasAct { activation } => {
2285                let shape = &node.shape;
2286                let n = shape.dim(shape.rank() - 1).unwrap_static();
2287                let total = shape.num_elements().unwrap();
2288                let m = total / n;
2289                let a_len = get_len(graph, node.inputs[0]);
2290                let k = a_len / m;
2291                Thunk::FusedMmBiasAct {
2292                    a: node_offset(arena, node.inputs[0]),
2293                    w: node_offset(arena, node.inputs[1]),
2294                    bias: node_offset(arena, node.inputs[2]),
2295                    c: node_offset(arena, node.id),
2296                    m: m as u32,
2297                    k: k as u32,
2298                    n: n as u32,
2299                    act: *activation,
2300                }
2301            }
2302
2303            Op::FusedResidualLN { has_bias, eps } => {
2304                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2305                let total = node.shape.num_elements().unwrap();
2306                let rows = total / h;
2307                let (g_idx, b_idx) = if *has_bias { (3, 4) } else { (2, 3) };
2308                Thunk::FusedResidualLN {
2309                    x: node_offset(arena, node.inputs[0]),
2310                    res: node_offset(arena, node.inputs[1]),
2311                    bias: if *has_bias {
2312                        node_offset(arena, node.inputs[2])
2313                    } else {
2314                        0
2315                    },
2316                    g: node_offset(arena, node.inputs[g_idx]),
2317                    b: node_offset(arena, node.inputs[b_idx]),
2318                    out: node_offset(arena, node.id),
2319                    rows: rows as u32,
2320                    h: h as u32,
2321                    eps: *eps,
2322                    has_bias: *has_bias,
2323                }
2324            }
2325
2326            Op::FusedResidualRmsNorm { has_bias, eps } => {
2327                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2328                let total = node.shape.num_elements().unwrap();
2329                let rows = total / h;
2330                let (g_idx, b_idx) = if *has_bias { (3, 4) } else { (2, 3) };
2331                Thunk::FusedResidualRmsNorm {
2332                    x: node_offset(arena, node.inputs[0]),
2333                    res: node_offset(arena, node.inputs[1]),
2334                    bias: if *has_bias {
2335                        node_offset(arena, node.inputs[2])
2336                    } else {
2337                        0
2338                    },
2339                    g: node_offset(arena, node.inputs[g_idx]),
2340                    b: node_offset(arena, node.inputs[b_idx]),
2341                    out: node_offset(arena, node.id),
2342                    rows: rows as u32,
2343                    h: h as u32,
2344                    eps: *eps,
2345                    has_bias: *has_bias,
2346                }
2347            }
2348
2349            Op::MatMul => {
2350                let shape = &node.shape;
2351                let a_shape = &graph.node(node.inputs[0]).shape;
2352                let b_shape = &graph.node(node.inputs[1]).shape;
2353                // Prefer inferred matmul shape from operands — ONNX bundle
2354                // meta often over-ranks outputs (e.g. [seq, seq, H]).
2355                let eff =
2356                    rlx_ir::shape::matmul_shape(a_shape, b_shape).unwrap_or_else(|_| shape.clone());
2357                let rank = eff.rank().max(2);
2358                let n = eff.dim(rank - 1).unwrap_static();
2359                let k_dim = a_shape.dim(a_shape.rank().max(2) - 1).unwrap_static();
2360                // Batched GEMM only when both operands carry batch dimensions.
2361                // 3D×2D (activations × shared weight) must flatten to one Sgemm.
2362                let both_batched = a_shape.rank() >= 3 && b_shape.rank() >= 3;
2363                let batched_3d = rank >= 3 && both_batched && a_shape.rank() + b_shape.rank() > 4;
2364                if batched_3d && shape.dtype() == rlx_ir::DType::F64 {
2365                    let mut batch_prod = 1usize;
2366                    for d in 0..rank - 2 {
2367                        batch_prod *= eff.dim(d).unwrap_static();
2368                    }
2369                    let m_dim = eff.dim(rank - 2).unwrap_static();
2370                    Thunk::BatchedDgemmF64 {
2371                        a: node_offset(arena, node.inputs[0]),
2372                        b: node_offset(arena, node.inputs[1]),
2373                        c: node_offset(arena, node.id),
2374                        batch: batch_prod as u32,
2375                        m: m_dim as u32,
2376                        k: k_dim as u32,
2377                        n: n as u32,
2378                    }
2379                } else if batched_3d && shape.dtype() == rlx_ir::DType::F32 {
2380                    let mut batch_prod = 1usize;
2381                    for d in 0..rank - 2 {
2382                        batch_prod *= eff.dim(d).unwrap_static();
2383                    }
2384                    let m_dim = eff.dim(rank - 2).unwrap_static();
2385                    Thunk::BatchedSgemm {
2386                        a: node_offset(arena, node.inputs[0]),
2387                        b: node_offset(arena, node.inputs[1]),
2388                        c: node_offset(arena, node.id),
2389                        batch: batch_prod as u32,
2390                        m: m_dim as u32,
2391                        k: k_dim as u32,
2392                        n: n as u32,
2393                    }
2394                } else {
2395                    let m = if a_shape.rank() >= 3 && b_shape.rank() <= 2 {
2396                        let mut m_prod = 1usize;
2397                        for d in 0..a_shape.rank() - 1 {
2398                            m_prod *= a_shape.dim(d).unwrap_static();
2399                        }
2400                        m_prod
2401                    } else if a_shape.rank() >= 2 {
2402                        a_shape.dim(a_shape.rank() - 2).unwrap_static()
2403                    } else {
2404                        eff.num_elements().unwrap_or(1) / n.max(1)
2405                    };
2406                    match shape.dtype() {
2407                        rlx_ir::DType::F64 => Thunk::Dgemm {
2408                            a: node_offset(arena, node.inputs[0]),
2409                            b: node_offset(arena, node.inputs[1]),
2410                            c: node_offset(arena, node.id),
2411                            m: m as u32,
2412                            k: k_dim as u32,
2413                            n: n as u32,
2414                        },
2415                        _ => Thunk::Sgemm {
2416                            a: node_offset(arena, node.inputs[0]),
2417                            b: node_offset(arena, node.inputs[1]),
2418                            c: node_offset(arena, node.id),
2419                            m: m as u32,
2420                            k: k_dim as u32,
2421                            n: n as u32,
2422                        },
2423                    }
2424                }
2425            }
2426
2427            Op::Binary(op) => {
2428                let lhs_len = get_len(graph, node.inputs[0]);
2429                let rhs_len = get_len(graph, node.inputs[1]);
2430                let out_len = node.shape.num_elements().unwrap();
2431                if node.shape.dtype() == rlx_ir::DType::C64 {
2432                    // Native C64 element-wise. Add/Sub/Mul/Div lower
2433                    // to `BinaryFullC64`; the rest don't have a
2434                    // single natural complex definition.
2435                    match op {
2436                        BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div => {}
2437                        BinaryOp::Max | BinaryOp::Min | BinaryOp::Pow => panic!(
2438                            "Op::Binary({op:?}) on DType::C64: complex \
2439                             max/min/pow have no single natural definition \
2440                             — caller should drop to 2N-real-block (see \
2441                             spike-ac) and pick a convention there"
2442                        ),
2443                    }
2444                }
2445                // Compute broadcast strides for the slow path. Empty
2446                // vectors when no broadcast is needed (the fast-path
2447                // kernel ignores them anyway).
2448                let (out_dims_bcast, bcast_lhs_strides, bcast_rhs_strides) =
2449                    if lhs_len == out_len && rhs_len == out_len {
2450                        (Vec::new(), Vec::new(), Vec::new())
2451                    } else {
2452                        let lhs_dims = get_static_dims(graph, node.inputs[0]);
2453                        let rhs_dims = get_static_dims(graph, node.inputs[1]);
2454                        let out_dims_v = get_static_dims(graph, node.id);
2455                        if lhs_dims.is_empty() || rhs_dims.is_empty() || out_dims_v.is_empty() {
2456                            // Dynamic shape — fall back to the legacy
2457                            // modulo path (correct for scalar / last-
2458                            // axis broadcast, which is the only
2459                            // dynamic case in practice).
2460                            (Vec::new(), Vec::new(), Vec::new())
2461                        } else {
2462                            let ls = broadcast_strides(&lhs_dims, &out_dims_v);
2463                            let rs = broadcast_strides(&rhs_dims, &out_dims_v);
2464                            let od: Vec<u32> = out_dims_v.iter().map(|x| *x as u32).collect();
2465                            (od, ls, rs)
2466                        }
2467                    };
2468                if node.shape.dtype() == rlx_ir::DType::C64 {
2469                    Thunk::BinaryFullC64 {
2470                        lhs: node_offset(arena, node.inputs[0]),
2471                        rhs: node_offset(arena, node.inputs[1]),
2472                        dst: node_offset(arena, node.id),
2473                        len: out_len as u32,
2474                        lhs_len: lhs_len as u32,
2475                        rhs_len: rhs_len as u32,
2476                        op: *op,
2477                        out_dims_bcast,
2478                        bcast_lhs_strides,
2479                        bcast_rhs_strides,
2480                    }
2481                } else if node.shape.dtype() == rlx_ir::DType::F64 {
2482                    // f64 path — no BiasAdd fast-path (yet); use the
2483                    // general binary-with-broadcast kernel.
2484                    Thunk::BinaryFullF64 {
2485                        lhs: node_offset(arena, node.inputs[0]),
2486                        rhs: node_offset(arena, node.inputs[1]),
2487                        dst: node_offset(arena, node.id),
2488                        len: out_len as u32,
2489                        lhs_len: lhs_len as u32,
2490                        rhs_len: rhs_len as u32,
2491                        op: *op,
2492                        out_dims_bcast,
2493                        bcast_lhs_strides,
2494                        bcast_rhs_strides,
2495                    }
2496                } else if matches!(op, BinaryOp::Add)
2497                    && rhs_len < out_len
2498                    && out_len % rhs_len == 0
2499                    && is_trailing_bias_broadcast(
2500                        graph.node(node.inputs[1]).shape.dims(),
2501                        graph.node(node.id).shape.dims(),
2502                    )
2503                {
2504                    // `BiasAdd` is only correct when the bias is a
2505                    // *trailing* broadcast — rhs dims match the right-
2506                    // hand side of the output dims (with size-1 only
2507                    // allowed in left-padded outer positions).
2508                    // SAM's rel-pos `[bh, h, w, 1, w] + [bh, h, w, h, w]`
2509                    // has rhs_len divide out_len cleanly but is a
2510                    // mid-shape singleton, NOT a trailing broadcast.
2511                    // Routing it through BiasAdd silently treats it as
2512                    // last-`rhs_len`-cols repeated — wrong values.
2513                    Thunk::BiasAdd {
2514                        src: node_offset(arena, node.inputs[0]),
2515                        bias: node_offset(arena, node.inputs[1]),
2516                        dst: node_offset(arena, node.id),
2517                        m: (out_len / rhs_len) as u32,
2518                        n: rhs_len as u32,
2519                    }
2520                } else {
2521                    let lhs_len = get_len(graph, node.inputs[0]);
2522                    Thunk::BinaryFull {
2523                        lhs: node_offset(arena, node.inputs[0]),
2524                        rhs: node_offset(arena, node.inputs[1]),
2525                        dst: node_offset(arena, node.id),
2526                        len: out_len as u32,
2527                        lhs_len: lhs_len as u32,
2528                        rhs_len: rhs_len as u32,
2529                        op: *op,
2530                        out_dims_bcast,
2531                        bcast_lhs_strides,
2532                        bcast_rhs_strides,
2533                        elem_bytes: node.shape.dtype().size_bytes() as u8,
2534                    }
2535                }
2536            }
2537
2538            Op::Activation(act) => {
2539                let len = node.shape.num_elements().unwrap();
2540                let in_off = node_offset(arena, node.inputs[0]);
2541                let out_off = node_offset(arena, node.id);
2542                if node.shape.dtype() == rlx_ir::DType::C64 {
2543                    // Only Neg/Exp/Log/Sqrt have natural complex
2544                    // extensions used in signal-processing graphs.
2545                    // Everything else (Sigmoid, Tanh, Relu, Abs,
2546                    // Sin/Cos/Tan/Atan, Round, GeLU family) is rejected.
2547                    match act {
2548                        Activation::Neg | Activation::Exp | Activation::Log | Activation::Sqrt => {}
2549                        other => panic!(
2550                            "Op::Activation({other:?}) on DType::C64: no \
2551                             natural complex extension — supported on C64: \
2552                             Neg, Exp, Log, Sqrt"
2553                        ),
2554                    }
2555                    Thunk::ActivationC64 {
2556                        src: in_off,
2557                        dst: out_off,
2558                        len: len as u32,
2559                        kind: *act,
2560                    }
2561                } else if node.shape.dtype() == rlx_ir::DType::F64 {
2562                    Thunk::ActivationF64 {
2563                        src: in_off,
2564                        dst: out_off,
2565                        len: len as u32,
2566                        kind: *act,
2567                    }
2568                } else if in_off == out_off {
2569                    // ActivationInPlace operates on a single buffer. When the
2570                    // planner has assigned input and output the same slot
2571                    // (typical post-fusion case), we just run on that slot.
2572                    Thunk::ActivationInPlace {
2573                        data: out_off,
2574                        len: len as u32,
2575                        act: *act,
2576                    }
2577                } else {
2578                    // Two-step: copy input → output, then activate output in place.
2579                    // The schedule executes them in this order; downstream
2580                    // thunks see the activated output at out_off.
2581                    thunks.push(Thunk::Copy {
2582                        src: in_off,
2583                        dst: out_off,
2584                        len: len as u32,
2585                    });
2586                    Thunk::ActivationInPlace {
2587                        data: out_off,
2588                        len: len as u32,
2589                        act: *act,
2590                    }
2591                }
2592            }
2593
2594            Op::Gather { axis } if *axis == 0 => {
2595                let table_shape = &graph.node(node.inputs[0]).shape;
2596                let table_total = table_shape.num_elements().unwrap();
2597                let trailing: usize = (1..table_shape.rank())
2598                    .map(|i| table_shape.dim(i).unwrap_static())
2599                    .product();
2600                let idx_len = get_len(graph, node.inputs[1]);
2601                let idx_i64 =
2602                    u8::from(graph.node(node.inputs[1]).shape.dtype() == rlx_ir::DType::I64);
2603                let table_bytes = graph.node(node.inputs[0]).shape.dtype().size_bytes() as u8;
2604                Thunk::Gather {
2605                    table: node_offset(arena, node.inputs[0]),
2606                    table_len: table_total as u32,
2607                    idx: node_offset(arena, node.inputs[1]),
2608                    dst: node_offset(arena, node.id),
2609                    num_idx: idx_len as u32,
2610                    trailing: trailing as u32,
2611                    idx_i64,
2612                    table_bytes,
2613                }
2614            }
2615
2616            Op::Gather { axis } => {
2617                // Non-zero axis: outer × num_idx × trailing layout.
2618                let table_shape = &graph.node(node.inputs[0]).shape;
2619                let rank = table_shape.rank();
2620                let outer: usize = (0..*axis)
2621                    .map(|i| table_shape.dim(i).unwrap_static())
2622                    .product::<usize>()
2623                    .max(1);
2624                let trailing: usize = (*axis + 1..rank)
2625                    .map(|i| table_shape.dim(i).unwrap_static())
2626                    .product::<usize>()
2627                    .max(1);
2628                let axis_dim = table_shape.dim(*axis).unwrap_static();
2629                let idx_len = get_len(graph, node.inputs[1]);
2630                let idx_i64 =
2631                    u8::from(graph.node(node.inputs[1]).shape.dtype() == rlx_ir::DType::I64);
2632                let table_bytes = graph.node(node.inputs[0]).shape.dtype().size_bytes() as u8;
2633                Thunk::GatherAxis {
2634                    table: node_offset(arena, node.inputs[0]),
2635                    idx: node_offset(arena, node.inputs[1]),
2636                    dst: node_offset(arena, node.id),
2637                    outer: outer as u32,
2638                    axis_dim: axis_dim as u32,
2639                    num_idx: idx_len as u32,
2640                    trailing: trailing as u32,
2641                    idx_i64,
2642                    table_bytes,
2643                }
2644            }
2645
2646            Op::Narrow { axis, start, len } => {
2647                let in_shape = &graph.node(node.inputs[0]).shape;
2648                let elem_bytes = in_shape.dtype().size_bytes() as u8;
2649                let rank = in_shape.rank();
2650                let outer: usize = (0..*axis)
2651                    .map(|i| in_shape.dim(i).unwrap_static())
2652                    .product::<usize>()
2653                    .max(1);
2654                let inner: usize = (*axis + 1..rank)
2655                    .map(|i| in_shape.dim(i).unwrap_static())
2656                    .product::<usize>()
2657                    .max(1);
2658                let in_axis = in_shape.dim(*axis).unwrap_static();
2659                let src_byte_offset =
2660                    node_offset(arena, node.inputs[0]) + start * inner * elem_bytes as usize;
2661                Thunk::Narrow {
2662                    src: src_byte_offset,
2663                    dst: node_offset(arena, node.id),
2664                    outer: outer as u32,
2665                    src_stride: (in_axis * inner) as u32, // elements per outer step in source
2666                    dst_stride: (*len * inner) as u32,    // elements per outer step in dest
2667                    inner: (*len * inner) as u32,         // elements to copy per outer step
2668                    elem_bytes,
2669                }
2670            }
2671
2672            Op::Reshape { .. } | Op::StopGradient => {
2673                // Pure layout change: same total element count, plain copy.
2674                let len = node.shape.num_elements().unwrap();
2675                let src = node_offset(arena, node.inputs[0]);
2676                let dst = node_offset(arena, node.id);
2677                match node.shape.dtype() {
2678                    rlx_ir::DType::F64 => Thunk::CopyF64 {
2679                        src,
2680                        dst,
2681                        len: len as u32,
2682                    },
2683                    rlx_ir::DType::I64 => Thunk::CopyI64 {
2684                        src,
2685                        dst,
2686                        len: len as u32,
2687                    },
2688                    _ => Thunk::Copy {
2689                        src,
2690                        dst,
2691                        len: len as u32,
2692                    },
2693                }
2694            }
2695
2696            Op::Cast { to } => {
2697                let in_node = graph.node(node.inputs[0]);
2698                let in_dtype = in_node.shape.dtype();
2699                let out_dtype = *to;
2700                let len = node.shape.num_elements().unwrap();
2701                let src = node_offset(arena, node.inputs[0]);
2702                let dst = node_offset(arena, node.id);
2703                if in_dtype == rlx_ir::DType::F32 && out_dtype == rlx_ir::DType::I64 {
2704                    Thunk::CastF32ToI64 {
2705                        src,
2706                        dst,
2707                        len: len as u32,
2708                    }
2709                } else if in_dtype == rlx_ir::DType::I64 && out_dtype == rlx_ir::DType::F32 {
2710                    Thunk::CastI64ToF32 {
2711                        src,
2712                        dst,
2713                        len: len as u32,
2714                    }
2715                } else if in_dtype == rlx_ir::DType::Bool && out_dtype == rlx_ir::DType::I32 {
2716                    Thunk::CastBoolToI32 {
2717                        src,
2718                        dst,
2719                        len: len as u32,
2720                    }
2721                } else if in_dtype == rlx_ir::DType::I32 && out_dtype == rlx_ir::DType::F32 {
2722                    Thunk::CastI32ToF32 {
2723                        src,
2724                        dst,
2725                        len: len as u32,
2726                    }
2727                } else if in_dtype == out_dtype {
2728                    match out_dtype {
2729                        rlx_ir::DType::F64 => Thunk::CopyF64 {
2730                            src,
2731                            dst,
2732                            len: len as u32,
2733                        },
2734                        rlx_ir::DType::I64 => Thunk::CopyI64 {
2735                            src,
2736                            dst,
2737                            len: len as u32,
2738                        },
2739                        _ => Thunk::Copy {
2740                            src,
2741                            dst,
2742                            len: len as u32,
2743                        },
2744                    }
2745                } else {
2746                    Thunk::Copy {
2747                        src,
2748                        dst,
2749                        len: len as u32,
2750                    }
2751                }
2752            }
2753
2754            Op::Quantize {
2755                axis,
2756                scales,
2757                zero_points,
2758            } => {
2759                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2760                Thunk::Quantize {
2761                    x: node_offset(arena, node.inputs[0]),
2762                    q: node_offset(arena, node.id),
2763                    len: node.shape.num_elements().unwrap() as u32,
2764                    chan_axis: chan_axis as u32,
2765                    chan_dim: chan_dim as u32,
2766                    inner: inner as u32,
2767                    scales: scales.clone(),
2768                    zero_points: zero_points.clone(),
2769                }
2770            }
2771
2772            Op::FakeQuantize {
2773                bits,
2774                axis,
2775                ste,
2776                scale_mode,
2777            } => {
2778                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2779                let state_off = match scale_mode {
2780                    rlx_ir::op::ScaleMode::PerBatch => None,
2781                    rlx_ir::op::ScaleMode::EMA { .. } | rlx_ir::op::ScaleMode::Fixed => {
2782                        // Second input carries the [chan_dim] scale state.
2783                        debug_assert_eq!(
2784                            node.inputs.len(),
2785                            2,
2786                            "EMA/Fixed FakeQuantize needs a state input"
2787                        );
2788                        Some(node_offset(arena, node.inputs[1]))
2789                    }
2790                };
2791                Thunk::FakeQuantize {
2792                    x: node_offset(arena, node.inputs[0]),
2793                    out: node_offset(arena, node.id),
2794                    len: node.shape.num_elements().unwrap() as u32,
2795                    chan_axis: chan_axis as u32,
2796                    chan_dim: chan_dim as u32,
2797                    inner: inner as u32,
2798                    bits: *bits,
2799                    ste: *ste,
2800                    scale_mode: *scale_mode,
2801                    state_off,
2802                }
2803            }
2804
2805            Op::FakeQuantizeLSQ { bits, axis } => {
2806                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2807                Thunk::FakeQuantizeLSQ {
2808                    x: node_offset(arena, node.inputs[0]),
2809                    scale_off: node_offset(arena, node.inputs[1]),
2810                    out: node_offset(arena, node.id),
2811                    len: node.shape.num_elements().unwrap() as u32,
2812                    chan_axis: chan_axis as u32,
2813                    chan_dim: chan_dim as u32,
2814                    inner: inner as u32,
2815                    bits: *bits,
2816                }
2817            }
2818
2819            Op::FakeQuantizeLSQBackwardX { bits, axis } => {
2820                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2821                Thunk::FakeQuantizeLSQBackwardX {
2822                    x: node_offset(arena, node.inputs[0]),
2823                    scale_off: node_offset(arena, node.inputs[1]),
2824                    dy: node_offset(arena, node.inputs[2]),
2825                    dx: node_offset(arena, node.id),
2826                    len: node.shape.num_elements().unwrap() as u32,
2827                    chan_axis: chan_axis as u32,
2828                    chan_dim: chan_dim as u32,
2829                    inner: inner as u32,
2830                    bits: *bits,
2831                }
2832            }
2833
2834            Op::FakeQuantizeLSQBackwardScale { bits, axis } => {
2835                // Output shape is [chan_dim] — node.shape doesn't
2836                // describe the input data layout, but inputs[0] does.
2837                let in_shape = &graph.node(node.inputs[0]).shape;
2838                let (chan_axis, chan_dim, inner) = quant_layout(in_shape, *axis);
2839                Thunk::FakeQuantizeLSQBackwardScale {
2840                    x: node_offset(arena, node.inputs[0]),
2841                    scale_off: node_offset(arena, node.inputs[1]),
2842                    dy: node_offset(arena, node.inputs[2]),
2843                    dscale: node_offset(arena, node.id),
2844                    len: in_shape.num_elements().unwrap() as u32,
2845                    chan_axis: chan_axis as u32,
2846                    chan_dim: chan_dim as u32,
2847                    inner: inner as u32,
2848                    bits: *bits,
2849                }
2850            }
2851
2852            Op::FakeQuantizeBackward { bits, axis, ste } => {
2853                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2854                Thunk::FakeQuantizeBackward {
2855                    x: node_offset(arena, node.inputs[0]),
2856                    dy: node_offset(arena, node.inputs[1]),
2857                    dx: node_offset(arena, node.id),
2858                    len: node.shape.num_elements().unwrap() as u32,
2859                    chan_axis: chan_axis as u32,
2860                    chan_dim: chan_dim as u32,
2861                    inner: inner as u32,
2862                    bits: *bits,
2863                    ste: *ste,
2864                }
2865            }
2866
2867            Op::Dequantize {
2868                axis,
2869                scales,
2870                zero_points,
2871            } => {
2872                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2873                Thunk::Dequantize {
2874                    q: node_offset(arena, node.inputs[0]),
2875                    x: node_offset(arena, node.id),
2876                    len: node.shape.num_elements().unwrap() as u32,
2877                    chan_axis: chan_axis as u32,
2878                    chan_dim: chan_dim as u32,
2879                    inner: inner as u32,
2880                    scales: scales.clone(),
2881                    zero_points: zero_points.clone(),
2882                }
2883            }
2884
2885            Op::Expand { .. } => {
2886                // Broadcast: build per-output-dim strides where any input dim
2887                // of size 1 has stride 0 (read the same element repeatedly).
2888                // Reuses the Thunk::Transpose runtime — N-D walk with strides
2889                // is identical; only the strides differ.
2890                let in_shape = &graph.node(node.inputs[0]).shape;
2891                let out_shape = &node.shape;
2892                let in_rank = in_shape.rank();
2893                let out_rank = out_shape.rank();
2894                // Implicit leading 1s if input has lower rank.
2895                let pad = out_rank.saturating_sub(in_rank);
2896                let in_dims: Vec<usize> = (0..out_rank)
2897                    .map(|i| {
2898                        if i < pad {
2899                            1
2900                        } else {
2901                            in_shape.dim(i - pad).unwrap_static()
2902                        }
2903                    })
2904                    .collect();
2905                // Row-major input strides (over the padded shape).
2906                let mut in_strides_full = vec![1usize; out_rank];
2907                for d in (0..out_rank.saturating_sub(1)).rev() {
2908                    in_strides_full[d] = in_strides_full[d + 1] * in_dims[d + 1];
2909                }
2910                let out_dims: Vec<u32> = (0..out_rank)
2911                    .map(|i| out_shape.dim(i).unwrap_static() as u32)
2912                    .collect();
2913                // Stride is 0 for broadcast dims (in_dim == 1 && out_dim > 1).
2914                let in_strides: Vec<u32> = (0..out_rank)
2915                    .map(|i| {
2916                        if in_dims[i] == 1 && (out_dims[i] as usize) > 1 {
2917                            0
2918                        } else {
2919                            in_strides_full[i] as u32
2920                        }
2921                    })
2922                    .collect();
2923                let in_total = in_dims.iter().product::<usize>() as u32;
2924                let src = node_offset(arena, node.inputs[0]);
2925                let dst = node_offset(arena, node.id);
2926                let elem_bytes = node.shape.dtype().size_bytes() as u8;
2927                match node.shape.dtype() {
2928                    rlx_ir::DType::F64 => Thunk::TransposeF64 {
2929                        src,
2930                        dst,
2931                        in_total,
2932                        out_dims,
2933                        in_strides,
2934                    },
2935                    _ => Thunk::Transpose {
2936                        src,
2937                        dst,
2938                        in_total,
2939                        out_dims,
2940                        in_strides,
2941                        elem_bytes,
2942                    },
2943                }
2944            }
2945
2946            Op::RmsNorm { eps, .. } => {
2947                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2948                let total = node.shape.num_elements().unwrap();
2949                Thunk::RmsNorm {
2950                    src: node_offset(arena, node.inputs[0]),
2951                    g: node_offset(arena, node.inputs[1]),
2952                    b: node_offset(arena, node.inputs[2]),
2953                    dst: node_offset(arena, node.id),
2954                    rows: (total / h) as u32,
2955                    h: h as u32,
2956                    eps: *eps,
2957                }
2958            }
2959
2960            Op::LayerNorm { eps, .. } => {
2961                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2962                let total = node.shape.num_elements().unwrap();
2963                Thunk::LayerNorm {
2964                    src: node_offset(arena, node.inputs[0]),
2965                    g: node_offset(arena, node.inputs[1]),
2966                    b: node_offset(arena, node.inputs[2]),
2967                    dst: node_offset(arena, node.id),
2968                    rows: (total / h) as u32,
2969                    h: h as u32,
2970                    eps: *eps,
2971                }
2972            }
2973
2974            Op::GroupNorm { num_groups, eps } => {
2975                let in_shape = &graph.node(node.inputs[0]).shape;
2976                let (n, c, h, w) = conv_nchw_dims(in_shape);
2977                Thunk::GroupNorm {
2978                    src: node_offset(arena, node.inputs[0]),
2979                    g: node_offset(arena, node.inputs[1]),
2980                    b: node_offset(arena, node.inputs[2]),
2981                    dst: node_offset(arena, node.id),
2982                    n,
2983                    c,
2984                    h,
2985                    w,
2986                    num_groups: *num_groups as u32,
2987                    eps: *eps,
2988                }
2989            }
2990
2991            Op::BatchNormInference { eps } => {
2992                let in_shape = &graph.node(node.inputs[0]).shape;
2993                let rank = in_shape.rank();
2994                let channels = in_shape.dim(rank - 1).unwrap_static();
2995                let total = in_shape.num_elements().unwrap_or(0);
2996                let count = (total / channels.max(1)) as u32;
2997                Thunk::BatchNormInference {
2998                    src: node_offset(arena, node.inputs[0]),
2999                    g: node_offset(arena, node.inputs[1]),
3000                    b: node_offset(arena, node.inputs[2]),
3001                    mean: node_offset(arena, node.inputs[3]),
3002                    var: node_offset(arena, node.inputs[4]),
3003                    dst: node_offset(arena, node.id),
3004                    count,
3005                    channels: channels as u32,
3006                    eps: *eps,
3007                }
3008            }
3009
3010            Op::BatchNormInferenceBackwardInput { eps } => {
3011                let x_shape = &graph.node(node.inputs[0]).shape;
3012                let rank = x_shape.rank();
3013                let channels = x_shape.dim(rank - 1).unwrap_static();
3014                let total = x_shape.num_elements().unwrap_or(0);
3015                Thunk::BatchNormInferenceBackwardInput {
3016                    x: node_offset(arena, node.inputs[0]),
3017                    gamma: node_offset(arena, node.inputs[1]),
3018                    mean: node_offset(arena, node.inputs[2]),
3019                    var: node_offset(arena, node.inputs[3]),
3020                    dy: node_offset(arena, node.inputs[4]),
3021                    dx: node_offset(arena, node.id),
3022                    count: (total / channels.max(1)) as u32,
3023                    channels: channels as u32,
3024                    eps: *eps,
3025                }
3026            }
3027
3028            Op::BatchNormInferenceBackwardGamma { eps } => {
3029                let x_shape = &graph.node(node.inputs[0]).shape;
3030                let rank = x_shape.rank();
3031                let channels = x_shape.dim(rank - 1).unwrap_static();
3032                let total = x_shape.num_elements().unwrap_or(0);
3033                let _gamma_shape = &graph.node(node.id).shape;
3034                Thunk::BatchNormInferenceBackwardGamma {
3035                    x: node_offset(arena, node.inputs[0]),
3036                    mean: node_offset(arena, node.inputs[1]),
3037                    var: node_offset(arena, node.inputs[2]),
3038                    dy: node_offset(arena, node.inputs[3]),
3039                    dgamma: node_offset(arena, node.id),
3040                    count: (total / channels.max(1)) as u32,
3041                    channels: channels as u32,
3042                    eps: *eps,
3043                }
3044            }
3045
3046            Op::BatchNormInferenceBackwardBeta => {
3047                let dy_shape = &graph.node(node.inputs[0]).shape;
3048                let rank = dy_shape.rank();
3049                let channels = dy_shape.dim(rank - 1).unwrap_static();
3050                let total = dy_shape.num_elements().unwrap_or(0);
3051                Thunk::BatchNormInferenceBackwardBeta {
3052                    dy: node_offset(arena, node.inputs[0]),
3053                    dbeta: node_offset(arena, node.id),
3054                    count: (total / channels.max(1)) as u32,
3055                    channels: channels as u32,
3056                }
3057            }
3058
3059            Op::LayerNorm2d { eps } => {
3060                let in_shape = &graph.node(node.inputs[0]).shape;
3061                let (n, c, h, w) = conv_nchw_dims(in_shape);
3062                Thunk::LayerNorm2d {
3063                    src: node_offset(arena, node.inputs[0]),
3064                    g: node_offset(arena, node.inputs[1]),
3065                    b: node_offset(arena, node.inputs[2]),
3066                    dst: node_offset(arena, node.id),
3067                    n,
3068                    c,
3069                    h,
3070                    w,
3071                    eps: *eps,
3072                }
3073            }
3074
3075            Op::ConvTranspose2d {
3076                kernel_size,
3077                stride,
3078                padding,
3079                dilation,
3080                output_padding: _,
3081                groups,
3082            } => {
3083                let in_shape = &graph.node(node.inputs[0]).shape;
3084                let out_shape = &node.shape;
3085                let (n, c_in, h, w_in) = conv_nchw_dims(in_shape);
3086                let (_, c_out, h_out, w_out) = conv_nchw_dims(out_shape);
3087                Thunk::ConvTranspose2d {
3088                    src: node_offset(arena, node.inputs[0]),
3089                    weight: node_offset(arena, node.inputs[1]),
3090                    dst: node_offset(arena, node.id),
3091                    n,
3092                    c_in,
3093                    h,
3094                    w_in,
3095                    c_out,
3096                    h_out,
3097                    w_out,
3098                    kh: kernel_size[0] as u32,
3099                    kw: kernel_size[1] as u32,
3100                    sh: stride.first().copied().unwrap_or(1) as u32,
3101                    sw: stride.get(1).copied().unwrap_or(1) as u32,
3102                    ph: padding.first().copied().unwrap_or(0) as u32,
3103                    pw: padding.get(1).copied().unwrap_or(0) as u32,
3104                    dh: dilation.first().copied().unwrap_or(1) as u32,
3105                    dw: dilation.get(1).copied().unwrap_or(1) as u32,
3106                    groups: *groups as u32,
3107                }
3108            }
3109
3110            Op::ResizeNearest2x => {
3111                let in_shape = &graph.node(node.inputs[0]).shape;
3112                let (n, c, h, w) = conv_nchw_dims(in_shape);
3113                Thunk::ResizeNearest2x {
3114                    src: node_offset(arena, node.inputs[0]),
3115                    dst: node_offset(arena, node.id),
3116                    n,
3117                    c,
3118                    h,
3119                    w,
3120                }
3121            }
3122
3123            Op::AxialRope2d {
3124                end_x,
3125                end_y,
3126                head_dim,
3127                num_heads,
3128                theta,
3129                repeat_factor,
3130            } => {
3131                let in_shape = &graph.node(node.inputs[0]).shape;
3132                let batch = in_shape.dim(0).unwrap_static() as u32;
3133                let seq = in_shape.dim(1).unwrap_static() as u32;
3134                let hidden = in_shape.dim(2).unwrap_static() as u32;
3135                Thunk::AxialRope2d {
3136                    src: node_offset(arena, node.inputs[0]),
3137                    dst: node_offset(arena, node.id),
3138                    batch,
3139                    seq,
3140                    hidden,
3141                    end_x: *end_x as u32,
3142                    end_y: *end_y as u32,
3143                    head_dim: *head_dim as u32,
3144                    num_heads: *num_heads as u32,
3145                    theta: *theta,
3146                    repeat_factor: *repeat_factor as u32,
3147                }
3148            }
3149
3150            Op::Softmax { axis } => {
3151                let rank = node.shape.rank();
3152                let ax = if *axis < 0 {
3153                    (rank as i32 + axis) as usize
3154                } else {
3155                    *axis as usize
3156                };
3157                let cols = node.shape.dim(ax).unwrap_static();
3158                let total = node.shape.num_elements().unwrap();
3159                let in_off = node_offset(arena, node.inputs[0]);
3160                let out_off = node_offset(arena, node.id);
3161                // Softmax kernel runs in-place on its data buffer. If the
3162                // planner gave input and output separate slots (their live
3163                // ranges overlap, so no aliasing), the output starts
3164                // uninitialized — emit a Copy first so the data is there.
3165                // Same pattern as Op::Activation.
3166                if in_off != out_off {
3167                    thunks.push(Thunk::Copy {
3168                        src: in_off,
3169                        dst: out_off,
3170                        len: total as u32,
3171                    });
3172                }
3173                Thunk::Softmax {
3174                    data: out_off,
3175                    rows: (total / cols) as u32,
3176                    cols: cols as u32,
3177                }
3178            }
3179
3180            Op::SelectiveScan { state_size } => {
3181                let in_shape = &graph.node(node.inputs[0]).shape;
3182                let (batch, seq, hidden) = (
3183                    in_shape.dim(0).unwrap_static(),
3184                    in_shape.dim(1).unwrap_static(),
3185                    in_shape.dim(2).unwrap_static(),
3186                );
3187                Thunk::SelectiveScan {
3188                    x: node_offset(arena, node.inputs[0]),
3189                    delta: node_offset(arena, node.inputs[1]),
3190                    a: node_offset(arena, node.inputs[2]),
3191                    b: node_offset(arena, node.inputs[3]),
3192                    c: node_offset(arena, node.inputs[4]),
3193                    dst: node_offset(arena, node.id),
3194                    batch: batch as u32,
3195                    seq: seq as u32,
3196                    hidden: hidden as u32,
3197                    state_size: *state_size as u32,
3198                }
3199            }
3200
3201            Op::GatedDeltaNet {
3202                state_size,
3203                carry_state,
3204            } => {
3205                let q_shape = &graph.node(node.inputs[0]).shape;
3206                let (batch, seq, heads) = (
3207                    q_shape.dim(0).unwrap_static(),
3208                    q_shape.dim(1).unwrap_static(),
3209                    q_shape.dim(2).unwrap_static(),
3210                );
3211                let state_off = if *carry_state {
3212                    node_offset(arena, node.inputs[5])
3213                } else {
3214                    0
3215                };
3216                Thunk::GatedDeltaNet {
3217                    q: node_offset(arena, node.inputs[0]),
3218                    k: node_offset(arena, node.inputs[1]),
3219                    v: node_offset(arena, node.inputs[2]),
3220                    g: node_offset(arena, node.inputs[3]),
3221                    beta: node_offset(arena, node.inputs[4]),
3222                    state: state_off,
3223                    dst: node_offset(arena, node.id),
3224                    batch: batch as u32,
3225                    seq: seq as u32,
3226                    heads: heads as u32,
3227                    state_size: *state_size as u32,
3228                }
3229            }
3230
3231            Op::QMatMul {
3232                x_zp,
3233                w_zp,
3234                out_zp,
3235                mult,
3236            } => {
3237                let x_shape = &graph.node(node.inputs[0]).shape;
3238                let w_shape = &graph.node(node.inputs[1]).shape;
3239                let m = x_shape.dim(0).unwrap_static();
3240                let k = x_shape.dim(1).unwrap_static();
3241                let n = w_shape.dim(1).unwrap_static();
3242                Thunk::QMatMul {
3243                    x: node_offset(arena, node.inputs[0]),
3244                    w: node_offset(arena, node.inputs[1]),
3245                    bias: node_offset(arena, node.inputs[2]),
3246                    out: node_offset(arena, node.id),
3247                    m: m as u32,
3248                    k: k as u32,
3249                    n: n as u32,
3250                    x_zp: *x_zp,
3251                    w_zp: *w_zp,
3252                    out_zp: *out_zp,
3253                    mult: *mult,
3254                }
3255            }
3256
3257            Op::QConv2d {
3258                kernel_size,
3259                stride,
3260                padding,
3261                dilation,
3262                groups,
3263                x_zp,
3264                w_zp,
3265                out_zp,
3266                mult,
3267            } => {
                    // Lowers a quantized 2-D convolution to `Thunk::QConv2d`.
                    // Only emitted when `kernel_size` has exactly two entries and the
                    // input, weight and output shapes are all rank 4; anything else
                    // becomes `Thunk::Nop`.
                    // NOTE(review): the Nop fallback looks like it leaves the output
                    // buffer unwritten without any diagnostic — confirm this is intended.
3268                let in_shape = &graph.node(node.inputs[0]).shape;
3269                let w_shape = &graph.node(node.inputs[1]).shape;
3270                let out_shape = &node.shape;
3271                if kernel_size.len() == 2
3272                    && in_shape.rank() == 4
3273                    && w_shape.rank() == 4
3274                    && out_shape.rank() == 4
3275                {
3276                    Thunk::QConv2d {
3277                        x: node_offset(arena, node.inputs[0]),
3278                        w: node_offset(arena, node.inputs[1]),
3279                        bias: node_offset(arena, node.inputs[2]),
3280                        out: node_offset(arena, node.id),
                            // Input dims 0..3 are read as (n, c_in, h, w_in).
3281                        n: in_shape.dim(0).unwrap_static() as u32,
3282                        c_in: in_shape.dim(1).unwrap_static() as u32,
3283                        h: in_shape.dim(2).unwrap_static() as u32,
3284                        w_in: in_shape.dim(3).unwrap_static() as u32,
                            // Output channel count and spatial size come from the
                            // node's own shape, not from the weight shape.
3285                        c_out: out_shape.dim(1).unwrap_static() as u32,
3286                        h_out: out_shape.dim(2).unwrap_static() as u32,
3287                        w_out: out_shape.dim(3).unwrap_static() as u32,
3288                        kh: kernel_size[0] as u32,
3289                        kw: kernel_size[1] as u32,
                            // Missing entries default to stride 1, padding 0, dilation 1.
3290                        sh: stride.first().copied().unwrap_or(1) as u32,
3291                        sw: stride.get(1).copied().unwrap_or(1) as u32,
3292                        ph: padding.first().copied().unwrap_or(0) as u32,
3293                        pw: padding.get(1).copied().unwrap_or(0) as u32,
3294                        dh: dilation.first().copied().unwrap_or(1) as u32,
3295                        dw: dilation.get(1).copied().unwrap_or(1) as u32,
3296                        groups: *groups as u32,
3297                        x_zp: *x_zp,
3298                        w_zp: *w_zp,
3299                        out_zp: *out_zp,
3300                        mult: *mult,
3301                    }
3302                } else {
3303                    Thunk::Nop
3304                }
3305            }
3306
3307            Op::DequantMatMul { scheme } => {
                    // Lowers a dequantize-then-matmul node to the thunk variant that
                    // matches its quantization scheme.
                    // Dimensions are recovered from element counts: `n` is the last dim
                    // of the output, `m = out_elems / n`, `k = x_elems / m`; each divisor
                    // is clamped with `.max(1)` so a zero dim cannot divide by zero.
                    // NOTE(review): `node.shape.rank() - 1` underflows for a rank-0
                    // output, and `num_elements().unwrap()` panics on `None` (presumably
                    // dynamic shapes) — confirm both are rejected upstream.
3308                use rlx_ir::quant::QuantScheme;
3309                let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3310                let total = node.shape.num_elements().unwrap();
3311                let m = total / n.max(1);
3312                let x_total = graph.node(node.inputs[0]).shape.num_elements().unwrap();
3313                let k = x_total / m.max(1);
                    // GGUF schemes take only (x, w_q); no separate scale / zero-point
                    // inputs are read on this path.
3314                if scheme.is_gguf() {
3315                    Thunk::DequantMatMulGguf {
3316                        x: node_offset(arena, node.inputs[0]),
3317                        w_q: node_offset(arena, node.inputs[1]),
3318                        dst: node_offset(arena, node.id),
3319                        m: m as u32,
3320                        k: k as u32,
3321                        n: n as u32,
3322                        scheme: *scheme,
3323                    }
3324                } else {
3325                    match scheme {
                            // NVFP4: input 2 is the scale, input 3 the global scale.
3326                        QuantScheme::Nvfp4Block => Thunk::DequantMatMulNvfp4 {
3327                            x: node_offset(arena, node.inputs[0]),
3328                            w_q: node_offset(arena, node.inputs[1]),
3329                            scale: node_offset(arena, node.inputs[2]),
3330                            global_scale: node_offset(arena, node.inputs[3]),
3331                            dst: node_offset(arena, node.id),
3332                            m: m as u32,
3333                            k: k as u32,
3334                            n: n as u32,
3335                        },
                            // NOTE(review): a zero-point offset (input 3) is passed while
                            // `is_asymmetric` is false — confirm the kernel ignores `zp`
                            // in that mode and that input 3 always exists.
3336                        QuantScheme::Int4Block { block_size } => Thunk::DequantMatMulInt4 {
3337                            x: node_offset(arena, node.inputs[0]),
3338                            w_q: node_offset(arena, node.inputs[1]),
3339                            scale: node_offset(arena, node.inputs[2]),
3340                            zp: node_offset(arena, node.inputs[3]),
3341                            dst: node_offset(arena, node.id),
3342                            m: m as u32,
3343                            k: k as u32,
3344                            n: n as u32,
3345                            block_size: *block_size,
3346                            is_asymmetric: false,
3347                        },
                            // Both FP8 flavours share one thunk; `e5m2` selects the format.
3348                        QuantScheme::Fp8E4m3 => Thunk::DequantMatMulFp8 {
3349                            x: node_offset(arena, node.inputs[0]),
3350                            w_q: node_offset(arena, node.inputs[1]),
3351                            scale: node_offset(arena, node.inputs[2]),
3352                            dst: node_offset(arena, node.id),
3353                            m: m as u32,
3354                            k: k as u32,
3355                            n: n as u32,
3356                            e5m2: false,
3357                        },
3358                        QuantScheme::Fp8E5m2 => Thunk::DequantMatMulFp8 {
3359                            x: node_offset(arena, node.inputs[0]),
3360                            w_q: node_offset(arena, node.inputs[1]),
3361                            scale: node_offset(arena, node.inputs[2]),
3362                            dst: node_offset(arena, node.id),
3363                            m: m as u32,
3364                            k: k as u32,
3365                            n: n as u32,
3366                            e5m2: true,
3367                        },
                            // The two Int8 block schemes differ only in `is_asymmetric`.
3368                        QuantScheme::Int8Block { block_size } => Thunk::DequantMatMul {
3369                            x: node_offset(arena, node.inputs[0]),
3370                            w_q: node_offset(arena, node.inputs[1]),
3371                            scale: node_offset(arena, node.inputs[2]),
3372                            zp: node_offset(arena, node.inputs[3]),
3373                            dst: node_offset(arena, node.id),
3374                            m: m as u32,
3375                            k: k as u32,
3376                            n: n as u32,
3377                            block_size: *block_size,
3378                            is_asymmetric: false,
3379                        },
3380                        QuantScheme::Int8BlockAsym { block_size } => Thunk::DequantMatMul {
3381                            x: node_offset(arena, node.inputs[0]),
3382                            w_q: node_offset(arena, node.inputs[1]),
3383                            scale: node_offset(arena, node.inputs[2]),
3384                            zp: node_offset(arena, node.inputs[3]),
3385                            dst: node_offset(arena, node.id),
3386                            m: m as u32,
3387                            k: k as u32,
3388                            n: n as u32,
3389                            block_size: *block_size,
3390                            is_asymmetric: true,
3391                        },
                            // Any other non-GGUF scheme aborts lowering with a panic.
3392                        other => panic!(
3393                            "DequantMatMul on CPU supports Int8/Int4/FP8/NVFP4 legacy or GGUF schemes; got {other}"
3394                        ),
3395                    }
3396                }
3397            }
3398
3399            Op::LoraMatMul { scale } => {
                    // Lowers a LoRA matmul node to `Thunk::LoraMatMul`, reading the four
                    // operands from inputs 0..3 in the order x, w, a, b.
3400                // x [m, k], w [k, n], a [k, r], b [r, n].
                    // All dimensions are recovered from element counts rather than from
                    // individual dims: `n` is the last output dim, `m = out_elems / n`,
                    // `k = x_elems / m`, `r = a_elems / k`; each divisor is clamped with
                    // `.max(1)`.
                    // NOTE(review): `rank() - 1` underflows for a rank-0 output — confirm
                    // that cannot reach this point.
3401                let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3402                let total = node.shape.num_elements().unwrap();
3403                let m = total / n.max(1);
3404                let x_total = graph.node(node.inputs[0]).shape.num_elements().unwrap();
3405                let k = x_total / m.max(1);
3406                let a_total = graph.node(node.inputs[2]).shape.num_elements().unwrap();
3407                let r = a_total / k.max(1);
3408                Thunk::LoraMatMul {
3409                    x: node_offset(arena, node.inputs[0]),
3410                    w: node_offset(arena, node.inputs[1]),
3411                    a: node_offset(arena, node.inputs[2]),
3412                    b: node_offset(arena, node.inputs[3]),
3413                    dst: node_offset(arena, node.id),
3414                    m: m as u32,
3415                    k: k as u32,
3416                    n: n as u32,
3417                    r: r as u32,
3418                    scale: *scale,
3419                }
3420            }
3421
3422            Op::Sample {
3423                top_k,
3424                top_p,
3425                temperature,
3426                seed,
3427            } => {
                    // Lowers a sampling node to `Thunk::Sample`; the sampling parameters
                    // are copied from the op unchanged (`top_k` is narrowed to u32).
3428                let in_shape = &graph.node(node.inputs[0]).shape;
3429                // Logits are [batch, vocab] (or [vocab] → batch=1).
                    // NOTE(review): for rank > 2 only dim 0 is used as `batch`, so any
                    // middle dims are not folded in — confirm logits never exceed rank 2.
3430                let (batch, vocab) = if in_shape.rank() >= 2 {
3431                    (
3432                        in_shape.dim(0).unwrap_static(),
3433                        in_shape.dim(in_shape.rank() - 1).unwrap_static(),
3434                    )
3435                } else {
                        // Rank < 2: a single row; an unknown element count becomes vocab 0.
3436                    (1, in_shape.num_elements().unwrap_or(0))
3437                };
3438                Thunk::Sample {
3439                    logits: node_offset(arena, node.inputs[0]),
3440                    dst: node_offset(arena, node.id),
3441                    batch: batch as u32,
3442                    vocab: vocab as u32,
3443                    top_k: *top_k as u32,
3444                    top_p: *top_p,
3445                    temperature: *temperature,
3446                    seed: *seed,
3447                }
3448            }
3449
3450            Op::Cumsum { axis, exclusive } => {
                    // Lowers a cumulative-sum node to `Thunk::Cumsum`, described to the
                    // kernel as a `rows x cols` problem where `cols` is the scanned axis.
3451                // For now CPU only supports last-axis cumsum (the
3452                // common case for sampling / ragged offsets).
3453                // Other axes can lower via Transpose → Cumsum →
3454                // Transpose; not on the hot path today.
3455                let rank = node.shape.rank();
                    // A negative axis is normalized by adding the rank.
                    // NOTE(review): an axis below `-rank` wraps to a huge usize here and
                    // is only caught by the assert below.
3456                let ax = if *axis < 0 {
3457                    (rank as i32 + axis) as usize
3458                } else {
3459                    *axis as usize
3460                };
                    // NOTE(review): `rank - 1` underflows for a rank-0 shape — confirm
                    // scalars cannot reach this arm.
3461                assert_eq!(
3462                    ax,
3463                    rank - 1,
3464                    "Cumsum only supports the last axis on CPU today"
3465                );
3466                let cols = node.shape.dim(ax).unwrap_static();
3467                let total = node.shape.num_elements().unwrap();
3468                Thunk::Cumsum {
3469                    src: node_offset(arena, node.inputs[0]),
3470                    dst: node_offset(arena, node.id),
                        // NOTE(review): unlike the `.max(1)` guards used by other arms,
                        // this division panics if the last dim is 0 — confirm empty
                        // tensors are filtered out before lowering.
3471                    rows: (total / cols) as u32,
3472                    cols: cols as u32,
3473                    exclusive: *exclusive,
3474                }
3475            }
3476
3477            Op::Attention {
3478                num_heads,
3479                head_dim,
3480                mask_kind,
                    // NOTE(review): `score_scale` and `attn_logit_softcap` are discarded
                    // here and never reach the thunk — confirm they are applied
                    // elsewhere or are unsupported on this backend.
3481                score_scale: _,
3482                attn_logit_softcap: _,
3483            } => {
                    // Lowers an attention node to `Thunk::Attention`. Inputs 0..2 are
                    // q, k, v; input 3 is read as the mask only for Custom/Bias masks.
3484                // Layout dispatch: rank-4 input could be either
3485                // `[B, S, H, D]` (CPU's historical convention) or
3486                // `[B, H, S, D]` (the convention the GPU/TPU backends
3487                // share). Disambiguate by which axis matches
3488                // `num_heads`. Rank-3 is always `[B, S, H*D]`.
                    // NOTE(review): when dim 1 equals `num_heads` the `[B, H, S, D]`
                    // branch always wins, so a `[B, S, H, D]` tensor with S == H would be
                    // misread — confirm callers avoid that case.
3489                let q_shape = &graph.node(node.inputs[0]).shape;
3490                let k_shape = &graph.node(node.inputs[1]).shape;
3491                let rank = q_shape.rank();
                    // Produces (batch, query seq len, key/value seq len, is-BHSD flag).
3492                let (batch, seq, kv_seq, bhsd) = if rank == 4 {
3493                    let d1 = q_shape.dim(1).unwrap_static();
3494                    let d2 = q_shape.dim(2).unwrap_static();
3495                    if d1 == *num_heads {
3496                        // [B, H, S, D]
3497                        (
3498                            q_shape.dim(0).unwrap_static(),
3499                            d2,
3500                            k_shape.dim(2).unwrap_static(),
3501                            true,
3502                        )
3503                    } else {
3504                        // [B, S, H, D]
3505                        (
3506                            q_shape.dim(0).unwrap_static(),
3507                            d1,
3508                            k_shape.dim(1).unwrap_static(),
3509                            false,
3510                        )
3511                    }
                    // Any other rank >= 3 (i.e. 3, or 5 and above) reads dims 0 and 1.
3512                } else if rank >= 3 {
3513                    (
3514                        q_shape.dim(0).unwrap_static(),
3515                        q_shape.dim(1).unwrap_static(),
3516                        k_shape.dim(1).unwrap_static(),
3517                        false,
3518                    )
3519                } else {
                        // Rank < 3: batch is fixed to 1 and dim 0 is the sequence length.
3520                    (
3521                        1,
3522                        q_shape.dim(0).unwrap_static(),
3523                        k_shape.dim(0).unwrap_static(),
3524                        false,
3525                    )
3526                };
                    // Offset 0 stands in for "no mask" on the other mask kinds.
3527                let mask_off = if matches!(
3528                    mask_kind,
3529                    rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
3530                ) {
3531                    node_offset(arena, node.inputs[3])
3532                } else {
3533                    0
3534                };
3535                let hs = (*num_heads * *head_dim) as u32;
3536                Thunk::Attention {
3537                    q: node_offset(arena, node.inputs[0]),
3538                    k: node_offset(arena, node.inputs[1]),
3539                    v: node_offset(arena, node.inputs[2]),
3540                    mask: mask_off,
3541                    out: node_offset(arena, node.id),
3542                    batch: batch as u32,
3543                    seq: seq as u32,
3544                    kv_seq: kv_seq as u32,
3545                    heads: *num_heads as u32,
3546                    head_dim: *head_dim as u32,
3547                    mask_kind: *mask_kind,
3548                    // Defaults: each input is its own contiguous buffer
3549                    // with row stride = hidden. Rewritten by the
3550                    // Narrow→Attention fusion when applicable.
3551                    q_row_stride: hs,
3552                    k_row_stride: hs,
3553                    v_row_stride: hs,
3554                    bhsd,
3555                }
3556            }
3557
3558            Op::AttentionBackward {
3559                num_heads,
3560                head_dim,
3561                mask_kind,
3562                wrt,
3563            } => {
                    // Lowers an attention-gradient node to `Thunk::AttentionBackward`.
                    // Inputs 0..3 are q, k, v, dy; input 4 is read as the mask only for
                    // Custom/Bias masks. `wrt` is forwarded to the thunk unchanged.
3564                let q_shape = &graph.node(node.inputs[0]).shape;
3565                let k_shape = &graph.node(node.inputs[1]).shape;
3566                let rank = q_shape.rank();
                    // Same layout detection as the forward `Op::Attention` arm: a rank-4
                    // q whose dim 1 equals `num_heads` is treated as [B, H, S, D],
                    // otherwise as [B, S, H, D].
                    // NOTE(review): this duplicates the forward arm's logic verbatim; a
                    // shared helper would keep the two from drifting apart.
3567                let (batch, seq, kv_seq, bhsd) = if rank == 4 {
3568                    let d1 = q_shape.dim(1).unwrap_static();
3569                    let d2 = q_shape.dim(2).unwrap_static();
3570                    if d1 == *num_heads {
3571                        (
3572                            q_shape.dim(0).unwrap_static(),
3573                            d2,
3574                            k_shape.dim(2).unwrap_static(),
3575                            true,
3576                        )
3577                    } else {
3578                        (
3579                            q_shape.dim(0).unwrap_static(),
3580                            d1,
3581                            k_shape.dim(1).unwrap_static(),
3582                            false,
3583                        )
3584                    }
3585                } else if rank >= 3 {
3586                    (
3587                        q_shape.dim(0).unwrap_static(),
3588                        q_shape.dim(1).unwrap_static(),
3589                        k_shape.dim(1).unwrap_static(),
3590                        false,
3591                    )
3592                } else {
                        // Rank < 3: batch is fixed to 1 and dim 0 is the sequence length.
3593                    (
3594                        1,
3595                        q_shape.dim(0).unwrap_static(),
3596                        k_shape.dim(0).unwrap_static(),
3597                        false,
3598                    )
3599                };
                    // Offset 0 stands in for "no mask" on the other mask kinds.
3600                let mask_off = if matches!(
3601                    mask_kind,
3602                    rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
3603                ) {
3604                    node_offset(arena, node.inputs[4])
3605                } else {
3606                    0
3607                };
3608                Thunk::AttentionBackward {
3609                    q: node_offset(arena, node.inputs[0]),
3610                    k: node_offset(arena, node.inputs[1]),
3611                    v: node_offset(arena, node.inputs[2]),
3612                    dy: node_offset(arena, node.inputs[3]),
3613                    mask: mask_off,
3614                    out: node_offset(arena, node.id),
3615                    batch: batch as u32,
3616                    seq: seq as u32,
3617                    kv_seq: kv_seq as u32,
3618                    heads: *num_heads as u32,
3619                    head_dim: *head_dim as u32,
3620                    mask_kind: *mask_kind,
3621                    wrt: *wrt,
3622                    bhsd,
3623                }
3624            }
3625
3626            Op::FusedAttentionBlock {
3627                num_heads,
3628                head_dim,
3629                has_bias,
3630                has_rope,
3631            } => {
                    // Lowers a fused attention block to `Thunk::FusedAttnBlock`.
                    // The first four inputs are fixed; bias and RoPE operands are
                    // optional and are located by walking `idx` from position 4.
3632                let x_shape = &graph.node(node.inputs[0]).shape;
3633                let (batch, seq) = if x_shape.rank() >= 3 {
3634                    (
3635                        x_shape.dim(0).unwrap_static(),
3636                        x_shape.dim(1).unwrap_static(),
3637                    )
3638                } else {
                        // Lower ranks: seq is the second-to-last dim and batch is what
                        // remains after dividing out seq * heads * head_dim.
                        // NOTE(review): `rank() - 2` underflows for rank < 2, and the
                        // division panics if any factor is 0 — confirm upstream checks.
3639                    let total = x_shape.num_elements().unwrap();
3640                    let s = x_shape.dim(x_shape.rank() - 2).unwrap_static();
3641                    (total / (s * num_heads * head_dim), s)
3642                };
3643                let hs = (*num_heads * *head_dim) as u32;
3644                // Inputs: hidden, qkv_w, out_w, mask, [qkv_b, out_b], [cos, sin]
3645                let mut idx = 4;
                    // When biases are present they occupy two slots, pushing the RoPE
                    // tables two positions later; absent operands get offset 0.
3646                let (qkv_b_off, out_b_off) = if *has_bias {
3647                    let qb = node_offset(arena, node.inputs[idx]);
3648                    let ob = node_offset(arena, node.inputs[idx + 1]);
3649                    idx += 2;
3650                    (qb, ob)
3651                } else {
3652                    (0, 0)
3653                };
3654                let (cos_off, sin_off, cl) = if *has_rope {
3655                    let c = node_offset(arena, node.inputs[idx]);
3656                    let s = node_offset(arena, node.inputs[idx + 1]);
                        // `get_len` is defined outside this view; presumably the element
                        // count of the cos table — TODO confirm.
3657                    let clen = get_len(graph, node.inputs[idx]);
3658                    (c, s, clen as u32)
3659                } else {
3660                    (0, 0, 0)
3661                };
3662
3663                Thunk::FusedAttnBlock {
3664                    hidden: node_offset(arena, node.inputs[0]),
3665                    qkv_w: node_offset(arena, node.inputs[1]),
3666                    out_w: node_offset(arena, node.inputs[2]),
3667                    mask: node_offset(arena, node.inputs[3]),
3668                    out: node_offset(arena, node.id),
3669                    qkv_b: qkv_b_off,
3670                    out_b: out_b_off,
3671                    cos: cos_off,
3672                    sin: sin_off,
3673                    cos_len: cl,
3674                    batch: batch as u32,
3675                    seq: seq as u32,
3676                    hs,
3677                    nh: *num_heads as u32,
3678                    dh: *head_dim as u32,
3679                    has_bias: *has_bias,
3680                    has_rope: *has_rope,
3681                }
3682            }
3683
3684            Op::Rope { head_dim, n_rot } => {
                    // Lowers a rotary-embedding node to `Thunk::Rope`. Inputs are
                    // (x, cos, sin) at positions 0, 1, 2.
3685                let x_shape = &graph.node(node.inputs[0]).shape;
3686                let (batch, seq, hidden) = if x_shape.rank() >= 3 {
3687                    (
3688                        x_shape.dim(0).unwrap_static(),
3689                        x_shape.dim(1).unwrap_static(),
3690                        x_shape.dim(2).unwrap_static(),
3691                    )
3692                } else {
                        // Rank < 3: batch is 1, dim 0 is seq and everything else is
                        // folded into `hidden`.
                        // NOTE(review): divides by dim 0 without a `.max(1)` guard, so a
                        // zero-length sequence panics here — confirm that is excluded.
3693                    let total = x_shape.num_elements().unwrap();
3694                    (
3695                        1,
3696                        x_shape.dim(0).unwrap_static(),
3697                        total / x_shape.dim(0).unwrap_static(),
3698                    )
3699                };
                    // `get_len` is defined outside this view; presumably the element
                    // count of the cos table — TODO confirm.
3700                let cos_len = get_len(graph, node.inputs[1]);
3701                Thunk::Rope {
3702                    src: node_offset(arena, node.inputs[0]),
3703                    cos: node_offset(arena, node.inputs[1]),
3704                    sin: node_offset(arena, node.inputs[2]),
3705                    dst: node_offset(arena, node.id),
3706                    batch: batch as u32,
3707                    seq: seq as u32,
3708                    hidden: hidden as u32,
3709                    head_dim: *head_dim as u32,
3710                    n_rot: *n_rot as u32,
3711                    cos_len: cos_len as u32,
3712                    // Default: source rows are tightly packed (rewritten
3713                    // by the Narrow→Rope fusion pass below if Rope ends
3714                    // up reading from a wider parent like QKV).
3715                    src_row_stride: hidden as u32,
3716                }
3717            }
3718
3719            Op::FusedSwiGLU {
                    // `cast_to` is ignored by this lowering.
3720                cast_to: _,
3721                gate_first,
3722            } => {
                    // Lowers a fused SwiGLU node to `Thunk::FusedSwiGLU`.
                    // `n_half` is the last dim of the *output* shape and `total` is the
                    // output element count; the input shape is not consulted.
                    // NOTE(review): `rank() - 1` underflows for a rank-0 output —
                    // confirm that cannot occur.
3723                let n_half = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3724                let total = node.shape.num_elements().unwrap();
3725                Thunk::FusedSwiGLU {
3726                    src: node_offset(arena, node.inputs[0]),
3727                    dst: node_offset(arena, node.id),
3728                    n_half: n_half as u32,
3729                    total: total as u32,
3730                    gate_first: *gate_first,
3731                }
3732            }
3733
3734            Op::Conv {
3735                kernel_size,
3736                stride,
3737                padding,
3738                dilation,
3739                groups,
3740            } => {
                    // Lowers a convolution node to one of three thunks: `Conv2D1x1` for
                    // the trivial pointwise case, `Conv2D` for any other two-entry
                    // kernel, or `Thunk::Nop` when neither set of conditions holds.
3741                let in_shape = &graph.node(node.inputs[0]).shape;
3742                let w_shape = &graph.node(node.inputs[1]).shape;
3743                let out_shape = &node.shape;
3744                // 1×1 fast path (plan #26): kH=kW=1, stride=1,
3745                // padding=0, dilation=1, groups=1. Emits a single
3746                // Conv2D1x1 thunk that BLAS-dispatches per batch.
                    // `all()` is vacuously true on empty lists, so omitted
                    // stride/padding/dilation also qualify for the fast path.
3747                let is_1x1_simple = kernel_size.len() == 2
3748                    && kernel_size[0] == 1
3749                    && kernel_size[1] == 1
3750                    && stride.iter().all(|&s| s == 1)
3751                    && padding.iter().all(|&p| p == 0)
3752                    && dilation.iter().all(|&d| d == 1)
3753                    && *groups == 1;
3754                if is_1x1_simple
3755                    && in_shape.rank() >= 3
3756                    && out_shape.rank() >= 3
3757                    && w_shape.rank() >= 2
3758                {
                        // `conv_nchw_dims` is defined outside this view; its 4-tuple is
                        // bound here as (n, c_in, h, w). How it treats rank-3 shapes is
                        // not visible — TODO confirm.
3759                    let (n, c_in, h, w) = conv_nchw_dims(in_shape);
3760                    let (_, c_out, _, _) = conv_nchw_dims(out_shape);
3761                    Thunk::Conv2D1x1 {
3762                        src: node_offset(arena, node.inputs[0]),
3763                        weight: node_offset(arena, node.inputs[1]),
3764                        dst: node_offset(arena, node.id),
3765                        n,
3766                        c_in,
3767                        c_out,
                            // Spatial size is flattened; saturates instead of wrapping.
3768                        hw: h.saturating_mul(w),
3769                    }
3770                } else if kernel_size.len() == 2
3771                    && in_shape.rank() >= 3
3772                    && w_shape.rank() >= 2
3773                    && out_shape.rank() >= 3
3774                {
3775                    let (n, c_in, h, w_in) = conv_nchw_dims(in_shape);
3776                    let (_, c_out, h_out, w_out) = conv_nchw_dims(out_shape);
3777                    Thunk::Conv2D {
3778                        src: node_offset(arena, node.inputs[0]),
3779                        weight: node_offset(arena, node.inputs[1]),
3780                        dst: node_offset(arena, node.id),
3781                        n,
3782                        c_in,
3783                        h,
3784                        w: w_in,
3785                        c_out,
3786                        h_out,
3787                        w_out,
3788                        kh: kernel_size[0] as u32,
3789                        kw: kernel_size[1] as u32,
                            // Missing entries default to stride 1, padding 0, dilation 1.
3790                        sh: stride.first().copied().unwrap_or(1) as u32,
3791                        sw: stride.get(1).copied().unwrap_or(1) as u32,
3792                        ph: padding.first().copied().unwrap_or(0) as u32,
3793                        pw: padding.get(1).copied().unwrap_or(0) as u32,
3794                        dh: dilation.first().copied().unwrap_or(1) as u32,
3795                        dw: dilation.get(1).copied().unwrap_or(1) as u32,
3796                        groups: *groups as u32,
3797                    }
3798                } else {
                        // NOTE(review): unsupported kernels/ranks fall through to a Nop
                        // with no diagnostic — confirm the output is produced elsewhere.
3799                    Thunk::Nop
3800                }
3801            }
3802
3803            Op::Pool {
3804                kind,
3805                kernel_size,
3806                stride,
3807                padding,
3808            } => {
                    // Lowers a pooling node to `Thunk::Pool2D`, or to `Thunk::Nop` when
                    // the kernel is not two-entry or either shape is not rank 4.
3809                // Currently support 2D pooling on rank-4 NCHW tensors.
3810                let in_shape = &graph.node(node.inputs[0]).shape;
3811                let out_shape = &node.shape;
3812                if kernel_size.len() == 2 && in_shape.rank() == 4 && out_shape.rank() == 4 {
3813                    Thunk::Pool2D {
3814                        src: node_offset(arena, node.inputs[0]),
3815                        dst: node_offset(arena, node.id),
3816                        n: in_shape.dim(0).unwrap_static() as u32,
3817                        c: in_shape.dim(1).unwrap_static() as u32,
3818                        h: in_shape.dim(2).unwrap_static() as u32,
3819                        w: in_shape.dim(3).unwrap_static() as u32,
                            // Only the output's spatial dims are read; n and c come from
                            // the input shape above.
3820                        h_out: out_shape.dim(2).unwrap_static() as u32,
3821                        w_out: out_shape.dim(3).unwrap_static() as u32,
3822                        kh: kernel_size[0] as u32,
3823                        kw: kernel_size[1] as u32,
                            // Missing entries default to stride 1 and padding 0.
3824                        sh: stride.first().copied().unwrap_or(1) as u32,
3825                        sw: stride.get(1).copied().unwrap_or(1) as u32,
3826                        ph: padding.first().copied().unwrap_or(0) as u32,
3827                        pw: padding.get(1).copied().unwrap_or(0) as u32,
3828                        kind: *kind,
3829                    }
3830                } else {
                        // NOTE(review): other pooling ranks silently become a Nop —
                        // confirm they are rejected or handled before this point.
3831                    Thunk::Nop
3832                }
3833            }
3834
3835            Op::Transpose { perm } => {
                    // Lowers a transpose node to `Thunk::TransposeF64` for f64 outputs
                    // or to the generic `Thunk::Transpose` (which also carries the
                    // element size) for every other dtype.
3836                // Pre-compute (out_dims, in_strides_for_each_out_dim) so the
3837                // runtime loop is just an N-D index walk + scatter.
3838                let in_shape = &graph.node(node.inputs[0]).shape;
3839                let in_rank = in_shape.rank();
                    // NOTE(review): only an out-of-range entry is rejected (as a silent
                    // Nop); `perm.len() != in_rank` and duplicate entries are not checked
                    // here — confirm the IR validates permutations before lowering.
3840                if perm.iter().any(|&p| p >= in_rank) {
3841                    Thunk::Nop
3842                } else {
3843                    let in_dims: Vec<usize> = (0..in_rank)
3844                        .map(|i| in_shape.dim(i).unwrap_static())
3845                        .collect();
3846                    // Row-major input strides: stride[d] = product of dims[d+1..].
3847                    let mut in_strides_full = vec![1usize; in_rank];
                        // `saturating_sub` keeps the range empty for rank 0.
3848                    for d in (0..in_rank.saturating_sub(1)).rev() {
3849                        in_strides_full[d] = in_strides_full[d + 1] * in_dims[d + 1];
3850                    }
                        // Output dim i takes its extent and its input stride from input
                        // axis `perm[i]`.
3851                    let out_dims: Vec<u32> = perm.iter().map(|&p| in_dims[p] as u32).collect();
3852                    let in_strides: Vec<u32> =
3853                        perm.iter().map(|&p| in_strides_full[p] as u32).collect();
3854                    let in_total = in_dims.iter().product::<usize>() as u32;
3855                    let src = node_offset(arena, node.inputs[0]);
3856                    let dst = node_offset(arena, node.id);
                        // Computed from the output dtype; only the generic variant
                        // below uses it.
3857                    let elem_bytes = node.shape.dtype().size_bytes() as u8;
3858                    match node.shape.dtype() {
3859                        rlx_ir::DType::F64 => Thunk::TransposeF64 {
3860                            src,
3861                            dst,
3862                            in_total,
3863                            out_dims,
3864                            in_strides,
3865                        },
3866                        _ => Thunk::Transpose {
3867                            src,
3868                            dst,
3869                            in_total,
3870                            out_dims,
3871                            in_strides,
3872                            elem_bytes,
3873                        },
3874                    }
3875                }
3876            }
3877
3878            Op::ScatterAdd => {
                    // Lowers a scatter-add node to `Thunk::ScatterAdd`. Input 0 holds the
                    // updates and input 1 the indices.
3879                // updates: [num_updates, ...trailing], indices: [num_updates],
3880                // output: [out_dim, ...trailing]
3881                let upd_shape = &graph.node(node.inputs[0]).shape;
3882                let out_shape = &node.shape;
3883                let num_updates = upd_shape.dim(0).unwrap_static();
3884                let out_dim = out_shape.dim(0).unwrap_static();
                    // Product of every output dim after the first; `.max(1)` turns a
                    // zero product into 1 (an empty product for rank 1 is already 1).
                    // NOTE(review): a genuinely zero-sized trailing dim is therefore
                    // reported as 1 — confirm that is the intended handling.
3885                let trailing: usize = (1..out_shape.rank())
3886                    .map(|i| out_shape.dim(i).unwrap_static())
3887                    .product::<usize>()
3888                    .max(1);
3889                Thunk::ScatterAdd {
3890                    updates: node_offset(arena, node.inputs[0]),
3891                    indices: node_offset(arena, node.inputs[1]),
3892                    dst: node_offset(arena, node.id),
3893                    num_updates: num_updates as u32,
3894                    out_dim: out_dim as u32,
3895                    trailing: trailing as u32,
3896                }
3897            }
3898
3899            Op::GroupedMatMul => {
                    // Lowers a grouped (per-expert) matmul node to `Thunk::GroupedMatMul`.
3900                // Inputs: [input(M, K), weight(E, K, N), expert_idx(M)]
3901                let in_shape = &graph.node(node.inputs[0]).shape;
3902                let w_shape = &graph.node(node.inputs[1]).shape;
                    // `m` and `k_dim` are the last two dims of the input, so leading
                    // dims (if any) are not folded into `m`.
                    // NOTE(review): `rank() - 2` underflows for an input of rank < 2 —
                    // confirm that is rejected upstream.
3903                let m = in_shape.dim(in_shape.rank() - 2).unwrap_static();
3904                let k_dim = in_shape.dim(in_shape.rank() - 1).unwrap_static();
                    // Expert count and `n` come from dims 0 and 2 of the weight.
3905                let num_experts = w_shape.dim(0).unwrap_static();
3906                let n = w_shape.dim(2).unwrap_static();
3907                Thunk::GroupedMatMul {
3908                    input: node_offset(arena, node.inputs[0]),
3909                    weight: node_offset(arena, node.inputs[1]),
3910                    expert_idx: node_offset(arena, node.inputs[2]),
3911                    dst: node_offset(arena, node.id),
3912                    m: m as u32,
3913                    k_dim: k_dim as u32,
3914                    n: n as u32,
3915                    num_experts: num_experts as u32,
3916                }
3917            }
3918
3919            Op::DequantGroupedMatMul { scheme } => {
                    // Lowers a grouped matmul over packed quantized expert weights to
                    // `Thunk::DequantGroupedMatMulGguf`. Inputs are (input, packed
                    // weights, expert indices) at positions 0, 1, 2.
3920                let in_shape = &graph.node(node.inputs[0]).shape;
3921                let w_shape = &graph.node(node.inputs[1]).shape;
3922                let m = in_shape.dim(in_shape.rank() - 2).unwrap_static();
3923                let k_dim = in_shape.dim(in_shape.rank() - 1).unwrap_static();
3924                let out_shape = &node.shape;
3925                let n = out_shape.dim(out_shape.rank() - 1).unwrap_static();
                    // The expert count is inferred, not read from a dim: one expert's
                    // slab is (k_dim * n) elements packed into blocks of `block_elems`
                    // elements / `block_bytes` bytes each.
                    // NOTE(review): the division by `block_elems` panics if
                    // `gguf_block_size()` returns 0 (presumably possible for a non-GGUF
                    // scheme) — confirm only GGUF schemes reach this arm.
3926                let block_elems = scheme.gguf_block_size() as usize;
3927                let block_bytes = scheme.gguf_block_bytes() as usize;
3928                let slab_bytes = (k_dim * n) / block_elems * block_bytes;
                    // The weight's element count is used as a byte count, so this
                    // assumes the packed weights are a byte-typed tensor — TODO confirm.
3929                let total_bytes = w_shape.num_elements().unwrap();
                    // NOTE(review): integer division silently drops any remainder, so a
                    // size mismatch yields a wrong expert count rather than an error.
3930                let num_experts = total_bytes / slab_bytes.max(1);
3931                Thunk::DequantGroupedMatMulGguf {
3932                    input: node_offset(arena, node.inputs[0]),
3933                    w_q: node_offset(arena, node.inputs[1]),
3934                    expert_idx: node_offset(arena, node.inputs[2]),
3935                    dst: node_offset(arena, node.id),
3936                    m: m as u32,
3937                    k_dim: k_dim as u32,
3938                    n: n as u32,
3939                    num_experts: num_experts as u32,
3940                    scheme: *scheme,
3941                }
3942            }
3943
3944            Op::DequantMoEWeights { scheme } => {
3945                let w_shape = &graph.node(node.inputs[0]).shape;
3946                let out_shape = &node.shape;
3947                let num_experts = out_shape.dim(0).unwrap_static();
3948                let k_dim = out_shape.dim(1).unwrap_static();
3949                let n = out_shape.dim(2).unwrap_static();
3950                let block_elems = scheme.gguf_block_size() as usize;
3951                let block_bytes = scheme.gguf_block_bytes() as usize;
3952                let slab_bytes = (k_dim * n) / block_elems * block_bytes;
3953                let total_bytes = w_shape.num_elements().unwrap();
3954                assert_eq!(
3955                    total_bytes,
3956                    num_experts * slab_bytes,
3957                    "DequantMoEWeights packed bytes mismatch"
3958                );
3959                Thunk::DequantMoEWeightsGguf {
3960                    w_q: node_offset(arena, node.inputs[0]),
3961                    dst: node_offset(arena, node.id),
3962                    k_dim: k_dim as u32,
3963                    n: n as u32,
3964                    num_experts: num_experts as u32,
3965                    scheme: *scheme,
3966                }
3967            }
3968
3969            Op::TopK { k } => {
3970                let in_shape = &graph.node(node.inputs[0]).shape;
3971                let rank = in_shape.rank();
3972                let axis_dim = in_shape.dim(rank - 1).unwrap_static();
3973                let outer = in_shape.num_elements().unwrap() / axis_dim;
3974                let indices_i64 = u8::from(graph.node(node.id).shape.dtype() == rlx_ir::DType::I64);
3975                Thunk::TopK {
3976                    src: node_offset(arena, node.inputs[0]),
3977                    dst: node_offset(arena, node.id),
3978                    outer: outer as u32,
3979                    axis_dim: axis_dim as u32,
3980                    k: *k as u32,
3981                    indices_i64,
3982                }
3983            }
3984
3985            Op::Reduce {
3986                op,
3987                axes,
3988                keep_dim: _,
3989            } => {
3990                // Decompose the input shape into [outer, reduced, inner]
3991                // around the reduced axis range. Non-contiguous reduced
3992                // axes aren't supported here — caller must transpose them
3993                // contiguous first (the coverage tool would surface the
3994                // gap if a model needs it).
3995                let in_shape = &graph.node(node.inputs[0]).shape;
3996                let rank = in_shape.rank();
3997                let mut sorted = axes.clone();
3998                sorted.sort();
3999                sorted.dedup();
4000                let contiguous = sorted.windows(2).all(|w| w[1] == w[0] + 1)
4001                    && !sorted.is_empty()
4002                    && *sorted.last().unwrap() < rank;
4003                if !contiguous {
4004                    Thunk::Nop
4005                } else {
4006                    let first = sorted[0];
4007                    let last = *sorted.last().unwrap();
4008                    let outer: usize = (0..first)
4009                        .map(|i| in_shape.dim(i).unwrap_static())
4010                        .product::<usize>()
4011                        .max(1);
4012                    let reduced: usize = (first..=last)
4013                        .map(|i| in_shape.dim(i).unwrap_static())
4014                        .product();
4015                    let inner: usize = (last + 1..rank)
4016                        .map(|i| in_shape.dim(i).unwrap_static())
4017                        .product::<usize>()
4018                        .max(1);
4019                    let src = node_offset(arena, node.inputs[0]);
4020                    let dst = node_offset(arena, node.id);
4021                    if node.shape.dtype() == rlx_ir::DType::F64 && matches!(op, ReduceOp::Sum) {
4022                        Thunk::ReduceSumF64 {
4023                            src,
4024                            dst,
4025                            outer: outer as u32,
4026                            reduced: reduced as u32,
4027                            inner: inner as u32,
4028                        }
4029                    } else {
4030                        Thunk::Reduce {
4031                            src,
4032                            dst,
4033                            outer: outer as u32,
4034                            reduced: reduced as u32,
4035                            inner: inner as u32,
4036                            op: *op,
4037                        }
4038                    }
4039                }
4040            }
4041
4042            Op::Compare(cmp) => {
4043                let len = node.shape.num_elements().unwrap();
4044                let in_dtype = graph.node(node.inputs[0]).shape.dtype();
4045                let inputs_i64 = u8::from(in_dtype == rlx_ir::DType::I64);
4046                Thunk::Compare {
4047                    lhs: node_offset(arena, node.inputs[0]),
4048                    rhs: node_offset(arena, node.inputs[1]),
4049                    dst: node_offset(arena, node.id),
4050                    len: len as u32,
4051                    op: *cmp,
4052                    inputs_i64,
4053                    inputs_elem_bytes: in_dtype.size_bytes() as u8,
4054                    dst_elem_bytes: node.shape.dtype().size_bytes() as u8,
4055                }
4056            }
4057
4058            Op::Where => {
4059                let len = node.shape.num_elements().unwrap();
4060                let elem_bytes = node.shape.dtype().size_bytes() as u8;
4061                let cond_elem_bytes = graph.node(node.inputs[0]).shape.dtype().size_bytes() as u8;
4062                Thunk::Where {
4063                    cond: node_offset(arena, node.inputs[0]),
4064                    on_true: node_offset(arena, node.inputs[1]),
4065                    on_false: node_offset(arena, node.inputs[2]),
4066                    dst: node_offset(arena, node.id),
4067                    len: len as u32,
4068                    elem_bytes,
4069                    cond_elem_bytes,
4070                }
4071            }
4072
4073            Op::ReluBackward => {
4074                let len: usize = (0..node.shape.rank())
4075                    .map(|i| node.shape.dim(i).unwrap_static())
4076                    .product();
4077                let x = node_offset(arena, node.inputs[0]);
4078                let dy = node_offset(arena, node.inputs[1]);
4079                let dx = node_offset(arena, node.id);
4080                match node.shape.dtype() {
4081                    rlx_ir::DType::F64 => Thunk::ReluBackwardF64 {
4082                        x,
4083                        dy,
4084                        dx,
4085                        len: len as u32,
4086                    },
4087                    _ => Thunk::ReluBackward {
4088                        x,
4089                        dy,
4090                        dx,
4091                        len: len as u32,
4092                    },
4093                }
4094            }
4095
4096            Op::ComplexNormSq => {
4097                let len: usize = (0..node.shape.rank())
4098                    .map(|i| node.shape.dim(i).unwrap_static())
4099                    .product();
4100                let src = node_offset(arena, node.inputs[0]);
4101                let dst = node_offset(arena, node.id);
4102                Thunk::ComplexNormSqF32 {
4103                    src,
4104                    dst,
4105                    len: len as u32,
4106                }
4107            }
4108
4109            Op::ComplexNormSqBackward => {
4110                let len: usize = (0..node.shape.rank())
4111                    .map(|i| node.shape.dim(i).unwrap_static())
4112                    .product();
4113                let z = node_offset(arena, node.inputs[0]);
4114                let g = node_offset(arena, node.inputs[1]);
4115                let dz = node_offset(arena, node.id);
4116                Thunk::ComplexNormSqBackwardF32 {
4117                    z,
4118                    g,
4119                    dz,
4120                    len: len as u32,
4121                }
4122            }
4123
4124            Op::Conjugate => {
4125                let len: usize = (0..node.shape.rank())
4126                    .map(|i| node.shape.dim(i).unwrap_static())
4127                    .product();
4128                Thunk::ConjugateC64 {
4129                    src: node_offset(arena, node.inputs[0]),
4130                    dst: node_offset(arena, node.id),
4131                    len: len as u32,
4132                }
4133            }
4134
4135            Op::ActivationBackward { kind } => {
4136                let len: usize = (0..node.shape.rank())
4137                    .map(|i| node.shape.dim(i).unwrap_static())
4138                    .product();
4139                let x = node_offset(arena, node.inputs[0]);
4140                let dy = node_offset(arena, node.inputs[1]);
4141                let dx = node_offset(arena, node.id);
4142                match node.shape.dtype() {
4143                    rlx_ir::DType::F64 => Thunk::ActivationBackwardF64 {
4144                        x,
4145                        dy,
4146                        dx,
4147                        len: len as u32,
4148                        kind: *kind,
4149                    },
4150                    _ => Thunk::ActivationBackward {
4151                        x,
4152                        dy,
4153                        dx,
4154                        len: len as u32,
4155                        kind: *kind,
4156                    },
4157                }
4158            }
4159
4160            Op::LayerNormBackwardInput { eps, .. } => {
4161                // axis = -1 only (matches forward LayerNorm thunk).
4162                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
4163                let total = node.shape.num_elements().unwrap();
4164                Thunk::LayerNormBackwardInput {
4165                    x: node_offset(arena, node.inputs[0]),
4166                    gamma: node_offset(arena, node.inputs[1]),
4167                    dy: node_offset(arena, node.inputs[2]),
4168                    dx: node_offset(arena, node.id),
4169                    rows: (total / h) as u32,
4170                    h: h as u32,
4171                    eps: *eps,
4172                }
4173            }
4174
4175            Op::LayerNormBackwardGamma { eps, .. } => {
4176                let x_shape = &graph.node(node.inputs[0]).shape;
4177                let h = x_shape.dim(x_shape.rank() - 1).unwrap_static();
4178                let x_total = x_shape.num_elements().unwrap();
4179                Thunk::LayerNormBackwardGamma {
4180                    x: node_offset(arena, node.inputs[0]),
4181                    dy: node_offset(arena, node.inputs[1]),
4182                    dgamma: node_offset(arena, node.id),
4183                    rows: (x_total / h) as u32,
4184                    h: h as u32,
4185                    eps: *eps,
4186                }
4187            }
4188
4189            Op::RmsNormBackwardInput { eps, .. }
4190            | Op::RmsNormBackwardGamma { eps, .. }
4191            | Op::RmsNormBackwardBeta { eps, .. } => {
4192                let x_shape = &graph.node(node.inputs[0]).shape;
4193                let h = x_shape.dim(x_shape.rank() - 1).unwrap_static();
4194                let rows = (x_shape.num_elements().unwrap() / h) as u32;
4195                let off = |i: usize| node_offset(arena, node.inputs[i]);
4196                let common = (off(0), off(1), off(2), off(3), rows, h as u32, *eps);
4197                match &node.op {
4198                    Op::RmsNormBackwardInput { .. } => Thunk::RmsNormBackwardInput {
4199                        x: common.0,
4200                        gamma: common.1,
4201                        beta: common.2,
4202                        dy: common.3,
4203                        dx: node_offset(arena, node.id),
4204                        rows: common.4,
4205                        h: common.5,
4206                        eps: common.6,
4207                    },
4208                    Op::RmsNormBackwardGamma { .. } => Thunk::RmsNormBackwardGamma {
4209                        x: common.0,
4210                        gamma: common.1,
4211                        beta: common.2,
4212                        dy: common.3,
4213                        dgamma: node_offset(arena, node.id),
4214                        rows: common.4,
4215                        h: common.5,
4216                        eps: common.6,
4217                    },
4218                    Op::RmsNormBackwardBeta { .. } => Thunk::RmsNormBackwardBeta {
4219                        x: common.0,
4220                        gamma: common.1,
4221                        beta: common.2,
4222                        dy: common.3,
4223                        dbeta: node_offset(arena, node.id),
4224                        rows: common.4,
4225                        h: common.5,
4226                        eps: common.6,
4227                    },
4228                    _ => unreachable!(),
4229                }
4230            }
4231
4232            Op::RopeBackward { head_dim, n_rot } => {
4233                let dy_shape = &graph.node(node.inputs[0]).shape;
4234                let (batch, seq, hidden) = if dy_shape.rank() >= 3 {
4235                    (
4236                        dy_shape.dim(0).unwrap_static(),
4237                        dy_shape.dim(1).unwrap_static(),
4238                        dy_shape.dim(2).unwrap_static(),
4239                    )
4240                } else {
4241                    (
4242                        1,
4243                        dy_shape.dim(0).unwrap_static(),
4244                        dy_shape.dim(1).unwrap_static(),
4245                    )
4246                };
4247                let cos_shape = &graph.node(node.inputs[1]).shape;
4248                let cos_len = cos_shape.num_elements().unwrap();
4249                Thunk::RopeBackward {
4250                    dy: node_offset(arena, node.inputs[0]),
4251                    cos: node_offset(arena, node.inputs[1]),
4252                    sin: node_offset(arena, node.inputs[2]),
4253                    dx: node_offset(arena, node.id),
4254                    batch: batch as u32,
4255                    seq: seq as u32,
4256                    hidden: hidden as u32,
4257                    head_dim: *head_dim as u32,
4258                    n_rot: *n_rot as u32,
4259                    cos_len: cos_len as u32,
4260                }
4261            }
4262
4263            Op::CumsumBackward { exclusive, .. } => {
4264                let dy_shape = &graph.node(node.inputs[0]).shape;
4265                let rank = dy_shape.rank();
4266                let cols = dy_shape.dim(rank - 1).unwrap_static();
4267                let rows = dy_shape.num_elements().unwrap() / cols;
4268                Thunk::CumsumBackward {
4269                    dy: node_offset(arena, node.inputs[0]),
4270                    dx: node_offset(arena, node.id),
4271                    rows: rows as u32,
4272                    cols: cols as u32,
4273                    exclusive: *exclusive,
4274                }
4275            }
4276
4277            Op::GatherBackward { .. } => {
4278                let dy_shape = &graph.node(node.inputs[0]).shape;
4279                let idx_shape = &graph.node(node.inputs[1]).shape;
4280                let out_shape = &node.shape;
4281                let rank = out_shape.rank();
4282                let axis = match &node.op {
4283                    Op::GatherBackward { axis } => *axis,
4284                    _ => 0,
4285                };
4286                let axis_u = if axis < 0 {
4287                    (rank as i32 + axis) as usize
4288                } else {
4289                    axis as usize
4290                };
4291                let outer: usize = (0..axis_u)
4292                    .map(|i| dy_shape.dim(i).unwrap_static())
4293                    .product::<usize>()
4294                    .max(1);
4295                let num_idx = idx_shape.dim(axis_u).unwrap_static();
4296                let trailing: usize = (axis_u + 1..dy_shape.rank())
4297                    .map(|i| dy_shape.dim(i).unwrap_static())
4298                    .product::<usize>()
4299                    .max(1);
4300                let axis_dim = out_shape.dim(axis_u).unwrap_static();
4301                Thunk::GatherBackward {
4302                    dy: node_offset(arena, node.inputs[0]),
4303                    indices: node_offset(arena, node.inputs[1]),
4304                    dst: node_offset(arena, node.id),
4305                    outer: outer as u32,
4306                    axis_dim: axis_dim as u32,
4307                    num_idx: num_idx as u32,
4308                    trailing: trailing as u32,
4309                }
4310            }
4311
4312            Op::GroupNormBackwardInput { num_groups, eps }
4313            | Op::GroupNormBackwardGamma { num_groups, eps }
4314            | Op::GroupNormBackwardBeta { num_groups, eps } => {
4315                let x_shape = &graph.node(node.inputs[0]).shape;
4316                let n = x_shape.dim(0).unwrap_static() as u32;
4317                let c = x_shape.dim(1).unwrap_static() as u32;
4318                let h = x_shape.dim(2).unwrap_static() as u32;
4319                let w = x_shape.dim(3).unwrap_static() as u32;
4320                match &node.op {
4321                    Op::GroupNormBackwardInput { .. } => Thunk::GroupNormBackwardInput {
4322                        x: node_offset(arena, node.inputs[0]),
4323                        gamma: node_offset(arena, node.inputs[1]),
4324                        beta: node_offset(arena, node.inputs[2]),
4325                        dy: node_offset(arena, node.inputs[3]),
4326                        dx: node_offset(arena, node.id),
4327                        n,
4328                        c,
4329                        h,
4330                        w,
4331                        num_groups: *num_groups as u32,
4332                        eps: *eps,
4333                    },
4334                    Op::GroupNormBackwardGamma { .. } => Thunk::GroupNormBackwardGamma {
4335                        x: node_offset(arena, node.inputs[0]),
4336                        dy: node_offset(arena, node.inputs[1]),
4337                        dgamma: node_offset(arena, node.id),
4338                        n,
4339                        c,
4340                        h,
4341                        w,
4342                        num_groups: *num_groups as u32,
4343                        eps: *eps,
4344                    },
4345                    Op::GroupNormBackwardBeta { .. } => Thunk::GroupNormBackwardBeta {
4346                        dy: node_offset(arena, node.inputs[1]),
4347                        dbeta: node_offset(arena, node.id),
4348                        n,
4349                        c,
4350                        h,
4351                        w,
4352                    },
4353                    _ => unreachable!(),
4354                }
4355            }
4356
4357            Op::MaxPool2dBackward {
4358                kernel_size,
4359                stride,
4360                padding,
4361            } => {
4362                let x_shape = &graph.node(node.inputs[0]).shape;
4363                let dy_shape = &graph.node(node.inputs[1]).shape;
4364                if kernel_size.len() == 2 && x_shape.rank() == 4 && dy_shape.rank() == 4 {
4365                    Thunk::MaxPool2dBackward {
4366                        x: node_offset(arena, node.inputs[0]),
4367                        dy: node_offset(arena, node.inputs[1]),
4368                        dx: node_offset(arena, node.id),
4369                        n: x_shape.dim(0).unwrap_static() as u32,
4370                        c: x_shape.dim(1).unwrap_static() as u32,
4371                        h: x_shape.dim(2).unwrap_static() as u32,
4372                        w: x_shape.dim(3).unwrap_static() as u32,
4373                        h_out: dy_shape.dim(2).unwrap_static() as u32,
4374                        w_out: dy_shape.dim(3).unwrap_static() as u32,
4375                        kh: kernel_size[0] as u32,
4376                        kw: kernel_size[1] as u32,
4377                        sh: stride.first().copied().unwrap_or(1) as u32,
4378                        sw: stride.get(1).copied().unwrap_or(1) as u32,
4379                        ph: padding.first().copied().unwrap_or(0) as u32,
4380                        pw: padding.get(1).copied().unwrap_or(0) as u32,
4381                    }
4382                } else {
4383                    Thunk::Nop
4384                }
4385            }
4386
4387            Op::Conv2dBackwardInput {
4388                kernel_size,
4389                stride,
4390                padding,
4391                dilation,
4392                groups,
4393            } => {
4394                let dy_shape = &graph.node(node.inputs[0]).shape;
4395                let w_shape = &graph.node(node.inputs[1]).shape;
4396                let out_shape = &node.shape;
4397                if kernel_size.len() == 2
4398                    && dy_shape.rank() == 4
4399                    && w_shape.rank() == 4
4400                    && out_shape.rank() == 4
4401                {
4402                    Thunk::Conv2dBackwardInput {
4403                        dy: node_offset(arena, node.inputs[0]),
4404                        w: node_offset(arena, node.inputs[1]),
4405                        dx: node_offset(arena, node.id),
4406                        n: out_shape.dim(0).unwrap_static() as u32,
4407                        c_in: out_shape.dim(1).unwrap_static() as u32,
4408                        h: out_shape.dim(2).unwrap_static() as u32,
4409                        w_in: out_shape.dim(3).unwrap_static() as u32,
4410                        c_out: dy_shape.dim(1).unwrap_static() as u32,
4411                        h_out: dy_shape.dim(2).unwrap_static() as u32,
4412                        w_out: dy_shape.dim(3).unwrap_static() as u32,
4413                        kh: kernel_size[0] as u32,
4414                        kw: kernel_size[1] as u32,
4415                        sh: stride.first().copied().unwrap_or(1) as u32,
4416                        sw: stride.get(1).copied().unwrap_or(1) as u32,
4417                        ph: padding.first().copied().unwrap_or(0) as u32,
4418                        pw: padding.get(1).copied().unwrap_or(0) as u32,
4419                        dh: dilation.first().copied().unwrap_or(1) as u32,
4420                        dw: dilation.get(1).copied().unwrap_or(1) as u32,
4421                        groups: *groups as u32,
4422                    }
4423                } else {
4424                    Thunk::Nop
4425                }
4426            }
4427
4428            Op::Conv2dBackwardWeight {
4429                kernel_size,
4430                stride,
4431                padding,
4432                dilation,
4433                groups,
4434            } => {
4435                let x_shape = &graph.node(node.inputs[0]).shape;
4436                let dy_shape = &graph.node(node.inputs[1]).shape;
4437                let dw_shape = &node.shape;
4438                if kernel_size.len() == 2
4439                    && x_shape.rank() == 4
4440                    && dy_shape.rank() == 4
4441                    && dw_shape.rank() == 4
4442                {
4443                    Thunk::Conv2dBackwardWeight {
4444                        x: node_offset(arena, node.inputs[0]),
4445                        dy: node_offset(arena, node.inputs[1]),
4446                        dw: node_offset(arena, node.id),
4447                        n: x_shape.dim(0).unwrap_static() as u32,
4448                        c_in: x_shape.dim(1).unwrap_static() as u32,
4449                        h: x_shape.dim(2).unwrap_static() as u32,
4450                        w: x_shape.dim(3).unwrap_static() as u32,
4451                        c_out: dy_shape.dim(1).unwrap_static() as u32,
4452                        h_out: dy_shape.dim(2).unwrap_static() as u32,
4453                        w_out: dy_shape.dim(3).unwrap_static() as u32,
4454                        kh: kernel_size[0] as u32,
4455                        kw: kernel_size[1] as u32,
4456                        sh: stride.first().copied().unwrap_or(1) as u32,
4457                        sw: stride.get(1).copied().unwrap_or(1) as u32,
4458                        ph: padding.first().copied().unwrap_or(0) as u32,
4459                        pw: padding.get(1).copied().unwrap_or(0) as u32,
4460                        dh: dilation.first().copied().unwrap_or(1) as u32,
4461                        dw_dil: dilation.get(1).copied().unwrap_or(1) as u32,
4462                        groups: *groups as u32,
4463                    }
4464                } else {
4465                    Thunk::Nop
4466                }
4467            }
4468
4469            Op::Im2Col {
4470                kernel_size,
4471                stride,
4472                padding,
4473                dilation,
4474            } => {
4475                let x_shape = &graph.node(node.inputs[0]).shape;
4476                let out_shape = &node.shape;
4477                if kernel_size.len() == 2 && x_shape.rank() == 4 && out_shape.rank() == 2 {
4478                    let n = match x_shape.dim(0) {
4479                        rlx_ir::shape::Dim::Static(v) => v as u32,
4480                        _ => 0,
4481                    };
4482                    let c_in = x_shape.dim(1).unwrap_static() as u32;
4483                    let h = x_shape.dim(2).unwrap_static() as u32;
4484                    let w = x_shape.dim(3).unwrap_static() as u32;
4485                    let kh = kernel_size[0] as u32;
4486                    let kw = kernel_size[1] as u32;
4487                    let sh = stride.first().copied().unwrap_or(1) as u32;
4488                    let sw = stride.get(1).copied().unwrap_or(1) as u32;
4489                    let ph = padding.first().copied().unwrap_or(0) as u32;
4490                    let pw = padding.get(1).copied().unwrap_or(0) as u32;
4491                    let dh = dilation.first().copied().unwrap_or(1) as u32;
4492                    let dw_dil = dilation.get(1).copied().unwrap_or(1) as u32;
4493                    let h_out = rlx_ir::shape::conv2d_spatial_output(
4494                        h as usize,
4495                        kh as usize,
4496                        sh as usize,
4497                        ph as usize,
4498                        dh as usize,
4499                    ) as u32;
4500                    let w_out = rlx_ir::shape::conv2d_spatial_output(
4501                        w as usize,
4502                        kw as usize,
4503                        sw as usize,
4504                        pw as usize,
4505                        dw_dil as usize,
4506                    ) as u32;
4507                    Thunk::Im2Col {
4508                        x: node_offset(arena, node.inputs[0]),
4509                        col: node_offset(arena, node.id),
4510                        n,
4511                        c_in,
4512                        h,
4513                        w,
4514                        h_out,
4515                        w_out,
4516                        kh,
4517                        kw,
4518                        sh,
4519                        sw,
4520                        ph,
4521                        pw,
4522                        dh,
4523                        dw_dil,
4524                    }
4525                } else {
4526                    Thunk::Nop
4527                }
4528            }
4529
4530            Op::SoftmaxCrossEntropyWithLogits => {
4531                let logits_shape = &graph.node(node.inputs[0]).shape;
4532                if logits_shape.rank() == 2 {
4533                    Thunk::SoftmaxCrossEntropy {
4534                        logits: node_offset(arena, node.inputs[0]),
4535                        labels: node_offset(arena, node.inputs[1]),
4536                        dst: node_offset(arena, node.id),
4537                        n: logits_shape.dim(0).unwrap_static() as u32,
4538                        c: logits_shape.dim(1).unwrap_static() as u32,
4539                    }
4540                } else {
4541                    Thunk::Nop
4542                }
4543            }
4544
4545            Op::SoftmaxCrossEntropyBackward => {
4546                let logits_shape = &graph.node(node.inputs[0]).shape;
4547                if logits_shape.rank() == 2 {
4548                    Thunk::SoftmaxCrossEntropyBackward {
4549                        logits: node_offset(arena, node.inputs[0]),
4550                        labels: node_offset(arena, node.inputs[1]),
4551                        d_loss: node_offset(arena, node.inputs[2]),
4552                        dlogits: node_offset(arena, node.id),
4553                        n: logits_shape.dim(0).unwrap_static() as u32,
4554                        c: logits_shape.dim(1).unwrap_static() as u32,
4555                    }
4556                } else {
4557                    Thunk::Nop
4558                }
4559            }
4560
4561            Op::DenseSolve => {
4562                // A: [n, n], b: [n] or [n, nrhs]. Output matches b.
4563                let a_shape = &graph.node(node.inputs[0]).shape;
4564                let n = a_shape.dim(0).unwrap_static();
4565                debug_assert_eq!(
4566                    n,
4567                    a_shape.dim(1).unwrap_static(),
4568                    "DenseSolve: A must be square"
4569                );
4570                let b_elems = node.shape.num_elements().unwrap();
4571                let nrhs = b_elems / n;
4572                match node.shape.dtype() {
4573                    rlx_ir::DType::F64 => Thunk::DenseSolveF64 {
4574                        a: node_offset(arena, node.inputs[0]),
4575                        b: node_offset(arena, node.inputs[1]),
4576                        x: node_offset(arena, node.id),
4577                        n: n as u32,
4578                        nrhs: nrhs as u32,
4579                    },
4580                    rlx_ir::DType::F32 => Thunk::DenseSolveF32 {
4581                        a: node_offset(arena, node.inputs[0]),
4582                        b: node_offset(arena, node.inputs[1]),
4583                        x: node_offset(arena, node.id),
4584                        n: n as u32,
4585                        nrhs: nrhs as u32,
4586                    },
4587                    other => panic!(
4588                        "DenseSolve: F32 + F64 lowered; got {other:?}. \
4589                         Add another variant when needed."
4590                    ),
4591                }
4592            }
4593
4594            Op::BatchedDenseSolve => {
4595                // A: [B, N, N], b: [B, N] or [B, N, K]. Output matches b.
4596                let a_shape = &graph.node(node.inputs[0]).shape;
4597                assert_eq!(a_shape.rank(), 3, "BatchedDenseSolve: A rank must be 3");
4598                let batch = a_shape.dim(0).unwrap_static();
4599                let n = a_shape.dim(1).unwrap_static();
4600                debug_assert_eq!(
4601                    n,
4602                    a_shape.dim(2).unwrap_static(),
4603                    "BatchedDenseSolve: A's last two dims must match"
4604                );
4605                let total = node.shape.num_elements().unwrap();
4606                let nrhs = total / (batch * n);
4607                match node.shape.dtype() {
4608                    rlx_ir::DType::F32 => Thunk::BatchedDenseSolveF32 {
4609                        a: node_offset(arena, node.inputs[0]),
4610                        b: node_offset(arena, node.inputs[1]),
4611                        x: node_offset(arena, node.id),
4612                        batch: batch as u32,
4613                        n: n as u32,
4614                        nrhs: nrhs as u32,
4615                    },
4616                    rlx_ir::DType::F64 => Thunk::BatchedDenseSolveF64 {
4617                        a: node_offset(arena, node.inputs[0]),
4618                        b: node_offset(arena, node.inputs[1]),
4619                        x: node_offset(arena, node.id),
4620                        batch: batch as u32,
4621                        n: n as u32,
4622                        nrhs: nrhs as u32,
4623                    },
4624                    other => panic!("BatchedDenseSolve: F32 + F64 only, got {other:?}"),
4625                }
4626            }
4627
4628            Op::Scan {
4629                body,
4630                length,
4631                save_trajectory,
4632                num_bcast,
4633                num_xs,
4634                num_checkpoints,
4635            } => {
4636                assert!(
4637                    *num_checkpoints == 0 || *num_checkpoints <= *length,
4638                    "Op::Scan: num_checkpoints={} must be 0 or ≤ length={}",
4639                    *num_checkpoints,
4640                    *length
4641                );
4642                if *num_checkpoints != 0 && *num_checkpoints != *length {
4643                    assert!(
4644                        *save_trajectory,
4645                        "Op::Scan: num_checkpoints<length only meaningful when save_trajectory=true"
4646                    );
4647                }
4648                // Plan + compile the body sub-graph standalone. The body
4649                // gets its own Arena; per execution we clone its
4650                // pristine bytes, copy the outer carry (and per-step xs
4651                // slices, if any) into the body's Input slots, run the
4652                // body schedule N times, then copy the body's output
4653                // back to the outer arena.
4654                //
4655                // Body invariants: 1 + num_xs Op::Inputs in NodeId order
4656                // — first declared is the carry, rest are x_t_i. Single
4657                // graph output (the next carry), same shape as carry.
4658                let body_plan = rlx_opt::memory::plan_memory(body);
4659                let _body_arena_size = body_plan.arena_size;
4660                // Snapshot per-input byte offsets before plan_memory
4661                // moves into the Arena below.
4662                let body_offsets: HashMap<NodeId, usize> = body_plan
4663                    .assignments
4664                    .iter()
4665                    .map(|(id, slot)| (*id, slot.offset))
4666                    .collect();
4667
4668                // Collect body Input nodes in NodeId order; first is
4669                // carry, rest are per-step xs in matching order.
4670                let mut body_inputs: Vec<NodeId> = body
4671                    .nodes()
4672                    .iter()
4673                    .filter(|n| matches!(n.op, Op::Input { .. }))
4674                    .map(|n| n.id)
4675                    .collect();
4676                body_inputs.sort();
4677                let n_body_inputs = body_inputs.len();
4678                let expected = 1 + *num_bcast as usize + *num_xs as usize;
4679                if n_body_inputs != expected {
4680                    let names: Vec<String> = body
4681                        .nodes()
4682                        .iter()
4683                        .filter_map(|n| match &n.op {
4684                            Op::Input { name } => Some(format!("{}={}", n.id, name)),
4685                            _ => None,
4686                        })
4687                        .collect();
4688                    panic!(
4689                        "Op::Scan body has {} Op::Input nodes; expected {} \
4690                            (1 carry + {} bcast + {} xs). Inputs by NodeId: [{}]",
4691                        n_body_inputs,
4692                        expected,
4693                        *num_bcast,
4694                        *num_xs,
4695                        names.join(", ")
4696                    );
4697                }
4698
4699                let body_input_id = body_inputs[0];
4700                let body_input_off = body_offsets[&body_input_id];
4701                let body_output_id = body
4702                    .outputs
4703                    .first()
4704                    .copied()
4705                    .expect("Op::Scan body must declare one output");
4706                let body_output_off = body_offsets[&body_output_id];
4707
4708                let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4709                // Fill body Constant nodes — mirror the outer-graph logic
4710                // in rlx-runtime/src/backend.rs (dtype-aware).
4711                for n in body.nodes() {
4712                    if let Op::Constant { data } = &n.op
4713                        && body_arena.has_buffer(n.id)
4714                        && !data.is_empty()
4715                    {
4716                        match n.shape.dtype() {
4717                            rlx_ir::DType::F64 => {
4718                                let off = body_arena.byte_offset(n.id);
4719                                let buf = body_arena.raw_buf_mut();
4720                                let nbytes = (buf.len() - off).min(data.len());
4721                                buf[off..off + nbytes].copy_from_slice(&data[..nbytes]);
4722                            }
4723                            _ => {
4724                                let buf = body_arena.slice_mut(n.id);
4725                                let n_floats = data.len() / 4;
4726                                let n_lim = buf.len().min(n_floats);
4727                                for i in 0..n_lim {
4728                                    let bytes = [
4729                                        data[i * 4],
4730                                        data[i * 4 + 1],
4731                                        data[i * 4 + 2],
4732                                        data[i * 4 + 3],
4733                                    ];
4734                                    buf[i] = f32::from_le_bytes(bytes);
4735                                }
4736                            }
4737                        }
4738                    }
4739                }
4740                let body_init = body_arena.raw_buf().to_vec();
4741                let body_schedule = compile_thunks(body, &body_arena);
4742
4743                // Carry bytes — for trajectory mode, the outer node's
4744                // shape is [length, *carry_shape], so dividing by length
4745                // gives one row's bytes; the body's input slot still
4746                // holds carry_shape bytes.
4747                let carry_bytes = if *save_trajectory {
4748                    let total = node
4749                        .shape
4750                        .size_bytes()
4751                        .expect("Op::Scan trajectory output must have static shape");
4752                    total / *length as usize
4753                } else {
4754                    node.shape
4755                        .size_bytes()
4756                        .expect("Op::Scan carry must have static shape")
4757                };
4758
4759                // Bcast inputs occupy body_inputs[1..1+num_bcast] and
4760                // outer node.inputs[1..1+num_bcast]. They keep their
4761                // natural shape (no [length, ...] prefix) and are
4762                // copied into body_buf ONCE before the scan loop.
4763                let mut bcast_inputs: Vec<(usize, usize, u32)> =
4764                    Vec::with_capacity(*num_bcast as usize);
4765                for i in 0..*num_bcast as usize {
4766                    let body_b_id = body_inputs[1 + i];
4767                    let body_b_off = body_offsets[&body_b_id];
4768                    let outer_b_id = node.inputs[1 + i];
4769                    let outer_b_off = node_offset(arena, outer_b_id);
4770                    let outer_b_shape = &graph.node(outer_b_id).shape;
4771                    let total = outer_b_shape
4772                        .size_bytes()
4773                        .expect("Op::Scan bcast must have static shape");
4774                    bcast_inputs.push((body_b_off, outer_b_off, total as u32));
4775                }
4776
4777                // xs occupy body_inputs[1+num_bcast..] and node.inputs
4778                // [1+num_bcast..]. Each has shape [length, *per_step];
4779                // per-step bytes = total / length.
4780                let mut xs_inputs: Vec<(usize, usize, u32)> = Vec::with_capacity(*num_xs as usize);
4781                let xs_base = 1 + *num_bcast as usize;
4782                for i in 0..*num_xs as usize {
4783                    let body_x_id = body_inputs[xs_base + i];
4784                    let body_x_off = body_offsets[&body_x_id];
4785                    let outer_xs_id = node.inputs[xs_base + i];
4786                    let outer_xs_off = node_offset(arena, outer_xs_id);
4787                    let outer_xs_shape = &graph.node(outer_xs_id).shape;
4788                    let total = outer_xs_shape
4789                        .size_bytes()
4790                        .expect("Op::Scan xs must have static shape");
4791                    let per_step = total / *length as usize;
4792                    xs_inputs.push((body_x_off, outer_xs_off, per_step as u32));
4793                }
4794
4795                Thunk::Scan {
4796                    body: Arc::new(body_schedule),
4797                    body_init: Arc::new(body_init),
4798                    body_input_off,
4799                    body_output_off,
4800                    outer_init_off: node_offset(arena, node.inputs[0]),
4801                    outer_final_off: node_offset(arena, node.id),
4802                    length: *length,
4803                    carry_bytes: carry_bytes as u32,
4804                    save_trajectory: *save_trajectory,
4805                    xs_inputs: Arc::new(xs_inputs),
4806                    bcast_inputs: Arc::new(bcast_inputs),
4807                    num_checkpoints: *num_checkpoints,
4808                }
4809            }
4810
4811            Op::ScanBackward {
4812                body_vjp,
4813                length,
4814                save_trajectory,
4815                num_xs,
4816                num_checkpoints,
4817                forward_body,
4818            } => {
4819                let is_recursive = *num_checkpoints != 0 && *num_checkpoints != *length;
4820                if is_recursive {
4821                    assert!(
4822                        forward_body.is_some(),
4823                        "Op::ScanBackward with num_checkpoints<length requires forward_body"
4824                    );
4825                }
4826                // body_vjp has signature
4827                //   (carry, x_t_0, ..., x_t_{num_xs-1}, d_output) → dcarry
4828                // Identify slots:
4829                //   * "d_output" by exact name (AD-introduced seed Input).
4830                //   * Remaining Inputs sorted by NodeId — first is the
4831                //     carry mirror, rest are x_t_i mirrors in body's
4832                //     original Op::Input declaration order.
4833                let body_plan = rlx_opt::memory::plan_memory(body_vjp);
4834                let body_offsets: HashMap<NodeId, usize> = body_plan
4835                    .assignments
4836                    .iter()
4837                    .map(|(id, slot)| (*id, slot.offset))
4838                    .collect();
4839                let mut body_d_output_off: Option<usize> = None;
4840                let mut body_other_inputs: Vec<(NodeId, usize)> = Vec::new();
4841                for n in body_vjp.nodes() {
4842                    if let Op::Input { name } = &n.op {
4843                        let off = body_offsets[&n.id];
4844                        if name == "d_output" {
4845                            body_d_output_off = Some(off);
4846                        } else {
4847                            body_other_inputs.push((n.id, off));
4848                        }
4849                    }
4850                }
4851                body_other_inputs.sort_by_key(|(id, _)| *id);
4852                let body_d_output_off =
4853                    body_d_output_off.expect("ScanBackward body_vjp missing 'd_output' Input");
4854                let expected_others = 1 + *num_xs as usize;
4855                assert_eq!(
4856                    body_other_inputs.len(),
4857                    expected_others,
4858                    "ScanBackward body_vjp has {} non-d_output Inputs; \
4859                     expected {} (1 carry + {} xs)",
4860                    body_other_inputs.len(),
4861                    expected_others,
4862                    num_xs
4863                );
4864                let body_carry_in_off = body_other_inputs[0].1;
4865                let body_x_offs: Vec<usize> = body_other_inputs
4866                    .iter()
4867                    .skip(1)
4868                    .map(|(_, off)| *off)
4869                    .collect();
4870                let body_dcarry_out_off = body_offsets[&body_vjp.outputs[0]];
4871
4872                let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4873                // Fill body_vjp's Constants (mirrors the Scan lowering).
4874                for n in body_vjp.nodes() {
4875                    if let Op::Constant { data } = &n.op
4876                        && body_arena.has_buffer(n.id)
4877                        && !data.is_empty()
4878                    {
4879                        match n.shape.dtype() {
4880                            rlx_ir::DType::F64 => {
4881                                let off = body_arena.byte_offset(n.id);
4882                                let buf = body_arena.raw_buf_mut();
4883                                let nb = (buf.len() - off).min(data.len());
4884                                buf[off..off + nb].copy_from_slice(&data[..nb]);
4885                            }
4886                            _ => {
4887                                let buf = body_arena.slice_mut(n.id);
4888                                let nf = data.len() / 4;
4889                                let nl = buf.len().min(nf);
4890                                for i in 0..nl {
4891                                    let bytes = [
4892                                        data[i * 4],
4893                                        data[i * 4 + 1],
4894                                        data[i * 4 + 2],
4895                                        data[i * 4 + 3],
4896                                    ];
4897                                    buf[i] = f32::from_le_bytes(bytes);
4898                                }
4899                            }
4900                        }
4901                    }
4902                }
4903                let body_init = body_arena.raw_buf().to_vec();
4904                let body_schedule = compile_thunks(body_vjp, &body_arena);
4905
4906                // Carry bytes from the dcarry output node (== carry shape).
4907                let carry_bytes = body_vjp
4908                    .node(body_vjp.outputs[0])
4909                    .shape
4910                    .size_bytes()
4911                    .expect("ScanBackward dcarry must be statically shaped");
4912                let carry_elem_size = body_vjp
4913                    .node(body_vjp.outputs[0])
4914                    .shape
4915                    .dtype()
4916                    .size_bytes() as u32;
4917
4918                // For each xs input on the outer node:
4919                // (outer_xs_base, per_step_bytes).
4920                let mut outer_xs_offs: Vec<(usize, u32)> = Vec::with_capacity(*num_xs as usize);
4921                for i in 0..*num_xs as usize {
4922                    let outer_xs_id = node.inputs[3 + i];
4923                    let outer_xs_off = node_offset(arena, outer_xs_id);
4924                    let outer_xs_shape = &graph.node(outer_xs_id).shape;
4925                    let total = outer_xs_shape
4926                        .size_bytes()
4927                        .expect("ScanBackward xs must have static shape");
4928                    let per_step = total / *length as usize;
4929                    outer_xs_offs.push((outer_xs_off, per_step as u32));
4930                }
4931
4932                // If recursive checkpointing is active, we also compile
4933                // the forward body so the executor can recompute
4934                // intermediate carries. The forward body is supplied
4935                // by the AD pass via `forward_body: Some(_)`.
4936                let (fb_schedule, fb_init, fb_carry_in_off, fb_output_off, fb_x_offs) =
4937                    if is_recursive {
4938                        let fb = forward_body.as_ref().unwrap();
4939                        let fb_plan = rlx_opt::memory::plan_memory(fb);
4940                        let fb_offsets: HashMap<NodeId, usize> = fb_plan
4941                            .assignments
4942                            .iter()
4943                            .map(|(id, slot)| (*id, slot.offset))
4944                            .collect();
4945                        let mut fb_inputs: Vec<NodeId> = fb
4946                            .nodes()
4947                            .iter()
4948                            .filter(|n| matches!(n.op, Op::Input { .. }))
4949                            .map(|n| n.id)
4950                            .collect();
4951                        fb_inputs.sort();
4952                        let fb_carry = fb_offsets[&fb_inputs[0]];
4953                        let fb_xs: Vec<usize> = (1..fb_inputs.len())
4954                            .map(|i| fb_offsets[&fb_inputs[i]])
4955                            .collect();
4956                        let fb_out = fb_offsets[&fb.outputs[0]];
4957                        let mut fb_arena = crate::arena::Arena::from_plan(fb_plan);
4958                        for n in fb.nodes() {
4959                            if let Op::Constant { data } = &n.op
4960                                && fb_arena.has_buffer(n.id)
4961                                && !data.is_empty()
4962                            {
4963                                // Byte-copy works for any
4964                                // numeric dtype as long as the
4965                                // arena slot is sized to hold
4966                                // it — the Constant's `data`
4967                                // already encodes the right
4968                                // bytes per element.
4969                                let off = fb_arena.byte_offset(n.id);
4970                                let buf = fb_arena.raw_buf_mut();
4971                                let nb = (buf.len() - off).min(data.len());
4972                                buf[off..off + nb].copy_from_slice(&data[..nb]);
4973                            }
4974                        }
4975                        let fb_init_bytes = fb_arena.raw_buf().to_vec();
4976                        let fb_sched = compile_thunks(fb, &fb_arena);
4977                        (
4978                            Some(Arc::new(fb_sched)),
4979                            Some(Arc::new(fb_init_bytes)),
4980                            fb_carry,
4981                            fb_out,
4982                            fb_xs,
4983                        )
4984                    } else {
4985                        (None, None, 0, 0, Vec::new())
4986                    };
4987
4988                Thunk::ScanBackward {
4989                    body_vjp: Arc::new(body_schedule),
4990                    body_init: Arc::new(body_init),
4991                    body_carry_in_off,
4992                    body_x_offs: Arc::new(body_x_offs),
4993                    body_d_output_off,
4994                    body_dcarry_out_off,
4995                    outer_init_off: node_offset(arena, node.inputs[0]),
4996                    outer_traj_off: node_offset(arena, node.inputs[1]),
4997                    outer_upstream_off: node_offset(arena, node.inputs[2]),
4998                    outer_xs_offs: Arc::new(outer_xs_offs),
4999                    outer_dinit_off: node_offset(arena, node.id),
5000                    length: *length,
5001                    carry_bytes: carry_bytes as u32,
5002                    carry_elem_size,
5003                    save_trajectory: *save_trajectory,
5004                    num_checkpoints: *num_checkpoints,
5005                    forward_body: fb_schedule,
5006                    forward_body_init: fb_init,
5007                    forward_body_carry_in_off: fb_carry_in_off,
5008                    forward_body_output_off: fb_output_off,
5009                    forward_body_x_offs: Arc::new(fb_x_offs),
5010                }
5011            }
5012
5013            Op::ScanBackwardXs {
5014                body_vjp,
5015                length,
5016                save_trajectory,
5017                num_xs,
5018                xs_idx,
5019                num_checkpoints,
5020                forward_body,
5021            } => {
5022                assert!(
5023                    *num_checkpoints == 0 || *num_checkpoints <= *length,
5024                    "Op::ScanBackwardXs: num_checkpoints={} must be 0 or ≤ length={}",
5025                    *num_checkpoints,
5026                    *length
5027                );
5028                let is_recursive = *num_checkpoints != 0 && *num_checkpoints != *length;
5029                if is_recursive {
5030                    assert!(
5031                        forward_body.is_some(),
5032                        "Op::ScanBackwardXs with num_checkpoints<length \
5033                         requires forward_body"
5034                    );
5035                }
5036                // Mirror ScanBackward's body_vjp slot identification +
5037                // arena prep, then add: per-iteration extraction of the
5038                // body_vjp output that corresponds to the chosen xs.
5039                //
5040                // body_vjp's outputs (from `grad(body, [carry, xs_0, ..., xs_{num_xs-1}])`):
5041                //   outputs[0]      = dcarry
5042                //   outputs[1 + i]  = dx_t_i
5043                let body_plan = rlx_opt::memory::plan_memory(body_vjp);
5044                let body_offsets: HashMap<NodeId, usize> = body_plan
5045                    .assignments
5046                    .iter()
5047                    .map(|(id, slot)| (*id, slot.offset))
5048                    .collect();
5049                let mut body_d_output_off: Option<usize> = None;
5050                let mut body_other_inputs: Vec<(NodeId, usize)> = Vec::new();
5051                for n in body_vjp.nodes() {
5052                    if let Op::Input { name } = &n.op {
5053                        let off = body_offsets[&n.id];
5054                        if name == "d_output" {
5055                            body_d_output_off = Some(off);
5056                        } else {
5057                            body_other_inputs.push((n.id, off));
5058                        }
5059                    }
5060                }
5061                body_other_inputs.sort_by_key(|(id, _)| *id);
5062                let body_d_output_off =
5063                    body_d_output_off.expect("ScanBackwardXs body_vjp missing 'd_output' Input");
5064                let expected_others = 1 + *num_xs as usize;
5065                assert_eq!(
5066                    body_other_inputs.len(),
5067                    expected_others,
5068                    "ScanBackwardXs body_vjp has {} non-d_output Inputs; expected {}",
5069                    body_other_inputs.len(),
5070                    expected_others
5071                );
5072                let body_carry_in_off = body_other_inputs[0].1;
5073                let body_x_offs: Vec<usize> = body_other_inputs
5074                    .iter()
5075                    .skip(1)
5076                    .map(|(_, off)| *off)
5077                    .collect();
5078                let body_dcarry_out_off = body_offsets[&body_vjp.outputs[0]];
5079                let dxs_out_node = body_vjp.outputs[1 + *xs_idx as usize];
5080                let body_dxs_out_off = body_offsets[&dxs_out_node];
5081
5082                let mut body_arena = crate::arena::Arena::from_plan(body_plan);
5083                for n in body_vjp.nodes() {
5084                    if let Op::Constant { data } = &n.op
5085                        && body_arena.has_buffer(n.id)
5086                        && !data.is_empty()
5087                    {
5088                        match n.shape.dtype() {
5089                            rlx_ir::DType::F64 => {
5090                                let off = body_arena.byte_offset(n.id);
5091                                let buf = body_arena.raw_buf_mut();
5092                                let nb = (buf.len() - off).min(data.len());
5093                                buf[off..off + nb].copy_from_slice(&data[..nb]);
5094                            }
5095                            _ => {
5096                                let buf = body_arena.slice_mut(n.id);
5097                                let nf = data.len() / 4;
5098                                let nl = buf.len().min(nf);
5099                                for i in 0..nl {
5100                                    let bytes = [
5101                                        data[i * 4],
5102                                        data[i * 4 + 1],
5103                                        data[i * 4 + 2],
5104                                        data[i * 4 + 3],
5105                                    ];
5106                                    buf[i] = f32::from_le_bytes(bytes);
5107                                }
5108                            }
5109                        }
5110                    }
5111                }
5112                let body_init = body_arena.raw_buf().to_vec();
5113                let body_schedule = compile_thunks(body_vjp, &body_arena);
5114
5115                let carry_bytes = body_vjp
5116                    .node(body_vjp.outputs[0])
5117                    .shape
5118                    .size_bytes()
5119                    .expect("ScanBackwardXs dcarry must be statically shaped");
5120                let carry_elem_size = body_vjp
5121                    .node(body_vjp.outputs[0])
5122                    .shape
5123                    .dtype()
5124                    .size_bytes() as u32;
5125                let per_step_bytes = body_vjp
5126                    .node(dxs_out_node)
5127                    .shape
5128                    .size_bytes()
5129                    .expect("ScanBackwardXs dxs body output must be statically shaped");
5130
5131                let mut outer_xs_offs: Vec<(usize, u32)> = Vec::with_capacity(*num_xs as usize);
5132                for i in 0..*num_xs as usize {
5133                    let outer_xs_id = node.inputs[3 + i];
5134                    let outer_xs_off = node_offset(arena, outer_xs_id);
5135                    let outer_xs_shape = &graph.node(outer_xs_id).shape;
5136                    let total = outer_xs_shape
5137                        .size_bytes()
5138                        .expect("ScanBackwardXs xs must have static shape");
5139                    let per_step = total / *length as usize;
5140                    outer_xs_offs.push((outer_xs_off, per_step as u32));
5141                }
5142
5143                // Compile forward_body for recompute when checkpointed.
5144                // Mirrors the same code path in the ScanBackward arm.
5145                let (fb_schedule, fb_init, fb_carry_in_off, fb_output_off, fb_x_offs) =
5146                    if is_recursive {
5147                        let fb = forward_body.as_ref().unwrap();
5148                        let fb_plan = rlx_opt::memory::plan_memory(fb);
5149                        let fb_offsets: HashMap<NodeId, usize> = fb_plan
5150                            .assignments
5151                            .iter()
5152                            .map(|(id, slot)| (*id, slot.offset))
5153                            .collect();
5154                        let mut fb_inputs: Vec<NodeId> = fb
5155                            .nodes()
5156                            .iter()
5157                            .filter(|n| matches!(n.op, Op::Input { .. }))
5158                            .map(|n| n.id)
5159                            .collect();
5160                        fb_inputs.sort();
5161                        let fb_carry = fb_offsets[&fb_inputs[0]];
5162                        let fb_xs: Vec<usize> = (1..fb_inputs.len())
5163                            .map(|i| fb_offsets[&fb_inputs[i]])
5164                            .collect();
5165                        let fb_out = fb_offsets[&fb.outputs[0]];
5166                        let mut fb_arena = crate::arena::Arena::from_plan(fb_plan);
5167                        for n in fb.nodes() {
5168                            if let Op::Constant { data } = &n.op
5169                                && fb_arena.has_buffer(n.id)
5170                                && !data.is_empty()
5171                            {
5172                                // Byte-copy works for any
5173                                // numeric dtype as long as the
5174                                // arena slot is sized to hold
5175                                // it — the Constant's `data`
5176                                // already encodes the right
5177                                // bytes per element.
5178                                let off = fb_arena.byte_offset(n.id);
5179                                let buf = fb_arena.raw_buf_mut();
5180                                let nb = (buf.len() - off).min(data.len());
5181                                buf[off..off + nb].copy_from_slice(&data[..nb]);
5182                            }
5183                        }
5184                        let fb_init_bytes = fb_arena.raw_buf().to_vec();
5185                        let fb_sched = compile_thunks(fb, &fb_arena);
5186                        (
5187                            Some(Arc::new(fb_sched)),
5188                            Some(Arc::new(fb_init_bytes)),
5189                            fb_carry,
5190                            fb_out,
5191                            fb_xs,
5192                        )
5193                    } else {
5194                        (None, None, 0, 0, Vec::new())
5195                    };
5196
5197                Thunk::ScanBackwardXs {
5198                    body_vjp: Arc::new(body_schedule),
5199                    body_init: Arc::new(body_init),
5200                    body_carry_in_off,
5201                    body_x_offs: Arc::new(body_x_offs),
5202                    body_d_output_off,
5203                    body_dcarry_out_off,
5204                    body_dxs_out_off,
5205                    outer_init_off: node_offset(arena, node.inputs[0]),
5206                    outer_traj_off: node_offset(arena, node.inputs[1]),
5207                    outer_upstream_off: node_offset(arena, node.inputs[2]),
5208                    outer_xs_offs: Arc::new(outer_xs_offs),
5209                    outer_dxs_off: node_offset(arena, node.id),
5210                    length: *length,
5211                    carry_bytes: carry_bytes as u32,
5212                    carry_elem_size,
5213                    per_step_bytes: per_step_bytes as u32,
5214                    save_trajectory: *save_trajectory,
5215                    num_checkpoints: *num_checkpoints,
5216                    forward_body: fb_schedule,
5217                    forward_body_init: fb_init,
5218                    forward_body_carry_in_off: fb_carry_in_off,
5219                    forward_body_output_off: fb_output_off,
5220                    forward_body_x_offs: Arc::new(fb_x_offs),
5221                }
5222            }
5223
5224            Op::Concat { axis } => {
5225                // Compute outer/inner from the OUTPUT shape: all inputs share
5226                // the same shape except along `axis`. The output's leading
5227                // and trailing dims match.
5228                let out_shape = &node.shape;
5229                let rank = out_shape.rank();
5230                let outer: usize = (0..*axis)
5231                    .map(|i| out_shape.dim(i).unwrap_static())
5232                    .product::<usize>()
5233                    .max(1);
5234                let inner: usize = (*axis + 1..rank)
5235                    .map(|i| out_shape.dim(i).unwrap_static())
5236                    .product::<usize>()
5237                    .max(1);
5238                let total_axis = out_shape.dim(*axis).unwrap_static();
5239                let inputs: Vec<(usize, u32)> = node
5240                    .inputs
5241                    .iter()
5242                    .map(|&in_id| {
5243                        let in_shape = &graph.node(in_id).shape;
5244                        let in_axis = in_shape.dim(*axis).unwrap_static();
5245                        (node_offset(arena, in_id), in_axis as u32)
5246                    })
5247                    .collect();
5248                let dst = node_offset(arena, node.id);
5249                match out_shape.dtype() {
5250                    rlx_ir::DType::F64 => Thunk::ConcatF64 {
5251                        dst,
5252                        outer: outer as u32,
5253                        inner: inner as u32,
5254                        total_axis: total_axis as u32,
5255                        inputs,
5256                    },
5257                    _ => Thunk::Concat {
5258                        dst,
5259                        outer: outer as u32,
5260                        inner: inner as u32,
5261                        total_axis: total_axis as u32,
5262                        inputs,
5263                    },
5264                }
5265            }
5266
5267            Op::GaussianSplatRender {
5268                width,
5269                height,
5270                tile_size,
5271                radius_scale,
5272                alpha_cutoff,
5273                max_splat_steps,
5274                transmittance_threshold,
5275                max_list_entries,
5276            } => {
5277                let elem_len =
5278                    |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5279                Thunk::GaussianSplatRender {
5280                    positions_off: node_offset(arena, node.inputs[0]),
5281                    positions_len: elem_len(node.inputs[0]),
5282                    scales_off: node_offset(arena, node.inputs[1]),
5283                    scales_len: elem_len(node.inputs[1]),
5284                    rotations_off: node_offset(arena, node.inputs[2]),
5285                    rotations_len: elem_len(node.inputs[2]),
5286                    opacities_off: node_offset(arena, node.inputs[3]),
5287                    opacities_len: elem_len(node.inputs[3]),
5288                    colors_off: node_offset(arena, node.inputs[4]),
5289                    colors_len: elem_len(node.inputs[4]),
5290                    sh_coeffs_off: node_offset(arena, node.inputs[5]),
5291                    sh_coeffs_len: elem_len(node.inputs[5]),
5292                    meta_off: node_offset(arena, node.inputs[6]),
5293                    dst_off: node_offset(arena, node.id),
5294                    dst_len: node.shape.num_elements().unwrap_or(0),
5295                    width: *width,
5296                    height: *height,
5297                    tile_size: *tile_size,
5298                    radius_scale: *radius_scale,
5299                    alpha_cutoff: *alpha_cutoff,
5300                    max_splat_steps: *max_splat_steps,
5301                    transmittance_threshold: *transmittance_threshold,
5302                    max_list_entries: *max_list_entries,
5303                }
5304            }
5305
5306            Op::GaussianSplatRenderBackward {
5307                width,
5308                height,
5309                tile_size,
5310                radius_scale,
5311                alpha_cutoff,
5312                max_splat_steps,
5313                transmittance_threshold,
5314                max_list_entries,
5315                loss_grad_clip,
5316                sh_band,
5317                max_anisotropy,
5318            } => {
5319                let elem_len =
5320                    |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5321                Thunk::GaussianSplatRenderBackward {
5322                    positions_off: node_offset(arena, node.inputs[0]),
5323                    positions_len: elem_len(node.inputs[0]),
5324                    scales_off: node_offset(arena, node.inputs[1]),
5325                    scales_len: elem_len(node.inputs[1]),
5326                    rotations_off: node_offset(arena, node.inputs[2]),
5327                    rotations_len: elem_len(node.inputs[2]),
5328                    opacities_off: node_offset(arena, node.inputs[3]),
5329                    opacities_len: elem_len(node.inputs[3]),
5330                    colors_off: node_offset(arena, node.inputs[4]),
5331                    colors_len: elem_len(node.inputs[4]),
5332                    sh_coeffs_off: node_offset(arena, node.inputs[5]),
5333                    sh_coeffs_len: elem_len(node.inputs[5]),
5334                    meta_off: node_offset(arena, node.inputs[6]),
5335                    d_loss_off: node_offset(arena, node.inputs[7]),
5336                    d_loss_len: elem_len(node.inputs[7]),
5337                    packed_off: node_offset(arena, node.id),
5338                    packed_len: node.shape.num_elements().unwrap_or(0),
5339                    width: *width,
5340                    height: *height,
5341                    tile_size: *tile_size,
5342                    radius_scale: *radius_scale,
5343                    alpha_cutoff: *alpha_cutoff,
5344                    max_splat_steps: *max_splat_steps,
5345                    transmittance_threshold: *transmittance_threshold,
5346                    max_list_entries: *max_list_entries,
5347                    loss_grad_clip: *loss_grad_clip,
5348                    sh_band: *sh_band,
5349                    max_anisotropy: *max_anisotropy,
5350                }
5351            }
5352
5353            Op::GaussianSplatPrepare {
5354                width,
5355                height,
5356                tile_size,
5357                radius_scale,
5358                alpha_cutoff,
5359                max_splat_steps,
5360                transmittance_threshold,
5361                max_list_entries,
5362            } => {
5363                let elem_len =
5364                    |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5365                Thunk::GaussianSplatPrepare {
5366                    positions_off: node_offset(arena, node.inputs[0]),
5367                    positions_len: elem_len(node.inputs[0]),
5368                    scales_off: node_offset(arena, node.inputs[1]),
5369                    scales_len: elem_len(node.inputs[1]),
5370                    rotations_off: node_offset(arena, node.inputs[2]),
5371                    rotations_len: elem_len(node.inputs[2]),
5372                    opacities_off: node_offset(arena, node.inputs[3]),
5373                    opacities_len: elem_len(node.inputs[3]),
5374                    colors_off: node_offset(arena, node.inputs[4]),
5375                    colors_len: elem_len(node.inputs[4]),
5376                    sh_coeffs_off: node_offset(arena, node.inputs[5]),
5377                    sh_coeffs_len: elem_len(node.inputs[5]),
5378                    meta_off: node_offset(arena, node.inputs[6]),
5379                    meta_len: elem_len(node.inputs[6]),
5380                    prep_off: node_offset(arena, node.id),
5381                    prep_len: node.shape.num_elements().unwrap_or(0),
5382                    width: *width,
5383                    height: *height,
5384                    tile_size: *tile_size,
5385                    radius_scale: *radius_scale,
5386                    alpha_cutoff: *alpha_cutoff,
5387                    max_splat_steps: *max_splat_steps,
5388                    transmittance_threshold: *transmittance_threshold,
5389                    max_list_entries: *max_list_entries,
5390                }
5391            }
5392
5393            Op::GaussianSplatRasterize {
5394                width,
5395                height,
5396                tile_size,
5397                alpha_cutoff,
5398                max_splat_steps,
5399                transmittance_threshold,
5400                max_list_entries,
5401            } => {
5402                let elem_len =
5403                    |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5404                let prep_id = node.inputs[0];
5405                let count = match &graph.node(prep_id).op {
5406                    rlx_ir::Op::GaussianSplatPrepare { .. } => {
5407                        elem_len(graph.node(prep_id).inputs[0]) / 3
5408                    }
5409                    _ => 1,
5410                };
5411                Thunk::GaussianSplatRasterize {
5412                    prep_off: node_offset(arena, prep_id),
5413                    prep_len: elem_len(prep_id),
5414                    meta_off: node_offset(arena, node.inputs[1]),
5415                    meta_len: elem_len(node.inputs[1]),
5416                    dst_off: node_offset(arena, node.id),
5417                    dst_len: node.shape.num_elements().unwrap_or(0),
5418                    count,
5419                    width: *width,
5420                    height: *height,
5421                    tile_size: *tile_size,
5422                    alpha_cutoff: *alpha_cutoff,
5423                    max_splat_steps: *max_splat_steps,
5424                    transmittance_threshold: *transmittance_threshold,
5425                    max_list_entries: *max_list_entries,
5426                }
5427            }
5428
5429            Op::Custom { name, attrs, .. } => {
5430                let kernel = crate::op_registry::lookup_cpu_kernel(name).unwrap_or_else(|| {
5431                    panic!(
5432                        "compile_thunks: no CPU kernel registered for \
5433                         Op::Custom('{name}'). Register one via \
5434                         rlx_cpu::op_registry::register_cpu_kernel \
5435                         before compiling on the CPU backend."
5436                    )
5437                });
5438                let inputs_v: Vec<(usize, u32, Shape)> = node
5439                    .inputs
5440                    .iter()
5441                    .map(|&in_id| {
5442                        let s = graph.node(in_id).shape.clone();
5443                        let len = s.num_elements().unwrap_or(0) as u32;
5444                        (node_offset(arena, in_id), len, s)
5445                    })
5446                    .collect();
5447                let out_len = node.shape.num_elements().unwrap_or(0) as u32;
5448                Thunk::CustomOp {
5449                    kernel,
5450                    inputs: inputs_v,
5451                    output: (node_offset(arena, node.id), out_len, node.shape.clone()),
5452                    attrs: attrs.clone(),
5453                }
5454            }
5455
5456            Op::Fft { inverse, norm } => {
5457                let shape = &node.shape;
5458                let meta = rlx_ir::fft::fft_meta(shape);
5459                let dtype = shape.dtype();
5460                assert!(
5461                    matches!(
5462                        dtype,
5463                        rlx_ir::DType::F32 | rlx_ir::DType::F64 | rlx_ir::DType::C64
5464                    ),
5465                    "Op::Fft on CPU requires F32, F64, or C64, got {dtype:?}"
5466                );
5467                Thunk::Fft1d {
5468                    src: node_offset(arena, node.inputs[0]),
5469                    dst: node_offset(arena, node.id),
5470                    outer: meta.outer as u32,
5471                    n_complex: meta.n_complex as u32,
5472                    inverse: *inverse,
5473                    norm_tag: norm.tag(),
5474                    dtype,
5475                }
5476            }
5477
5478            Op::FftButterflyStage { stage, n_fft } => {
5479                let state_shape = graph.node(node.inputs[0]).shape.clone();
5480                assert_eq!(
5481                    state_shape.dtype(),
5482                    rlx_ir::DType::F32,
5483                    "Op::FftButterflyStage requires F32 state"
5484                );
5485                let batch = state_shape.dim(0).unwrap_static() as u32;
5486                Thunk::FftButterflyStage {
5487                    state_src: node_offset(arena, node.inputs[0]),
5488                    state_dst: node_offset(arena, node.id),
5489                    gate_src: node_offset(arena, node.inputs[1]),
5490                    rev_src: node_offset(arena, node.inputs[2]),
5491                    tw_re_src: node_offset(arena, node.inputs[3]),
5492                    tw_im_src: node_offset(arena, node.inputs[4]),
5493                    batch,
5494                    n_fft: *n_fft,
5495                    stage: *stage,
5496                }
5497            }
5498
5499            Op::LogMel => {
5500                let spec_shape = graph.node(node.inputs[0]).shape.clone();
5501                let filt_shape = graph.node(node.inputs[1]).shape.clone();
5502                let meta = rlx_ir::audio::log_mel_meta(&spec_shape, &filt_shape)
5503                    .unwrap_or_else(|e| panic!("Op::LogMel: {e}"));
5504                Thunk::LogMel {
5505                    spec: node_offset(arena, node.inputs[0]),
5506                    filters: node_offset(arena, node.inputs[1]),
5507                    dst: node_offset(arena, node.id),
5508                    outer: meta.outer as u32,
5509                    n_fft: meta.n_fft as u32,
5510                    n_bins: meta.n_bins as u32,
5511                    n_mels: meta.n_mels as u32,
5512                }
5513            }
5514
5515            Op::LogMelBackward => {
5516                let spec_shape = graph.node(node.inputs[0]).shape.clone();
5517                let filt_shape = graph.node(node.inputs[1]).shape.clone();
5518                let meta = rlx_ir::audio::log_mel_meta(&spec_shape, &filt_shape)
5519                    .unwrap_or_else(|e| panic!("Op::LogMelBackward: {e}"));
5520                Thunk::LogMelBackward {
5521                    spec: node_offset(arena, node.inputs[0]),
5522                    filters: node_offset(arena, node.inputs[1]),
5523                    dy: node_offset(arena, node.inputs[2]),
5524                    dst: node_offset(arena, node.id),
5525                    outer: meta.outer as u32,
5526                    n_fft: meta.n_fft as u32,
5527                    n_bins: meta.n_bins as u32,
5528                    n_mels: meta.n_mels as u32,
5529                }
5530            }
5531
5532            Op::CustomFn {
5533                fwd_body,
5534                num_inputs,
5535                ..
5536            } => {
5537                // Plan + compile the body sub-graph standalone, fill its
5538                // Constants (mirrors the Op::Scan body lowering), then
5539                // capture per-input copy specs and the output spec.
5540                // Body Inputs in NodeId order match the outer node's
5541                // operand vector by position.
5542                let body_plan = rlx_opt::memory::plan_memory(fwd_body);
5543                let body_offsets: HashMap<NodeId, usize> = body_plan
5544                    .assignments
5545                    .iter()
5546                    .map(|(id, slot)| (*id, slot.offset))
5547                    .collect();
5548
5549                let mut body_input_ids: Vec<NodeId> = fwd_body
5550                    .nodes()
5551                    .iter()
5552                    .filter(|n| matches!(n.op, Op::Input { .. }))
5553                    .map(|n| n.id)
5554                    .collect();
5555                body_input_ids.sort();
5556                assert_eq!(
5557                    body_input_ids.len(),
5558                    *num_inputs as usize,
5559                    "Op::CustomFn fwd_body has {} Op::Input(s); declared num_inputs={}",
5560                    body_input_ids.len(),
5561                    *num_inputs,
5562                );
5563
5564                let mut body_arena = crate::arena::Arena::from_plan(body_plan);
5565                for n in fwd_body.nodes() {
5566                    if let Op::Constant { data } = &n.op
5567                        && body_arena.has_buffer(n.id)
5568                        && !data.is_empty()
5569                    {
5570                        match n.shape.dtype() {
5571                            rlx_ir::DType::F64 => {
5572                                let off = body_arena.byte_offset(n.id);
5573                                let buf = body_arena.raw_buf_mut();
5574                                let nb = (buf.len() - off).min(data.len());
5575                                buf[off..off + nb].copy_from_slice(&data[..nb]);
5576                            }
5577                            _ => {
5578                                let buf = body_arena.slice_mut(n.id);
5579                                let nf = data.len() / 4;
5580                                let nl = buf.len().min(nf);
5581                                for i in 0..nl {
5582                                    let bytes = [
5583                                        data[i * 4],
5584                                        data[i * 4 + 1],
5585                                        data[i * 4 + 2],
5586                                        data[i * 4 + 3],
5587                                    ];
5588                                    buf[i] = f32::from_le_bytes(bytes);
5589                                }
5590                            }
5591                        }
5592                    }
5593                }
5594                let body_init = body_arena.raw_buf().to_vec();
5595                let body_schedule = compile_thunks(fwd_body, &body_arena);
5596
5597                // Per primal input: (body_input_off, outer_input_off, bytes).
5598                let inputs_v: Vec<(usize, usize, u32)> = (0..*num_inputs as usize)
5599                    .map(|i| {
5600                        let body_in = body_input_ids[i];
5601                        let body_off = body_offsets[&body_in];
5602                        let outer_in = node.inputs[i];
5603                        let outer_off = node_offset(arena, outer_in);
5604                        let bytes = graph
5605                            .node(outer_in)
5606                            .shape
5607                            .size_bytes()
5608                            .expect("Op::CustomFn primal input must have static shape");
5609                        (body_off, outer_off, bytes as u32)
5610                    })
5611                    .collect();
5612
5613                let body_output_id = fwd_body
5614                    .outputs
5615                    .first()
5616                    .copied()
5617                    .expect("Op::CustomFn fwd_body must declare exactly one output");
5618                let body_output_off = body_offsets[&body_output_id];
5619                let out_bytes = node
5620                    .shape
5621                    .size_bytes()
5622                    .expect("Op::CustomFn output must have static shape");
5623
5624                Thunk::CustomFn {
5625                    body: Arc::new(body_schedule),
5626                    body_init: Arc::new(body_init),
5627                    inputs: Arc::new(inputs_v),
5628                    body_output_off,
5629                    outer_output_off: node_offset(arena, node.id),
5630                    out_bytes: out_bytes as u32,
5631                }
5632            }
5633
5634            _ => Thunk::Nop,
5635        };
5636        thunks.push(t);
5637    }
5638
5639    let cfg = crate::config::RuntimeConfig::global();
5640    let mask_thr = cfg.mask_binary_threshold;
5641    let mask_neg = cfg.attn_mask_neg_inf;
5642    let score_skip = cfg.score_skip_threshold;
5643
5644    // Pre-compile closures (skip Nops — they're filtered out)
5645    let compiled_fns: Vec<Arc<dyn Fn(*mut u8) + Send + Sync>> = thunks
5646        .iter()
5647        .filter(|t| !matches!(t, Thunk::Nop))
5648        .map(|thunk| {
5649            match thunk.clone() {
5650                Thunk::Nop => Arc::new(|_: *mut u8| {}) as Arc<dyn Fn(*mut u8) + Send + Sync>,
5651
5652                Thunk::Sgemm { a, b, c, m, k, n } => {
5653                    let (m, k, n) = (m as usize, k as usize, n as usize);
5654                    Arc::new(move |base: *mut u8| unsafe {
5655                        crate::blas::sgemm(
5656                            sl(a, base, m * k),
5657                            sl(b, base, k * n),
5658                            sl_mut(c, base, m * n),
5659                            m,
5660                            k,
5661                            n,
5662                        );
5663                    })
5664                }
5665
5666                Thunk::DenseSolveF64 { a, b, x, n, nrhs } => {
5667                    let (n_, nrhs_) = (n as usize, nrhs as usize);
5668                    Arc::new(move |base: *mut u8| unsafe {
5669                        let a_src = sl_f64(a, base, n_ * n_);
5670                        let b_src = sl_f64(b, base, n_ * nrhs_);
5671                        let mut a_scratch: Vec<f64> = a_src.to_vec();
5672                        let mut x_buf: Vec<f64> = b_src.to_vec();
5673                        let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
5674                        if info != 0 {
5675                            panic!("DenseSolveF64: singular (info={info})");
5676                        }
5677                        sl_mut_f64(x, base, n_ * nrhs_).copy_from_slice(&x_buf);
5678                    })
5679                }
5680
5681                Thunk::DenseSolveF32 { a, b, x, n, nrhs } => {
5682                    let (n_, nrhs_) = (n as usize, nrhs as usize);
5683                    Arc::new(move |base: *mut u8| unsafe {
5684                        let a_src = sl(a, base, n_ * n_);
5685                        let b_src = sl(b, base, n_ * nrhs_);
5686                        let mut a_scratch: Vec<f32> = a_src.to_vec();
5687                        let mut x_buf: Vec<f32> = b_src.to_vec();
5688                        let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
5689                        if info != 0 {
5690                            panic!("DenseSolveF32: singular (info={info})");
5691                        }
5692                        sl_mut(x, base, n_ * nrhs_).copy_from_slice(&x_buf);
5693                    })
5694                }
5695
5696                Thunk::FusedMmBiasAct {
5697                    a,
5698                    w,
5699                    bias,
5700                    c,
5701                    m,
5702                    k,
5703                    n,
5704                    act,
5705                } => {
5706                    let (m, k, n) = (m as usize, k as usize, n as usize);
5707                    Arc::new(move |base: *mut u8| unsafe {
5708                        let out = sl_mut(c, base, m * n);
5709                        crate::blas::sgemm(sl(a, base, m * k), sl(w, base, k * n), out, m, k, n);
5710                        // Bias + activation epilogue. Gelu uses the fused
5711                        // `par_bias_gelu` kernel (bias add + Gelu in one
5712                        // pass). For everything else, do the bias add first
5713                        // and then apply the activation per-element. The
5714                        // pre-fix code dispatched `_ => bias_add` and dropped
5715                        // the activation entirely — silent correctness bug
5716                        // for Silu/Relu/Sigmoid/etc.
5717                        match act {
5718                            Some(Activation::Gelu) => {
5719                                crate::kernels::par_bias_gelu(out, sl(bias, base, n), m, n)
5720                            }
5721                            Some(other) => {
5722                                crate::blas::bias_add(out, sl(bias, base, n), m, n);
5723                                apply_activation_inplace(out, other);
5724                            }
5725                            None => crate::blas::bias_add(out, sl(bias, base, n), m, n),
5726                        }
5727                    })
5728                }
5729
5730                Thunk::FusedResidualLN {
5731                    x,
5732                    res,
5733                    bias,
5734                    g,
5735                    b,
5736                    out,
5737                    rows,
5738                    h,
5739                    eps,
5740                    has_bias,
5741                } => {
5742                    let (rows, h) = (rows as usize, h as usize);
5743                    Arc::new(move |base: *mut u8| unsafe {
5744                        let zero = vec![0f32; h]; // closure only — not hot path
5745                        let bi = if has_bias { sl(bias, base, h) } else { &zero };
5746                        let xp = sl(x, base, rows * h).as_ptr() as usize;
5747                        let rp = sl(res, base, rows * h).as_ptr() as usize;
5748                        let op = sl_mut(out, base, rows * h).as_mut_ptr() as usize;
5749                        let bp = bi.as_ptr() as usize;
5750                        let gp = sl(g, base, h).as_ptr() as usize;
5751                        let bbp = sl(b, base, h).as_ptr() as usize;
5752                        crate::pool::par_for(rows, 4, &|off, cnt| {
5753                            let xs = std::slice::from_raw_parts(
5754                                (xp as *const f32).add(off * h),
5755                                cnt * h,
5756                            );
5757                            let rs = std::slice::from_raw_parts(
5758                                (rp as *const f32).add(off * h),
5759                                cnt * h,
5760                            );
5761                            let os = std::slice::from_raw_parts_mut(
5762                                (op as *mut f32).add(off * h),
5763                                cnt * h,
5764                            );
5765                            let bi = std::slice::from_raw_parts(bp as *const f32, h);
5766                            let g = std::slice::from_raw_parts(gp as *const f32, h);
5767                            let b = std::slice::from_raw_parts(bbp as *const f32, h);
5768                            crate::kernels::residual_bias_layer_norm(
5769                                xs, rs, bi, g, b, os, cnt, h, eps,
5770                            );
5771                        });
5772                    })
5773                }
5774
5775                Thunk::BiasAdd {
5776                    src,
5777                    bias,
5778                    dst,
5779                    m,
5780                    n,
5781                } => {
5782                    let (m, n) = (m as usize, n as usize);
5783                    let len = m * n;
5784                    Arc::new(move |base: *mut u8| unsafe {
5785                        let out = sl_mut(dst, base, len);
5786                        if src != dst {
5787                            let src_ptr = base.add(src) as *const f32;
5788                            let dst_ptr = base.add(dst) as *mut f32;
5789                            if src_ptr != dst_ptr {
5790                                std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len);
5791                            }
5792                        }
5793                        crate::blas::bias_add(out, sl(bias, base, n), m, n);
5794                    })
5795                }
5796
5797                Thunk::Gather {
5798                    table,
5799                    table_len,
5800                    idx,
5801                    dst,
5802                    num_idx,
5803                    trailing,
5804                    idx_i64,
5805                    table_bytes,
5806                } => {
5807                    let (ni, tr, tl) = (num_idx as usize, trailing as usize, table_len as usize);
5808                    let rows = tl / tr.max(1);
5809                    let (idx_i64, table_bytes) = (idx_i64, table_bytes);
5810                    Arc::new(move |base: *mut u8| unsafe {
5811                        if table_bytes == 8 {
5812                            let tab = sl_i64(table, base, tl);
5813                            let out = sl_mut_i64(dst, base, ni * tr);
5814                            if idx_i64 != 0 {
5815                                let ids = sl_i64(idx, base, ni);
5816                                for i in 0..ni {
5817                                    let row = ids[i].max(0) as usize;
5818                                    if row < rows {
5819                                        out[i * tr..(i + 1) * tr]
5820                                            .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
5821                                    }
5822                                }
5823                            } else {
5824                                let ids = sl(idx, base, ni);
5825                                for i in 0..ni {
5826                                    let row = ids[i] as usize;
5827                                    if row < rows {
5828                                        out[i * tr..(i + 1) * tr]
5829                                            .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
5830                                    }
5831                                }
5832                            }
5833                        } else {
5834                            let tab = sl(table, base, tl);
5835                            let out = sl_mut(dst, base, ni * tr);
5836                            if idx_i64 != 0 {
5837                                let ids = sl_i64(idx, base, ni);
5838                                for i in 0..ni {
5839                                    let row = ids[i].max(0) as usize;
5840                                    if row < rows {
5841                                        out[i * tr..(i + 1) * tr]
5842                                            .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
5843                                    }
5844                                }
5845                            } else {
5846                                let ids = sl(idx, base, ni);
5847                                for i in 0..ni {
5848                                    let row = ids[i] as usize;
5849                                    if row < rows {
5850                                        out[i * tr..(i + 1) * tr]
5851                                            .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
5852                                    }
5853                                }
5854                            }
5855                        }
5856                    })
5857                }
5858
                // Narrow is delegated wholesale: every field is forwarded unchanged
                // to `narrow_thunk_closure`, and the closure it returns becomes this
                // arm's result.
                Thunk::Narrow {
                    src,
                    dst,
                    outer,
                    src_stride,
                    dst_stride,
                    inner,
                    elem_bytes,
                } => {
                    narrow_thunk_closure(src, dst, outer, src_stride, dst_stride, inner, elem_bytes)
                }
5870
5871                Thunk::Copy { src, dst, len } => {
5872                    let len = len as usize;
5873                    Arc::new(move |base: *mut u8| unsafe {
5874                        if src == dst || len == 0 {
5875                            return;
5876                        }
5877                        let src_ptr = base.add(src) as *const f32;
5878                        let dst_ptr = base.add(dst) as *mut f32;
5879                        if src_ptr == dst_ptr {
5880                            return;
5881                        }
5882                        std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len);
5883                    })
5884                }
5885
5886                Thunk::Softmax { data, rows, cols } => {
5887                    let (rows, cols) = (rows as usize, cols as usize);
5888                    Arc::new(move |base: *mut u8| unsafe {
5889                        crate::naive::softmax(sl_mut(data, base, rows * cols), rows, cols);
5890                    })
5891                }
5892
5893                Thunk::Cumsum {
5894                    src,
5895                    dst,
5896                    rows,
5897                    cols,
5898                    exclusive,
5899                } => {
5900                    let (rows, cols) = (rows as usize, cols as usize);
5901                    Arc::new(move |base: *mut u8| unsafe {
5902                        let s = sl(src, base, rows * cols);
5903                        let d = sl_mut(dst, base, rows * cols);
5904                        if exclusive {
5905                            for r in 0..rows {
5906                                let mut acc = 0.0f32;
5907                                for c in 0..cols {
5908                                    d[r * cols + c] = acc;
5909                                    acc += s[r * cols + c];
5910                                }
5911                            }
5912                        } else {
5913                            for r in 0..rows {
5914                                let mut acc = 0.0f32;
5915                                for c in 0..cols {
5916                                    acc += s[r * cols + c];
5917                                    d[r * cols + c] = acc;
5918                                }
5919                            }
5920                        }
5921                    })
5922                }
5923
5924                Thunk::Sample {
5925                    logits,
5926                    dst,
5927                    batch,
5928                    vocab,
5929                    top_k,
5930                    top_p,
5931                    temperature,
5932                    seed,
5933                } => {
5934                    let (b, v) = (batch as usize, vocab as usize);
5935                    let k = (top_k as usize).min(v);
5936                    Arc::new(move |base: *mut u8| unsafe {
5937                        let lg = sl(logits, base, b * v);
5938                        let out = sl_mut(dst, base, b);
5939                        let mut rng =
5940                            rlx_ir::Philox4x32::new(if seed == 0 { 0xDEADBEEF } else { seed });
5941                        for bi in 0..b {
5942                            let row = &lg[bi * v..(bi + 1) * v];
5943                            out[bi] = sample_row(row, k, top_p, temperature, &mut rng) as f32;
5944                        }
5945                    })
5946                }
5947
5948                Thunk::DequantMatMul {
5949                    x,
5950                    w_q,
5951                    scale,
5952                    zp,
5953                    dst,
5954                    m,
5955                    k,
5956                    n,
5957                    block_size,
5958                    is_asymmetric,
5959                } => {
5960                    let (m, k, n, bs) = (m as usize, k as usize, n as usize, block_size as usize);
5961                    let n_blocks_per_col = k.div_ceil(bs);
5962                    Arc::new(move |base: *mut u8| unsafe {
5963                        let xs = sl(x, base, m * k);
5964                        // w_q is packed i8 — use raw byte slice + reinterpret.
5965                        let raw = base.add(w_q);
5966                        let w_bytes = std::slice::from_raw_parts(raw as *const i8, k * n);
5967                        let scales = sl(scale, base, n_blocks_per_col * n);
5968                        let zps = if is_asymmetric {
5969                            sl(zp, base, n_blocks_per_col * n)
5970                        } else {
5971                            &[][..]
5972                        };
5973                        let out = sl_mut(dst, base, m * n);
5974                        dequant_matmul_int8(
5975                            xs,
5976                            w_bytes,
5977                            scales,
5978                            zps,
5979                            out,
5980                            m,
5981                            k,
5982                            n,
5983                            bs,
5984                            is_asymmetric,
5985                        );
5986                    })
5987                }
5988
5989                Thunk::DequantMatMulGguf {
5990                    x,
5991                    w_q,
5992                    dst,
5993                    m,
5994                    k,
5995                    n,
5996                    scheme,
5997                } => {
5998                    let (m, k, n) = (m as usize, k as usize, n as usize);
5999                    let block_bytes = scheme.gguf_block_bytes() as usize;
6000                    let block_elems = scheme.gguf_block_size() as usize;
6001                    let total_bytes = (k * n) / block_elems * block_bytes;
6002                    Arc::new(move |base: *mut u8| unsafe {
6003                        let xs = sl(x, base, m * k);
6004                        let w_bytes =
6005                            std::slice::from_raw_parts(base.add(w_q) as *const u8, total_bytes);
6006                        let out = sl_mut(dst, base, m * n);
6007                        crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, scheme);
6008                    })
6009                }
6010
6011                Thunk::DequantMatMulInt4 {
6012                    x,
6013                    w_q,
6014                    scale,
6015                    zp,
6016                    dst,
6017                    m,
6018                    k,
6019                    n,
6020                    block_size,
6021                    is_asymmetric,
6022                } => {
6023                    let (m, k, n, bs) = (m as usize, k as usize, n as usize, block_size as usize);
6024                    let n_blocks = k.div_ceil(bs);
6025                    Arc::new(move |base: *mut u8| unsafe {
6026                        let xs = sl(x, base, m * k);
6027                        let w_bytes = std::slice::from_raw_parts(
6028                            base.add(w_q) as *const u8,
6029                            (k * n).div_ceil(2),
6030                        );
6031                        let scales = sl(scale, base, n_blocks * n);
6032                        let zps = if is_asymmetric {
6033                            sl(zp, base, n_blocks * n)
6034                        } else {
6035                            &[][..]
6036                        };
6037                        let out = sl_mut(dst, base, m * n);
6038                        dequant_matmul_int4(
6039                            xs,
6040                            w_bytes,
6041                            scales,
6042                            zps,
6043                            out,
6044                            m,
6045                            k,
6046                            n,
6047                            bs,
6048                            is_asymmetric,
6049                        );
6050                    })
6051                }
6052
6053                Thunk::DequantMatMulFp8 {
6054                    x,
6055                    w_q,
6056                    scale,
6057                    dst,
6058                    m,
6059                    k,
6060                    n,
6061                    e5m2,
6062                } => {
6063                    let (m, k, n) = (m as usize, k as usize, n as usize);
6064                    Arc::new(move |base: *mut u8| unsafe {
6065                        let xs = sl(x, base, m * k);
6066                        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, k * n);
6067                        let scales = sl(scale, base, n);
6068                        let out = sl_mut(dst, base, m * n);
6069                        dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, e5m2);
6070                    })
6071                }
6072
6073                Thunk::DequantMatMulNvfp4 {
6074                    x,
6075                    w_q,
6076                    scale,
6077                    global_scale,
6078                    dst,
6079                    m,
6080                    k,
6081                    n,
6082                } => {
6083                    let (m, k, n) = (m as usize, k as usize, n as usize);
6084                    let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
6085                    Arc::new(move |base: *mut u8| unsafe {
6086                        let xs = sl(x, base, m * k);
6087                        let w_bytes = std::slice::from_raw_parts(
6088                            base.add(w_q) as *const u8,
6089                            (k * n).div_ceil(2),
6090                        );
6091                        let scale_bytes =
6092                            std::slice::from_raw_parts(base.add(scale) as *const u8, n_scale);
6093                        let gs = sl(global_scale, base, 1)[0];
6094                        let out = sl_mut(dst, base, m * n);
6095                        dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
6096                    })
6097                }
6098
6099                Thunk::LoraMatMul {
6100                    x,
6101                    w,
6102                    a,
6103                    b,
6104                    dst,
6105                    m,
6106                    k,
6107                    n,
6108                    r,
6109                    scale,
6110                } => {
6111                    let (m, k, n, r) = (m as usize, k as usize, n as usize, r as usize);
6112                    Arc::new(move |base: *mut u8| unsafe {
6113                        let xs = sl(x, base, m * k);
6114                        let ws = sl(w, base, k * n);
6115                        let a_s = sl(a, base, k * r);
6116                        let bs = sl(b, base, r * n);
6117                        let out = sl_mut(dst, base, m * n);
6118                        // Step 1: out = x · W.
6119                        crate::blas::sgemm(xs, ws, out, m, k, n);
6120                        // Step 2: tmp = x · A (rank-r intermediate; tiny).
6121                        let mut tmp = vec![0f32; m * r];
6122                        crate::blas::sgemm(xs, a_s, &mut tmp, m, k, r);
6123                        // Step 3: out += scale * (tmp · B).
6124                        // sgemm_accumulate uses alpha=1.0 internally, so
6125                        // scale tmp first.
6126                        if scale != 1.0 {
6127                            for v in tmp.iter_mut() {
6128                                *v *= scale;
6129                            }
6130                        }
6131                        crate::blas::sgemm_accumulate(&tmp, bs, out, m, r, n);
6132                    })
6133                }
6134
6135                Thunk::LayerNorm {
6136                    src,
6137                    g,
6138                    b,
6139                    dst,
6140                    rows,
6141                    h,
6142                    eps,
6143                } => {
6144                    let (rows, h) = (rows as usize, h as usize);
6145                    Arc::new(move |base: *mut u8| unsafe {
6146                        let inp = sl(src, base, rows * h);
6147                        let gamma = sl(g, base, h);
6148                        let beta = sl(b, base, h);
6149                        let out = sl_mut(dst, base, rows * h);
6150                        for row in 0..rows {
6151                            crate::kernels::layer_norm_row(
6152                                &inp[row * h..(row + 1) * h],
6153                                gamma,
6154                                beta,
6155                                &mut out[row * h..(row + 1) * h],
6156                                h,
6157                                eps,
6158                            );
6159                        }
6160                    })
6161                }
6162
6163                Thunk::BatchNormInference {
6164                    src,
6165                    g,
6166                    b,
6167                    mean,
6168                    var,
6169                    dst,
6170                    count,
6171                    channels,
6172                    eps,
6173                } => {
6174                    let count = count as usize;
6175                    let c = channels as usize;
6176                    let n = count * c;
6177                    let (src, g, b, mean, var, dst) = (src, g, b, mean, var, dst);
6178                    Arc::new(move |base: *mut u8| unsafe {
6179                        crate::kernels::batch_norm_inference(
6180                            sl(src, base, n),
6181                            sl(g, base, c),
6182                            sl(b, base, c),
6183                            sl(mean, base, c),
6184                            sl(var, base, c),
6185                            sl_mut(dst, base, n),
6186                            c,
6187                            eps,
6188                        );
6189                    })
6190                }
6191
6192                Thunk::Attention {
6193                    q,
6194                    k,
6195                    v,
6196                    mask,
6197                    out,
6198                    batch,
6199                    seq,
6200                    kv_seq,
6201                    heads,
6202                    head_dim,
6203                    mask_kind,
6204                    q_row_stride,
6205                    k_row_stride,
6206                    v_row_stride,
6207                    bhsd,
6208                } => {
6209                    if std::env::var("RLX_ATTN_DEBUG").is_ok() {
6210                        eprintln!("[attn-compile] batch={batch} seq={seq} kv_seq={kv_seq} heads={heads} bhsd={bhsd}");
6211                    }
6212                    // Q seq length (`q_s`) and K/V seq length (`k_s`) differ
6213                    // during cached decode (`q_s=1`, `k_s=past_seq+1`). The
6214                    // earlier version of this kernel destructured
6215                    // `kv_seq: _` and used a single `s = seq` for both axes,
6216                    // so cached decode only scored 1×1 instead of 1×k_s —
6217                    // attention couldn't see the past K cache and decode
6218                    // collapsed into repetitive fragments
6219                    // (`Self-based on [1\nAnswer: Self-based on [1…`).
6220                    let (b, q_s, k_s, nh, dh) = (
6221                        batch as usize,
6222                        seq as usize,
6223                        kv_seq as usize,
6224                        heads as usize,
6225                        head_dim as usize,
6226                    );
6227                    let hs = nh * dh;
6228                    let qrs = q_row_stride as usize;
6229                    let krs = k_row_stride as usize;
6230                    let vrs = v_row_stride as usize;
6231                    let scale = (dh as f32).powf(-0.5);
6232                    Arc::new(move |base: *mut u8| unsafe {
6233                        if std::env::var("RLX_ATTN_DEBUG").is_ok() {
6234                            eprintln!("[attn] b={b} q_s={q_s} k_s={k_s} nh={nh} dh={dh} bhsd={bhsd} mask_kind={:?}", mask_kind);
6235                        }
6236                        // Slice lengths use the source's row stride so the
6237                        // compiler-emitted bounds checks cover the whole
6238                        // strided span (the kernel walks with q/k/v_rs).
6239                        // For [B, H, S, D] the buffer is dense B*H*S*D.
6240                        let (q_len, k_len, v_len, o_len) = if bhsd {
6241                            let qn = b * nh * q_s * dh;
6242                            let kn = b * nh * k_s * dh;
6243                            (qn, kn, kn, qn)
6244                        } else {
6245                            (b * q_s * qrs, b * k_s * krs, b * k_s * vrs, b * q_s * hs)
6246                        };
6247                        let q_d = sl(q, base, q_len);
6248                        let k_d = sl(k, base, k_len);
6249                        let v_d = sl(v, base, v_len);
6250                        let m_d: &[f32] = match mask_kind {
6251                            rlx_ir::op::MaskKind::Custom => sl(mask, base, b * k_s),
6252                            rlx_ir::op::MaskKind::Bias => sl(mask, base, b * nh * q_s * k_s),
6253                            _ => &[],
6254                        };
6255                        let o_d = sl_mut(out, base, o_len);
6256                        let mut qh = vec![0f32; q_s * dh];
6257                        let mut kh = vec![0f32; k_s * dh];
6258                        let mut vh = vec![0f32; k_s * dh];
6259                        let mut sc = vec![0f32; q_s * k_s];
6260                        let mut oh = vec![0f32; q_s * dh];
6261                        for bi in 0..b {
6262                            for hi in 0..nh {
6263                                // Gather per-head Q.
6264                                for si in 0..q_s {
6265                                    let q_off = if bhsd {
6266                                        bi * nh * q_s * dh + hi * q_s * dh + si * dh
6267                                    } else {
6268                                        bi * q_s * qrs + si * qrs + hi * dh
6269                                    };
6270                                    qh[si * dh..(si + 1) * dh]
6271                                        .copy_from_slice(&q_d[q_off..q_off + dh]);
6272                                }
6273                                // Gather per-head K, V.
6274                                for si in 0..k_s {
6275                                    let (k_off, v_off) = if bhsd {
6276                                        (
6277                                            bi * nh * k_s * dh + hi * k_s * dh + si * dh,
6278                                            bi * nh * k_s * dh + hi * k_s * dh + si * dh,
6279                                        )
6280                                    } else {
6281                                        (
6282                                            bi * k_s * krs + si * krs + hi * dh,
6283                                            bi * k_s * vrs + si * vrs + hi * dh,
6284                                        )
6285                                    };
6286                                    kh[si * dh..(si + 1) * dh]
6287                                        .copy_from_slice(&k_d[k_off..k_off + dh]);
6288                                    vh[si * dh..(si + 1) * dh]
6289                                        .copy_from_slice(&v_d[v_off..v_off + dh]);
6290                                }
6291                                for qi in 0..q_s {
6292                                    for ki in 0..k_s {
6293                                        let mut dot = 0f32;
6294                                        for d in 0..dh {
6295                                            dot += qh[qi * dh + d] * kh[ki * dh + d];
6296                                        }
6297                                        sc[qi * k_s + ki] = dot * scale;
6298                                    }
6299                                }
6300                                // Apply mask. Causal/SlidingWindow use absolute
6301                                // positions so they handle Lq != Lk (decode mode
6302                                // with cached K/V): q_offset = k_s - q_s.
6303                                let q_offset = k_s.saturating_sub(q_s);
6304                                match mask_kind {
6305                                    rlx_ir::op::MaskKind::None => {}
6306                                    rlx_ir::op::MaskKind::Causal => {
6307                                        for qi in 0..q_s {
6308                                            let abs_q = q_offset + qi;
6309                                            for ki in (abs_q + 1)..k_s {
6310                                                sc[qi * k_s + ki] = mask_neg;
6311                                            }
6312                                        }
6313                                    }
6314                                    rlx_ir::op::MaskKind::SlidingWindow(w) => {
6315                                        for qi in 0..q_s {
6316                                            let abs_q = q_offset + qi;
6317                                            let lo = abs_q.saturating_sub(w);
6318                                            for ki in 0..k_s {
6319                                                if ki < lo || ki > abs_q {
6320                                                    sc[qi * k_s + ki] = mask_neg;
6321                                                }
6322                                            }
6323                                        }
6324                                    }
6325                                    rlx_ir::op::MaskKind::Custom => {
6326                                        for qi in 0..q_s {
6327                                            for ki in 0..k_s {
6328                                                if m_d[bi * k_s + ki] < mask_thr {
6329                                                    sc[qi * k_s + ki] = mask_neg;
6330                                                }
6331                                            }
6332                                        }
6333                                    }
6334                                    rlx_ir::op::MaskKind::Bias => {
6335                                        let per_bh = q_s * k_s;
6336                                        let off = (bi * nh + hi) * per_bh;
6337                                        for i in 0..per_bh {
6338                                            sc[i] += m_d[off + i];
6339                                        }
6340                                    }
6341                                }
6342                                crate::naive::softmax(&mut sc, q_s, k_s);
6343                                oh.fill(0.0);
6344                                for qi in 0..q_s {
6345                                    for ki in 0..k_s {
6346                                        let w = sc[qi * k_s + ki];
6347                                        if w > score_skip {
6348                                            for d in 0..dh {
6349                                                oh[qi * dh + d] += w * vh[ki * dh + d];
6350                                            }
6351                                        }
6352                                    }
6353                                }
6354                                for si in 0..q_s {
6355                                    let off = if bhsd {
6356                                        bi * nh * q_s * dh + hi * q_s * dh + si * dh
6357                                    } else {
6358                                        bi * q_s * hs + si * hs + hi * dh
6359                                    };
6360                                    o_d[off..off + dh].copy_from_slice(&oh[si * dh..(si + 1) * dh]);
6361                                }
6362                            }
6363                        }
6364                    })
6365                }
6366
6367                Thunk::FusedSwiGLU {
6368                    src,
6369                    dst,
6370                    n_half,
6371                    total,
6372                    gate_first,
6373                } => {
6374                    let n = n_half as usize;
6375                    let t = total as usize;
6376                    let outer = t / n;
6377                    let in_total = outer * 2 * n;
6378                    Arc::new(move |base: *mut u8| unsafe {
6379                        let inp = sl(src, base, in_total);
6380                        let out = sl_mut(dst, base, t);
6381                        for o in 0..outer {
6382                            let in_row = &inp[o * 2 * n..(o + 1) * 2 * n];
6383                            let out_row = &mut out[o * n..(o + 1) * n];
6384                            for i in 0..n {
6385                                let (up, gate) = if gate_first {
6386                                    (in_row[n + i], in_row[i])
6387                                } else {
6388                                    (in_row[i], in_row[n + i])
6389                                };
6390                                out_row[i] = up * (gate / (1.0 + (-gate).exp()));
6391                            }
6392                        }
6393                    })
6394                }
6395
6396                Thunk::Concat {
6397                    dst,
6398                    outer,
6399                    inner,
6400                    total_axis,
6401                    inputs,
6402                } => {
6403                    let outer = outer as usize;
6404                    let inner = inner as usize;
6405                    let total_axis = total_axis as usize;
6406                    let out_total = outer * total_axis * inner;
6407                    // Pre-compute the destination row offset for each input
6408                    // (cumulative axis offsets times inner).
6409                    let mut layout: Vec<(usize, usize, usize)> = Vec::with_capacity(inputs.len());
6410                    let mut cum: usize = 0;
6411                    for (src_off, in_axis) in &inputs {
6412                        let in_axis = *in_axis as usize;
6413                        layout.push((*src_off, cum * inner, in_axis * inner));
6414                        cum += in_axis;
6415                    }
6416                    Arc::new(move |base: *mut u8| unsafe {
6417                        let out = sl_mut(dst, base, out_total);
6418                        let row_stride = total_axis * inner;
6419                        for (src_off, dst_col_off, copy_per_row) in &layout {
6420                            let in_total = outer * *copy_per_row;
6421                            let inp = sl(*src_off, base, in_total);
6422                            for o in 0..outer {
6423                                let dst_row_start = o * row_stride + *dst_col_off;
6424                                let src_row_start = o * *copy_per_row;
6425                                out[dst_row_start..dst_row_start + *copy_per_row].copy_from_slice(
6426                                    &inp[src_row_start..src_row_start + *copy_per_row],
6427                                );
6428                            }
6429                        }
6430                    })
6431                }
6432
6433                Thunk::CustomOp {
6434                    kernel,
6435                    inputs,
6436                    output,
6437                    attrs,
6438                } => {
6439                    // Capture-by-move: clone the Arc and Vecs once into the
6440                    // closure. Dispatch by output dtype each call (the
6441                    // dtype is fixed at compile time but it's cheaper to
6442                    // branch once per execution than to monomorphize a
6443                    // dozen closure variants).
6444                    let kernel = kernel.clone();
6445                    let attrs = attrs.clone();
6446                    let inputs = inputs.clone();
6447                    let (out_off, out_len, out_shape) = output.clone();
6448                    Arc::new(move |base: *mut u8| unsafe {
6449                        dispatch_custom_op(
6450                            &*kernel, &inputs, out_off, out_len, &out_shape, &attrs, base,
6451                        );
6452                    })
6453                }
6454
                // Gaussian-splat forward render. This arm is a pure pass-through:
                // every field is forwarded, in declaration order and unchanged, to
                // `crate::splat::execute_gaussian_splat_render`, followed by the
                // arena base pointer.
                Thunk::GaussianSplatRender {
                    positions_off,
                    positions_len,
                    scales_off,
                    scales_len,
                    rotations_off,
                    rotations_len,
                    opacities_off,
                    opacities_len,
                    colors_off,
                    colors_len,
                    sh_coeffs_off,
                    sh_coeffs_len,
                    meta_off,
                    dst_off,
                    dst_len,
                    width,
                    height,
                    tile_size,
                    radius_scale,
                    alpha_cutoff,
                    max_splat_steps,
                    transmittance_threshold,
                    max_list_entries,
                } => Arc::new(move |base: *mut u8| unsafe {
                    // The callee receives offsets and lengths rather than slices,
                    // so no arena views are created here.
                    crate::splat::execute_gaussian_splat_render(
                        positions_off,
                        positions_len,
                        scales_off,
                        scales_len,
                        rotations_off,
                        rotations_len,
                        opacities_off,
                        opacities_len,
                        colors_off,
                        colors_len,
                        sh_coeffs_off,
                        sh_coeffs_len,
                        meta_off,
                        dst_off,
                        dst_len,
                        width,
                        height,
                        tile_size,
                        radius_scale,
                        alpha_cutoff,
                        max_splat_steps,
                        transmittance_threshold,
                        max_list_entries,
                        base,
                    );
                }),
6507
6508                Thunk::GaussianSplatRenderBackward {
6509                    positions_off,
6510                    positions_len,
6511                    scales_off,
6512                    scales_len,
6513                    rotations_off,
6514                    rotations_len,
6515                    opacities_off,
6516                    opacities_len,
6517                    colors_off,
6518                    colors_len,
6519                    sh_coeffs_off,
6520                    sh_coeffs_len,
6521                    meta_off,
6522                    d_loss_off,
6523                    d_loss_len,
6524                    packed_off,
6525                    packed_len,
6526                    width,
6527                    height,
6528                    tile_size,
6529                    radius_scale,
6530                    alpha_cutoff,
6531                    max_splat_steps,
6532                    transmittance_threshold,
6533                    max_list_entries,
6534                    loss_grad_clip,
6535                    sh_band,
6536                    max_anisotropy,
6537                } => Arc::new(move |base: *mut u8| unsafe {
6538                    crate::splat::execute_gaussian_splat_render_backward(
6539                        positions_off,
6540                        positions_len,
6541                        scales_off,
6542                        scales_len,
6543                        rotations_off,
6544                        rotations_len,
6545                        opacities_off,
6546                        opacities_len,
6547                        colors_off,
6548                        colors_len,
6549                        sh_coeffs_off,
6550                        sh_coeffs_len,
6551                        meta_off,
6552                        d_loss_off,
6553                        d_loss_len,
6554                        packed_off,
6555                        packed_len,
6556                        width,
6557                        height,
6558                        tile_size,
6559                        radius_scale,
6560                        alpha_cutoff,
6561                        max_splat_steps,
6562                        transmittance_threshold,
6563                        max_list_entries,
6564                        loss_grad_clip,
6565                        sh_band,
6566                        max_anisotropy,
6567                        base,
6568                    );
6569                }),
6570
6571                Thunk::GaussianSplatPrepare {
6572                    positions_off,
6573                    positions_len,
6574                    scales_off,
6575                    scales_len,
6576                    rotations_off,
6577                    rotations_len,
6578                    opacities_off,
6579                    opacities_len,
6580                    colors_off,
6581                    colors_len,
6582                    sh_coeffs_off,
6583                    sh_coeffs_len,
6584                    meta_off,
6585                    meta_len,
6586                    prep_off,
6587                    prep_len,
6588                    width,
6589                    height,
6590                    tile_size,
6591                    radius_scale,
6592                    alpha_cutoff,
6593                    max_splat_steps,
6594                    transmittance_threshold,
6595                    max_list_entries,
6596                } => Arc::new(move |base: *mut u8| unsafe {
6597                    crate::splat::execute_gaussian_splat_prepare(
6598                        positions_off,
6599                        positions_len,
6600                        scales_off,
6601                        scales_len,
6602                        rotations_off,
6603                        rotations_len,
6604                        opacities_off,
6605                        opacities_len,
6606                        colors_off,
6607                        colors_len,
6608                        sh_coeffs_off,
6609                        sh_coeffs_len,
6610                        meta_off,
6611                        meta_len,
6612                        prep_off,
6613                        prep_len,
6614                        width,
6615                        height,
6616                        tile_size,
6617                        radius_scale,
6618                        alpha_cutoff,
6619                        max_splat_steps,
6620                        transmittance_threshold,
6621                        max_list_entries,
6622                        base,
6623                    );
6624                }),
6625
6626                Thunk::GaussianSplatRasterize {
6627                    prep_off,
6628                    prep_len,
6629                    meta_off,
6630                    meta_len,
6631                    dst_off,
6632                    dst_len,
6633                    count,
6634                    width,
6635                    height,
6636                    tile_size,
6637                    alpha_cutoff,
6638                    max_splat_steps,
6639                    transmittance_threshold,
6640                    max_list_entries,
6641                } => Arc::new(move |base: *mut u8| unsafe {
6642                    crate::splat::execute_gaussian_splat_rasterize(
6643                        prep_off,
6644                        prep_len,
6645                        meta_off,
6646                        meta_len,
6647                        dst_off,
6648                        dst_len,
6649                        count,
6650                        width,
6651                        height,
6652                        tile_size,
6653                        alpha_cutoff,
6654                        max_splat_steps,
6655                        transmittance_threshold,
6656                        max_list_entries,
6657                        base,
6658                    );
6659                }),
6660
6661                Thunk::Fft1d {
6662                    src,
6663                    dst,
6664                    outer,
6665                    n_complex,
6666                    inverse,
6667                    norm_tag,
6668                    dtype,
6669                } => {
6670                    let f: Arc<dyn Fn(*mut u8) + Send + Sync> = match dtype {
6671                        rlx_ir::DType::F64 => Arc::new(move |base: *mut u8| unsafe {
6672                            execute_fft1d_f64(
6673                                src,
6674                                dst,
6675                                outer as usize,
6676                                n_complex as usize,
6677                                inverse,
6678                                norm_tag,
6679                                base,
6680                            );
6681                        }),
6682                        rlx_ir::DType::F32 => Arc::new(move |base: *mut u8| unsafe {
6683                            execute_fft1d_f32(
6684                                src,
6685                                dst,
6686                                outer as usize,
6687                                n_complex as usize,
6688                                inverse,
6689                                norm_tag,
6690                                base,
6691                            );
6692                        }),
6693                        rlx_ir::DType::C64 => Arc::new(move |base: *mut u8| unsafe {
6694                            execute_fft1d_c64(
6695                                src,
6696                                dst,
6697                                outer as usize,
6698                                n_complex as usize,
6699                                inverse,
6700                                norm_tag,
6701                                base,
6702                            );
6703                        }),
6704                        other => panic!("Op::Fft on CPU requires F32/F64/C64, got {other:?}"),
6705                    };
6706                    f
6707                }
6708
6709                Thunk::FftButterflyStage {
6710                    state_src,
6711                    state_dst,
6712                    gate_src,
6713                    rev_src,
6714                    tw_re_src,
6715                    tw_im_src,
6716                    batch,
6717                    n_fft,
6718                    stage,
6719                } => Arc::new(move |base: *mut u8| unsafe {
6720                    execute_fft_butterfly_stage_f32(
6721                        state_src,
6722                        state_dst,
6723                        gate_src,
6724                        rev_src,
6725                        tw_re_src,
6726                        tw_im_src,
6727                        batch as usize,
6728                        n_fft as usize,
6729                        stage as usize,
6730                        base,
6731                    );
6732                }),
6733
6734                Thunk::LogMel {
6735                    spec,
6736                    filters,
6737                    dst,
6738                    outer,
6739                    n_fft,
6740                    n_bins,
6741                    n_mels,
6742                } => Arc::new(move |base: *mut u8| unsafe {
6743                    execute_log_mel_f32(
6744                        spec,
6745                        filters,
6746                        dst,
6747                        outer as usize,
6748                        n_fft as usize,
6749                        n_bins as usize,
6750                        n_mels as usize,
6751                        base,
6752                    );
6753                }),
6754
6755                Thunk::LogMelBackward {
6756                    spec,
6757                    filters,
6758                    dy,
6759                    dst,
6760                    outer,
6761                    n_fft,
6762                    n_bins,
6763                    n_mels,
6764                } => Arc::new(move |base: *mut u8| unsafe {
6765                    execute_log_mel_backward_f32(
6766                        spec,
6767                        filters,
6768                        dy,
6769                        dst,
6770                        outer as usize,
6771                        n_fft as usize,
6772                        n_bins as usize,
6773                        n_mels as usize,
6774                        base,
6775                    );
6776                }),
6777
6778                _ => Arc::new(|_: *mut u8| {}),
6779            }
6780        })
6781        .collect();
6782
6783    // ── Thunk-level attention fusion ──────────────────────
6784    // For small batch*seq, fuse QKV→Narrow×3→[Rope×2]→Attention→OutProj
6785    // into a single FusedAttnBlock. Auto-detects from Attention thunks.
6786    let fuse_threshold: usize = rlx_ir::env::var("RLX_FUSE_ATTN_THRESHOLD")
6787        .and_then(|v| v.parse().ok())
6788        .unwrap_or(64);
6789    let should_fuse = thunks.iter().any(|t| match t {
6790        Thunk::Attention { batch, seq, .. } => {
6791            (*batch as usize) * (*seq as usize) <= fuse_threshold
6792        }
6793        _ => false,
6794    });
6795
6796    if should_fuse {
6797        // Build non-Nop index for pattern matching across Nop gaps
6798        let active: Vec<usize> = thunks
6799            .iter()
6800            .enumerate()
6801            .filter(|(_, t)| !matches!(t, Thunk::Nop))
6802            .map(|(i, _)| i)
6803            .collect();
6804
6805        let mut kill = vec![false; thunks.len()]; // mark thunks to remove
6806        let mut insertions: Vec<(usize, Thunk)> = Vec::new(); // (position, replacement)
6807
6808        let mut ai = 0;
6809        while ai < active.len() {
6810            // Helper: get active thunk at offset from current
6811            let a = |off: usize| -> Option<(usize, &Thunk)> {
6812                active.get(ai + off).map(|&idx| (idx, &thunks[idx]))
6813            };
6814
6815            // Try BERT pattern: FusedMmBiasAct(QKV) → Narrow×3 → Attention → FusedMmBiasAct(out)
6816            let matched = (|| {
6817                let (_i0, t0) = a(0)?;
6818                let (_, t1) = a(1)?;
6819                let (_, t2) = a(2)?;
6820                let (_, t3) = a(3)?;
6821
6822                // a[0] must be FusedMmBiasAct or Sgemm (QKV projection)
6823                let (hidden, qkv_w, qkv_b, has_b) = match t0 {
6824                    Thunk::FusedMmBiasAct {
6825                        a,
6826                        w,
6827                        bias,
6828                        n: _,
6829                        act: None,
6830                        ..
6831                    } => (*a, *w, *bias, true),
6832                    Thunk::Sgemm { a, b, n: _, .. } => (*a, *b, 0, false),
6833                    _ => return None,
6834                };
6835
6836                // a[1..3] must be Narrows
6837                if !matches!(t1, Thunk::Narrow { .. }) {
6838                    return None;
6839                }
6840                if !matches!(t2, Thunk::Narrow { .. }) {
6841                    return None;
6842                }
6843                if !matches!(t3, Thunk::Narrow { .. }) {
6844                    return None;
6845                }
6846
6847                // Look for optional Rope×2 then Attention
6848                let (has_rope, attn_ai, cos_off, sin_off, cl) = if let Some((
6849                    _,
6850                    Thunk::Rope {
6851                        cos, sin, cos_len, ..
6852                    },
6853                )) = a(4)
6854                {
6855                    if matches!(a(5).map(|x| x.1), Some(Thunk::Rope { .. })) {
6856                        if matches!(a(6).map(|x| x.1), Some(Thunk::Attention { .. })) {
6857                            (true, 6, *cos, *sin, *cos_len)
6858                        } else {
6859                            return None;
6860                        }
6861                    } else {
6862                        return None;
6863                    }
6864                } else if matches!(a(4).map(|x| x.1), Some(Thunk::Attention { .. })) {
6865                    (false, 4, 0, 0, 0)
6866                } else {
6867                    return None;
6868                };
6869
6870                let (_attn_real_idx, attn_t) = a(attn_ai)?;
6871                let (batch, seq, heads, head_dim, mask) = match attn_t {
6872                    Thunk::Attention {
6873                        batch,
6874                        seq,
6875                        heads,
6876                        head_dim,
6877                        mask,
6878                        ..
6879                    } => (*batch, *seq, *heads, *head_dim, *mask),
6880                    _ => return None,
6881                };
6882
6883                // Next active must be out projection (FusedMmBiasAct or Sgemm)
6884                let (_out_real_idx, out_t) = a(attn_ai + 1)?;
6885                let (out_w, out_b, out_dst) = match out_t {
6886                    Thunk::FusedMmBiasAct {
6887                        w,
6888                        bias,
6889                        c,
6890                        act: None,
6891                        ..
6892                    } => (*w, *bias, *c),
6893                    Thunk::Sgemm { b: w, c, .. } => (*w, 0, *c),
6894                    _ => return None,
6895                };
6896
6897                let hs = heads * head_dim;
6898                let total_active = attn_ai + 2; // number of active thunks consumed
6899
6900                Some((
6901                    total_active,
6902                    Thunk::FusedAttnBlock {
6903                        hidden,
6904                        qkv_w,
6905                        out_w,
6906                        mask,
6907                        out: out_dst,
6908                        qkv_b: if has_b { qkv_b } else { 0 },
6909                        out_b: if has_b { out_b } else { 0 },
6910                        cos: cos_off,
6911                        sin: sin_off,
6912                        cos_len: cl,
6913                        batch,
6914                        seq,
6915                        hs,
6916                        nh: heads,
6917                        dh: head_dim,
6918                        has_bias: has_b,
6919                        has_rope,
6920                    },
6921                ))
6922            })();
6923
6924            if let Some((count, fused_thunk)) = matched {
6925                // Mark consumed thunks for removal
6926                for off in 0..count {
6927                    if let Some(&idx) = active.get(ai + off) {
6928                        kill[idx] = true;
6929                    }
6930                }
6931                // Insert replacement at position of the QKV thunk
6932                insertions.push((active[ai], fused_thunk));
6933                ai += count;
6934            } else {
6935                ai += 1;
6936            }
6937        }
6938
6939        // Rebuild thunk list: keep non-killed, insert fused at right positions
6940        if !insertions.is_empty() {
6941            let mut new_thunks = Vec::with_capacity(thunks.len());
6942            let mut insert_idx = 0;
6943            for (i, t) in thunks.into_iter().enumerate() {
6944                if insert_idx < insertions.len() && insertions[insert_idx].0 == i {
6945                    new_thunks.push(insertions[insert_idx].1.clone());
6946                    insert_idx += 1;
6947                }
6948                if !kill[i] {
6949                    new_thunks.push(t);
6950                }
6951            }
6952            if cfg.verbose >= 1 {
6953                eprintln!(
6954                    "[rlx] fused_attention: {} attention blocks fused",
6955                    insertions.len()
6956                );
6957            }
6958            thunks = new_thunks;
6959        }
6960    }
6961
6962    // ── Full layer fusion ──────────────────────────────────
6963    // After attention blocks are fused, scan for full layer patterns:
6964    // BERT:  FusedAttnBlock → FusedResidualLN → FusedMmBiasAct(gelu) → Sgemm → BiasAdd → FusedResidualLN
6965    // Nomic: FusedAttnBlock → BinaryFull(add) → LayerNorm → Sgemm → [Narrow×2 → Silu → BinaryFull(mul)] → Sgemm → BinaryFull(add) → LayerNorm
6966    if should_fuse {
6967        let active: Vec<usize> = thunks
6968            .iter()
6969            .enumerate()
6970            .filter(|(_, t)| !matches!(t, Thunk::Nop))
6971            .map(|(i, _)| i)
6972            .collect();
6973
6974        let mut kill = vec![false; thunks.len()];
6975        let mut insertions: Vec<(usize, Thunk)> = Vec::new();
6976
6977        let a = |ai: usize| -> Option<&Thunk> { active.get(ai).map(|&i| &thunks[i]) };
6978
6979        let mut ai = 0;
6980        while ai < active.len() {
6981            // BERT pattern: FusedAttnBlock → FusedResidualLN → FusedMmBiasAct(gelu) → FusedMmBiasAct(none) → FusedResidualLN
6982            let bert_match = (|| -> Option<usize> {
6983                let fab = a(ai)?;
6984                let rln1 = a(ai + 1)?;
6985                let ffn1 = a(ai + 2)?;
6986                let ffn2 = a(ai + 3)?;
6987                let rln2 = a(ai + 4)?;
6988
6989                let (hidden, qkv_w, qkv_b, out_w, out_b, mask, batch, seq, hs, nh, dh) = match fab {
6990                    Thunk::FusedAttnBlock {
6991                        hidden,
6992                        qkv_w,
6993                        qkv_b,
6994                        out_w,
6995                        out_b,
6996                        mask,
6997                        batch,
6998                        seq,
6999                        hs,
7000                        nh,
7001                        dh,
7002                        has_bias: true,
7003                        has_rope: false,
7004                        ..
7005                    } => (
7006                        *hidden, *qkv_w, *qkv_b, *out_w, *out_b, *mask, *batch, *seq, *hs, *nh, *dh,
7007                    ),
7008                    _ => return None,
7009                };
7010                let (ln1_g, ln1_b, eps1) = match rln1 {
7011                    Thunk::FusedResidualLN { g, b, eps, .. } => (*g, *b, *eps),
7012                    _ => return None,
7013                };
7014                let (fc1_w, fc1_b, int_dim) = match ffn1 {
7015                    Thunk::FusedMmBiasAct {
7016                        w,
7017                        bias,
7018                        n,
7019                        act: Some(Activation::Gelu),
7020                        ..
7021                    } => (*w, *bias, *n),
7022                    _ => return None,
7023                };
7024                let (fc2_w, fc2_b) = match ffn2 {
7025                    Thunk::FusedMmBiasAct {
7026                        w, bias, act: None, ..
7027                    } => (*w, *bias),
7028                    _ => return None,
7029                };
7030                let (ln2_g, ln2_b, eps2, out) = match rln2 {
7031                    Thunk::FusedResidualLN { g, b, eps, out, .. } => (*g, *b, *eps, *out),
7032                    _ => return None,
7033                };
7034
7035                for off in 0..5 {
7036                    kill[active[ai + off]] = true;
7037                }
7038                insertions.push((
7039                    active[ai],
7040                    Thunk::FusedBertLayer {
7041                        hidden,
7042                        qkv_w,
7043                        qkv_b,
7044                        out_w,
7045                        out_b,
7046                        mask,
7047                        ln1_g,
7048                        ln1_b,
7049                        eps1,
7050                        fc1_w,
7051                        fc1_b,
7052                        fc2_w,
7053                        fc2_b,
7054                        ln2_g,
7055                        ln2_b,
7056                        eps2,
7057                        out,
7058                        batch,
7059                        seq,
7060                        hs,
7061                        nh,
7062                        dh,
7063                        int_dim,
7064                    },
7065                ));
7066                Some(5)
7067            })();
7068            if let Some(n) = bert_match {
7069                ai += n;
7070                continue;
7071            }
7072
7073            // Nomic full layer fusion — disabled pending SwiGLU stride debugging.
7074            // Nomic still benefits from FusedAttnBlock (attention-level fusion).
7075            // The body below is kept as reference for when the stride bug is fixed.
7076            #[allow(unreachable_code)]
7077            let nomic_match = (|| -> Option<usize> {
7078                return None; // TODO: fix SwiGLU strided fc2 output mismatch
7079                let fab = a(ai)?;
7080                let (hidden, qkv_w, out_w, mask, cos, sin, cos_len, batch, seq, hs, nh, dh) =
7081                    match fab {
7082                        Thunk::FusedAttnBlock {
7083                            hidden,
7084                            qkv_w,
7085                            out_w,
7086                            mask,
7087                            cos,
7088                            sin,
7089                            cos_len,
7090                            batch,
7091                            seq,
7092                            hs,
7093                            nh,
7094                            dh,
7095                            has_bias: false,
7096                            has_rope: true,
7097                            ..
7098                        } => (
7099                            *hidden, *qkv_w, *out_w, *mask, *cos, *sin, *cos_len, *batch, *seq,
7100                            *hs, *nh, *dh,
7101                        ),
7102                        _ => return None,
7103                    };
7104                // FusedResidualLN for LN1
7105                let (ln1_g, ln1_b, eps1) = match a(ai + 1)? {
7106                    Thunk::FusedResidualLN { g, b, eps, .. } => (*g, *b, *eps),
7107                    _ => return None,
7108                };
7109                // Sgemm (fused fc11+fc12)
7110                let fused_fc_w = match a(ai + 2)? {
7111                    Thunk::Sgemm { b: w, .. } => *w,
7112                    _ => return None,
7113                };
7114                // Narrow×2 for split
7115                if !matches!(a(ai + 3)?, Thunk::Narrow { .. }) {
7116                    return None;
7117                }
7118                if !matches!(a(ai + 4)?, Thunk::Narrow { .. }) {
7119                    return None;
7120                }
7121                // SiLU
7122                if !matches!(
7123                    a(ai + 5)?,
7124                    Thunk::ActivationInPlace {
7125                        act: Activation::Silu,
7126                        ..
7127                    }
7128                ) {
7129                    return None;
7130                }
7131                // BinaryFull(Mul) for gate
7132                if !matches!(
7133                    a(ai + 6)?,
7134                    Thunk::BinaryFull {
7135                        op: BinaryOp::Mul,
7136                        ..
7137                    }
7138                ) {
7139                    return None;
7140                }
7141                // Sgemm (fc2)
7142                let fc2_w = match a(ai + 7)? {
7143                    Thunk::Sgemm { b: w, .. } => *w,
7144                    _ => return None,
7145                };
7146                // Get int_dim from the Narrow (inner = int_dim for last-axis narrow)
7147                let int_dim = match a(ai + 3)? {
7148                    Thunk::Narrow { inner, .. } => *inner,
7149                    _ => return None,
7150                };
7151                // FusedResidualLN for LN2
7152                let (ln2_g, ln2_b, eps2, out) = match a(ai + 8)? {
7153                    Thunk::FusedResidualLN { g, b, eps, out, .. } => (*g, *b, *eps, *out),
7154                    _ => return None,
7155                };
7156
7157                for off in 0..9 {
7158                    kill[active[ai + off]] = true;
7159                }
7160                insertions.push((
7161                    active[ai],
7162                    Thunk::FusedNomicLayer {
7163                        hidden,
7164                        qkv_w,
7165                        out_w,
7166                        mask,
7167                        cos,
7168                        sin,
7169                        cos_len,
7170                        ln1_g,
7171                        ln1_b,
7172                        eps1,
7173                        fc11_w: fused_fc_w,
7174                        fc12_w: 0,
7175                        fc2_w,
7176                        ln2_g,
7177                        ln2_b,
7178                        eps2,
7179                        out,
7180                        batch,
7181                        seq,
7182                        hs,
7183                        nh,
7184                        dh,
7185                        int_dim,
7186                    },
7187                ));
7188                Some(9)
7189            })();
7190            if let Some(n) = nomic_match {
7191                ai += n;
7192                continue;
7193            }
7194
7195            ai += 1;
7196        }
7197
7198        if !insertions.is_empty() {
7199            let mut new_thunks = Vec::with_capacity(thunks.len());
7200            let mut ins_idx = 0;
7201            for (i, t) in thunks.into_iter().enumerate() {
7202                if ins_idx < insertions.len() && insertions[ins_idx].0 == i {
7203                    new_thunks.push(insertions[ins_idx].1.clone());
7204                    ins_idx += 1;
7205                }
7206                if !kill[i] {
7207                    new_thunks.push(t);
7208                }
7209            }
7210            if cfg.verbose >= 1 {
7211                eprintln!(
7212                    "[rlx] fused_layer: {} full transformer layers fused",
7213                    insertions.len()
7214                );
7215            }
7216            thunks = new_thunks;
7217        }
7218    }
7219
7220    // ── Narrow → Rope thunk fusion (plan #45) ──────────────
7221    // Runs *after* FusedAttnBlock fusion so it only catches the medium-
7222    // batch path (batch*seq > 64) where the bigger fusion didn't fire.
7223    // Pattern: a Rope thunk whose `src` is the dst of an immediately-
7224    // preceding Narrow whose dst has no other consumer in this schedule.
7225    // Rewrite Rope to read directly from the parent buffer with the
7226    // parent's row stride; the Narrow becomes a Nop.
7227    //
7228    // Skipping the Narrow's write saves one full pass over Q/K (B*S*hs
7229    // f32) per Rope. For Nomic h=768 / batch=8 / seq=15 / 12 layers
7230    // that's 2 ropes/layer × 369 KB = ~8.9 MB of write traffic gone.
7231    {
7232        // Collect every byte-offset that's read as a thunk's `src` so
7233        // we know whether a Narrow's dst has consumers other than Rope.
7234        let mut read_offsets: HashMap<usize, usize> = HashMap::new();
7235        for t in &thunks {
7236            for off in thunk_read_offsets(t) {
7237                *read_offsets.entry(off).or_insert(0) += 1;
7238            }
7239        }
7240
7241        let mut fused_count = 0usize;
7242        for i in 0..thunks.len().saturating_sub(1) {
7243            // Look for Rope at i+1 reading from Narrow at i (skip Nops
7244            // between them since the planner left them in place).
7245            let narrow = match &thunks[i] {
7246                Thunk::Narrow { .. } => i,
7247                _ => continue,
7248            };
7249            // Find the next non-Nop thunk
7250            let mut j = narrow + 1;
7251            while j < thunks.len() && matches!(thunks[j], Thunk::Nop) {
7252                j += 1;
7253            }
7254            if j >= thunks.len() {
7255                continue;
7256            }
7257            // Must be Rope reading Narrow's dst
7258            let (n_src, n_dst, n_src_stride) = match &thunks[narrow] {
7259                Thunk::Narrow {
7260                    src,
7261                    dst,
7262                    src_stride,
7263                    ..
7264                } => (*src, *dst, *src_stride),
7265                _ => continue,
7266            };
7267            let rope_reads_narrow = matches!(&thunks[j],
7268                Thunk::Rope { src, .. } if *src == n_dst);
7269            if !rope_reads_narrow {
7270                continue;
7271            }
7272            // Conservatively require that the Narrow's dst has exactly
7273            // one reader (the Rope). Anything else and rewriting would
7274            // skip a needed write.
7275            if read_offsets.get(&n_dst).copied().unwrap_or(0) != 1 {
7276                continue;
7277            }
7278
7279            // Rewire: Rope reads from Narrow's adjusted source with the
7280            // parent buffer's row stride.
7281            if let Thunk::Rope {
7282                src,
7283                src_row_stride,
7284                ..
7285            } = &mut thunks[j]
7286            {
7287                *src = n_src;
7288                *src_row_stride = n_src_stride;
7289            }
7290            thunks[narrow] = Thunk::Nop;
7291            fused_count += 1;
7292        }
7293
7294        if fused_count > 0 && cfg.verbose >= 1 {
7295            eprintln!(
7296                "[rlx] fused_qk_rope: {} Narrow→Rope pairs collapsed",
7297                fused_count
7298            );
7299        }
7300    }
7301
7302    // ── Narrow×3 → Attention thunk fusion (plan #46 deep) ────
7303    // For each Attention thunk in the schedule, look up the producers
7304    // of its q/k/v inputs. If each is a Narrow whose dst has exactly
7305    // one consumer (the Attention), rewire Attention to read directly
7306    // from the parent buffer with the parent's row stride. The three
7307    // Narrows become Nops.
7308    //
7309    // This catches the BERT/Nomic QKV split path that FusedAttnBlock
7310    // misses (batch*seq > 64) — eliminates Q/K/V copies entirely.
7311    // For minilm6 batch=32 seq=16 hs=384: 3 × 32*16*384*4 = 2.3 MB
7312    // per layer × 6 layers = ~14 MB of write traffic gone.
7313    {
7314        let mut read_counts: HashMap<usize, usize> = HashMap::new();
7315        for t in &thunks {
7316            for off in thunk_read_offsets(t) {
7317                *read_counts.entry(off).or_insert(0) += 1;
7318            }
7319        }
7320        // Build dst→index map for fast producer lookup.
7321        let mut dst_to_idx: HashMap<usize, usize> = HashMap::new();
7322        for (i, t) in thunks.iter().enumerate() {
7323            if let Thunk::Narrow { dst, .. } = t {
7324                dst_to_idx.insert(*dst, i);
7325            }
7326        }
7327
7328        let mut fused_count = 0usize;
7329        for i in 0..thunks.len() {
7330            let (q_off, k_off, v_off) = match &thunks[i] {
7331                Thunk::Attention { q, k, v, .. } => (*q, *k, *v),
7332                _ => continue,
7333            };
7334            // All three inputs must come from Narrows.
7335            let q_n = match dst_to_idx.get(&q_off).copied() {
7336                Some(x) => x,
7337                None => continue,
7338            };
7339            let k_n = match dst_to_idx.get(&k_off).copied() {
7340                Some(x) => x,
7341                None => continue,
7342            };
7343            let v_n = match dst_to_idx.get(&v_off).copied() {
7344                Some(x) => x,
7345                None => continue,
7346            };
7347            // Each Narrow's dst must have exactly one reader (this Attn).
7348            if read_counts.get(&q_off).copied().unwrap_or(0) != 1 {
7349                continue;
7350            }
7351            if read_counts.get(&k_off).copied().unwrap_or(0) != 1 {
7352                continue;
7353            }
7354            if read_counts.get(&v_off).copied().unwrap_or(0) != 1 {
7355                continue;
7356            }
7357
7358            let (q_src, q_stride) = match &thunks[q_n] {
7359                Thunk::Narrow {
7360                    src, src_stride, ..
7361                } => (*src, *src_stride),
7362                _ => continue,
7363            };
7364            let (k_src, k_stride) = match &thunks[k_n] {
7365                Thunk::Narrow {
7366                    src, src_stride, ..
7367                } => (*src, *src_stride),
7368                _ => continue,
7369            };
7370            let (v_src, v_stride) = match &thunks[v_n] {
7371                Thunk::Narrow {
7372                    src, src_stride, ..
7373                } => (*src, *src_stride),
7374                _ => continue,
7375            };
7376
7377            if let Thunk::Attention {
7378                q,
7379                k,
7380                v,
7381                q_row_stride,
7382                k_row_stride,
7383                v_row_stride,
7384                ..
7385            } = &mut thunks[i]
7386            {
7387                *q = q_src;
7388                *k = k_src;
7389                *v = v_src;
7390                *q_row_stride = q_stride;
7391                *k_row_stride = k_stride;
7392                *v_row_stride = v_stride;
7393            }
7394            thunks[q_n] = Thunk::Nop;
7395            thunks[k_n] = Thunk::Nop;
7396            thunks[v_n] = Thunk::Nop;
7397            fused_count += 1;
7398        }
7399
7400        if fused_count > 0 && cfg.verbose >= 1 {
7401            eprintln!(
7402                "[rlx] fused_strided_attn: {} Narrow×3→Attention rewrites",
7403                fused_count
7404            );
7405        }
7406    }
7407
7408    ThunkSchedule {
7409        thunks,
7410        moe_resident: None,
7411        moe_resident_layers: None,
7412        moe_topk_capture: None,
7413        mask_threshold: cfg.mask_binary_threshold,
7414        mask_neg_inf: cfg.attn_mask_neg_inf,
7415        score_skip: cfg.score_skip_threshold,
7416        compiled_fns,
7417    }
7418}
7419
7420fn get_len(graph: &Graph, id: NodeId) -> usize {
7421    graph.node(id).shape.num_elements().unwrap_or(0)
7422}
7423
7424/// Static `usize` dims of a node's shape, or empty if any dim is dynamic.
7425fn get_static_dims(graph: &Graph, id: NodeId) -> Vec<usize> {
7426    let dims = graph.node(id).shape.dims();
7427    let mut out = Vec::with_capacity(dims.len());
7428    for d in dims {
7429        if let Some(s) = match d {
7430            rlx_ir::Dim::Static(s) => Some(*s),
7431            _ => None,
7432        } {
7433            out.push(s);
7434        } else {
7435            return Vec::new();
7436        }
7437    }
7438    out
7439}
7440
7441/// NumPy-style broadcast strides for one operand into the flat output
7442/// buffer. Returns a length-`out_dims.len()` `Vec<u32>` where entry
7443/// `d` is `0` if the input is size-1 (broadcast) at output dim `d`
7444/// (after left-padding with size-1 to match ranks), otherwise the
7445/// natural row-major stride into the *input* buffer.
7446///
7447/// Caller iterates output flat index `i` → output coords (row-major)
7448/// → input flat index = dot(coords, strides). The result is correct
7449/// for any broadcast pattern (scalar, last-axis, middle-axis,
7450/// bidirectional).
7451/// True when `rhs_dims` describes a *trailing* broadcast of `out_dims`
7452/// — i.e. every rhs dim either equals the corresponding output dim
7453/// (counting from the right) or rhs is shorter (left-padded with 1s).
7454/// Mid-shape singletons (e.g. rhs `[a, b, 1, d]` into out `[a, b, c, d]`
7455/// where `c > 1`) are NOT trailing broadcasts and require the
7456/// shape-aware `BinaryFull` slow path — `BiasAdd`'s linear bias-replicated
7457/// kernel silently miscomputes them.
7458fn is_trailing_bias_broadcast(rhs_dims: &[rlx_ir::Dim], out_dims: &[rlx_ir::Dim]) -> bool {
7459    if rhs_dims.len() > out_dims.len() {
7460        return false;
7461    }
7462    let off = out_dims.len() - rhs_dims.len();
7463    for i in 0..rhs_dims.len() {
7464        let r = match rhs_dims[i] {
7465            rlx_ir::Dim::Static(n) => n,
7466            _ => return false,
7467        };
7468        let o = match out_dims[off + i] {
7469            rlx_ir::Dim::Static(n) => n,
7470            _ => return false,
7471        };
7472        if r != o {
7473            return false;
7474        }
7475    }
7476    true
7477}
7478
/// NumPy-style broadcast strides for one operand into the flat output
/// buffer.
///
/// `in_dims` is right-aligned against `out_dims` (implicitly left-padded
/// with size-1 dims). The result has `out_dims.len()` entries: entry `d`
/// is `0` where the input has size 1 at output axis `d` (broadcast),
/// otherwise the row-major element stride into the *input* buffer.
///
/// # Panics
///
/// * if `in_dims` has a higher rank than `out_dims`;
/// * if a non-1 input dim differs from the aligned output dim;
/// * if a required stride does not fit in `u32`, or the running element
///   count overflows `usize`. Previously the stride was truncated with
///   `as u32`, which could silently turn a large stride into `0`, i.e.
///   into a broadcast.
fn broadcast_strides(in_dims: &[usize], out_dims: &[usize]) -> Vec<u32> {
    let r_out = out_dims.len();
    let r_in = in_dims.len();
    assert!(
        r_in <= r_out,
        "broadcast: input rank {r_in} > output rank {r_out}"
    );
    let pad = r_out - r_in;
    // Every entry starts at 0, which is already the right answer for
    // broadcast (size-1 / padded) axes.
    let mut strides = vec![0u32; r_out];
    // Product of the non-broadcast input dims to the right of axis `d`.
    let mut acc: usize = 1;
    for d in (0..r_out).rev() {
        let in_size = if d < pad { 1 } else { in_dims[d - pad] };
        if in_size == 1 {
            continue;
        }
        assert_eq!(
            in_size, out_dims[d],
            "broadcast: input dim {in_size} doesn't match output dim {} at axis {d}",
            out_dims[d]
        );
        strides[d] = u32::try_from(acc).unwrap_or_else(|_| {
            panic!("broadcast: stride {} at axis {} overflows u32", acc, d)
        });
        acc = acc
            .checked_mul(in_size)
            .expect("broadcast: input element count overflows usize");
    }
    strides
}
7505
7506/// Execute a thunk schedule on a raw arena buffer.
7507/// Fastest executor: call pre-compiled closures sequentially.
7508/// Zero match dispatch — each closure is a direct kernel call.
7509pub fn execute_compiled(schedule: &ThunkSchedule, arena_buf: &mut [u8]) {
7510    let base = arena_buf.as_mut_ptr();
7511    for f in &schedule.compiled_fns {
7512        f(base);
7513    }
7514}
7515
7516/// Active-extent execution stub. The runtime calls this when it has an
7517/// active-extent hint set. CPU doesn't implement per-thunk active-extent
7518/// scaling yet — return false so the caller falls back to the full
7519/// `execute_thunks` path.
7520pub fn execute_thunks_active(
7521    schedule: &ThunkSchedule,
7522    _arena_buf: &mut [u8],
7523    _actual: usize,
7524    _upper: usize,
7525) -> bool {
7526    let _ = schedule;
7527    false
7528}
7529
7530/// Match-based executor (fallback, used by tests).
7531struct MoeResidencyGuard;
7532impl Drop for MoeResidencyGuard {
7533    fn drop(&mut self) {
7534        if let Some(stats) = crate::moe_residency::take_stats() {
7535            crate::moe_residency::stash_last_forward_stats(stats);
7536        } else {
7537            crate::moe_residency::clear_mask();
7538        }
7539    }
7540}
7541
7542fn thunk_kind_name(t: &Thunk) -> &'static str {
7543    match t {
7544        Thunk::Nop => "Nop",
7545        Thunk::Gather { .. } => "Gather",
7546        Thunk::GatherAxis { .. } => "GatherAxis",
7547        Thunk::TopK { .. } => "TopK",
7548        Thunk::Copy { .. } => "Copy",
7549        Thunk::CopyF64 { .. } => "CopyF64",
7550        Thunk::CopyI64 { .. } => "CopyI64",
7551        Thunk::CastF32ToI64 { .. } => "CastF32ToI64",
7552        Thunk::CastI64ToF32 { .. } => "CastI64ToF32",
7553        Thunk::CastBoolToI32 { .. } => "CastBoolToI32",
7554        Thunk::CastI32ToF32 { .. } => "CastI32ToF32",
7555        Thunk::Transpose { .. } => "Transpose",
7556        Thunk::TransposeF64 { .. } => "TransposeF64",
7557        Thunk::Where { .. } => "Where",
7558        Thunk::Compare { .. } => "Compare",
7559        Thunk::BinaryFull { .. } => "BinaryFull",
7560        Thunk::BinaryFullF64 { .. } => "BinaryFullF64",
7561        Thunk::Sgemm { .. } => "Sgemm",
7562        Thunk::Dgemm { .. } => "Dgemm",
7563        Thunk::FusedMmBiasAct { .. } => "FusedMmBiasAct",
7564        Thunk::BiasAdd { .. } => "BiasAdd",
7565        Thunk::LayerNorm { .. } => "LayerNorm",
7566        Thunk::Softmax { .. } => "Softmax",
7567        Thunk::Conv2D { .. } => "Conv2D",
7568        Thunk::Conv2D1x1 { .. } => "Conv2D1x1",
7569        Thunk::CustomOp { .. } => "CustomOp",
7570        Thunk::ActivationInPlace { .. } => "ActivationInPlace",
7571        Thunk::Narrow { .. } => "Narrow",
7572        Thunk::Cumsum { .. } => "Cumsum",
7573        Thunk::Reduce { .. } => "Reduce",
7574        Thunk::BatchedSgemm { .. } => "BatchedSgemm",
7575        Thunk::DequantMatMul { .. } => "DequantMatMul",
7576        Thunk::Quantize { .. } => "Quantize",
7577        Thunk::Dequantize { .. } => "Dequantize",
7578        Thunk::ConvTranspose2d { .. } => "ConvTranspose2d",
7579        Thunk::ResizeNearest2x { .. } => "ResizeNearest2x",
7580        _ => "Other",
7581    }
7582}
7583
7584pub fn execute_thunks(schedule: &ThunkSchedule, arena_buf: &mut [u8]) {
7585    crate::moe_residency::reset_gmm_counters();
7586    if let Some(layers) = schedule.moe_resident_layers.clone() {
7587        crate::moe_residency::set_per_layer_masks(Some(layers));
7588    } else {
7589        crate::moe_residency::set_mask(schedule.moe_resident.clone());
7590    }
7591    if let Some(cap) = schedule.moe_topk_capture.as_ref() {
7592        cap.clear();
7593    }
7594    let _moe_guard = MoeResidencyGuard;
7595    let base = arena_buf.as_mut_ptr();
7596    let mask_thr = schedule.mask_threshold;
7597    let mask_neg = schedule.mask_neg_inf;
7598    let score_thr = schedule.score_skip;
7599    let thunks = &schedule.thunks;
7600    let len = thunks.len();
7601
7602    // Pre-allocate ALL reusable buffers once (zero per-call allocation)
7603    let max_h = thunks
7604        .iter()
7605        .filter_map(|t| match t {
7606            Thunk::FusedResidualLN { h, .. }
7607            | Thunk::FusedResidualRmsNorm { h, .. }
7608            | Thunk::LayerNorm { h, .. } => Some(*h as usize),
7609            _ => None,
7610        })
7611        .max()
7612        .unwrap_or(0);
7613    let zero_bias = vec![0f32; max_h];
7614
7615    // Pre-allocate per-(batch,head) score buffers for parallel SDPA.
7616    // Q/K/V/out are accessed via strided BLAS — no deinterleave copy needed.
7617    let max_sdpa = thunks
7618        .iter()
7619        .filter_map(|t| match t {
7620            Thunk::Attention {
7621                batch,
7622                seq,
7623                kv_seq,
7624                heads,
7625                head_dim,
7626                ..
7627            } => Some((
7628                *batch as usize,
7629                (*seq as usize).max(*kv_seq as usize),
7630                *heads as usize,
7631                *head_dim as usize,
7632            )),
7633            _ => None,
7634        })
7635        .fold((0, 0, 0, 0), |(mb, ms, mh, md), (b, s, h, d)| {
7636            (mb.max(b), ms.max(s), mh.max(h), md.max(d))
7637        });
7638    let (max_batch, max_seq, max_heads, _max_dh) = max_sdpa;
7639    let max_units = max_batch * max_heads;
7640    let mut sdpa_scores = vec![0f32; max_units * max_seq * max_seq];
7641
7642    // Pre-allocate fused layer buffers (reused across all 12+ layers — zero malloc per layer)
7643    let fl = thunks
7644        .iter()
7645        .filter_map(|t| match t {
7646            Thunk::FusedBertLayer {
7647                batch,
7648                seq,
7649                hs,
7650                int_dim,
7651                ..
7652            } => {
7653                let m = (*batch as usize) * (*seq as usize);
7654                let h = *hs as usize;
7655                let id = *int_dim as usize;
7656                Some((m, h, id, m * (*seq as usize)))
7657            }
7658            Thunk::FusedNomicLayer {
7659                batch,
7660                seq,
7661                hs,
7662                int_dim,
7663                ..
7664            } => {
7665                let m = (*batch as usize) * (*seq as usize);
7666                let h = *hs as usize;
7667                let id = *int_dim as usize;
7668                Some((m, h, id, m * (*seq as usize)))
7669            }
7670            _ => None,
7671        })
7672        .fold((0, 0, 0, 0), |(mm, mh, mi, ms), (m, h, id, ss)| {
7673            (mm.max(m), mh.max(h), mi.max(id), ms.max(ss))
7674        });
7675    let (fl_m, fl_h, fl_int, fl_ss) = fl;
7676    let mut fl_qkv = vec![0f32; fl_m * 3 * fl_h];
7677    let mut fl_attn = vec![0f32; fl_m * fl_h];
7678    let mut fl_res = vec![0f32; fl_m * fl_h];
7679    let mut fl_normed = vec![0f32; fl_m * fl_h];
7680    let mut fl_ffn = vec![0f32; fl_m * fl_int.max(2 * fl_int)]; // Nomic needs 2×int for fused fc11+fc12
7681    let mut fl_sc = vec![0f32; fl_ss.max(1)];
7682
7683    let trace_thunks = std::env::var_os("RLX_TRACE_THUNK").is_some();
7684    if trace_thunks {
7685        eprintln!(
7686            "[thunk] prealloc max_h={max_h} sdpa={} fl_m={fl_m} fl_h={fl_h} fl_int={fl_int}",
7687            max_units * max_seq * max_seq
7688        );
7689    }
7690    for i in 0..len {
7691        let thunk = unsafe { thunks.get_unchecked(i) };
7692        if trace_thunks && (i < 120 || i % 200 == 0 || i + 1 == len) {
7693            eprintln!("[thunk {i}/{len}] {}", thunk_kind_name(thunk));
7694        }
7695        let trace_done = trace_thunks && i < 120;
7696        match thunk {
7697            Thunk::Nop => {}
7698
7699            Thunk::GaussianSplatRender {
7700                positions_off,
7701                positions_len,
7702                scales_off,
7703                scales_len,
7704                rotations_off,
7705                rotations_len,
7706                opacities_off,
7707                opacities_len,
7708                colors_off,
7709                colors_len,
7710                sh_coeffs_off,
7711                sh_coeffs_len,
7712                meta_off,
7713                dst_off,
7714                dst_len,
7715                width,
7716                height,
7717                tile_size,
7718                radius_scale,
7719                alpha_cutoff,
7720                max_splat_steps,
7721                transmittance_threshold,
7722                max_list_entries,
7723            } => unsafe {
7724                crate::splat::execute_gaussian_splat_render(
7725                    *positions_off,
7726                    *positions_len,
7727                    *scales_off,
7728                    *scales_len,
7729                    *rotations_off,
7730                    *rotations_len,
7731                    *opacities_off,
7732                    *opacities_len,
7733                    *colors_off,
7734                    *colors_len,
7735                    *sh_coeffs_off,
7736                    *sh_coeffs_len,
7737                    *meta_off,
7738                    *dst_off,
7739                    *dst_len,
7740                    *width,
7741                    *height,
7742                    *tile_size,
7743                    *radius_scale,
7744                    *alpha_cutoff,
7745                    *max_splat_steps,
7746                    *transmittance_threshold,
7747                    *max_list_entries,
7748                    base,
7749                );
7750            },
7751
7752            Thunk::GaussianSplatRenderBackward {
7753                positions_off,
7754                positions_len,
7755                scales_off,
7756                scales_len,
7757                rotations_off,
7758                rotations_len,
7759                opacities_off,
7760                opacities_len,
7761                colors_off,
7762                colors_len,
7763                sh_coeffs_off,
7764                sh_coeffs_len,
7765                meta_off,
7766                d_loss_off,
7767                d_loss_len,
7768                packed_off,
7769                packed_len,
7770                width,
7771                height,
7772                tile_size,
7773                radius_scale,
7774                alpha_cutoff,
7775                max_splat_steps,
7776                transmittance_threshold,
7777                max_list_entries,
7778                loss_grad_clip,
7779                sh_band,
7780                max_anisotropy,
7781            } => unsafe {
7782                crate::splat::execute_gaussian_splat_render_backward(
7783                    *positions_off,
7784                    *positions_len,
7785                    *scales_off,
7786                    *scales_len,
7787                    *rotations_off,
7788                    *rotations_len,
7789                    *opacities_off,
7790                    *opacities_len,
7791                    *colors_off,
7792                    *colors_len,
7793                    *sh_coeffs_off,
7794                    *sh_coeffs_len,
7795                    *meta_off,
7796                    *d_loss_off,
7797                    *d_loss_len,
7798                    *packed_off,
7799                    *packed_len,
7800                    *width,
7801                    *height,
7802                    *tile_size,
7803                    *radius_scale,
7804                    *alpha_cutoff,
7805                    *max_splat_steps,
7806                    *transmittance_threshold,
7807                    *max_list_entries,
7808                    *loss_grad_clip,
7809                    *sh_band,
7810                    *max_anisotropy,
7811                    base,
7812                );
7813            },
7814
7815            Thunk::GaussianSplatPrepare {
7816                positions_off,
7817                positions_len,
7818                scales_off,
7819                scales_len,
7820                rotations_off,
7821                rotations_len,
7822                opacities_off,
7823                opacities_len,
7824                colors_off,
7825                colors_len,
7826                sh_coeffs_off,
7827                sh_coeffs_len,
7828                meta_off,
7829                meta_len,
7830                prep_off,
7831                prep_len,
7832                width,
7833                height,
7834                tile_size,
7835                radius_scale,
7836                alpha_cutoff,
7837                max_splat_steps,
7838                transmittance_threshold,
7839                max_list_entries,
7840            } => unsafe {
7841                crate::splat::execute_gaussian_splat_prepare(
7842                    *positions_off,
7843                    *positions_len,
7844                    *scales_off,
7845                    *scales_len,
7846                    *rotations_off,
7847                    *rotations_len,
7848                    *opacities_off,
7849                    *opacities_len,
7850                    *colors_off,
7851                    *colors_len,
7852                    *sh_coeffs_off,
7853                    *sh_coeffs_len,
7854                    *meta_off,
7855                    *meta_len,
7856                    *prep_off,
7857                    *prep_len,
7858                    *width,
7859                    *height,
7860                    *tile_size,
7861                    *radius_scale,
7862                    *alpha_cutoff,
7863                    *max_splat_steps,
7864                    *transmittance_threshold,
7865                    *max_list_entries,
7866                    base,
7867                );
7868            },
7869
7870            Thunk::GaussianSplatRasterize {
7871                prep_off,
7872                prep_len,
7873                meta_off,
7874                meta_len,
7875                dst_off,
7876                dst_len,
7877                count,
7878                width,
7879                height,
7880                tile_size,
7881                alpha_cutoff,
7882                max_splat_steps,
7883                transmittance_threshold,
7884                max_list_entries,
7885            } => unsafe {
7886                crate::splat::execute_gaussian_splat_rasterize(
7887                    *prep_off,
7888                    *prep_len,
7889                    *meta_off,
7890                    *meta_len,
7891                    *dst_off,
7892                    *dst_len,
7893                    *count,
7894                    *width,
7895                    *height,
7896                    *tile_size,
7897                    *alpha_cutoff,
7898                    *max_splat_steps,
7899                    *transmittance_threshold,
7900                    *max_list_entries,
7901                    base,
7902                );
7903            },
7904
7905            Thunk::Fft1d {
7906                src,
7907                dst,
7908                outer,
7909                n_complex,
7910                inverse,
7911                norm_tag,
7912                dtype,
7913            } => unsafe {
7914                match dtype {
7915                    rlx_ir::DType::F64 => execute_fft1d_f64(
7916                        *src,
7917                        *dst,
7918                        *outer as usize,
7919                        *n_complex as usize,
7920                        *inverse,
7921                        *norm_tag,
7922                        base,
7923                    ),
7924                    rlx_ir::DType::F32 => execute_fft1d_f32(
7925                        *src,
7926                        *dst,
7927                        *outer as usize,
7928                        *n_complex as usize,
7929                        *inverse,
7930                        *norm_tag,
7931                        base,
7932                    ),
7933                    rlx_ir::DType::C64 => execute_fft1d_c64(
7934                        *src,
7935                        *dst,
7936                        *outer as usize,
7937                        *n_complex as usize,
7938                        *inverse,
7939                        *norm_tag,
7940                        base,
7941                    ),
7942                    other => panic!("Op::Fft on CPU requires F32/F64/C64, got {other:?}"),
7943                }
7944            },
7945
7946            Thunk::FftButterflyStage {
7947                state_src,
7948                state_dst,
7949                gate_src,
7950                rev_src,
7951                tw_re_src,
7952                tw_im_src,
7953                batch,
7954                n_fft,
7955                stage,
7956            } => unsafe {
7957                execute_fft_butterfly_stage_f32(
7958                    *state_src,
7959                    *state_dst,
7960                    *gate_src,
7961                    *rev_src,
7962                    *tw_re_src,
7963                    *tw_im_src,
7964                    *batch as usize,
7965                    *n_fft as usize,
7966                    *stage as usize,
7967                    base,
7968                );
7969            },
7970
7971            Thunk::LogMel {
7972                spec,
7973                filters,
7974                dst,
7975                outer,
7976                n_fft,
7977                n_bins,
7978                n_mels,
7979            } => unsafe {
7980                execute_log_mel_f32(
7981                    *spec,
7982                    *filters,
7983                    *dst,
7984                    *outer as usize,
7985                    *n_fft as usize,
7986                    *n_bins as usize,
7987                    *n_mels as usize,
7988                    base,
7989                );
7990            },
7991
7992            Thunk::LogMelBackward {
7993                spec,
7994                filters,
7995                dy,
7996                dst,
7997                outer,
7998                n_fft,
7999                n_bins,
8000                n_mels,
8001            } => unsafe {
8002                execute_log_mel_backward_f32(
8003                    *spec,
8004                    *filters,
8005                    *dy,
8006                    *dst,
8007                    *outer as usize,
8008                    *n_fft as usize,
8009                    *n_bins as usize,
8010                    *n_mels as usize,
8011                    base,
8012                );
8013            },
8014
8015            // CustomFn dispatch (interpreted path). Mirrors the
8016            // pre-compiled-closure variant elsewhere in this file.
8017            // Patched by rlx-eda.
8018            Thunk::CustomFn {
8019                body,
8020                body_init,
8021                inputs,
8022                body_output_off,
8023                outer_output_off,
8024                out_bytes,
8025            } => {
8026                let mut body_buf: Vec<u8> = (**body_init).clone();
8027                unsafe {
8028                    for (body_in_off, outer_in_off, n_bytes) in inputs.iter() {
8029                        let src = (base as *const u8).add(*outer_in_off);
8030                        let dst = body_buf.as_mut_ptr().add(*body_in_off);
8031                        std::ptr::copy_nonoverlapping(src, dst, *n_bytes as usize);
8032                    }
8033                }
8034                execute_thunks(body, &mut body_buf);
8035                unsafe {
8036                    let src = body_buf.as_ptr().add(*body_output_off);
8037                    let dst = base.add(*outer_output_off);
8038                    std::ptr::copy_nonoverlapping(src, dst, *out_bytes as usize);
8039                }
8040            }
8041
8042            Thunk::Sgemm { a, b, c, m, k, n } => {
8043                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
8044                if trace_thunks {
8045                    eprintln!("[sgemm] m={m} k={k} n={n} a={} b={} c={}", *a, *b, *c);
8046                }
8047                let c_len = m.saturating_mul(n);
8048                let a_len = m.saturating_mul(k);
8049                let b_len = k.saturating_mul(n);
8050                let arena_len = arena_buf.len();
8051                let max_a = (arena_len.saturating_sub(*a)) / 4;
8052                let max_b = (arena_len.saturating_sub(*b)) / 4;
8053                let max_c = (arena_len.saturating_sub(*c)) / 4;
8054                let a_len = a_len.min(max_a);
8055                let b_len = b_len.min(max_b);
8056                let c_len = c_len.min(max_c);
8057                unsafe {
8058                    let a_sl = sl(*a, base, a_len);
8059                    let b_sl = sl(*b, base, b_len);
8060                    let c_sl = sl_mut(*c, base, c_len);
8061                    if std::ptr::eq(a_sl.as_ptr(), c_sl.as_ptr())
8062                        || std::ptr::eq(b_sl.as_ptr(), c_sl.as_ptr())
8063                    {
8064                        let mut tmp = vec![0.0f32; c_len];
8065                        crate::blas::sgemm_auto(a_sl, b_sl, &mut tmp, m, k, n);
8066                        c_sl.copy_from_slice(&tmp);
8067                    } else {
8068                        crate::blas::sgemm_auto(a_sl, b_sl, c_sl, m, k, n);
8069                    }
8070                }
8071            }
8072
8073            Thunk::DenseSolveF64 { a, b, x, n, nrhs } => {
8074                let (n_, nrhs_) = (*n as usize, *nrhs as usize);
8075                // LAPACK overwrites both A and B; clone into scratch
8076                // each call. Caller's A and b must be preserved for
8077                // VJP recompute. (Eventually: swap to a factor-once /
8078                // solve-many scheme; that's the symbolic-reuse story
8079                // and lives with the sparse path.)
8080                unsafe {
8081                    let a_src = sl_f64(*a, base, n_ * n_);
8082                    let b_src = sl_f64(*b, base, n_ * nrhs_);
8083                    let mut a_scratch: Vec<f64> = a_src.to_vec();
8084                    let mut x_buf: Vec<f64> = b_src.to_vec();
8085                    let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
8086                    if info != 0 {
8087                        panic!(
8088                            "DenseSolveF64: dgesv reported singular matrix \
8089                                (info={info}, n={n_}, nrhs={nrhs_})"
8090                        );
8091                    }
8092                    let dst = sl_mut_f64(*x, base, n_ * nrhs_);
8093                    dst.copy_from_slice(&x_buf);
8094                }
8095            }
8096
8097            Thunk::DenseSolveF32 { a, b, x, n, nrhs } => {
8098                let (n_, nrhs_) = (*n as usize, *nrhs as usize);
8099                unsafe {
8100                    let a_src = sl(*a, base, n_ * n_);
8101                    let b_src = sl(*b, base, n_ * nrhs_);
8102                    let mut a_scratch: Vec<f32> = a_src.to_vec();
8103                    let mut x_buf: Vec<f32> = b_src.to_vec();
8104                    let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
8105                    if info != 0 {
8106                        panic!(
8107                            "DenseSolveF32: sgesv reported singular matrix \
8108                             (info={info}, n={n_}, nrhs={nrhs_})"
8109                        );
8110                    }
8111                    let dst = sl_mut(*x, base, n_ * nrhs_);
8112                    dst.copy_from_slice(&x_buf);
8113                }
8114            }
8115
8116            Thunk::BatchedDenseSolveF64 {
8117                a,
8118                b,
8119                x,
8120                batch,
8121                n,
8122                nrhs,
8123            } => {
8124                // Per slice: extract A_i and b_i, dgesv, write x_i.
8125                // LAPACK has no batched dgesv on Accelerate, so this
8126                // is a serial loop over the batch axis. cuSOLVER /
8127                // hipSOLVER expose `getrfBatched` / `getrsBatched` for
8128                // the GPU path — we'll wire that in rlx-cuda when
8129                // someone needs Linux+CUDA.
8130                let (b_, n_, nrhs_) = (*batch as usize, *n as usize, *nrhs as usize);
8131                let a_stride = n_ * n_;
8132                let b_stride = n_ * nrhs_;
8133                unsafe {
8134                    let a_full = sl_f64(*a, base, b_ * a_stride);
8135                    let b_full = sl_f64(*b, base, b_ * b_stride);
8136                    let x_full = sl_mut_f64(*x, base, b_ * b_stride);
8137                    for bi in 0..b_ {
8138                        let mut a_scratch: Vec<f64> =
8139                            a_full[bi * a_stride..(bi + 1) * a_stride].to_vec();
8140                        let mut x_buf: Vec<f64> =
8141                            b_full[bi * b_stride..(bi + 1) * b_stride].to_vec();
8142                        let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
8143                        if info != 0 {
8144                            panic!(
8145                                "BatchedDenseSolveF64: slice {bi} \
8146                                    singular (info={info}, n={n_}, nrhs={nrhs_})"
8147                            );
8148                        }
8149                        x_full[bi * b_stride..(bi + 1) * b_stride].copy_from_slice(&x_buf);
8150                    }
8151                }
8152            }
8153
8154            Thunk::BatchedDenseSolveF32 {
8155                a,
8156                b,
8157                x,
8158                batch,
8159                n,
8160                nrhs,
8161            } => {
8162                let (b_, n_, nrhs_) = (*batch as usize, *n as usize, *nrhs as usize);
8163                let a_stride = n_ * n_;
8164                let b_stride = n_ * nrhs_;
8165                unsafe {
8166                    let a_full = sl(*a, base, b_ * a_stride);
8167                    let b_full = sl(*b, base, b_ * b_stride);
8168                    let x_full = sl_mut(*x, base, b_ * b_stride);
8169                    for bi in 0..b_ {
8170                        let mut a_scratch = a_full[bi * a_stride..(bi + 1) * a_stride].to_vec();
8171                        let mut x_buf = b_full[bi * b_stride..(bi + 1) * b_stride].to_vec();
8172                        let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
8173                        if info != 0 {
8174                            panic!("BatchedDenseSolveF32: slice {bi} singular (info={info})");
8175                        }
8176                        x_full[bi * b_stride..(bi + 1) * b_stride].copy_from_slice(&x_buf);
8177                    }
8178                }
8179            }
8180
8181            Thunk::BatchedDgemmF64 {
8182                a,
8183                b,
8184                c,
8185                batch,
8186                m,
8187                k,
8188                n,
8189            } => {
8190                let (b_, m_, k_, n_) = (*batch as usize, *m as usize, *k as usize, *n as usize);
8191                let a_stride = m_ * k_;
8192                let b_stride = k_ * n_;
8193                let c_stride = m_ * n_;
8194                unsafe {
8195                    let a_full = sl_f64(*a, base, b_ * a_stride);
8196                    let b_full = sl_f64(*b, base, b_ * b_stride);
8197                    let c_full = sl_mut_f64(*c, base, b_ * c_stride);
8198                    for bi in 0..b_ {
8199                        let a_slice = &a_full[bi * a_stride..(bi + 1) * a_stride];
8200                        let b_slice = &b_full[bi * b_stride..(bi + 1) * b_stride];
8201                        let c_slice = &mut c_full[bi * c_stride..(bi + 1) * c_stride];
8202                        crate::blas::dgemm(a_slice, b_slice, c_slice, m_, k_, n_);
8203                    }
8204                }
8205            }
8206
8207            Thunk::BatchedSgemm {
8208                a,
8209                b,
8210                c,
8211                batch,
8212                m,
8213                k,
8214                n,
8215            } => {
8216                let (b_, m_, k_, n_) = (*batch as usize, *m as usize, *k as usize, *n as usize);
8217                if trace_thunks {
8218                    eprintln!(
8219                        "[batched-sgemm] batch={b_} m={m_} k={k_} n={n_} a={} b={} c={}",
8220                        *a, *b, *c
8221                    );
8222                }
8223                let a_stride = m_.saturating_mul(k_);
8224                let b_stride = k_.saturating_mul(n_);
8225                let c_stride = m_.saturating_mul(n_);
8226                let arena_len = arena_buf.len();
8227                let a_cap = (arena_len.saturating_sub(*a)) / 4;
8228                let b_cap = (arena_len.saturating_sub(*b)) / 4;
8229                let c_cap = (arena_len.saturating_sub(*c)) / 4;
8230                let a_elems = (b_ * a_stride).min(a_cap);
8231                let b_elems = (b_ * b_stride).min(b_cap);
8232                let c_elems = (b_ * c_stride).min(c_cap);
8233                let b_eff = b_
8234                    .min(a_elems.checked_div(a_stride).unwrap_or(0))
8235                    .min(b_elems.checked_div(b_stride).unwrap_or(0))
8236                    .min(c_elems.checked_div(c_stride).unwrap_or(0));
8237                unsafe {
8238                    let a_full = sl(*a, base, a_elems);
8239                    let b_full = sl(*b, base, b_elems);
8240                    let c_full = sl_mut(*c, base, c_elems);
8241                    for bi in 0..b_eff {
8242                        let a0 = bi * a_stride;
8243                        let b0 = bi * b_stride;
8244                        let c0 = bi * c_stride;
8245                        if a0 + a_stride > a_full.len()
8246                            || b0 + b_stride > b_full.len()
8247                            || c0 + c_stride > c_full.len()
8248                        {
8249                            break;
8250                        }
8251                        let a_slice = &a_full[a0..a0 + a_stride];
8252                        let b_slice = &b_full[b0..b0 + b_stride];
8253                        let c_slice = &mut c_full[c0..c0 + c_stride];
8254                        if std::ptr::eq(a_slice.as_ptr(), c_slice.as_mut_ptr())
8255                            || std::ptr::eq(b_slice.as_ptr(), c_slice.as_mut_ptr())
8256                        {
8257                            let mut tmp = vec![0.0f32; c_stride];
8258                            crate::blas::sgemm_auto(a_slice, b_slice, &mut tmp, m_, k_, n_);
8259                            c_slice.copy_from_slice(&tmp);
8260                        } else {
8261                            crate::blas::sgemm_auto(a_slice, b_slice, c_slice, m_, k_, n_);
8262                        }
8263                    }
8264                }
8265            }
8266
8267            Thunk::Dgemm { a, b, c, m, k, n } => {
8268                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
8269                unsafe {
8270                    crate::blas::dgemm(
8271                        sl_f64(*a, base, m * k),
8272                        sl_f64(*b, base, k * n),
8273                        sl_mut_f64(*c, base, m * n),
8274                        m,
8275                        k,
8276                        n,
8277                    );
8278                }
8279            }
8280
8281            Thunk::TransposeF64 {
8282                src,
8283                dst,
8284                in_total,
8285                out_dims,
8286                in_strides,
8287            } => unsafe {
8288                let inp = sl_f64(*src, base, *in_total as usize);
8289                let out_total: usize = out_dims.iter().map(|d| *d as usize).product();
8290                let out = sl_mut_f64(*dst, base, out_total);
8291                transpose_walk_f64(inp, out, out_dims, in_strides);
8292            },
8293
8294            Thunk::ActivationF64 {
8295                src,
8296                dst,
8297                len,
8298                kind,
8299            } => {
8300                let len = *len as usize;
8301                unsafe {
8302                    let inp = sl_f64(*src, base, len);
8303                    let out = sl_mut_f64(*dst, base, len);
8304                    apply_activation_f64(inp, out, *kind);
8305                }
8306            }
8307
8308            Thunk::ReduceSumF64 {
8309                src,
8310                dst,
8311                outer,
8312                reduced,
8313                inner,
8314            } => {
8315                let (o, r, n) = (*outer as usize, *reduced as usize, *inner as usize);
8316                unsafe {
8317                    let inp = sl_f64(*src, base, o * r * n);
8318                    let out = sl_mut_f64(*dst, base, o * n);
8319                    reduce_sum_f64(inp, out, o, r, n);
8320                }
8321            }
8322
8323            Thunk::CopyF64 { src, dst, len } => {
8324                let mut len = *len as usize;
8325                if *src == *dst || len == 0 {
8326                    continue;
8327                }
8328                let arena_len = arena_buf.len();
8329                let max_from_src = (arena_len.saturating_sub(*src)) / 8;
8330                let max_from_dst = (arena_len.saturating_sub(*dst)) / 8;
8331                len = len.min(max_from_src).min(max_from_dst);
8332                if len == 0 {
8333                    continue;
8334                }
8335                let byte_len = len.saturating_mul(8);
8336                unsafe {
8337                    std::ptr::copy(base.add(*src), base.add(*dst), byte_len);
8338                }
8339            }
8340
8341            Thunk::CopyI64 { src, dst, len } => {
8342                let mut len = *len as usize;
8343                if *src == *dst || len == 0 {
8344                    continue;
8345                }
8346                let arena_len = arena_buf.len();
8347                let max_from_src = (arena_len.saturating_sub(*src)) / 8;
8348                let max_from_dst = (arena_len.saturating_sub(*dst)) / 8;
8349                len = len.min(max_from_src).min(max_from_dst);
8350                if len == 0 {
8351                    continue;
8352                }
8353                let byte_len = len.saturating_mul(8);
8354                unsafe {
8355                    std::ptr::copy(base.add(*src), base.add(*dst), byte_len);
8356                }
8357            }
8358
8359            Thunk::CastF32ToI64 { src, dst, len } => {
8360                let len = *len as usize;
8361                if len == 0 {
8362                    continue;
8363                }
8364                unsafe {
8365                    let inp = sl(*src, base, len);
8366                    let out = sl_mut_i64(*dst, base, len);
8367                    for i in 0..len {
8368                        out[i] = inp[i].round() as i64;
8369                    }
8370                }
8371            }
8372
8373            Thunk::CastI64ToF32 { src, dst, len } => {
8374                let len = *len as usize;
8375                if len == 0 {
8376                    continue;
8377                }
8378                unsafe {
8379                    let inp = sl_i64(*src, base, len);
8380                    let out = sl_mut(*dst, base, len);
8381                    for i in 0..len {
8382                        out[i] = inp[i] as f32;
8383                    }
8384                }
8385            }
8386
8387            Thunk::CastBoolToI32 { src, dst, len } => {
8388                let len = *len as usize;
8389                if len == 0 {
8390                    continue;
8391                }
8392                unsafe {
8393                    let inp = &arena_buf[*src..*src + len];
8394                    let out = sl_mut_i32(*dst, base, len);
8395                    for i in 0..len {
8396                        out[i] = i32::from(inp[i] != 0);
8397                    }
8398                }
8399            }
8400
8401            Thunk::CastI32ToF32 { src, dst, len } => {
8402                let len = *len as usize;
8403                if len == 0 {
8404                    continue;
8405                }
8406                unsafe {
8407                    let inp = sl_i32(*src, base, len);
8408                    let out = sl_mut(*dst, base, len);
8409                    for i in 0..len {
8410                        out[i] = inp[i] as f32;
8411                    }
8412                }
8413            }
8414
8415            Thunk::BinaryFullF64 {
8416                lhs,
8417                rhs,
8418                dst,
8419                len,
8420                lhs_len,
8421                rhs_len,
8422                op,
8423                out_dims_bcast,
8424                bcast_lhs_strides,
8425                bcast_rhs_strides,
8426            } => {
8427                let len = *len as usize;
8428                let lhs_len = *lhs_len as usize;
8429                let rhs_len = *rhs_len as usize;
8430                unsafe {
8431                    let l = sl_f64(*lhs, base, lhs_len);
8432                    let r = sl_f64(*rhs, base, rhs_len);
8433                    let d = sl_mut_f64(*dst, base, len);
8434                    if lhs_len == len && rhs_len == len {
8435                        for i in 0..len {
8436                            d[i] = binary_op_f64(*op, l[i], r[i]);
8437                        }
8438                    } else if !out_dims_bcast.is_empty() {
8439                        // Shape-aware broadcast path: correct for
8440                        // arbitrary NumPy-style broadcasts including
8441                        // bidirectional `[N,1] op [1,S]`.
8442                        let rank = out_dims_bcast.len();
8443                        let mut coords = vec![0u32; rank];
8444                        for i in 0..len {
8445                            let mut rem = i;
8446                            for ax in (0..rank).rev() {
8447                                let sz = out_dims_bcast[ax] as usize;
8448                                coords[ax] = (rem % sz) as u32;
8449                                rem /= sz;
8450                            }
8451                            let mut li: usize = 0;
8452                            let mut ri: usize = 0;
8453                            for ax in 0..rank {
8454                                li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
8455                                ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
8456                            }
8457                            d[i] = binary_op_f64(*op, l[li], r[ri]);
8458                        }
8459                    } else {
8460                        // Fallback: legacy modulo path (preserved for
8461                        // dynamic-shape graphs where strides can't be
8462                        // precomputed). Only correct for scalar /
8463                        // last-axis broadcast.
8464                        for i in 0..len {
8465                            d[i] = binary_op_f64(*op, l[i % lhs_len], r[i % rhs_len]);
8466                        }
8467                    }
8468                }
8469            }
8470
8471            Thunk::BinaryFullC64 {
8472                lhs,
8473                rhs,
8474                dst,
8475                len,
8476                lhs_len,
8477                rhs_len,
8478                op,
8479                out_dims_bcast,
8480                bcast_lhs_strides,
8481                bcast_rhs_strides,
8482            } => {
8483                // Complex element layout: [re_0, im_0, re_1, im_1, ...]
8484                // Underlying f32 buffer length is 2·N (N = complex
8485                // element count). All offsets are byte offsets; the
8486                // `sl` helper reads as f32 starting at the byte
8487                // offset, so f32-length = 2·complex-len.
8488                let n_out = *len as usize;
8489                let n_l = *lhs_len as usize;
8490                let n_r = *rhs_len as usize;
8491                unsafe {
8492                    let l = sl(*lhs, base, 2 * n_l);
8493                    let r = sl(*rhs, base, 2 * n_r);
8494                    let d = sl_mut(*dst, base, 2 * n_out);
8495                    let do_c64 = |a_re: f32, a_im: f32, b_re: f32, b_im: f32| -> (f32, f32) {
8496                        match op {
8497                            BinaryOp::Add => (a_re + b_re, a_im + b_im),
8498                            BinaryOp::Sub => (a_re - b_re, a_im - b_im),
8499                            BinaryOp::Mul => (a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re),
8500                            BinaryOp::Div => {
8501                                let denom = b_re * b_re + b_im * b_im;
8502                                (
8503                                    (a_re * b_re + a_im * b_im) / denom,
8504                                    (a_im * b_re - a_re * b_im) / denom,
8505                                )
8506                            }
8507                            BinaryOp::Max | BinaryOp::Min | BinaryOp::Pow => {
8508                                unreachable!("C64 max/min/pow rejected at lowering")
8509                            }
8510                        }
8511                    };
8512                    if n_l == n_out && n_r == n_out {
8513                        for i in 0..n_out {
8514                            let (re, im) = do_c64(l[2 * i], l[2 * i + 1], r[2 * i], r[2 * i + 1]);
8515                            d[2 * i] = re;
8516                            d[2 * i + 1] = im;
8517                        }
8518                    } else if !out_dims_bcast.is_empty() {
8519                        // Strided complex broadcast: strides are in
8520                        // *complex element* units; multiply by 2 when
8521                        // indexing into the f32 buffer.
8522                        let rank = out_dims_bcast.len();
8523                        let mut coords = vec![0u32; rank];
8524                        for i in 0..n_out {
8525                            let mut rem = i;
8526                            for ax in (0..rank).rev() {
8527                                let sz = out_dims_bcast[ax] as usize;
8528                                coords[ax] = (rem % sz) as u32;
8529                                rem /= sz;
8530                            }
8531                            let mut li: usize = 0;
8532                            let mut ri: usize = 0;
8533                            for ax in 0..rank {
8534                                li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
8535                                ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
8536                            }
8537                            let (re, im) =
8538                                do_c64(l[2 * li], l[2 * li + 1], r[2 * ri], r[2 * ri + 1]);
8539                            d[2 * i] = re;
8540                            d[2 * i + 1] = im;
8541                        }
8542                    } else {
8543                        // Modulo fallback (scalar / last-axis broadcast).
8544                        for i in 0..n_out {
8545                            let li = if n_l == 1 { 0 } else { i % n_l };
8546                            let ri = if n_r == 1 { 0 } else { i % n_r };
8547                            let (re, im) =
8548                                do_c64(l[2 * li], l[2 * li + 1], r[2 * ri], r[2 * ri + 1]);
8549                            d[2 * i] = re;
8550                            d[2 * i + 1] = im;
8551                        }
8552                    }
8553                }
8554            }
8555
8556            Thunk::ComplexNormSqF32 { src, dst, len } => {
8557                let n = *len as usize;
8558                unsafe {
8559                    let s = sl(*src, base, 2 * n);
8560                    let d = sl_mut(*dst, base, n);
8561                    for i in 0..n {
8562                        let re = s[2 * i];
8563                        let im = s[2 * i + 1];
8564                        d[i] = re * re + im * im;
8565                    }
8566                }
8567            }
8568
8569            Thunk::ComplexNormSqBackwardF32 { z, g, dz, len } => {
8570                // Wirtinger: dz = g · z, element-wise complex
8571                // (g is real, z is complex).
8572                let n = *len as usize;
8573                unsafe {
8574                    let zb = sl(*z, base, 2 * n);
8575                    let gb = sl(*g, base, n);
8576                    let db = sl_mut(*dz, base, 2 * n);
8577                    for i in 0..n {
8578                        let re = zb[2 * i];
8579                        let im = zb[2 * i + 1];
8580                        let gv = gb[i];
8581                        db[2 * i] = gv * re;
8582                        db[2 * i + 1] = gv * im;
8583                    }
8584                }
8585            }
8586
8587            Thunk::ConjugateC64 { src, dst, len } => {
8588                let n = *len as usize;
8589                unsafe {
8590                    let s = sl(*src, base, 2 * n);
8591                    let d = sl_mut(*dst, base, 2 * n);
8592                    for i in 0..n {
8593                        d[2 * i] = s[2 * i];
8594                        d[2 * i + 1] = -s[2 * i + 1];
8595                    }
8596                }
8597            }
8598
8599            Thunk::ActivationC64 {
8600                src,
8601                dst,
8602                len,
8603                kind,
8604            } => {
8605                let n = *len as usize;
8606                unsafe {
8607                    let s = sl(*src, base, 2 * n);
8608                    let d = sl_mut(*dst, base, 2 * n);
8609                    for i in 0..n {
8610                        let a = s[2 * i];
8611                        let b = s[2 * i + 1];
8612                        let (re, im) = match kind {
8613                            Activation::Neg => (-a, -b),
8614                            Activation::Exp => {
8615                                // exp(a + bi) = e^a · (cos b + i·sin b)
8616                                let ea = a.exp();
8617                                (ea * b.cos(), ea * b.sin())
8618                            }
8619                            Activation::Log => {
8620                                // log(z) = log|z| + i·arg(z), principal branch
8621                                let r = (a * a + b * b).sqrt();
8622                                (r.ln(), b.atan2(a))
8623                            }
8624                            Activation::Sqrt => {
8625                                // sqrt(a+bi) = sqrt((|z|+a)/2) + sign(b)·i·sqrt((|z|-a)/2)
8626                                // Principal branch; for b == 0 and a < 0 returns +i·sqrt(|a|).
8627                                let r = (a * a + b * b).sqrt();
8628                                let re = ((r + a) * 0.5).max(0.0).sqrt();
8629                                let im_mag = ((r - a) * 0.5).max(0.0).sqrt();
8630                                let im = if b >= 0.0 { im_mag } else { -im_mag };
8631                                (re, im)
8632                            }
8633                            _ => unreachable!("non-C64 activation kind survived lowering"),
8634                        };
8635                        d[2 * i] = re;
8636                        d[2 * i + 1] = im;
8637                    }
8638                }
8639            }
8640
8641            Thunk::Scan {
8642                body,
8643                body_init,
8644                body_input_off,
8645                body_output_off,
8646                outer_init_off,
8647                outer_final_off,
8648                length,
8649                carry_bytes,
8650                save_trajectory,
8651                xs_inputs,
8652                bcast_inputs,
8653                num_checkpoints,
8654            } => {
8655                let cb = *carry_bytes as usize;
8656                let n_steps = *length as usize;
8657                // Checkpoint mode: when 0 < K < length, save trajectory[k]
8658                // only when t == c_k = floor((k+1) * length / K) - 1.
8659                // The last index c_{K-1} = length - 1 always.
8660                let k_total = if *num_checkpoints == 0 || *num_checkpoints == *length {
8661                    n_steps // save every step
8662                } else {
8663                    *num_checkpoints as usize
8664                };
8665                let checkpoint_t_for_k = |k: usize| -> usize {
8666                    if k_total == n_steps {
8667                        k
8668                    } else {
8669                        ((k + 1) * n_steps)
8670                            .div_ceil(k_total)
8671                            .saturating_sub(1)
8672                            .min(n_steps - 1)
8673                    }
8674                };
8675                let mut next_k = 0usize;
8676
8677                let mut body_buf: Vec<u8> = (**body_init).clone();
8678                unsafe {
8679                    std::ptr::copy_nonoverlapping(
8680                        base.add(*outer_init_off),
8681                        body_buf.as_mut_ptr().add(*body_input_off),
8682                        cb,
8683                    );
8684                    // Broadcast inputs: copy each one into the body's
8685                    // input slot ONCE. They aren't touched in the
8686                    // iteration loop below (in contrast to xs).
8687                    for (body_b_off, outer_b_off, total_bytes) in bcast_inputs.iter() {
8688                        std::ptr::copy_nonoverlapping(
8689                            base.add(*outer_b_off),
8690                            body_buf.as_mut_ptr().add(*body_b_off),
8691                            *total_bytes as usize,
8692                        );
8693                    }
8694                }
8695                for t in 0..n_steps {
8696                    for (body_x_off, outer_xs_off, per_step_bytes) in xs_inputs.iter() {
8697                        let psb = *per_step_bytes as usize;
8698                        unsafe {
8699                            std::ptr::copy_nonoverlapping(
8700                                base.add(*outer_xs_off + t * psb),
8701                                body_buf.as_mut_ptr().add(*body_x_off),
8702                                psb,
8703                            );
8704                        }
8705                    }
8706
8707                    execute_thunks(body, &mut body_buf);
8708
8709                    if *save_trajectory && next_k < k_total && t == checkpoint_t_for_k(next_k) {
8710                        unsafe {
8711                            std::ptr::copy_nonoverlapping(
8712                                body_buf.as_ptr().add(*body_output_off),
8713                                base.add(*outer_final_off + next_k * cb),
8714                                cb,
8715                            );
8716                        }
8717                        next_k += 1;
8718                    }
8719
8720                    if *body_output_off != *body_input_off {
8721                        body_buf
8722                            .copy_within(*body_output_off..*body_output_off + cb, *body_input_off);
8723                    }
8724                }
8725
8726                if !*save_trajectory {
8727                    // Single final-carry write.
8728                    unsafe {
8729                        std::ptr::copy_nonoverlapping(
8730                            body_buf.as_ptr().add(*body_output_off),
8731                            base.add(*outer_final_off),
8732                            cb,
8733                        );
8734                    }
8735                }
8736            }
8737
8738            Thunk::ScanBackward {
8739                body_vjp,
8740                body_init,
8741                body_carry_in_off,
8742                body_x_offs,
8743                body_d_output_off,
8744                body_dcarry_out_off,
8745                outer_init_off,
8746                outer_traj_off,
8747                outer_upstream_off,
8748                outer_xs_offs,
8749                outer_dinit_off,
8750                length,
8751                carry_bytes,
8752                save_trajectory,
8753                num_checkpoints,
8754                forward_body,
8755                forward_body_init,
8756                forward_body_carry_in_off,
8757                forward_body_output_off,
8758                forward_body_x_offs,
8759                carry_elem_size,
8760            } => {
8761                // Two backward paths share the same per-iteration body
8762                // (body_vjp run + dcarry threading). The "All" path
8763                // reads the carry directly from the saved trajectory
8764                // each step. The "Recursive checkpointing" path stores
8765                // only K saved checkpoints and reconstructs intermediate
8766                // carries via Griewank-style recursive subdivision —
8767                // see [`griewank_process_segment`]. Auxiliary memory
8768                // is `O(log(segment_size) · carry_bytes)` for the
8769                // recursion stack, vs the old segment-cache scheme's
8770                // `O(segment_size · carry_bytes)`. Total recompute work
8771                // grows from `O(length)` to `O(length · log)`, which
8772                // is the canonical Griewank trade.
8773                let cb = *carry_bytes as usize;
8774                let n_steps = *length as usize;
8775                let k_total = *num_checkpoints as usize;
8776                let is_recursive = k_total != 0 && k_total != n_steps;
8777                let checkpoint_t_for_k = |k: usize| -> usize {
8778                    ((k + 1) * n_steps)
8779                        .div_ceil(k_total)
8780                        .saturating_sub(1)
8781                        .min(n_steps - 1)
8782                };
8783
8784                let mut fwd_buf: Vec<u8> = if is_recursive {
8785                    (**forward_body_init.as_ref().unwrap()).clone()
8786                } else {
8787                    Vec::new()
8788                };
8789
8790                let mut dcarry: Vec<u8> = vec![0u8; cb];
8791                if !*save_trajectory {
8792                    unsafe {
8793                        std::ptr::copy_nonoverlapping(
8794                            base.add(*outer_upstream_off),
8795                            dcarry.as_mut_ptr(),
8796                            cb,
8797                        );
8798                    }
8799                }
8800
8801                let mut body_buf: Vec<u8> = (**body_init).clone();
8802
8803                // Per-iteration backward action — shared between the
8804                // direct-trajectory (All) and Griewank (Recursive) paths.
8805                // Both feed the same body_vjp run with carry-at-t,
8806                // x_t_i, and d_output, then thread dcarry backward.
8807                let process_iter =
8808                    |t: usize, carry_in: &[u8], dcarry: &mut Vec<u8>, body_buf: &mut Vec<u8>| {
8809                        if *save_trajectory {
8810                            unsafe {
8811                                let up_off = *outer_upstream_off + t * cb;
8812                                match *carry_elem_size {
8813                                    4 => {
8814                                        let up_ptr = base.add(up_off) as *const f32;
8815                                        let dc_ptr = dcarry.as_mut_ptr() as *mut f32;
8816                                        let n_elems = cb / 4;
8817                                        for i in 0..n_elems {
8818                                            *dc_ptr.add(i) += *up_ptr.add(i);
8819                                        }
8820                                    }
8821                                    8 => {
8822                                        let up_ptr = base.add(up_off) as *const f64;
8823                                        let dc_ptr = dcarry.as_mut_ptr() as *mut f64;
8824                                        let n_elems = cb / 8;
8825                                        for i in 0..n_elems {
8826                                            *dc_ptr.add(i) += *up_ptr.add(i);
8827                                        }
8828                                    }
8829                                    other => panic!(
8830                                        "ScanBackward: unsupported carry elem size {other} \
8831                                     (only f32/f64 carries are supported today)"
8832                                    ),
8833                                }
8834                            }
8835                        }
8836                        body_buf[*body_carry_in_off..*body_carry_in_off + cb]
8837                            .copy_from_slice(carry_in);
8838                        unsafe {
8839                            for (i, body_x_off) in body_x_offs.iter().enumerate() {
8840                                let (outer_xs_off, per_step_bytes) = outer_xs_offs[i];
8841                                let psb = per_step_bytes as usize;
8842                                std::ptr::copy_nonoverlapping(
8843                                    base.add(outer_xs_off + t * psb),
8844                                    body_buf.as_mut_ptr().add(*body_x_off),
8845                                    psb,
8846                                );
8847                            }
8848                            std::ptr::copy_nonoverlapping(
8849                                dcarry.as_ptr(),
8850                                body_buf.as_mut_ptr().add(*body_d_output_off),
8851                                cb,
8852                            );
8853                        }
8854                        execute_thunks(body_vjp, body_buf);
8855                        unsafe {
8856                            std::ptr::copy_nonoverlapping(
8857                                body_buf.as_ptr().add(*body_dcarry_out_off),
8858                                dcarry.as_mut_ptr(),
8859                                cb,
8860                            );
8861                        }
8862                    };
8863
8864                if is_recursive {
8865                    // Griewank treeverse path. Process saved-checkpoint
8866                    // segments from highest-t to lowest-t; within each,
8867                    // recursive binary subdivision via
8868                    // `griewank_process_segment`. Auxiliary memory:
8869                    // O(log(seg_size) · cb) for the recursion stack
8870                    // (vs O(seg_size · cb) for the older segment-cache
8871                    // scheme); recompute work: O(seg_size · log).
8872                    let leaf_threshold = 4usize;
8873                    let fb_sched = forward_body.as_ref().unwrap();
8874                    let fb_init = forward_body_init.as_ref().unwrap().as_slice();
8875                    let mut segment_end = n_steps - 1;
8876                    for seg_k in (0..k_total).rev() {
8877                        let segment_start = if seg_k == 0 {
8878                            0
8879                        } else {
8880                            checkpoint_t_for_k(seg_k - 1) + 1
8881                        };
8882                        let mut anchor: Vec<u8> = vec![0u8; cb];
8883                        unsafe {
8884                            let src = if seg_k == 0 {
8885                                base.add(*outer_init_off)
8886                            } else {
8887                                base.add(*outer_traj_off + (seg_k - 1) * cb)
8888                            };
8889                            std::ptr::copy_nonoverlapping(src, anchor.as_mut_ptr(), cb);
8890                        }
8891                        // Closure adapter for the helper's signature
8892                        // (mutably re-borrows dcarry / body_buf each call).
8893                        let mut leaf_action = |t: usize, carry_in: &[u8]| {
8894                            process_iter(t, carry_in, &mut dcarry, &mut body_buf);
8895                        };
8896                        unsafe {
8897                            griewank_process_segment(
8898                                segment_start,
8899                                segment_end,
8900                                &anchor,
8901                                cb,
8902                                fb_sched,
8903                                fb_init,
8904                                *forward_body_carry_in_off,
8905                                *forward_body_output_off,
8906                                forward_body_x_offs,
8907                                base,
8908                                outer_xs_offs,
8909                                &mut fwd_buf,
8910                                leaf_threshold,
8911                                &mut leaf_action,
8912                            );
8913                        }
8914                        if seg_k == 0 {
8915                            break;
8916                        }
8917                        segment_end = segment_start - 1;
8918                    }
8919                } else {
8920                    // All-trajectory path: read each carry directly
8921                    // from the saved trajectory buffer.
8922                    let mut carry_buf: Vec<u8> = vec![0u8; cb];
8923                    for t in (0..n_steps).rev() {
8924                        unsafe {
8925                            let src = if t == 0 {
8926                                base.add(*outer_init_off)
8927                            } else {
8928                                base.add(*outer_traj_off + (t - 1) * cb)
8929                            };
8930                            std::ptr::copy_nonoverlapping(src, carry_buf.as_mut_ptr(), cb);
8931                        }
8932                        process_iter(t, &carry_buf, &mut dcarry, &mut body_buf);
8933                    }
8934                }
8935
8936                unsafe {
8937                    std::ptr::copy_nonoverlapping(dcarry.as_ptr(), base.add(*outer_dinit_off), cb);
8938                }
8939            }
8940
            // Backward sweep of a scan that emits the per-step input
            // gradient rows (`dxs`). Iterates `t` from `length - 1` down
            // to 0; for each step it seeds `body_buf` with the carry that
            // entered step `t`, the step's `xs` rows and the running
            // `dcarry`, runs `body_vjp`, then writes that step's `dxs` row
            // to `outer_dxs_off + t * per_step_bytes` and pulls the new
            // `dcarry` back out of the body buffer.
            //
            // Unlike the `ScanBackward` arm, this arm has no outer
            // `dinit` slot: the final `dcarry` is dropped when the arm ends.
            Thunk::ScanBackwardXs {
                body_vjp,
                body_init,
                body_carry_in_off,
                body_x_offs,
                body_d_output_off,
                body_dcarry_out_off,
                body_dxs_out_off,
                outer_init_off,
                outer_traj_off,
                outer_upstream_off,
                outer_xs_offs,
                outer_dxs_off,
                length,
                carry_bytes,
                carry_elem_size,
                per_step_bytes,
                save_trajectory,
                num_checkpoints,
                forward_body,
                forward_body_init,
                forward_body_carry_in_off,
                forward_body_output_off,
                forward_body_x_offs,
            } => {
                let cb = *carry_bytes as usize;
                let psb = *per_step_bytes as usize;
                let n_steps = *length as usize;
                let k_total = *num_checkpoints as usize;
                // "Recursive" (checkpointed) mode: only some carries were
                // saved. `k_total == 0` or `k_total == n_steps` means every
                // carry is available in the trajectory buffer.
                let is_recursive = k_total != 0 && k_total != n_steps;
                // Step index of checkpoint `k`:
                // ceil((k + 1) * n_steps / k_total) - 1, clamped to the
                // last step. Only invoked from the `is_recursive` path, so
                // `k_total` is non-zero there.
                let checkpoint_t_for_k = |k: usize| -> usize {
                    ((k + 1) * n_steps)
                        .div_ceil(k_total)
                        .saturating_sub(1)
                        .min(n_steps - 1)
                };

                // Forward-body recompute scratch + segment cache. This arm
                // still uses the segment-cache scheme (the `ScanBackward`
                // arm uses `griewank_process_segment` instead): on a cache
                // miss the whole checkpoint segment containing `t` is
                // replayed from its anchor and kept in `seg_cache`
                // (`seg_size * cb` bytes). `t` only decreases in the loop
                // below, so each segment is replayed once and total
                // recompute work is O(length).
                let mut fwd_buf: Vec<u8> = if is_recursive {
                    (**forward_body_init.as_ref().unwrap()).clone()
                } else {
                    Vec::new()
                };
                let mut seg_cache: Vec<u8> = Vec::new();
                // `usize::MAX` marks "no segment cached yet".
                let mut seg_start_t: usize = usize::MAX;
                let mut seg_count: usize = 0;
                // Writes the carry that entered step `t` into `dst`
                // (`cb` bytes). The cache state is passed in explicitly so
                // the closure itself captures nothing mutably.
                let recompute_carry_t =
                    |t: usize,
                     dst: &mut [u8],
                     fwd_buf: &mut Vec<u8>,
                     seg_cache: &mut Vec<u8>,
                     seg_start_t: &mut usize,
                     seg_count: &mut usize| {
                        if !is_recursive {
                            // Full trajectory saved: carry-in of step 0 is
                            // the initial carry, otherwise row `t - 1`.
                            unsafe {
                                let src = if t == 0 {
                                    base.add(*outer_init_off)
                                } else {
                                    base.add(*outer_traj_off + (t - 1) * cb)
                                };
                                std::ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), cb);
                            }
                            return;
                        }
                        // Cache hit: `t` lies inside the cached segment.
                        if *seg_start_t != usize::MAX
                            && t >= *seg_start_t
                            && t < *seg_start_t + *seg_count
                        {
                            let off = (t - *seg_start_t) * cb;
                            dst.copy_from_slice(&seg_cache[off..off + cb]);
                            return;
                        }
                        // Cache miss: locate the first checkpoint at or
                        // after `t`; fall back to the last segment.
                        let seg_k = (0..k_total)
                            .find(|&k| t <= checkpoint_t_for_k(k))
                            .unwrap_or(k_total - 1);
                        // Anchor = first step of the segment plus a pointer
                        // to the carry entering it: the initial carry for
                        // segment 0, otherwise trajectory row `seg_k - 1`
                        // (presumably the carry saved at the previous
                        // checkpoint — confirm against the forward scan).
                        let (anchor_t, anchor_ptr): (usize, *const u8) = if seg_k == 0 {
                            (0, unsafe { base.add(*outer_init_off) as *const u8 })
                        } else {
                            let prev_ck = checkpoint_t_for_k(seg_k - 1);
                            (prev_ck + 1, unsafe {
                                base.add(*outer_traj_off + (seg_k - 1) * cb) as *const u8
                            })
                        };
                        let seg_end_t = checkpoint_t_for_k(seg_k);
                        let seg_size = seg_end_t - anchor_t + 1;

                        // Reset the forward scratch to its pristine image,
                        // then seed its carry-in slot with the anchor.
                        fwd_buf.copy_from_slice(forward_body_init.as_ref().unwrap());
                        unsafe {
                            std::ptr::copy_nonoverlapping(
                                anchor_ptr,
                                fwd_buf.as_mut_ptr().add(*forward_body_carry_in_off),
                                cb,
                            );
                        }
                        // Cache slot `i` holds the carry entering step
                        // `anchor_t + i`; slot 0 is the anchor itself.
                        seg_cache.resize(seg_size * cb, 0u8);
                        seg_cache[0..cb].copy_from_slice(
                            &fwd_buf[*forward_body_carry_in_off..*forward_body_carry_in_off + cb],
                        );
                        let fb_sched = forward_body.as_ref().unwrap();
                        for i in 1..seg_size {
                            // Replay forward step `cur_iter` to obtain the
                            // carry entering step `cur_iter + 1`.
                            let cur_iter = anchor_t + i - 1;
                            for (idx, fb_x_off) in forward_body_x_offs.iter().enumerate() {
                                let (outer_xs_off, x_psb) = outer_xs_offs[idx];
                                let xb = x_psb as usize;
                                unsafe {
                                    std::ptr::copy_nonoverlapping(
                                        base.add(outer_xs_off + cur_iter * xb),
                                        fwd_buf.as_mut_ptr().add(*fb_x_off),
                                        xb,
                                    );
                                }
                            }
                            execute_thunks(fb_sched, fwd_buf);
                            // Feed the body's output carry back into its
                            // carry-in slot unless they already share an
                            // offset.
                            if *forward_body_output_off != *forward_body_carry_in_off {
                                fwd_buf.copy_within(
                                    *forward_body_output_off..*forward_body_output_off + cb,
                                    *forward_body_carry_in_off,
                                );
                            }
                            let cache_off = i * cb;
                            seg_cache[cache_off..cache_off + cb].copy_from_slice(
                                &fwd_buf
                                    [*forward_body_carry_in_off..*forward_body_carry_in_off + cb],
                            );
                        }
                        *seg_start_t = anchor_t;
                        *seg_count = seg_size;

                        let off = (t - anchor_t) * cb;
                        dst.copy_from_slice(&seg_cache[off..off + cb]);
                    };

                // Running carry gradient. Without a saved trajectory the
                // upstream gradient is a single carry-sized block that
                // seeds it once; with a saved trajectory it starts at zero
                // and one upstream row is accumulated per step below.
                let mut dcarry: Vec<u8> = vec![0u8; cb];
                if !*save_trajectory {
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            base.add(*outer_upstream_off),
                            dcarry.as_mut_ptr(),
                            cb,
                        );
                    }
                }

                // Private working copy of the VJP body's buffer image.
                let mut body_buf: Vec<u8> = (**body_init).clone();

                for t in (0..n_steps).rev() {
                    if *save_trajectory {
                        // dcarry += upstream[t], element-wise in the
                        // carry's float width.
                        unsafe {
                            let up_off = *outer_upstream_off + t * cb;
                            match *carry_elem_size {
                                4 => {
                                    let up_ptr = base.add(up_off) as *const f32;
                                    let dc_ptr = dcarry.as_mut_ptr() as *mut f32;
                                    let n_elems = cb / 4;
                                    for i in 0..n_elems {
                                        *dc_ptr.add(i) += *up_ptr.add(i);
                                    }
                                }
                                8 => {
                                    let up_ptr = base.add(up_off) as *const f64;
                                    let dc_ptr = dcarry.as_mut_ptr() as *mut f64;
                                    let n_elems = cb / 8;
                                    for i in 0..n_elems {
                                        *dc_ptr.add(i) += *up_ptr.add(i);
                                    }
                                }
                                other => panic!(
                                    "ScanBackwardXs: unsupported carry elem size {other} \
                                     (only f32/f64 carries are supported today)"
                                ),
                            }
                        }
                    }

                    // Seed body_vjp's carry input via the recompute
                    // helper (works for both All and Recursive modes),
                    // then x_t_i + d_output.
                    let carry_dst_start = *body_carry_in_off;
                    {
                        let carry_slice = &mut body_buf[carry_dst_start..carry_dst_start + cb];
                        recompute_carry_t(
                            t,
                            carry_slice,
                            &mut fwd_buf,
                            &mut seg_cache,
                            &mut seg_start_t,
                            &mut seg_count,
                        );
                    }
                    unsafe {
                        // Row `t` of every scanned input goes to its slot
                        // in the body buffer; each input has its own
                        // per-step byte size.
                        for (i, body_x_off) in body_x_offs.iter().enumerate() {
                            let (outer_xs_off, x_psb) = outer_xs_offs[i];
                            let xb = x_psb as usize;
                            std::ptr::copy_nonoverlapping(
                                base.add(outer_xs_off + t * xb),
                                body_buf.as_mut_ptr().add(*body_x_off),
                                xb,
                            );
                        }
                        // The running carry gradient is the body's
                        // upstream (`d_output`) for this step.
                        std::ptr::copy_nonoverlapping(
                            dcarry.as_ptr(),
                            body_buf.as_mut_ptr().add(*body_d_output_off),
                            cb,
                        );
                    }

                    execute_thunks(body_vjp, &mut body_buf);

                    // Stash this step's dxs into row `t` of the outer
                    // [length, *per_step_xs] output.
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            body_buf.as_ptr().add(*body_dxs_out_off),
                            base.add(*outer_dxs_off + t * psb),
                            psb,
                        );
                    }

                    // Update dcarry for next backward iteration.
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            body_buf.as_ptr().add(*body_dcarry_out_off),
                            dcarry.as_mut_ptr(),
                            cb,
                        );
                    }
                }
            }
9171
9172            Thunk::FusedMmBiasAct {
9173                a,
9174                w,
9175                bias,
9176                c,
9177                m,
9178                k,
9179                n,
9180                act,
9181            } => {
9182                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
9183                unsafe {
9184                    let out = sl_mut(*c, base, m * n);
9185                    crate::blas::sgemm_auto(sl(*a, base, m * k), sl(*w, base, k * n), out, m, k, n);
9186                    match act {
9187                        Some(Activation::Gelu) => {
9188                            crate::kernels::par_bias_gelu(out, sl(*bias, base, n), m, n)
9189                        }
9190                        Some(other) => {
9191                            crate::blas::bias_add(out, sl(*bias, base, n), m, n);
9192                            apply_activation_inplace(out, *other);
9193                        }
9194                        None => crate::blas::bias_add(out, sl(*bias, base, n), m, n),
9195                    }
9196                }
9197            }
9198
9199            Thunk::FusedResidualLN {
9200                x,
9201                res,
9202                bias,
9203                g,
9204                b,
9205                out,
9206                rows,
9207                h,
9208                eps,
9209                has_bias,
9210            } => {
9211                let (rows, h) = (*rows as usize, *h as usize);
9212                unsafe {
9213                    let zero = &zero_bias[..h];
9214                    let bi = if *has_bias { sl(*bias, base, h) } else { zero };
9215                    let x_ptr = sl(*x, base, rows * h).as_ptr() as usize;
9216                    let r_ptr = sl(*res, base, rows * h).as_ptr() as usize;
9217                    let o_ptr = sl_mut(*out, base, rows * h).as_mut_ptr() as usize;
9218                    let bi_ptr = bi.as_ptr() as usize;
9219                    let g_ptr = sl(*g, base, h).as_ptr() as usize;
9220                    let b_ptr = sl(*b, base, h).as_ptr() as usize;
9221                    let e = *eps;
9222                    crate::pool::par_for(rows, 4, &|off, cnt| {
9223                        let xs =
9224                            std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
9225                        let rs =
9226                            std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
9227                        let os = std::slice::from_raw_parts_mut(
9228                            (o_ptr as *mut f32).add(off * h),
9229                            cnt * h,
9230                        );
9231                        let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
9232                        let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
9233                        let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
9234                        crate::kernels::residual_bias_layer_norm(xs, rs, bi, g, b, os, cnt, h, e);
9235                    });
9236                }
9237            }
9238
9239            Thunk::FusedResidualRmsNorm {
9240                x,
9241                res,
9242                bias,
9243                g,
9244                b,
9245                out,
9246                rows,
9247                h,
9248                eps,
9249                has_bias,
9250            } => {
9251                let (rows, h) = (*rows as usize, *h as usize);
9252                unsafe {
9253                    let zero = &zero_bias[..h];
9254                    let bi = if *has_bias { sl(*bias, base, h) } else { zero };
9255                    let x_ptr = sl(*x, base, rows * h).as_ptr() as usize;
9256                    let r_ptr = sl(*res, base, rows * h).as_ptr() as usize;
9257                    let o_ptr = sl_mut(*out, base, rows * h).as_mut_ptr() as usize;
9258                    let bi_ptr = bi.as_ptr() as usize;
9259                    let g_ptr = sl(*g, base, h).as_ptr() as usize;
9260                    let b_ptr = sl(*b, base, h).as_ptr() as usize;
9261                    let e = *eps;
9262                    crate::pool::par_for(rows, 4, &|off, cnt| {
9263                        let xs =
9264                            std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
9265                        let rs =
9266                            std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
9267                        let os = std::slice::from_raw_parts_mut(
9268                            (o_ptr as *mut f32).add(off * h),
9269                            cnt * h,
9270                        );
9271                        let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
9272                        let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
9273                        let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
9274                        crate::kernels::residual_bias_rms_norm(xs, rs, bi, g, b, os, cnt, h, e);
9275                    });
9276                }
9277            }
9278
9279            Thunk::BiasAdd {
9280                src,
9281                bias,
9282                dst,
9283                m,
9284                n,
9285            } => {
9286                let (m, n) = (*m as usize, *n as usize);
9287                let len = m * n;
9288                unsafe {
9289                    let out = sl_mut(*dst, base, len);
9290                    if *src != *dst {
9291                        let src_ptr = base.add(*src) as *const f32;
9292                        let dst_ptr = base.add(*dst) as *mut f32;
9293                        if src_ptr != dst_ptr {
9294                            std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len);
9295                        }
9296                    }
9297                    crate::blas::bias_add(out, sl(*bias, base, n), m, n);
9298                }
9299            }
9300
9301            Thunk::BinaryFull {
9302                lhs,
9303                rhs,
9304                dst,
9305                len,
9306                lhs_len,
9307                rhs_len,
9308                op,
9309                out_dims_bcast,
9310                bcast_lhs_strides,
9311                bcast_rhs_strides,
9312                elem_bytes,
9313            } => {
9314                let len = *len as usize;
9315                let ll = (*lhs_len as usize).max(1);
9316                let rl = (*rhs_len as usize).max(1);
9317                let eb = (*elem_bytes).max(1) as usize;
9318                let arena_len = arena_buf.len();
9319                let ll = ll.min((arena_len.saturating_sub(*lhs)) / eb);
9320                let rl = rl.min((arena_len.saturating_sub(*rhs)) / eb);
9321                let len = len.min((arena_len.saturating_sub(*dst)) / eb);
9322                unsafe {
9323                    if eb == 8 {
9324                        let l = sl_i64(*lhs, base, ll);
9325                        let r = sl_i64(*rhs, base, rl);
9326                        let o = sl_mut_i64(*dst, base, len);
9327                        if !out_dims_bcast.is_empty() {
9328                            let rank = out_dims_bcast.len();
9329                            let mut coords = vec![0u32; rank];
9330                            for i in 0..len {
9331                                let mut rem = i;
9332                                for ax in (0..rank).rev() {
9333                                    let sz = out_dims_bcast[ax] as usize;
9334                                    coords[ax] = (rem % sz) as u32;
9335                                    rem /= sz;
9336                                }
9337                                let mut li = 0usize;
9338                                let mut ri = 0usize;
9339                                for ax in 0..rank {
9340                                    li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
9341                                    ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
9342                                }
9343                                o[i] = match op {
9344                                    BinaryOp::Add => l[li].wrapping_add(r[ri]),
9345                                    BinaryOp::Sub => l[li].wrapping_sub(r[ri]),
9346                                    BinaryOp::Mul => l[li].wrapping_mul(r[ri]),
9347                                    BinaryOp::Div => {
9348                                        if r[ri] == 0 {
9349                                            0
9350                                        } else {
9351                                            l[li] / r[ri]
9352                                        }
9353                                    }
9354                                    BinaryOp::Max => l[li].max(r[ri]),
9355                                    BinaryOp::Min => l[li].min(r[ri]),
9356                                    BinaryOp::Pow => l[li].pow(r[ri].max(0) as u32),
9357                                };
9358                            }
9359                        } else {
9360                            for i in 0..len {
9361                                let li = if ll == 1 { 0 } else { i % ll };
9362                                let ri = if rl == 1 { 0 } else { i % rl };
9363                                o[i] = match op {
9364                                    BinaryOp::Add => l[li].wrapping_add(r[ri]),
9365                                    BinaryOp::Sub => l[li].wrapping_sub(r[ri]),
9366                                    BinaryOp::Mul => l[li].wrapping_mul(r[ri]),
9367                                    BinaryOp::Div => {
9368                                        if r[ri] == 0 {
9369                                            0
9370                                        } else {
9371                                            l[li] / r[ri]
9372                                        }
9373                                    }
9374                                    BinaryOp::Max => l[li].max(r[ri]),
9375                                    BinaryOp::Min => l[li].min(r[ri]),
9376                                    BinaryOp::Pow => l[li].pow(r[ri].max(0) as u32),
9377                                };
9378                            }
9379                        }
9380                    } else {
9381                        let l = sl(*lhs, base, ll);
9382                        let r = sl(*rhs, base, rl);
9383                        let o = sl_mut(*dst, base, len);
9384                        if ll == len && rl == len {
9385                            #[cfg(target_arch = "aarch64")]
9386                            if matches!(op, BinaryOp::Add | BinaryOp::Mul) {
9387                                use std::arch::aarch64::*;
9388                                let chunks = len / 4;
9389                                for c in 0..chunks {
9390                                    let off = c * 4;
9391                                    let vl = vld1q_f32(l.as_ptr().add(off));
9392                                    let vr = vld1q_f32(r.as_ptr().add(off));
9393                                    let res = match op {
9394                                        BinaryOp::Add => vaddq_f32(vl, vr),
9395                                        BinaryOp::Mul => vmulq_f32(vl, vr),
9396                                        _ => unreachable!(),
9397                                    };
9398                                    vst1q_f32(o.as_mut_ptr().add(off), res);
9399                                }
9400                                for i in (chunks * 4)..len {
9401                                    o[i] = match op {
9402                                        BinaryOp::Add => l[i] + r[i],
9403                                        BinaryOp::Mul => l[i] * r[i],
9404                                        _ => unreachable!(),
9405                                    };
9406                                }
9407                                continue;
9408                            }
9409                        }
9410                        if !out_dims_bcast.is_empty() {
9411                            let rank = out_dims_bcast.len();
9412                            let mut coords = vec![0u32; rank];
9413                            for i in 0..len {
9414                                let mut rem = i;
9415                                for ax in (0..rank).rev() {
9416                                    let sz = out_dims_bcast[ax] as usize;
9417                                    coords[ax] = (rem % sz) as u32;
9418                                    rem /= sz;
9419                                }
9420                                let mut li = 0usize;
9421                                let mut ri = 0usize;
9422                                for ax in 0..rank {
9423                                    li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
9424                                    ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
9425                                }
9426                                o[i] = match op {
9427                                    BinaryOp::Add => l[li] + r[ri],
9428                                    BinaryOp::Sub => l[li] - r[ri],
9429                                    BinaryOp::Mul => l[li] * r[ri],
9430                                    BinaryOp::Div => l[li] / r[ri],
9431                                    BinaryOp::Max => l[li].max(r[ri]),
9432                                    BinaryOp::Min => l[li].min(r[ri]),
9433                                    BinaryOp::Pow => l[li].powf(r[ri]),
9434                                };
9435                            }
9436                        } else {
9437                            for i in 0..len {
9438                                let li = if ll == 1 { 0 } else { i % ll };
9439                                let ri = if rl == 1 { 0 } else { i % rl };
9440                                o[i] = match op {
9441                                    BinaryOp::Add => l[li] + r[ri],
9442                                    BinaryOp::Sub => l[li] - r[ri],
9443                                    BinaryOp::Mul => l[li] * r[ri],
9444                                    BinaryOp::Div => l[li] / r[ri],
9445                                    BinaryOp::Max => l[li].max(r[ri]),
9446                                    BinaryOp::Min => l[li].min(r[ri]),
9447                                    BinaryOp::Pow => l[li].powf(r[ri]),
9448                                };
9449                            }
9450                        }
9451                    }
9452                }
9453            }
9454
9455            Thunk::Gather {
9456                table,
9457                table_len,
9458                idx,
9459                dst,
9460                num_idx,
9461                trailing,
9462                idx_i64,
9463                table_bytes,
9464            } => {
9465                let (ni, tr) = (*num_idx as usize, *trailing as usize);
9466                let rows = *table_len as usize / tr.max(1);
9467                unsafe {
9468                    if *table_bytes == 8 {
9469                        let tab = sl_i64(*table, base, *table_len as usize);
9470                        let out = sl_mut_i64(*dst, base, ni * tr);
9471                        if *idx_i64 != 0 {
9472                            let ids = sl_i64(*idx, base, ni);
9473                            for i in 0..ni {
9474                                let row = ids[i].max(0) as usize;
9475                                if row < rows {
9476                                    out[i * tr..(i + 1) * tr]
9477                                        .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
9478                                }
9479                            }
9480                        } else {
9481                            let ids = sl(*idx, base, ni);
9482                            for i in 0..ni {
9483                                let row = ids[i] as usize;
9484                                if row < rows {
9485                                    out[i * tr..(i + 1) * tr]
9486                                        .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
9487                                }
9488                            }
9489                        }
9490                    } else {
9491                        let tab = sl(*table, base, *table_len as usize);
9492                        let out = sl_mut(*dst, base, ni * tr);
9493                        if *idx_i64 != 0 {
9494                            let ids = sl_i64(*idx, base, ni);
9495                            for i in 0..ni {
9496                                let row = ids[i].max(0) as usize;
9497                                if row < rows {
9498                                    out[i * tr..(i + 1) * tr]
9499                                        .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
9500                                }
9501                            }
9502                        } else {
9503                            let ids = sl(*idx, base, ni);
9504                            for i in 0..ni {
9505                                let row = ids[i] as usize;
9506                                if row < rows {
9507                                    out[i * tr..(i + 1) * tr]
9508                                        .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
9509                                }
9510                            }
9511                        }
9512                    }
9513                }
9514            }
9515
9516            Thunk::Narrow {
9517                src,
9518                dst,
9519                outer,
9520                src_stride,
9521                dst_stride,
9522                inner,
9523                elem_bytes,
9524            } => {
9525                let (outer, ss, ds, inner, eb) = (
9526                    *outer as usize,
9527                    *src_stride as usize,
9528                    *dst_stride as usize,
9529                    *inner as usize,
9530                    *elem_bytes as usize,
9531                );
9532                let row_bytes = inner.saturating_mul(eb);
9533                let src_row_stride = ss.saturating_mul(eb);
9534                let dst_row_stride = ds.saturating_mul(eb);
9535                if trace_thunks {
9536                    eprintln!(
9537                        "[narrow] src={} dst={} outer={outer} ss={ss} ds={ds} inner={inner} eb={eb} row={row_bytes} arena={}",
9538                        *src,
9539                        *dst,
9540                        arena_buf.len()
9541                    );
9542                }
9543                if row_bytes > 0 && *src != *dst {
9544                    let arena_len = arena_buf.len();
9545                    for o in 0..outer {
9546                        let s_off = *src + o * src_row_stride;
9547                        let d_off = *dst + o * dst_row_stride;
9548                        if s_off == d_off {
9549                            continue;
9550                        }
9551                        if s_off.saturating_add(row_bytes) > arena_len
9552                            || d_off.saturating_add(row_bytes) > arena_len
9553                        {
9554                            break;
9555                        }
9556                        unsafe {
9557                            std::ptr::copy_nonoverlapping(
9558                                base.add(s_off),
9559                                base.add(d_off),
9560                                row_bytes,
9561                            );
9562                        }
9563                    }
9564                }
9565            }
9566
9567            Thunk::Copy { src, dst, len } => {
9568                let mut len = *len as usize;
9569                if *src == *dst || len == 0 {
9570                    continue;
9571                }
9572                let arena_len = arena_buf.len();
9573                let max_from_src = (arena_len.saturating_sub(*src)) / 4;
9574                let max_from_dst = (arena_len.saturating_sub(*dst)) / 4;
9575                len = len.min(max_from_src).min(max_from_dst);
9576                if len == 0 {
9577                    continue;
9578                }
9579                let byte_len = len.saturating_mul(4);
9580                unsafe {
9581                    std::ptr::copy(base.add(*src), base.add(*dst), byte_len);
9582                }
9583            }
9584
9585            Thunk::LayerNorm {
9586                src,
9587                g,
9588                b,
9589                dst,
9590                rows,
9591                h,
9592                eps,
9593            } => {
9594                let (rows, h) = (*rows as usize, *h as usize);
9595                unsafe {
9596                    let input = sl(*src, base, rows * h);
9597                    let gamma = sl(*g, base, h);
9598                    let beta = sl(*b, base, h);
9599                    let output = sl_mut(*dst, base, rows * h);
9600                    // Parallelize across rows (same pattern as FusedResidualLN)
9601                    if rows >= 4 && rows * h >= 30_000 {
9602                        let i_ptr = input.as_ptr() as usize;
9603                        let o_ptr = output.as_mut_ptr() as usize;
9604                        let g_ptr = gamma.as_ptr() as usize;
9605                        let b_ptr = beta.as_ptr() as usize;
9606                        let e = *eps;
9607                        crate::pool::par_for(rows, 4, &|off, cnt| {
9608                            let inp = std::slice::from_raw_parts(
9609                                (i_ptr as *const f32).add(off * h),
9610                                cnt * h,
9611                            );
9612                            let out = std::slice::from_raw_parts_mut(
9613                                (o_ptr as *mut f32).add(off * h),
9614                                cnt * h,
9615                            );
9616                            let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
9617                            let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
9618                            for row in 0..cnt {
9619                                crate::kernels::layer_norm_row(
9620                                    &inp[row * h..(row + 1) * h],
9621                                    g,
9622                                    b,
9623                                    &mut out[row * h..(row + 1) * h],
9624                                    h,
9625                                    e,
9626                                );
9627                            }
9628                        });
9629                    } else {
9630                        for row in 0..rows {
9631                            crate::kernels::layer_norm_row(
9632                                &input[row * h..(row + 1) * h],
9633                                gamma,
9634                                beta,
9635                                &mut output[row * h..(row + 1) * h],
9636                                h,
9637                                *eps,
9638                            );
9639                        }
9640                    }
9641                }
9642            }
9643
9644            Thunk::GroupNorm {
9645                src,
9646                g,
9647                b,
9648                dst,
9649                n,
9650                c,
9651                h,
9652                w,
9653                num_groups,
9654                eps,
9655            } => {
9656                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
9657                let plane = c * h * w;
9658                unsafe {
9659                    for ni in 0..n {
9660                        let input = sl(*src, base.add(ni * plane), plane);
9661                        let gamma = sl(*g, base, c);
9662                        let beta = sl(*b, base, c);
9663                        let output = sl_mut(*dst, base.add(ni * plane), plane);
9664                        crate::kernels::group_norm_nchw(
9665                            input,
9666                            gamma,
9667                            beta,
9668                            output,
9669                            1,
9670                            c,
9671                            h,
9672                            w,
9673                            *num_groups as usize,
9674                            *eps,
9675                        );
9676                    }
9677                }
9678            }
9679
9680            Thunk::BatchNormInference {
9681                src,
9682                g,
9683                b,
9684                mean,
9685                var,
9686                dst,
9687                count,
9688                channels,
9689                eps,
9690            } => {
9691                let count = *count as usize;
9692                let c = *channels as usize;
9693                let n = count * c;
9694                unsafe {
9695                    crate::kernels::batch_norm_inference(
9696                        sl(*src, base, n),
9697                        sl(*g, base, c),
9698                        sl(*b, base, c),
9699                        sl(*mean, base, c),
9700                        sl(*var, base, c),
9701                        sl_mut(*dst, base, n),
9702                        c,
9703                        *eps,
9704                    );
9705                }
9706            }
9707
9708            Thunk::LayerNorm2d {
9709                src,
9710                g,
9711                b,
9712                dst,
9713                n,
9714                c,
9715                h,
9716                w,
9717                eps,
9718            } => {
9719                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
9720                let plane = c * h * w;
9721                unsafe {
9722                    let input = sl(*src, base, n * plane);
9723                    let gamma = sl(*g, base, c);
9724                    let beta = sl(*b, base, c);
9725                    let output = sl_mut(*dst, base, n * plane);
9726                    crate::kernels::layer_norm2d_nchw(input, gamma, beta, output, n, c, h, w, *eps);
9727                }
9728            }
9729
9730            Thunk::ConvTranspose2d {
9731                src,
9732                weight,
9733                dst,
9734                n,
9735                c_in,
9736                h,
9737                w_in,
9738                c_out,
9739                h_out,
9740                w_out,
9741                kh,
9742                kw,
9743                sh,
9744                sw,
9745                ph,
9746                pw,
9747                dh,
9748                dw,
9749                groups,
9750            } => {
9751                let n = *n as usize;
9752                let c_in = *c_in as usize;
9753                let h = *h as usize;
9754                let w_in = *w_in as usize;
9755                let c_out = *c_out as usize;
9756                let h_out = *h_out as usize;
9757                let w_out = *w_out as usize;
9758                unsafe {
9759                    let inp = sl(*src, base, n * c_in * h * w_in);
9760                    let wt = sl(
9761                        *weight,
9762                        base,
9763                        c_in * (c_out / *groups as usize) * (*kh as usize) * (*kw as usize),
9764                    );
9765                    let out = sl_mut(*dst, base, n * c_out * h_out * w_out);
9766                    crate::kernels::conv_transpose2d_nchw(
9767                        inp,
9768                        wt,
9769                        out,
9770                        n,
9771                        c_in,
9772                        h,
9773                        w_in,
9774                        c_out,
9775                        h_out,
9776                        w_out,
9777                        *kh as usize,
9778                        *kw as usize,
9779                        *sh as usize,
9780                        *sw as usize,
9781                        *ph as usize,
9782                        *pw as usize,
9783                        *dh as usize,
9784                        *dw as usize,
9785                        *groups as usize,
9786                    );
9787                }
9788            }
9789
9790            Thunk::ResizeNearest2x {
9791                src,
9792                dst,
9793                n,
9794                c,
9795                h,
9796                w,
9797            } => {
9798                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
9799                let in_plane = c * h * w;
9800                let out_plane = c * h * 2 * w * 2;
9801                unsafe {
9802                    for ni in 0..n {
9803                        let input = sl(*src, base.add(ni * in_plane), in_plane);
9804                        let output = sl_mut(*dst, base.add(ni * out_plane), out_plane);
9805                        crate::kernels::resize_nearest_2x_nchw(input, output, c, h, w);
9806                    }
9807                }
9808            }
9809
9810            Thunk::AxialRope2d {
9811                src,
9812                dst,
9813                batch,
9814                seq,
9815                hidden,
9816                end_x,
9817                end_y,
9818                head_dim,
9819                num_heads,
9820                theta,
9821                repeat_factor,
9822            } => {
9823                let b = *batch as usize;
9824                let s = *seq as usize;
9825                let hdim = *head_dim as usize;
9826                let nh = *num_heads as usize;
9827                let plane = s * (*hidden as usize);
9828                unsafe {
9829                    for bi in 0..b {
9830                        let input = sl(*src, base.add(bi * plane), plane);
9831                        let output = sl_mut(*dst, base.add(bi * plane), plane);
9832                        let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
9833                            input,
9834                            nh,
9835                            s,
9836                            hdim,
9837                            *end_x as usize,
9838                            *end_y as usize,
9839                            *theta,
9840                            *repeat_factor as usize,
9841                        );
9842                        output.copy_from_slice(&rotated);
9843                    }
9844                }
9845            }
9846
9847            Thunk::RmsNorm {
9848                src,
9849                g,
9850                b,
9851                dst,
9852                rows,
9853                h,
9854                eps,
9855            } => {
9856                let (rows, h) = (*rows as usize, *h as usize);
9857                unsafe {
9858                    let input = sl(*src, base, rows * h);
9859                    let gamma = sl(*g, base, h);
9860                    let beta = sl(*b, base, h);
9861                    let output = sl_mut(*dst, base, rows * h);
9862                    let inv_h = 1.0 / h as f32;
9863                    for row in 0..rows {
9864                        let in_row = &input[row * h..(row + 1) * h];
9865                        let out_row = &mut output[row * h..(row + 1) * h];
9866                        // RMS = sqrt(mean(x^2) + eps); scale = 1/RMS.
9867                        let mut sumsq = 0f32;
9868                        for &v in in_row {
9869                            sumsq += v * v;
9870                        }
9871                        let inv_rms = (sumsq * inv_h + *eps).sqrt().recip();
9872                        for i in 0..h {
9873                            out_row[i] = in_row[i] * inv_rms * gamma[i] + beta[i];
9874                        }
9875                    }
9876                }
9877            }
9878
9879            Thunk::Softmax { data, rows, cols } => {
9880                let (rows, cols) = (*rows as usize, *cols as usize);
9881                unsafe {
9882                    crate::kernels::neon_softmax(sl_mut(*data, base, rows * cols), rows, cols);
9883                }
9884            }
9885
9886            Thunk::Cumsum {
9887                src,
9888                dst,
9889                rows,
9890                cols,
9891                exclusive,
9892            } => {
9893                let (rows, cols) = (*rows as usize, *cols as usize);
9894                unsafe {
9895                    let s = sl(*src, base, rows * cols);
9896                    let d = sl_mut(*dst, base, rows * cols);
9897                    if *exclusive {
9898                        for r in 0..rows {
9899                            let mut acc = 0.0f32;
9900                            for c in 0..cols {
9901                                d[r * cols + c] = acc;
9902                                acc += s[r * cols + c];
9903                            }
9904                        }
9905                    } else {
9906                        for r in 0..rows {
9907                            let mut acc = 0.0f32;
9908                            for c in 0..cols {
9909                                acc += s[r * cols + c];
9910                                d[r * cols + c] = acc;
9911                            }
9912                        }
9913                    }
9914                }
9915            }
9916
9917            Thunk::Sample {
9918                logits,
9919                dst,
9920                batch,
9921                vocab,
9922                top_k,
9923                top_p,
9924                temperature,
9925                seed,
9926            } => {
9927                let (b, v) = (*batch as usize, *vocab as usize);
9928                let k = (*top_k as usize).min(v);
9929                unsafe {
9930                    let lg = sl(*logits, base, b * v);
9931                    let out = sl_mut(*dst, base, b);
9932                    let mut rng =
9933                        rlx_ir::Philox4x32::new(if *seed == 0 { 0xDEADBEEF } else { *seed });
9934                    for bi in 0..b {
9935                        let row = &lg[bi * v..(bi + 1) * v];
9936                        out[bi] = sample_row(row, k, *top_p, *temperature, &mut rng) as f32;
9937                    }
9938                }
9939            }
9940
9941            Thunk::GatedDeltaNet {
9942                q,
9943                k,
9944                v,
9945                g,
9946                beta,
9947                state,
9948                dst,
9949                batch,
9950                seq,
9951                heads,
9952                state_size,
9953            } => unsafe {
9954                execute_gated_delta_net_f32(
9955                    *q,
9956                    *k,
9957                    *v,
9958                    *g,
9959                    *beta,
9960                    *state,
9961                    *dst,
9962                    *batch as usize,
9963                    *seq as usize,
9964                    *heads as usize,
9965                    *state_size as usize,
9966                    base,
9967                );
9968            },
9969
9970            Thunk::SelectiveScan {
9971                x,
9972                delta,
9973                a,
9974                b: bp,
9975                c: cp,
9976                dst,
9977                batch,
9978                seq,
9979                hidden,
9980                state_size,
9981            } => {
9982                let (b, s, h, n) = (
9983                    *batch as usize,
9984                    *seq as usize,
9985                    *hidden as usize,
9986                    *state_size as usize,
9987                );
9988                unsafe {
9989                    let xs = sl(*x, base, b * s * h);
9990                    let dt = sl(*delta, base, b * s * h);
9991                    let am = sl(*a, base, h * n);
9992                    let bm = sl(*bp, base, b * s * n);
9993                    let cm = sl(*cp, base, b * s * n);
9994                    let out = sl_mut(*dst, base, b * s * h);
9995
9996                    // State buffer per-batch: h channels × n state.
9997                    // Sequential along the seq dimension; could
9998                    // parallelize over batch+channel later.
9999                    let mut state = vec![0f32; h * n];
10000                    for bi in 0..b {
10001                        // Reset state at the start of each batch row.
10002                        for v in state.iter_mut() {
10003                            *v = 0.0;
10004                        }
10005                        for si in 0..s {
10006                            let x_row = &xs[bi * s * h + si * h..bi * s * h + (si + 1) * h];
10007                            let dt_row = &dt[bi * s * h + si * h..bi * s * h + (si + 1) * h];
10008                            let b_row = &bm[bi * s * n + si * n..bi * s * n + (si + 1) * n];
10009                            let c_row = &cm[bi * s * n + si * n..bi * s * n + (si + 1) * n];
10010                            let out_row = &mut out[bi * s * h + si * h..bi * s * h + (si + 1) * h];
10011
10012                            for ci in 0..h {
10013                                let d = dt_row[ci];
10014                                let xv = x_row[ci];
10015                                let mut acc = 0f32;
10016                                for ni in 0..n {
10017                                    // Discretize: exp(d * a) and d * b.
10018                                    let da = (d * am[ci * n + ni]).exp();
10019                                    state[ci * n + ni] =
10020                                        da * state[ci * n + ni] + d * b_row[ni] * xv;
10021                                    acc += c_row[ni] * state[ci * n + ni];
10022                                }
10023                                out_row[ci] = acc;
10024                            }
10025                        }
10026                    }
10027                }
10028            }
10029
10030            Thunk::DequantMatMul {
10031                x,
10032                w_q,
10033                scale,
10034                zp,
10035                dst,
10036                m,
10037                k,
10038                n,
10039                block_size,
10040                is_asymmetric,
10041            } => {
10042                let (m, k, n, bs) = (*m as usize, *k as usize, *n as usize, *block_size as usize);
10043                let n_blocks = k.div_ceil(bs);
10044                unsafe {
10045                    let xs = sl(*x, base, m * k);
10046                    let w_bytes = std::slice::from_raw_parts(base.add(*w_q) as *const i8, k * n);
10047                    let scales = sl(*scale, base, n_blocks * n);
10048                    let zps = if *is_asymmetric {
10049                        sl(*zp, base, n_blocks * n)
10050                    } else {
10051                        &[][..]
10052                    };
10053                    let out = sl_mut(*dst, base, m * n);
10054                    dequant_matmul_int8(xs, w_bytes, scales, zps, out, m, k, n, bs, *is_asymmetric);
10055                }
10056            }
10057
10058            Thunk::DequantMatMulGguf {
10059                x,
10060                w_q,
10061                dst,
10062                m,
10063                k,
10064                n,
10065                scheme,
10066            } => {
10067                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
10068                let block_bytes = scheme.gguf_block_bytes() as usize;
10069                let block_elems = scheme.gguf_block_size() as usize;
10070                debug_assert!(
10071                    block_bytes > 0 && block_elems > 0,
10072                    "non-GGUF scheme in GGUF arm"
10073                );
10074                debug_assert!(
10075                    (k * n).is_multiple_of(block_elems),
10076                    "k*n={} not aligned to GGUF block size {}",
10077                    k * n,
10078                    block_elems
10079                );
10080                let total_bytes = (k * n) / block_elems * block_bytes;
10081                unsafe {
10082                    let xs = sl(*x, base, m * k);
10083                    let w_bytes_ptr = base.add(*w_q) as *const u8;
10084                    let w_bytes = std::slice::from_raw_parts(w_bytes_ptr, total_bytes);
10085                    let out = sl_mut(*dst, base, m * n);
10086                    crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, *scheme);
10087                }
10088            }
10089
10090            Thunk::DequantMatMulInt4 {
10091                x,
10092                w_q,
10093                scale,
10094                zp,
10095                dst,
10096                m,
10097                k,
10098                n,
10099                block_size,
10100                is_asymmetric,
10101            } => {
10102                let (m, k, n, bs) = (*m as usize, *k as usize, *n as usize, *block_size as usize);
10103                let n_blocks = k.div_ceil(bs);
10104                unsafe {
10105                    let xs = sl(*x, base, m * k);
10106                    let w_bytes = std::slice::from_raw_parts(
10107                        base.add(*w_q) as *const u8,
10108                        (k * n).div_ceil(2),
10109                    );
10110                    let scales = sl(*scale, base, n_blocks * n);
10111                    let zps = if *is_asymmetric {
10112                        sl(*zp, base, n_blocks * n)
10113                    } else {
10114                        &[][..]
10115                    };
10116                    let out = sl_mut(*dst, base, m * n);
10117                    dequant_matmul_int4(xs, w_bytes, scales, zps, out, m, k, n, bs, *is_asymmetric);
10118                }
10119            }
10120
10121            Thunk::DequantMatMulFp8 {
10122                x,
10123                w_q,
10124                scale,
10125                dst,
10126                m,
10127                k,
10128                n,
10129                e5m2,
10130            } => {
10131                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
10132                unsafe {
10133                    let xs = sl(*x, base, m * k);
10134                    let w_bytes = std::slice::from_raw_parts(base.add(*w_q) as *const u8, k * n);
10135                    let scales = sl(*scale, base, n);
10136                    let out = sl_mut(*dst, base, m * n);
10137                    dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, *e5m2);
10138                }
10139            }
10140
10141            Thunk::DequantMatMulNvfp4 {
10142                x,
10143                w_q,
10144                scale,
10145                global_scale,
10146                dst,
10147                m,
10148                k,
10149                n,
10150            } => {
10151                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
10152                let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
10153                unsafe {
10154                    let xs = sl(*x, base, m * k);
10155                    let w_bytes = std::slice::from_raw_parts(
10156                        base.add(*w_q) as *const u8,
10157                        (k * n).div_ceil(2),
10158                    );
10159                    let scale_bytes =
10160                        std::slice::from_raw_parts(base.add(*scale) as *const u8, n_scale);
10161                    let gs = sl(*global_scale, base, 1)[0];
10162                    let out = sl_mut(*dst, base, m * n);
10163                    dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
10164                }
10165            }
10166
10167            Thunk::LoraMatMul {
10168                x,
10169                w,
10170                a,
10171                b,
10172                dst,
10173                m,
10174                k,
10175                n,
10176                r,
10177                scale,
10178            } => {
10179                let (m, k, n, r) = (*m as usize, *k as usize, *n as usize, *r as usize);
10180                unsafe {
10181                    let xs = sl(*x, base, m * k);
10182                    let ws = sl(*w, base, k * n);
10183                    let a_s = sl(*a, base, k * r);
10184                    let bs = sl(*b, base, r * n);
10185                    let out = sl_mut(*dst, base, m * n);
10186                    crate::blas::sgemm(xs, ws, out, m, k, n);
10187                    let mut tmp = vec![0f32; m * r];
10188                    crate::blas::sgemm(xs, a_s, &mut tmp, m, k, r);
10189                    if *scale != 1.0 {
10190                        for v in tmp.iter_mut() {
10191                            *v *= *scale;
10192                        }
10193                    }
10194                    crate::blas::sgemm_accumulate(&tmp, bs, out, m, r, n);
10195                }
10196            }
10197
10198            Thunk::Attention {
10199                q,
10200                k,
10201                v,
10202                mask,
10203                out,
10204                batch,
10205                seq,
10206                kv_seq,
10207                heads,
10208                head_dim,
10209                mask_kind,
10210                q_row_stride,
10211                k_row_stride,
10212                v_row_stride,
10213                bhsd,
10214            } => {
10215                let (b, q_s, k_s, nh, dh) = (
10216                    *batch as usize,
10217                    *seq as usize,
10218                    *kv_seq as usize,
10219                    *heads as usize,
10220                    *head_dim as usize,
10221                );
10222                let hs = nh * dh;
10223                // For [B, H, S, D] layout each (b, h) tile is dense
10224                // contiguous; the qrs/krs/vrs strides are not used.
10225                let (qrs, krs, vrs) = if *bhsd {
10226                    (dh, dh, dh)
10227                } else {
10228                    (
10229                        *q_row_stride as usize,
10230                        *k_row_stride as usize,
10231                        *v_row_stride as usize,
10232                    )
10233                };
10234                let bhsd = *bhsd;
10235                let _ = (q_row_stride, k_row_stride, v_row_stride);
10236                let scale = (dh as f32).powf(-0.5);
10237                let ss = q_s * k_s;
10238                let cfg = crate::config::RuntimeConfig::global();
10239                unsafe {
10240                    // Slice lengths cover the strided span. When Q/K/V
10241                    // alias the parent QKV (post-#46-fusion), the same
10242                    // bytes back all three slices — compiler bounds
10243                    // checks see the right size. For [B, H, S, D] the
10244                    // buffer is densely B*H*S*D elements; the row
10245                    // strides aren't used.
10246                    let q_len = if bhsd {
10247                        b * nh * q_s * dh
10248                    } else {
10249                        b * q_s * qrs
10250                    };
10251                    let k_len = if bhsd {
10252                        b * nh * k_s * dh
10253                    } else {
10254                        b * k_s * krs
10255                    };
10256                    let v_len = if bhsd {
10257                        b * nh * k_s * dh
10258                    } else {
10259                        b * k_s * vrs
10260                    };
10261                    let q_data = sl(*q, base, q_len);
10262                    let k_data = sl(*k, base, k_len);
10263                    let v_data = sl(*v, base, v_len);
10264                    let mask_data: &[f32] = match mask_kind {
10265                        rlx_ir::op::MaskKind::Custom => sl(*mask, base, b * k_s),
10266                        rlx_ir::op::MaskKind::Bias => sl(*mask, base, b * nh * q_s * k_s),
10267                        _ => &[],
10268                    };
10269                    let out_len = if bhsd {
10270                        b * nh * q_s * dh
10271                    } else {
10272                        b * q_s * hs
10273                    };
10274                    let out_data = sl_mut(*out, base, out_len);
10275
10276                    // ── [B, H, S, D] fallback ──────────────────────
10277                    // The NEON / strided-BLAS specializations below
10278                    // are written for the [B, S, H, D] layout. When
10279                    // the input is head-major ([B, H, S, D] —
10280                    // matching rlx-cuda / rlx-rocm / rlx-tpu), bypass
10281                    // them and run a simple (correct but slower)
10282                    // scalar implementation. Production-CPU inference
10283                    // graphs use [B, S, H, D] so they still hit the
10284                    // hot path; cross-backend parity tests use
10285                    // [B, H, S, D] and land here.
10286                    if bhsd {
10287                        let scores = &mut sdpa_scores[..ss];
10288                        for bi in 0..b {
10289                            for hi in 0..nh {
10290                                let q_head_base = bi * nh * q_s * dh + hi * q_s * dh;
10291                                let k_head_base = bi * nh * k_s * dh + hi * k_s * dh;
10292                                // Q@K^T
10293                                for qi in 0..q_s {
10294                                    let q_base = q_head_base + qi * dh;
10295                                    for ki in 0..k_s {
10296                                        let k_base = k_head_base + ki * dh;
10297                                        let mut dot = 0f32;
10298                                        for d in 0..dh {
10299                                            dot += q_data[q_base + d] * k_data[k_base + d];
10300                                        }
10301                                        scores[qi * k_s + ki] = dot * scale;
10302                                        if matches!(mask_kind, rlx_ir::op::MaskKind::Custom)
10303                                            && !mask_data.is_empty()
10304                                            && mask_data[bi * k_s + ki] < mask_thr
10305                                        {
10306                                            scores[qi * k_s + ki] = mask_neg;
10307                                        }
10308                                    }
10309                                }
10310                                if matches!(mask_kind, rlx_ir::op::MaskKind::Bias) {
10311                                    let off = (bi * nh + hi) * q_s * k_s;
10312                                    for i in 0..q_s * k_s {
10313                                        scores[i] += mask_data[off + i];
10314                                    }
10315                                }
10316                                apply_synthetic_mask(scores, q_s, k_s, *mask_kind);
10317                                crate::kernels::neon_softmax(scores, q_s, k_s);
10318                                // score @ V
10319                                for qi in 0..q_s {
10320                                    let o_base = q_head_base + qi * dh;
10321                                    for d in 0..dh {
10322                                        out_data[o_base + d] = 0.0;
10323                                    }
10324                                    for ki in 0..k_s {
10325                                        let sc = scores[qi * k_s + ki];
10326                                        if sc > score_thr {
10327                                            let v_base = k_head_base + ki * dh;
10328                                            for d in 0..dh {
10329                                                out_data[o_base + d] += sc * v_data[v_base + d];
10330                                            }
10331                                        }
10332                                    }
10333                                }
10334                            }
10335                        }
10336                        continue;
10337                    }
10338
10339                    // ── Auto-select kernel: NEON dots vs strided BLAS ───
10340                    // For tiny inputs (batch=1, short seq), per-head BLAS call
10341                    // overhead (~0.5µs × 2 calls × num_heads × num_layers)
10342                    // exceeds the NEON compute cost. Use direct strided NEON
10343                    // with zero dispatch overhead.
10344                    // For batch≥2: always BLAS + par_for (parallelism wins).
10345                    if b == 1 && q_s.max(k_s) <= cfg.sdpa_seq_threshold {
10346                        // ── Sequential NEON path (zero overhead) ──
10347                        let scores = &mut sdpa_scores[..ss];
10348                        #[cfg(target_arch = "aarch64")]
10349                        let neon_chunks = dh / 4;
10350
10351                        for bi in 0..b {
10352                            for hi in 0..nh {
10353                                // Q@K^T via strided NEON dot products
10354                                for qi in 0..q_s {
10355                                    let q_off = bi * q_s * qrs + qi * qrs + hi * dh;
10356                                    for ki in 0..k_s {
10357                                        let k_off = bi * k_s * krs + ki * krs + hi * dh;
10358                                        #[cfg(target_arch = "aarch64")]
10359                                        let mut dot;
10360                                        #[cfg(not(target_arch = "aarch64"))]
10361                                        let mut dot = 0f32;
10362                                        #[cfg(target_arch = "aarch64")]
10363                                        {
10364                                            use std::arch::aarch64::*;
10365                                            let mut acc = vdupq_n_f32(0.0);
10366                                            for c in 0..neon_chunks {
10367                                                let vq =
10368                                                    vld1q_f32(q_data.as_ptr().add(q_off + c * 4));
10369                                                let vk =
10370                                                    vld1q_f32(k_data.as_ptr().add(k_off + c * 4));
10371                                                acc = vfmaq_f32(acc, vq, vk);
10372                                            }
10373                                            dot = vaddvq_f32(acc);
10374                                            for d in (neon_chunks * 4)..dh {
10375                                                dot += q_data[q_off + d] * k_data[k_off + d];
10376                                            }
10377                                        }
10378                                        #[cfg(not(target_arch = "aarch64"))]
10379                                        for d in 0..dh {
10380                                            dot += q_data[q_off + d] * k_data[k_off + d];
10381                                        }
10382                                        scores[qi * k_s + ki] = dot * scale;
10383                                        // Inner-loop Custom mask check —
10384                                        // Causal / SlidingWindow / None
10385                                        // apply outside the loop below.
10386                                        // Skip for Bias — that mask is a
10387                                        // per-head additive tensor, not a
10388                                        // 0/1 key-padding mask.
10389                                        if matches!(mask_kind, rlx_ir::op::MaskKind::Custom)
10390                                            && !mask_data.is_empty()
10391                                            && mask_data[bi * k_s + ki] < mask_thr
10392                                        {
10393                                            scores[qi * k_s + ki] = mask_neg;
10394                                        }
10395                                    }
10396                                }
10397
10398                                if matches!(mask_kind, rlx_ir::op::MaskKind::Bias) {
10399                                    let off = (bi * nh + hi) * q_s * k_s;
10400                                    for i in 0..q_s * k_s {
10401                                        scores[i] += mask_data[off + i];
10402                                    }
10403                                }
10404                                apply_synthetic_mask(scores, q_s, k_s, *mask_kind);
10405                                crate::kernels::neon_softmax(scores, q_s, k_s);
10406
10407                                // Score@V via strided NEON accumulation (zero-copy)
10408                                for qi in 0..q_s {
10409                                    let o_off = bi * q_s * hs + qi * hs + hi * dh;
10410                                    // Zero output for this head position
10411                                    for d in 0..dh {
10412                                        out_data[o_off + d] = 0.0;
10413                                    }
10414                                    for ki in 0..k_s {
10415                                        let sc = scores[qi * k_s + ki];
10416                                        if sc > score_thr {
10417                                            let v_off = bi * k_s * vrs + ki * vrs + hi * dh;
10418                                            #[cfg(target_arch = "aarch64")]
10419                                            {
10420                                                use std::arch::aarch64::*;
10421                                                let vsc = vdupq_n_f32(sc);
10422                                                for c in 0..neon_chunks {
10423                                                    let off = c * 4;
10424                                                    let vo = vld1q_f32(
10425                                                        out_data.as_ptr().add(o_off + off),
10426                                                    );
10427                                                    let vv =
10428                                                        vld1q_f32(v_data.as_ptr().add(v_off + off));
10429                                                    vst1q_f32(
10430                                                        out_data.as_mut_ptr().add(o_off + off),
10431                                                        vfmaq_f32(vo, vsc, vv),
10432                                                    );
10433                                                }
10434                                            }
10435                                            #[cfg(not(target_arch = "aarch64"))]
10436                                            for d in 0..dh {
10437                                                out_data[o_off + d] += sc * v_data[v_off + d];
10438                                            }
10439                                        }
10440                                    }
10441                                }
10442                            }
10443                        }
10444                    } else {
10445                        // ── Parallel strided BLAS path (high throughput) ──
10446                        let total_work = b * nh;
10447                        let q_addr = q_data.as_ptr() as usize;
10448                        let k_addr = k_data.as_ptr() as usize;
10449                        let v_addr = v_data.as_ptr() as usize;
10450                        let m_addr = mask_data.as_ptr() as usize;
10451                        let o_addr = out_data.as_mut_ptr() as usize;
10452                        let sc_addr = sdpa_scores.as_mut_ptr() as usize;
10453
10454                        crate::pool::par_for(total_work, 1, &|off, cnt| {
10455                            for idx in off..off + cnt {
10456                                let bi = idx / nh;
10457                                let hi = idx % nh;
10458
10459                                let q_start = (q_addr as *const f32).add(bi * q_s * qrs + hi * dh);
10460                                let k_start = (k_addr as *const f32).add(bi * k_s * krs + hi * dh);
10461                                let v_start = (v_addr as *const f32).add(bi * k_s * vrs + hi * dh);
10462                                let o_start = (o_addr as *mut f32).add(bi * q_s * hs + hi * dh);
10463                                let sc = std::slice::from_raw_parts_mut(
10464                                    (sc_addr as *mut f32).add(idx * ss),
10465                                    ss,
10466                                );
10467
10468                                // LDA = qrs, LDB = krs (parent row strides
10469                                // when fused; hs otherwise).
10470                                crate::blas::sgemm_general(
10471                                    q_start,
10472                                    k_start,
10473                                    sc.as_mut_ptr(),
10474                                    q_s,
10475                                    k_s,
10476                                    dh,
10477                                    scale,
10478                                    0.0,
10479                                    qrs,
10480                                    krs,
10481                                    k_s,
10482                                    false,
10483                                    true,
10484                                );
10485
10486                                match mask_kind {
10487                                    rlx_ir::op::MaskKind::Custom => {
10488                                        let mask_bi = std::slice::from_raw_parts(
10489                                            (m_addr as *const f32).add(bi * k_s),
10490                                            k_s,
10491                                        );
10492                                        for ki in 0..k_s {
10493                                            if mask_bi[ki] < mask_thr {
10494                                                for qi in 0..q_s {
10495                                                    sc[qi * k_s + ki] = mask_neg;
10496                                                }
10497                                            }
10498                                        }
10499                                    }
10500                                    rlx_ir::op::MaskKind::Bias => {
10501                                        // Per-head additive bias slice.
10502                                        let bias = std::slice::from_raw_parts(
10503                                            (m_addr as *const f32).add((bi * nh + hi) * q_s * k_s),
10504                                            q_s * k_s,
10505                                        );
10506                                        for i in 0..q_s * k_s {
10507                                            sc[i] += bias[i];
10508                                        }
10509                                    }
10510                                    _ => apply_synthetic_mask(sc, q_s, k_s, *mask_kind),
10511                                }
10512
10513                                crate::kernels::neon_softmax(sc, q_s, k_s);
10514
10515                                // LDB = vrs (parent row stride when
10516                                // fused; hs otherwise). LDC stays hs —
10517                                // output is its own contiguous buffer.
10518                                crate::blas::sgemm_general(
10519                                    sc.as_ptr(),
10520                                    v_start,
10521                                    o_start,
10522                                    q_s,
10523                                    dh,
10524                                    k_s,
10525                                    1.0,
10526                                    0.0,
10527                                    k_s,
10528                                    vrs,
10529                                    hs,
10530                                    false,
10531                                    false,
10532                                );
10533                            }
10534                        });
10535                    }
10536                }
10537            }
10538
10539            Thunk::AttentionBackward {
10540                q,
10541                k,
10542                v,
10543                dy,
10544                mask,
10545                out,
10546                batch,
10547                seq,
10548                kv_seq,
10549                heads,
10550                head_dim,
10551                mask_kind,
10552                wrt,
10553                bhsd,
10554            } => {
10555                let (b, q_s, k_s, nh, dh) = (
10556                    *batch as usize,
10557                    *seq as usize,
10558                    *kv_seq as usize,
10559                    *heads as usize,
10560                    *head_dim as usize,
10561                );
10562                unsafe {
10563                    let q_len = if *bhsd {
10564                        b * nh * q_s * dh
10565                    } else {
10566                        b * q_s * nh * dh
10567                    };
10568                    let k_len = if *bhsd {
10569                        b * nh * k_s * dh
10570                    } else {
10571                        b * k_s * nh * dh
10572                    };
10573                    let out_len = match wrt {
10574                        rlx_ir::op::AttentionBwdWrt::Key | rlx_ir::op::AttentionBwdWrt::Value => {
10575                            k_len
10576                        }
10577                        rlx_ir::op::AttentionBwdWrt::Query => q_len,
10578                    };
10579                    let q_data = sl(*q, base, q_len);
10580                    let k_data = sl(*k, base, k_len);
10581                    let v_data = sl(*v, base, k_len);
10582                    let dy_data = sl(*dy, base, q_len);
10583                    let out_data = sl_mut(*out, base, out_len);
10584                    let mask_data: &[f32] = if *mask != 0 {
10585                        let ml = match mask_kind {
10586                            rlx_ir::op::MaskKind::Custom => b * k_s,
10587                            rlx_ir::op::MaskKind::Bias => b * nh * q_s * k_s,
10588                            _ => 0,
10589                        };
10590                        sl(*mask, base, ml)
10591                    } else {
10592                        &[]
10593                    };
10594                    crate::attention_bwd::attention_backward(
10595                        *wrt, q_data, k_data, v_data, dy_data, out_data, b, nh, q_s, k_s, dh,
10596                        *mask_kind, mask_data, *bhsd,
10597                    );
10598                }
10599            }
10600
10601            Thunk::ActivationInPlace { data, len, act } => {
10602                let len = *len as usize;
10603                unsafe {
10604                    let d = sl_mut(*data, base, len);
10605                    match act {
10606                        Activation::Gelu => crate::kernels::par_gelu_inplace(d),
10607                        Activation::GeluApprox => crate::kernels::par_gelu_approx_inplace(d),
10608                        Activation::Silu => crate::kernels::par_silu_inplace(d),
10609                        Activation::Relu => {
10610                            for v in d.iter_mut() {
10611                                *v = v.max(0.0);
10612                            }
10613                        }
10614                        Activation::Sigmoid => {
10615                            for v in d.iter_mut() {
10616                                *v = 1.0 / (1.0 + (-*v).exp());
10617                            }
10618                        }
10619                        Activation::Tanh => {
10620                            for v in d.iter_mut() {
10621                                *v = v.tanh();
10622                            }
10623                        }
10624                        Activation::Exp => {
10625                            for v in d.iter_mut() {
10626                                *v = v.exp();
10627                            }
10628                        }
10629                        Activation::Log => {
10630                            for v in d.iter_mut() {
10631                                *v = v.ln();
10632                            }
10633                        }
10634                        Activation::Sqrt => {
10635                            for v in d.iter_mut() {
10636                                *v = v.sqrt();
10637                            }
10638                        }
10639                        Activation::Rsqrt => {
10640                            for v in d.iter_mut() {
10641                                *v = 1.0 / v.sqrt();
10642                            }
10643                        }
10644                        Activation::Neg => {
10645                            for v in d.iter_mut() {
10646                                *v = -*v;
10647                            }
10648                        }
10649                        Activation::Abs => {
10650                            for v in d.iter_mut() {
10651                                *v = v.abs();
10652                            }
10653                        }
10654                        Activation::Round => {
10655                            for v in d.iter_mut() {
10656                                *v = v.round();
10657                            }
10658                        }
10659                        Activation::Sin => {
10660                            for v in d.iter_mut() {
10661                                *v = v.sin();
10662                            }
10663                        }
10664                        Activation::Cos => {
10665                            for v in d.iter_mut() {
10666                                *v = v.cos();
10667                            }
10668                        }
10669                        Activation::Tan => {
10670                            for v in d.iter_mut() {
10671                                *v = v.tan();
10672                            }
10673                        }
10674                        Activation::Atan => {
10675                            for v in d.iter_mut() {
10676                                *v = v.atan();
10677                            }
10678                        }
10679                    }
10680                }
10681            }
10682
10683            Thunk::FusedAttnBlock {
10684                hidden,
10685                qkv_w,
10686                out_w,
10687                mask,
10688                out,
10689                qkv_b,
10690                out_b,
10691                cos,
10692                sin,
10693                cos_len,
10694                batch,
10695                seq,
10696                hs,
10697                nh,
10698                dh,
10699                has_bias,
10700                has_rope,
10701            } => {
10702                let (b, s) = (*batch as usize, *seq as usize);
10703                let (h, n_h, d_h) = (*hs as usize, *nh as usize, *dh as usize);
10704                let m = b * s;
10705                let scale = (d_h as f32).powf(-0.5);
10706                let half = d_h / 2;
10707                unsafe {
10708                    let inp = sl(*hidden, base, m * h);
10709                    let wq = sl(*qkv_w, base, h * 3 * h);
10710                    let wo = sl(*out_w, base, h * h);
10711                    let mk = sl(*mask, base, b * s);
10712                    let dst = sl_mut(*out, base, m * h);
10713
10714                    // Stack-allocated intermediates — all fit in L1 cache for small batch
10715                    let mut qkv = vec![0f32; m * 3 * h];
10716                    let mut attn_out = vec![0f32; m * h];
10717                    let mut scores_buf = vec![0f32; s * s]; // one head at a time
10718
10719                    // 1. QKV projection: [m, h] @ [h, 3h] → [m, 3h]
10720                    crate::blas::sgemm(inp, wq, &mut qkv, m, h, 3 * h);
10721                    if *has_bias {
10722                        let bias = sl(*qkv_b, base, 3 * h);
10723                        crate::blas::bias_add(&mut qkv, bias, m, 3 * h);
10724                    }
10725
10726                    // 2. Multi-head SDPA (Q/K/V are views into qkv at offsets 0, h, 2h)
10727                    //    Process heads sequentially with inline RoPE — zero copy.
10728                    #[cfg(target_arch = "aarch64")]
10729                    let neon_chunks = d_h / 4;
10730                    #[cfg(target_arch = "aarch64")]
10731                    let _rope_chunks = half / 4;
10732
10733                    for bi in 0..b {
10734                        for hi in 0..n_h {
10735                            // For each (query_pos, key_pos): compute Q@K^T with inline RoPE
10736                            for qi in 0..s {
10737                                let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
10738                                for ki in 0..s {
10739                                    let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
10740                                    let mut dot = 0f32;
10741
10742                                    if *has_rope {
10743                                        // Apply RoPE inline during dot product
10744                                        let q_cos = qi * half;
10745                                        let k_cos = ki * half;
10746                                        let cos_tab = sl(*cos, base, *cos_len as usize);
10747                                        let sin_tab = sl(*sin, base, *cos_len as usize);
10748                                        // First half: (q1*c - q2*s) * (k1*c - k2*s)
10749                                        // Second half: (q2*c + q1*s) * (k2*c + k1*s)
10750                                        for i in 0..half {
10751                                            let q1 = qkv[q_base + i];
10752                                            let q2 = qkv[q_base + half + i];
10753                                            let k1 = qkv[k_base + i];
10754                                            let k2 = qkv[k_base + half + i];
10755                                            let c_q = cos_tab[q_cos + i];
10756                                            let s_q = sin_tab[q_cos + i];
10757                                            let c_k = cos_tab[k_cos + i];
10758                                            let s_k = sin_tab[k_cos + i];
10759                                            let qr1 = q1 * c_q - q2 * s_q;
10760                                            let kr1 = k1 * c_k - k2 * s_k;
10761                                            let qr2 = q2 * c_q + q1 * s_q;
10762                                            let kr2 = k2 * c_k + k1 * s_k;
10763                                            dot += qr1 * kr1 + qr2 * kr2;
10764                                        }
10765                                    } else {
10766                                        // Standard dot product
10767                                        #[cfg(target_arch = "aarch64")]
10768                                        {
10769                                            use std::arch::aarch64::*;
10770                                            let mut acc = vdupq_n_f32(0.0);
10771                                            for c in 0..neon_chunks {
10772                                                let vq =
10773                                                    vld1q_f32(qkv.as_ptr().add(q_base + c * 4));
10774                                                let vk =
10775                                                    vld1q_f32(qkv.as_ptr().add(k_base + c * 4));
10776                                                acc = vfmaq_f32(acc, vq, vk);
10777                                            }
10778                                            dot = vaddvq_f32(acc);
10779                                            for d in (neon_chunks * 4)..d_h {
10780                                                dot += qkv[q_base + d] * qkv[k_base + d];
10781                                            }
10782                                        }
10783                                        #[cfg(not(target_arch = "aarch64"))]
10784                                        for d in 0..d_h {
10785                                            dot += qkv[q_base + d] * qkv[k_base + d];
10786                                        }
10787                                    }
10788
10789                                    scores_buf[qi * s + ki] = dot * scale;
10790                                    if mk[bi * s + ki] < mask_thr {
10791                                        scores_buf[qi * s + ki] = mask_neg;
10792                                    }
10793                                }
10794                            }
10795
10796                            // Softmax
10797                            crate::kernels::neon_softmax(&mut scores_buf[..s * s], s, s);
10798
10799                            // Score @ V accumulation (V at offset 2h in QKV)
10800                            for qi in 0..s {
10801                                let o_base = bi * s * h + qi * h + hi * d_h;
10802                                for d in 0..d_h {
10803                                    attn_out[o_base + d] = 0.0;
10804                                }
10805                                for ki in 0..s {
10806                                    let sc = scores_buf[qi * s + ki];
10807                                    if sc > score_thr {
10808                                        let v_base = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
10809                                        #[cfg(target_arch = "aarch64")]
10810                                        {
10811                                            use std::arch::aarch64::*;
10812                                            let vsc = vdupq_n_f32(sc);
10813                                            for c in 0..neon_chunks {
10814                                                let off = c * 4;
10815                                                let vo =
10816                                                    vld1q_f32(attn_out.as_ptr().add(o_base + off));
10817                                                let vv = vld1q_f32(qkv.as_ptr().add(v_base + off));
10818                                                vst1q_f32(
10819                                                    attn_out.as_mut_ptr().add(o_base + off),
10820                                                    vfmaq_f32(vo, vsc, vv),
10821                                                );
10822                                            }
10823                                        }
10824                                        #[cfg(not(target_arch = "aarch64"))]
10825                                        for d in 0..d_h {
10826                                            attn_out[o_base + d] += sc * qkv[v_base + d];
10827                                        }
10828                                    }
10829                                }
10830                            }
10831                        }
10832                    }
10833
10834                    // 3. Output projection: [m, h] @ [h, h] → dst
10835                    crate::blas::sgemm(&attn_out, wo, dst, m, h, h);
10836                    if *has_bias {
10837                        let bias = sl(*out_b, base, h);
10838                        crate::blas::bias_add(dst, bias, m, h);
10839                    }
10840                }
10841            }
10842
10843            Thunk::Rope {
10844                src,
10845                cos,
10846                sin,
10847                dst,
10848                batch,
10849                seq,
10850                hidden,
10851                head_dim,
10852                n_rot,
10853                cos_len,
10854                src_row_stride,
10855            } => {
10856                let (b, s, hs, dh, nr) = (
10857                    *batch as usize,
10858                    *seq as usize,
10859                    *hidden as usize,
10860                    *head_dim as usize,
10861                    *n_rot as usize,
10862                );
10863                let tab_half = dh / 2;
10864                let rot_half = nr / 2;
10865                let nh = hs / dh;
10866                let cl = *cos_len as usize;
10867                let src_rs = *src_row_stride as usize;
10868                unsafe {
10869                    let x = sl(*src, base, b * s * src_rs);
10870                    let cos_tab = sl(*cos, base, cl);
10871                    let sin_tab = sl(*sin, base, cl);
10872                    let out = sl_mut(*dst, base, b * s * hs);
10873
10874                    let total = b * s;
10875                    let x_ptr = x.as_ptr() as usize;
10876                    let o_ptr = out.as_mut_ptr() as usize;
10877                    let c_ptr = cos_tab.as_ptr() as usize;
10878                    let s_ptr = sin_tab.as_ptr() as usize;
10879
10880                    crate::pool::par_for(total, 4, &|off, cnt| {
10881                        for idx in off..off + cnt {
10882                            let bi = idx / s;
10883                            let si = idx % s;
10884                            let tab_off = si * tab_half;
10885
10886                            for hi in 0..nh {
10887                                let src_base = bi * s * src_rs + si * src_rs + hi * dh;
10888                                let dst_base = bi * s * hs + si * hs + hi * dh;
10889                                let xp = (x_ptr as *const f32).add(src_base);
10890                                let op = (o_ptr as *mut f32).add(dst_base);
10891                                let cp = (c_ptr as *const f32).add(tab_off);
10892                                let sp = (s_ptr as *const f32).add(tab_off);
10893
10894                                for i in 0..rot_half {
10895                                    let x1 = *xp.add(i);
10896                                    let x2 = *xp.add(rot_half + i);
10897                                    let cv = *cp.add(i);
10898                                    let sv = *sp.add(i);
10899                                    *op.add(i) = x1 * cv - x2 * sv;
10900                                    *op.add(rot_half + i) = x2 * cv + x1 * sv;
10901                                }
10902                                for j in nr..dh {
10903                                    *op.add(j) = *xp.add(j);
10904                                }
10905                            }
10906                        }
10907                    });
10908                }
10909            }
10910            Thunk::FusedBertLayer {
10911                hidden,
10912                qkv_w,
10913                qkv_b,
10914                out_w,
10915                out_b,
10916                mask,
10917                ln1_g,
10918                ln1_b,
10919                eps1,
10920                fc1_w,
10921                fc1_b,
10922                fc2_w,
10923                fc2_b,
10924                ln2_g,
10925                ln2_b,
10926                eps2,
10927                out,
10928                batch,
10929                seq,
10930                hs,
10931                nh,
10932                dh,
10933                int_dim,
10934            } => {
10935                let (b, s, h, n_h, d_h) = (
10936                    *batch as usize,
10937                    *seq as usize,
10938                    *hs as usize,
10939                    *nh as usize,
10940                    *dh as usize,
10941                );
10942                let m = b * s;
10943                let id = *int_dim as usize;
10944                let scale = (d_h as f32).powf(-0.5);
10945                let _half = d_h / 2;
10946                #[cfg(target_arch = "aarch64")]
10947                let neon_chunks = d_h / 4;
10948                unsafe {
10949                    let inp = sl(*hidden, base, m * h);
10950                    let dst = sl_mut(*out, base, m * h);
10951                    let mk = sl(*mask, base, b * s);
10952
10953                    // Pre-allocated buffers (zero malloc per layer — allocated once before thunk loop)
10954                    let qkv = std::slice::from_raw_parts_mut(fl_qkv.as_mut_ptr(), m * 3 * h);
10955                    let attn = std::slice::from_raw_parts_mut(fl_attn.as_mut_ptr(), m * h);
10956                    let res = std::slice::from_raw_parts_mut(fl_res.as_mut_ptr(), m * h);
10957                    let normed = std::slice::from_raw_parts_mut(fl_normed.as_mut_ptr(), m * h);
10958                    let ffn = std::slice::from_raw_parts_mut(fl_ffn.as_mut_ptr(), m * id);
10959                    let sc = std::slice::from_raw_parts_mut(fl_sc.as_mut_ptr(), s * s);
10960
10961                    // QKV (parallelized across cores — multiple AMX coprocessors)
10962                    crate::blas::par_sgemm_bias(
10963                        inp,
10964                        sl(*qkv_w, base, h * 3 * h),
10965                        sl(*qkv_b, base, 3 * h),
10966                        qkv,
10967                        m,
10968                        h,
10969                        3 * h,
10970                    );
10971
10972                    // SDPA per head (sequential NEON, inline — zero overhead)
10973                    for bi in 0..b {
10974                        for hi in 0..n_h {
10975                            for qi in 0..s {
10976                                for ki in 0..s {
10977                                    let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
10978                                    let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
10979                                    #[cfg(target_arch = "aarch64")]
10980                                    let dot;
10981                                    #[cfg(not(target_arch = "aarch64"))]
10982                                    let mut dot = 0f32;
10983                                    #[cfg(target_arch = "aarch64")]
10984                                    {
10985                                        use std::arch::aarch64::*;
10986                                        let mut acc = vdupq_n_f32(0.0);
10987                                        for c in 0..neon_chunks {
10988                                            acc = vfmaq_f32(
10989                                                acc,
10990                                                vld1q_f32(qkv.as_ptr().add(q_base + c * 4)),
10991                                                vld1q_f32(qkv.as_ptr().add(k_base + c * 4)),
10992                                            );
10993                                        }
10994                                        dot = vaddvq_f32(acc);
10995                                    }
10996                                    #[cfg(not(target_arch = "aarch64"))]
10997                                    for d in 0..d_h {
10998                                        dot += qkv[q_base + d] * qkv[k_base + d];
10999                                    }
11000                                    sc[qi * s + ki] = dot * scale;
11001                                    if mk[bi * s + ki] < mask_thr {
11002                                        sc[qi * s + ki] = mask_neg;
11003                                    }
11004                                }
11005                            }
11006                            crate::kernels::neon_softmax(&mut sc[..s * s], s, s);
11007                            for qi in 0..s {
11008                                let o = bi * s * h + qi * h + hi * d_h;
11009                                for d in 0..d_h {
11010                                    attn[o + d] = 0.0;
11011                                }
11012                                for ki in 0..s {
11013                                    let w = sc[qi * s + ki];
11014                                    if w > score_thr {
11015                                        let v = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
11016                                        #[cfg(target_arch = "aarch64")]
11017                                        {
11018                                            use std::arch::aarch64::*;
11019                                            let vw = vdupq_n_f32(w);
11020                                            for c in 0..neon_chunks {
11021                                                let off = c * 4;
11022                                                vst1q_f32(
11023                                                    attn.as_mut_ptr().add(o + off),
11024                                                    vfmaq_f32(
11025                                                        vld1q_f32(attn.as_ptr().add(o + off)),
11026                                                        vw,
11027                                                        vld1q_f32(qkv.as_ptr().add(v + off)),
11028                                                    ),
11029                                                );
11030                                            }
11031                                        }
11032                                        #[cfg(not(target_arch = "aarch64"))]
11033                                        for d in 0..d_h {
11034                                            attn[o + d] += w * qkv[v + d];
11035                                        }
11036                                    }
11037                                }
11038                            }
11039                        }
11040                    }
11041
11042                    // Out proj (sgemm + bias fused) + residual add with NEON
11043                    crate::blas::sgemm_bias(
11044                        attn,
11045                        sl(*out_w, base, h * h),
11046                        sl(*out_b, base, h),
11047                        res,
11048                        m,
11049                        h,
11050                        h,
11051                    );
11052                    #[cfg(target_arch = "aarch64")]
11053                    {
11054                        use std::arch::aarch64::*;
11055                        let chunks_h = (m * h) / 4;
11056                        for c in 0..chunks_h {
11057                            let off = c * 4;
11058                            vst1q_f32(
11059                                res.as_mut_ptr().add(off),
11060                                vaddq_f32(
11061                                    vld1q_f32(res.as_ptr().add(off)),
11062                                    vld1q_f32(inp.as_ptr().add(off)),
11063                                ),
11064                            );
11065                        }
11066                        for i in (chunks_h * 4)..(m * h) {
11067                            res[i] += inp[i];
11068                        }
11069                    }
11070                    #[cfg(not(target_arch = "aarch64"))]
11071                    for i in 0..m * h {
11072                        res[i] += inp[i];
11073                    }
11074
11075                    // LN1 (fused residual already done above — just normalize)
11076                    let g1 = sl(*ln1_g, base, h);
11077                    let b1 = sl(*ln1_b, base, h);
11078                    for r in 0..m {
11079                        crate::kernels::layer_norm_row(
11080                            &res[r * h..(r + 1) * h],
11081                            g1,
11082                            b1,
11083                            &mut normed[r * h..(r + 1) * h],
11084                            h,
11085                            *eps1,
11086                        );
11087                    }
11088
11089                    // FFN: fc1 (parallel across cores) + GELU
11090                    crate::blas::par_sgemm_bias(
11091                        normed,
11092                        sl(*fc1_w, base, h * id),
11093                        sl(*fc1_b, base, id),
11094                        ffn,
11095                        m,
11096                        h,
11097                        id,
11098                    );
11099                    crate::kernels::par_gelu_inplace(ffn);
11100
11101                    // fc2 + bias (parallel across cores) + residual with NEON
11102                    crate::blas::par_sgemm_bias(
11103                        ffn,
11104                        sl(*fc2_w, base, id * h),
11105                        sl(*fc2_b, base, h),
11106                        res,
11107                        m,
11108                        id,
11109                        h,
11110                    );
11111                    #[cfg(target_arch = "aarch64")]
11112                    {
11113                        use std::arch::aarch64::*;
11114                        let chunks_h = (m * h) / 4;
11115                        for c in 0..chunks_h {
11116                            let off = c * 4;
11117                            vst1q_f32(
11118                                res.as_mut_ptr().add(off),
11119                                vaddq_f32(
11120                                    vld1q_f32(res.as_ptr().add(off)),
11121                                    vld1q_f32(normed.as_ptr().add(off)),
11122                                ),
11123                            );
11124                        }
11125                        for i in (chunks_h * 4)..(m * h) {
11126                            res[i] += normed[i];
11127                        }
11128                    }
11129                    #[cfg(not(target_arch = "aarch64"))]
11130                    for i in 0..m * h {
11131                        res[i] += normed[i];
11132                    }
11133
11134                    // LN2 → output
11135                    let g2 = sl(*ln2_g, base, h);
11136                    let b2 = sl(*ln2_b, base, h);
11137                    for r in 0..m {
11138                        crate::kernels::layer_norm_row(
11139                            &res[r * h..(r + 1) * h],
11140                            g2,
11141                            b2,
11142                            &mut dst[r * h..(r + 1) * h],
11143                            h,
11144                            *eps2,
11145                        );
11146                    }
11147                }
11148            }
11149
11150            Thunk::FusedNomicLayer {
11151                hidden,
11152                qkv_w,
11153                out_w,
11154                mask,
11155                cos,
11156                sin,
11157                cos_len,
11158                ln1_g,
11159                ln1_b,
11160                eps1,
11161                fc11_w,
11162                fc12_w: _,
11163                fc2_w,
11164                ln2_g,
11165                ln2_b,
11166                eps2,
11167                out,
11168                batch,
11169                seq,
11170                hs,
11171                nh,
11172                dh,
11173                int_dim,
11174            } => {
11175                let (b, s, h, n_h, d_h) = (
11176                    *batch as usize,
11177                    *seq as usize,
11178                    *hs as usize,
11179                    *nh as usize,
11180                    *dh as usize,
11181                );
11182                let m = b * s;
11183                let id = *int_dim as usize;
11184                let scale = (d_h as f32).powf(-0.5);
11185                let half_dh = d_h / 2;
11186                #[cfg(target_arch = "aarch64")]
11187                let neon_chunks = d_h / 4;
11188                unsafe {
11189                    let inp = sl(*hidden, base, m * h);
11190                    let dst = sl_mut(*out, base, m * h);
11191                    let mk = sl(*mask, base, b * s);
11192                    let cos_tab = sl(*cos, base, *cos_len as usize);
11193                    let sin_tab = sl(*sin, base, *cos_len as usize);
11194                    // fc11_w is the fused [h, 2*int_dim] weight (fc11 || fc12 concatenated)
11195                    let fused_fc_w = sl(*fc11_w, base, h * 2 * id);
11196
11197                    let mut qkv = vec![0f32; m * 3 * h];
11198                    let mut attn = vec![0f32; m * h];
11199                    let mut res = vec![0f32; m * h];
11200                    let mut normed = vec![0f32; m * h];
11201                    let mut ffn_concat = vec![0f32; m * 2 * id]; // fc11||fc12 output
11202                    let mut sc = vec![0f32; s * s];
11203
11204                    // QKV (no bias)
11205                    crate::blas::sgemm(inp, sl(*qkv_w, base, h * 3 * h), &mut qkv, m, h, 3 * h);
11206
11207                    // SDPA with inline RoPE
11208                    for bi in 0..b {
11209                        for hi in 0..n_h {
11210                            for qi in 0..s {
11211                                for ki in 0..s {
11212                                    let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
11213                                    let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
11214                                    let mut dot = 0f32;
11215                                    for i in 0..half_dh {
11216                                        let q1 = qkv[q_base + i];
11217                                        let q2 = qkv[q_base + half_dh + i];
11218                                        let k1 = qkv[k_base + i];
11219                                        let k2 = qkv[k_base + half_dh + i];
11220                                        let cq = cos_tab[qi * half_dh + i];
11221                                        let sq = sin_tab[qi * half_dh + i];
11222                                        let ck = cos_tab[ki * half_dh + i];
11223                                        let sk = sin_tab[ki * half_dh + i];
11224                                        dot += (q1 * cq - q2 * sq) * (k1 * ck - k2 * sk)
11225                                            + (q2 * cq + q1 * sq) * (k2 * ck + k1 * sk);
11226                                    }
11227                                    sc[qi * s + ki] = dot * scale;
11228                                    if mk[bi * s + ki] < mask_thr {
11229                                        sc[qi * s + ki] = mask_neg;
11230                                    }
11231                                }
11232                            }
11233                            crate::kernels::neon_softmax(&mut sc[..s * s], s, s);
11234                            for qi in 0..s {
11235                                let o = bi * s * h + qi * h + hi * d_h;
11236                                for d in 0..d_h {
11237                                    attn[o + d] = 0.0;
11238                                }
11239                                for ki in 0..s {
11240                                    let w = sc[qi * s + ki];
11241                                    if w > score_thr {
11242                                        let v = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
11243                                        #[cfg(target_arch = "aarch64")]
11244                                        {
11245                                            use std::arch::aarch64::*;
11246                                            let vw = vdupq_n_f32(w);
11247                                            for c in 0..neon_chunks {
11248                                                let off = c * 4;
11249                                                vst1q_f32(
11250                                                    attn.as_mut_ptr().add(o + off),
11251                                                    vfmaq_f32(
11252                                                        vld1q_f32(attn.as_ptr().add(o + off)),
11253                                                        vw,
11254                                                        vld1q_f32(qkv.as_ptr().add(v + off)),
11255                                                    ),
11256                                                );
11257                                            }
11258                                        }
11259                                        #[cfg(not(target_arch = "aarch64"))]
11260                                        for d in 0..d_h {
11261                                            attn[o + d] += w * qkv[v + d];
11262                                        }
11263                                    }
11264                                }
11265                            }
11266                        }
11267                    }
11268
11269                    // Out proj (no bias) + residual
11270                    crate::blas::sgemm(&attn, sl(*out_w, base, h * h), &mut res, m, h, h);
11271                    for i in 0..m * h {
11272                        res[i] += inp[i];
11273                    }
11274
11275                    // LN1
11276                    let g1 = sl(*ln1_g, base, h);
11277                    let b1 = sl(*ln1_b, base, h);
11278                    for r in 0..m {
11279                        crate::kernels::layer_norm_row(
11280                            &res[r * h..(r + 1) * h],
11281                            g1,
11282                            b1,
11283                            &mut normed[r * h..(r + 1) * h],
11284                            h,
11285                            *eps1,
11286                        );
11287                    }
11288
11289                    // SwiGLU: fused fc11+fc12 sgemm, then split, silu, mul
11290                    crate::blas::sgemm(&normed, fused_fc_w, &mut ffn_concat, m, h, 2 * id);
11291                    // Split: first id cols = fc11 (up), second id cols = fc12 (gate)
11292                    // SiLU on gate, then multiply up * gate → store in up region
11293                    for row in 0..m {
11294                        let bo = row * 2 * id;
11295                        // SiLU in-place on gate portion
11296                        for j in 0..id {
11297                            let x = ffn_concat[bo + id + j];
11298                            ffn_concat[bo + id + j] = x / (1.0 + (-x).exp());
11299                        }
11300                        // Multiply: up[j] *= gate[j]
11301                        for j in 0..id {
11302                            ffn_concat[bo + j] *= ffn_concat[bo + id + j];
11303                        }
11304                    }
11305
11306                    // fc2 (no bias) + residual  — read from first id cols of ffn_concat
11307                    // Need contiguous [m, id] for sgemm. Copy or use strided sgemm.
11308                    // The up*gate result is at ffn_concat[row * 2*id .. row * 2*id + id]
11309                    // Stride = 2*id. Use sgemm_general with lda = 2*id.
11310                    crate::blas::sgemm_general(
11311                        ffn_concat.as_ptr(),
11312                        sl(*fc2_w, base, id * h).as_ptr(),
11313                        res.as_mut_ptr(),
11314                        m,
11315                        h,
11316                        id,
11317                        1.0,
11318                        0.0,
11319                        2 * id,
11320                        h,
11321                        h,
11322                        false,
11323                        false,
11324                    );
11325                    for i in 0..m * h {
11326                        res[i] += normed[i];
11327                    }
11328
11329                    // LN2 → output
11330                    let g2 = sl(*ln2_g, base, h);
11331                    let b2 = sl(*ln2_b, base, h);
11332                    for r in 0..m {
11333                        crate::kernels::layer_norm_row(
11334                            &res[r * h..(r + 1) * h],
11335                            g2,
11336                            b2,
11337                            &mut dst[r * h..(r + 1) * h],
11338                            h,
11339                            *eps2,
11340                        );
11341                    }
11342                }
11343            }
11344
11345            Thunk::FusedSwiGLU {
11346                src,
11347                dst,
11348                n_half,
11349                total,
11350                gate_first,
11351            } => {
11352                let n = *n_half as usize;
11353                let t = *total as usize;
11354                let outer = t / n;
11355                let in_total = outer * 2 * n;
11356                let gate_first = *gate_first;
11357                unsafe {
11358                    let inp = sl(*src, base, in_total);
11359                    let out = sl_mut(*dst, base, t);
11360                    for o in 0..outer {
11361                        let in_row = &inp[o * 2 * n..(o + 1) * 2 * n];
11362                        let out_row = &mut out[o * n..(o + 1) * n];
11363                        for i in 0..n {
11364                            let (up, gate) = if gate_first {
11365                                (in_row[n + i], in_row[i])
11366                            } else {
11367                                (in_row[i], in_row[n + i])
11368                            };
11369                            out_row[i] = up * (gate / (1.0 + (-gate).exp()));
11370                        }
11371                    }
11372                }
11373            }
11374
11375            Thunk::Concat {
11376                dst,
11377                outer,
11378                inner,
11379                total_axis,
11380                inputs,
11381            } => {
11382                let outer = *outer as usize;
11383                let inner = *inner as usize;
11384                let total_axis = *total_axis as usize;
11385                let row_stride = total_axis * inner;
11386                let out_total = outer * row_stride;
11387                unsafe {
11388                    let out = sl_mut(*dst, base, out_total);
11389                    let mut cum: usize = 0;
11390                    for (src_off, in_axis) in inputs {
11391                        let in_axis = *in_axis as usize;
11392                        let copy_per_row = in_axis * inner;
11393                        let dst_col_off = cum * inner;
11394                        let in_total = outer * copy_per_row;
11395                        let inp = sl(*src_off, base, in_total);
11396                        for o in 0..outer {
11397                            let dst_row_start = o * row_stride + dst_col_off;
11398                            let src_row_start = o * copy_per_row;
11399                            out[dst_row_start..dst_row_start + copy_per_row]
11400                                .copy_from_slice(&inp[src_row_start..src_row_start + copy_per_row]);
11401                        }
11402                        cum += in_axis;
11403                    }
11404                }
11405            }
11406
11407            Thunk::ConcatF64 {
11408                dst,
11409                outer,
11410                inner,
11411                total_axis,
11412                inputs,
11413            } => {
11414                let outer = *outer as usize;
11415                let inner = *inner as usize;
11416                let total_axis = *total_axis as usize;
11417                let row_stride = total_axis * inner;
11418                let out_total = outer * row_stride;
11419                unsafe {
11420                    let out = sl_mut_f64(*dst, base, out_total);
11421                    let mut cum: usize = 0;
11422                    for (src_off, in_axis) in inputs {
11423                        let in_axis = *in_axis as usize;
11424                        let copy_per_row = in_axis * inner;
11425                        let dst_col_off = cum * inner;
11426                        let in_total = outer * copy_per_row;
11427                        let inp = sl_f64(*src_off, base, in_total);
11428                        for o in 0..outer {
11429                            let dst_row_start = o * row_stride + dst_col_off;
11430                            let src_row_start = o * copy_per_row;
11431                            out[dst_row_start..dst_row_start + copy_per_row]
11432                                .copy_from_slice(&inp[src_row_start..src_row_start + copy_per_row]);
11433                        }
11434                        cum += in_axis;
11435                    }
11436                }
11437            }
11438
11439            Thunk::Compare {
11440                lhs,
11441                rhs,
11442                dst,
11443                len,
11444                op,
11445                inputs_i64,
11446                inputs_elem_bytes,
11447                dst_elem_bytes,
11448            } => {
11449                let len = *len as usize;
11450                let arena_len = arena_buf.len();
11451                let elem = (*inputs_elem_bytes).max(1) as usize;
11452                let dst_eb = (*dst_elem_bytes).max(1) as usize;
11453                let max_l = (arena_len.saturating_sub(*lhs)) / elem;
11454                let max_r = (arena_len.saturating_sub(*rhs)) / elem;
11455                let max_d = (arena_len.saturating_sub(*dst)) / dst_eb;
11456                let len = len.min(max_l).min(max_r).min(max_d);
11457                if trace_thunks && len > 0 {
11458                    eprintln!("[compare] len={len} lhs={} rhs={} dst={}", *lhs, *rhs, *dst);
11459                }
11460                if elem == 1 {
11461                    let l = arena_buf[*lhs..*lhs + len].to_vec();
11462                    let r = arena_buf[*rhs..*rhs + len].to_vec();
11463                    for i in 0..len {
11464                        let v = match op {
11465                            CmpOp::Eq => l[i] == r[i],
11466                            CmpOp::Ne => l[i] != r[i],
11467                            CmpOp::Lt => l[i] < r[i],
11468                            CmpOp::Le => l[i] <= r[i],
11469                            CmpOp::Gt => l[i] > r[i],
11470                            CmpOp::Ge => l[i] >= r[i],
11471                        };
11472                        if *dst_elem_bytes == 1 {
11473                            arena_buf[*dst + i] = u8::from(v);
11474                        } else {
11475                            unsafe {
11476                                let o = sl_mut(*dst, base, len);
11477                                o[i] = if v { 1.0 } else { 0.0 };
11478                            }
11479                        }
11480                    }
11481                } else if *inputs_i64 != 0 {
11482                    unsafe {
11483                        let l = sl_i64(*lhs, base, len);
11484                        let r = sl_i64(*rhs, base, len);
11485                        for i in 0..len {
11486                            let v = match op {
11487                                CmpOp::Eq => l[i] == r[i],
11488                                CmpOp::Ne => l[i] != r[i],
11489                                CmpOp::Lt => l[i] < r[i],
11490                                CmpOp::Le => l[i] <= r[i],
11491                                CmpOp::Gt => l[i] > r[i],
11492                                CmpOp::Ge => l[i] >= r[i],
11493                            };
11494                            if *dst_elem_bytes == 1 {
11495                                arena_buf[*dst + i] = u8::from(v);
11496                            } else {
11497                                let o = sl_mut(*dst, base, len);
11498                                o[i] = if v { 1.0 } else { 0.0 };
11499                            }
11500                        }
11501                    }
11502                } else {
11503                    unsafe {
11504                        let l = sl(*lhs, base, len);
11505                        let r = sl(*rhs, base, len);
11506                        for i in 0..len {
11507                            let v = match op {
11508                                CmpOp::Eq => l[i] == r[i],
11509                                CmpOp::Ne => l[i] != r[i],
11510                                CmpOp::Lt => l[i] < r[i],
11511                                CmpOp::Le => l[i] <= r[i],
11512                                CmpOp::Gt => l[i] > r[i],
11513                                CmpOp::Ge => l[i] >= r[i],
11514                            };
11515                            if *dst_elem_bytes == 1 {
11516                                arena_buf[*dst + i] = u8::from(v);
11517                            } else {
11518                                let o = sl_mut(*dst, base, len);
11519                                o[i] = if v { 1.0 } else { 0.0 };
11520                            }
11521                        }
11522                    }
11523                }
11524            }
11525
11526            Thunk::Where {
11527                cond,
11528                on_true,
11529                on_false,
11530                dst,
11531                len,
11532                elem_bytes,
11533                cond_elem_bytes,
11534            } => {
11535                let len = *len as usize;
11536                let eb = *elem_bytes as usize;
11537                let cond_eb = (*cond_elem_bytes).max(1) as usize;
11538                let arena_len = arena_buf.len();
11539                let len = len
11540                    .min((arena_len.saturating_sub(*cond)) / cond_eb)
11541                    .min((arena_len.saturating_sub(*on_true)) / eb)
11542                    .min((arena_len.saturating_sub(*on_false)) / eb)
11543                    .min((arena_len.saturating_sub(*dst)) / eb);
11544                unsafe {
11545                    if *elem_bytes == 8 {
11546                        let t = sl_i64(*on_true, base, len);
11547                        let e = sl_i64(*on_false, base, len);
11548                        let o = sl_mut_i64(*dst, base, len);
11549                        if *cond_elem_bytes == 1 {
11550                            let c = &arena_buf[*cond..*cond + len];
11551                            for i in 0..len {
11552                                o[i] = if c[i] != 0 { t[i] } else { e[i] };
11553                            }
11554                        } else {
11555                            let c = sl_i64(*cond, base, len);
11556                            for i in 0..len {
11557                                o[i] = if c[i] != 0 { t[i] } else { e[i] };
11558                            }
11559                        }
11560                    } else if *cond_elem_bytes == 1 {
11561                        let c = &arena_buf[*cond..*cond + len];
11562                        let t = sl(*on_true, base, len);
11563                        let e = sl(*on_false, base, len);
11564                        let o = sl_mut(*dst, base, len);
11565                        for i in 0..len {
11566                            o[i] = if c[i] != 0 { t[i] } else { e[i] };
11567                        }
11568                    } else {
11569                        let c = sl(*cond, base, len);
11570                        let t = sl(*on_true, base, len);
11571                        let e = sl(*on_false, base, len);
11572                        let o = sl_mut(*dst, base, len);
11573                        for i in 0..len {
11574                            o[i] = if c[i] != 0.0 { t[i] } else { e[i] };
11575                        }
11576                    }
11577                }
11578            }
11579
11580            Thunk::ScatterAdd {
11581                updates,
11582                indices,
11583                dst,
11584                num_updates,
11585                out_dim,
11586                trailing,
11587            } => {
11588                let num_updates = *num_updates as usize;
11589                let out_dim = *out_dim as usize;
11590                let trailing = *trailing as usize;
11591                unsafe {
11592                    let upd = sl(*updates, base, num_updates * trailing);
11593                    let ids = sl(*indices, base, num_updates);
11594                    let out = sl_mut(*dst, base, out_dim * trailing);
11595                    // Zero the output first — semantics are accumulate-into-zeros.
11596                    for v in out.iter_mut() {
11597                        *v = 0.0;
11598                    }
11599                    for i in 0..num_updates {
11600                        let row = ids[i] as usize;
11601                        debug_assert!(row < out_dim, "ScatterAdd index out of range");
11602                        let src_off = i * trailing;
11603                        let dst_off = row * trailing;
11604                        for j in 0..trailing {
11605                            out[dst_off + j] += upd[src_off + j];
11606                        }
11607                    }
11608                }
11609            }
11610
11611            Thunk::GroupedMatMul {
11612                input,
11613                weight,
11614                expert_idx,
11615                dst,
11616                m,
11617                k_dim,
11618                n,
11619                num_experts,
11620            } => {
11621                let m = *m as usize;
11622                let k_dim = *k_dim as usize;
11623                let n = *n as usize;
11624                let num_experts = *num_experts as usize;
11625                unsafe {
11626                    let inp = sl(*input, base, m * k_dim);
11627                    let wt = sl(*weight, base, num_experts * k_dim * n);
11628                    let ids = sl(*expert_idx, base, m);
11629                    let out = sl_mut(*dst, base, m * n);
11630
11631                    // Counting-sort tokens by their assigned expert.
11632                    // counts[e] = how many tokens routed to expert e.
11633                    let mut counts = vec![0usize; num_experts];
11634                    for i in 0..m {
11635                        let e = ids[i] as usize;
11636                        debug_assert!(
11637                            e < num_experts,
11638                            "expert_idx out of range: {e} >= {num_experts}"
11639                        );
11640                        counts[e] += 1;
11641                    }
11642                    // Cumulative offsets into the packed buffer.
11643                    let mut offsets = vec![0usize; num_experts + 1];
11644                    for e in 0..num_experts {
11645                        offsets[e + 1] = offsets[e] + counts[e];
11646                    }
11647                    // Pack: each expert's rows land contiguously in `packed_in`.
11648                    // `original_pos[packed_idx] = original_token_idx` for the
11649                    // unpermute step at the end.
11650                    let mut packed_in = vec![0f32; m * k_dim];
11651                    let mut original_pos = vec![0usize; m];
11652                    let mut write_idx = vec![0usize; num_experts];
11653                    for i in 0..m {
11654                        let e = ids[i] as usize;
11655                        let dst_row = offsets[e] + write_idx[e];
11656                        packed_in[dst_row * k_dim..(dst_row + 1) * k_dim]
11657                            .copy_from_slice(&inp[i * k_dim..(i + 1) * k_dim]);
11658                        original_pos[dst_row] = i;
11659                        write_idx[e] += 1;
11660                    }
11661
11662                    // One BLAS sgemm per expert. Skip experts with no
11663                    // tokens — common at the tail when M is much smaller
11664                    // than num_experts × k.
11665                    let mut packed_out = vec![0f32; m * n];
11666                    let expert_stride = k_dim * n;
11667                    let gmm_ord = crate::moe_residency::next_gmm_ord();
11668                    let moe_layer = gmm_ord / 3;
11669                    for e in 0..num_experts {
11670                        let count = counts[e];
11671                        if count == 0 {
11672                            continue;
11673                        }
11674                        crate::moe_residency::record_expert_tokens(moe_layer, e, count);
11675                        let in_start = offsets[e];
11676                        let in_slice = &packed_in[in_start * k_dim..(in_start + count) * k_dim];
11677                        let w_slab: &[f32] =
11678                            if !crate::moe_residency::expert_on_device_for_layer(moe_layer, e) {
11679                                if let Some(ptr) =
11680                                    crate::moe_residency::host_expert_weight_ptr(gmm_ord, e)
11681                                {
11682                                    std::slice::from_raw_parts(ptr, expert_stride)
11683                                } else {
11684                                    &wt[e * expert_stride..(e + 1) * expert_stride]
11685                                }
11686                            } else {
11687                                &wt[e * expert_stride..(e + 1) * expert_stride]
11688                            };
11689                        let out_slice = &mut packed_out[in_start * n..(in_start + count) * n];
11690                        crate::blas::sgemm(in_slice, w_slab, out_slice, count, k_dim, n);
11691                    }
11692
11693                    // Unpermute back to original token order.
11694                    for packed_idx in 0..m {
11695                        let i = original_pos[packed_idx];
11696                        out[i * n..(i + 1) * n]
11697                            .copy_from_slice(&packed_out[packed_idx * n..(packed_idx + 1) * n]);
11698                    }
11699                }
11700            }
11701
11702            Thunk::DequantGroupedMatMulGguf {
11703                input,
11704                w_q,
11705                expert_idx,
11706                dst,
11707                m,
11708                k_dim,
11709                n,
11710                num_experts,
11711                scheme,
11712            } => {
11713                let m = *m as usize;
11714                let k_dim = *k_dim as usize;
11715                let n = *n as usize;
11716                let num_experts = *num_experts as usize;
11717                let block_elems = scheme.gguf_block_size() as usize;
11718                let block_bytes = scheme.gguf_block_bytes() as usize;
11719                let slab_bytes = (k_dim * n) / block_elems * block_bytes;
11720                unsafe {
11721                    let inp = sl(*input, base, m * k_dim);
11722                    let wt = std::slice::from_raw_parts(
11723                        base.add(*w_q) as *const u8,
11724                        num_experts * slab_bytes,
11725                    );
11726                    let ids = sl(*expert_idx, base, m);
11727                    let out = sl_mut(*dst, base, m * n);
11728                    crate::gguf_matmul::gguf_grouped_matmul_bt(
11729                        inp,
11730                        wt,
11731                        ids,
11732                        out,
11733                        m,
11734                        k_dim,
11735                        n,
11736                        num_experts,
11737                        *scheme,
11738                    );
11739                }
11740            }
11741
11742            Thunk::DequantMoEWeightsGguf {
11743                w_q,
11744                dst,
11745                k_dim,
11746                n,
11747                num_experts,
11748                scheme,
11749            } => {
11750                let k_dim = *k_dim as usize;
11751                let n = *n as usize;
11752                let num_experts = *num_experts as usize;
11753                let block_elems = scheme.gguf_block_size() as usize;
11754                let block_bytes = scheme.gguf_block_bytes() as usize;
11755                let slab_bytes = (k_dim * n) / block_elems * block_bytes;
11756                unsafe {
11757                    let wt = std::slice::from_raw_parts(
11758                        base.add(*w_q) as *const u8,
11759                        num_experts * slab_bytes,
11760                    );
11761                    let out = sl_mut(*dst, base, num_experts * k_dim * n);
11762                    crate::gguf_matmul::dequant_moe_weights_to_grouped_f32(
11763                        wt,
11764                        out,
11765                        num_experts,
11766                        k_dim,
11767                        n,
11768                        *scheme,
11769                    );
11770                }
11771            }
11772
11773            Thunk::TopK {
11774                src,
11775                dst,
11776                outer,
11777                axis_dim,
11778                k,
11779                indices_i64,
11780            } => {
11781                let outer = *outer as usize;
11782                let axis_dim = *axis_dim as usize;
11783                let k = *k as usize;
11784                unsafe {
11785                    let inp = sl(*src, base, outer * axis_dim);
11786                    // Repeated argmax with masking. O(k * axis_dim) per row;
11787                    // good enough for small k (MoE typical k=2–8). For larger
11788                    // k a partial heap would win.
11789                    let mut row_buf: Vec<f32> = vec![0.0; axis_dim];
11790                    if *indices_i64 != 0 {
11791                        let out = sl_mut_i64(*dst, base, outer * k);
11792                        for o in 0..outer {
11793                            row_buf.copy_from_slice(&inp[o * axis_dim..(o + 1) * axis_dim]);
11794                            for ki in 0..k {
11795                                let mut best_i = 0usize;
11796                                let mut best_v = row_buf[0];
11797                                for i in 1..axis_dim {
11798                                    let v = row_buf[i];
11799                                    if v > best_v {
11800                                        best_v = v;
11801                                        best_i = i;
11802                                    }
11803                                }
11804                                out[o * k + ki] = best_i as i64;
11805                                row_buf[best_i] = f32::NEG_INFINITY;
11806                            }
11807                        }
11808                    } else {
11809                        let out = sl_mut(*dst, base, outer * k);
11810                        for o in 0..outer {
11811                            row_buf.copy_from_slice(&inp[o * axis_dim..(o + 1) * axis_dim]);
11812                            for ki in 0..k {
11813                                let mut best_i = 0usize;
11814                                let mut best_v = row_buf[0];
11815                                for i in 1..axis_dim {
11816                                    let v = row_buf[i];
11817                                    if v > best_v {
11818                                        best_v = v;
11819                                        best_i = i;
11820                                    }
11821                                }
11822                                out[o * k + ki] = best_i as f32;
11823                                row_buf[best_i] = f32::NEG_INFINITY;
11824                            }
11825                        }
11826                        if let Some(cap) = schedule.moe_topk_capture.as_ref() {
11827                            cap.push_topk_f32(&out[..outer * k], axis_dim);
11828                        }
11829                    }
11830                }
11831            }
11832
11833            Thunk::Reduce {
11834                src,
11835                dst,
11836                outer,
11837                reduced,
11838                inner,
11839                op,
11840            } => {
11841                let outer = *outer as usize;
11842                let reduced = *reduced as usize;
11843                let inner = *inner as usize;
11844                let in_total = outer * reduced * inner;
11845                let out_total = outer * inner;
11846                unsafe {
11847                    let inp = sl(*src, base, in_total);
11848                    let out = sl_mut(*dst, base, out_total);
11849                    for o in 0..outer {
11850                        for i in 0..inner {
11851                            let mut acc = match op {
11852                                ReduceOp::Max => f32::NEG_INFINITY,
11853                                ReduceOp::Min => f32::INFINITY,
11854                                ReduceOp::Prod => 1.0f32,
11855                                _ => 0.0f32, // Sum / Mean
11856                            };
11857                            // Walk the reduced axis with stride `inner`.
11858                            for r in 0..reduced {
11859                                let v = inp[o * reduced * inner + r * inner + i];
11860                                acc = match op {
11861                                    ReduceOp::Sum | ReduceOp::Mean => acc + v,
11862                                    ReduceOp::Max => acc.max(v),
11863                                    ReduceOp::Min => acc.min(v),
11864                                    ReduceOp::Prod => acc * v,
11865                                };
11866                            }
11867                            if matches!(op, ReduceOp::Mean) {
11868                                acc /= reduced as f32;
11869                            }
11870                            out[o * inner + i] = acc;
11871                        }
11872                    }
11873                }
11874            }
11875
11876            Thunk::Conv2D1x1 {
11877                src,
11878                weight,
11879                dst,
11880                n,
11881                c_in,
11882                c_out,
11883                hw,
11884            } => {
11885                let n = *n as usize;
11886                let c_in = *c_in as usize;
11887                let c_out = *c_out as usize;
11888                let hw = *hw as usize;
11889                unsafe {
11890                    let inp = sl(*src, base, n * c_in * hw);
11891                    let wt = sl(*weight, base, c_out * c_in);
11892                    let out = sl_mut(*dst, base, n * c_out * hw);
11893                    // Per-batch sgemm: weight [c_out, c_in] @ input
11894                    // [c_in, hw] = output [c_out, hw]. The weight is
11895                    // shared across batches, so we get to dispatch
11896                    // BLAS once per N (typically 1).
11897                    for ni in 0..n {
11898                        let in_off = ni * c_in * hw;
11899                        let out_off = ni * c_out * hw;
11900                        crate::blas::sgemm(
11901                            wt,
11902                            &inp[in_off..in_off + c_in * hw],
11903                            &mut out[out_off..out_off + c_out * hw],
11904                            c_out,
11905                            c_in,
11906                            hw,
11907                        );
11908                    }
11909                }
11910            }
11911
11912            Thunk::Conv2D {
11913                src,
11914                weight,
11915                dst,
11916                n,
11917                c_in,
11918                h,
11919                w,
11920                c_out,
11921                h_out,
11922                w_out,
11923                kh,
11924                kw,
11925                sh,
11926                sw,
11927                ph,
11928                pw,
11929                dh,
11930                dw,
11931                groups,
11932            } => {
11933                let n = *n as usize;
11934                let c_in = *c_in as usize;
11935                let h = *h as usize;
11936                let w = *w as usize;
11937                let c_out = *c_out as usize;
11938                let h_out = *h_out as usize;
11939                let w_out = *w_out as usize;
11940                let kh = *kh as usize;
11941                let kw = *kw as usize;
11942                let sh = *sh as usize;
11943                let sw = *sw as usize;
11944                let ph = *ph as usize;
11945                let pw = *pw as usize;
11946                let dh = *dh as usize;
11947                let dw = *dw as usize;
11948                let groups = *groups as usize;
11949                let c_in_per_g = c_in / groups;
11950                let c_out_per_g = c_out / groups;
11951                unsafe {
11952                    let inp = sl(*src, base, n * c_in * h * w);
11953                    let wt = sl(*weight, base, c_out * c_in_per_g * kh * kw);
11954                    let out = sl_mut(*dst, base, n * c_out * h_out * w_out);
11955                    for ni in 0..n {
11956                        for co in 0..c_out {
11957                            let g = co / c_out_per_g;
11958                            let ci_start = g * c_in_per_g;
11959                            for ho in 0..h_out {
11960                                for wo in 0..w_out {
11961                                    let mut acc = 0f32;
11962                                    for ci_off in 0..c_in_per_g {
11963                                        let ci = ci_start + ci_off;
11964                                        let in_chan = ((ni * c_in) + ci) * h * w;
11965                                        let wt_chan = ((co * c_in_per_g) + ci_off) * kh * kw;
11966                                        for ki in 0..kh {
11967                                            for kj in 0..kw {
11968                                                let hi = ho * sh + ki * dh;
11969                                                let wi = wo * sw + kj * dw;
11970                                                if hi < ph || wi < pw {
11971                                                    continue;
11972                                                }
11973                                                let hi = hi - ph;
11974                                                let wi = wi - pw;
11975                                                if hi >= h || wi >= w {
11976                                                    continue;
11977                                                }
11978                                                acc += inp[in_chan + hi * w + wi]
11979                                                    * wt[wt_chan + ki * kw + kj];
11980                                            }
11981                                        }
11982                                    }
11983                                    out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] =
11984                                        acc;
11985                                }
11986                            }
11987                        }
11988                    }
11989                }
11990            }
11991
11992            Thunk::Pool2D {
11993                src,
11994                dst,
11995                n,
11996                c,
11997                h,
11998                w,
11999                h_out,
12000                w_out,
12001                kh,
12002                kw,
12003                sh,
12004                sw,
12005                ph,
12006                pw,
12007                kind,
12008            } => {
12009                let n = *n as usize;
12010                let c = *c as usize;
12011                let h = *h as usize;
12012                let w = *w as usize;
12013                let h_out = *h_out as usize;
12014                let w_out = *w_out as usize;
12015                let kh = *kh as usize;
12016                let kw = *kw as usize;
12017                let sh = *sh as usize;
12018                let sw = *sw as usize;
12019                let ph = *ph as usize;
12020                let pw = *pw as usize;
12021                let kernel_area = (kh * kw) as f32;
12022                unsafe {
12023                    let inp = sl(*src, base, n * c * h * w);
12024                    let out = sl_mut(*dst, base, n * c * h_out * w_out);
12025                    for ni in 0..n {
12026                        for ci in 0..c {
12027                            let in_chan = ni * c * h * w + ci * h * w;
12028                            let out_chan = ni * c * h_out * w_out + ci * h_out * w_out;
12029                            for ho in 0..h_out {
12030                                for wo in 0..w_out {
12031                                    let mut acc = match kind {
12032                                        ReduceOp::Max => f32::NEG_INFINITY,
12033                                        _ => 0f32, // Mean (and Sum/Min/Prod fall back here)
12034                                    };
12035                                    for ki in 0..kh {
12036                                        for kj in 0..kw {
12037                                            let hi = ho * sh + ki;
12038                                            let wi = wo * sw + kj;
12039                                            // Padded-zero region.
12040                                            if hi < ph || wi < pw {
12041                                                continue;
12042                                            }
12043                                            let hi = hi - ph;
12044                                            let wi = wi - pw;
12045                                            if hi >= h || wi >= w {
12046                                                continue;
12047                                            }
12048                                            let v = inp[in_chan + hi * w + wi];
12049                                            match kind {
12050                                                ReduceOp::Max => acc = acc.max(v),
12051                                                _ => acc += v,
12052                                            }
12053                                        }
12054                                    }
12055                                    if matches!(kind, ReduceOp::Mean) {
12056                                        acc /= kernel_area;
12057                                    }
12058                                    out[out_chan + ho * w_out + wo] = acc;
12059                                }
12060                            }
12061                        }
12062                    }
12063                }
12064            }
12065
12066            Thunk::ReluBackward { x, dy, dx, len } => {
12067                let len = *len as usize;
12068                unsafe {
12069                    let xs = sl(*x, base, len);
12070                    let dys = sl(*dy, base, len);
12071                    let out = sl_mut(*dx, base, len);
12072                    for i in 0..len {
12073                        out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
12074                    }
12075                }
12076            }
12077
12078            Thunk::ReluBackwardF64 { x, dy, dx, len } => {
12079                let len = *len as usize;
12080                unsafe {
12081                    let xs = sl_f64(*x, base, len);
12082                    let dys = sl_f64(*dy, base, len);
12083                    let out = sl_mut_f64(*dx, base, len);
12084                    for i in 0..len {
12085                        out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
12086                    }
12087                }
12088            }
12089
12090            Thunk::QMatMul {
12091                x,
12092                w,
12093                bias,
12094                out,
12095                m,
12096                k,
12097                n,
12098                x_zp,
12099                w_zp,
12100                out_zp,
12101                mult,
12102            } => {
12103                let m = *m as usize;
12104                let k = *k as usize;
12105                let n = *n as usize;
12106                unsafe {
12107                    let x_ptr = base.add(*x) as *const i8;
12108                    let w_ptr = base.add(*w) as *const i8;
12109                    let bias_ptr = base.add(*bias) as *const i32;
12110                    let out_ptr = base.add(*out) as *mut i8;
12111                    for mi in 0..m {
12112                        for ni in 0..n {
12113                            let mut acc: i32 = *bias_ptr.add(ni);
12114                            for ki in 0..k {
12115                                let xv = *x_ptr.add(mi * k + ki) as i32 - *x_zp;
12116                                let wv = *w_ptr.add(ki * n + ni) as i32 - *w_zp;
12117                                acc += xv * wv;
12118                            }
12119                            // Requantize: round(acc · mult) + out_zp,
12120                            // clamped to i8.
12121                            let r = (acc as f32 * *mult).round() as i32 + *out_zp;
12122                            let r = r.clamp(-128, 127) as i8;
12123                            *out_ptr.add(mi * n + ni) = r;
12124                        }
12125                    }
12126                }
12127            }
12128
12129            Thunk::QConv2d {
12130                x,
12131                w,
12132                bias,
12133                out,
12134                n,
12135                c_in,
12136                h,
12137                w_in,
12138                c_out,
12139                h_out,
12140                w_out,
12141                kh,
12142                kw,
12143                sh,
12144                sw,
12145                ph,
12146                pw,
12147                dh,
12148                dw,
12149                groups,
12150                x_zp,
12151                w_zp,
12152                out_zp,
12153                mult,
12154            } => {
12155                let n = *n as usize;
12156                let c_in = *c_in as usize;
12157                let h = *h as usize;
12158                let w_in = *w_in as usize;
12159                let c_out = *c_out as usize;
12160                let h_out = *h_out as usize;
12161                let w_out = *w_out as usize;
12162                let kh = *kh as usize;
12163                let kw = *kw as usize;
12164                let sh = *sh as usize;
12165                let sw = *sw as usize;
12166                let ph = *ph as usize;
12167                let pw = *pw as usize;
12168                let dh = *dh as usize;
12169                let dw = *dw as usize;
12170                let groups = *groups as usize;
12171                let c_in_per_g = c_in / groups;
12172                let c_out_per_g = c_out / groups;
12173                unsafe {
12174                    let x_ptr = base.add(*x) as *const i8;
12175                    let w_ptr = base.add(*w) as *const i8;
12176                    let bias_ptr = base.add(*bias) as *const i32;
12177                    let out_ptr = base.add(*out) as *mut i8;
12178                    for ni in 0..n {
12179                        for co in 0..c_out {
12180                            let g = co / c_out_per_g;
12181                            let ci_start = g * c_in_per_g;
12182                            for ho in 0..h_out {
12183                                for wo in 0..w_out {
12184                                    let mut acc: i32 = *bias_ptr.add(co);
12185                                    for ci_off in 0..c_in_per_g {
12186                                        let ci = ci_start + ci_off;
12187                                        let in_chan = ((ni * c_in) + ci) * h * w_in;
12188                                        let wt_chan = ((co * c_in_per_g) + ci_off) * kh * kw;
12189                                        for ki in 0..kh {
12190                                            for kj in 0..kw {
12191                                                let hi = ho * sh + ki * dh;
12192                                                let wi = wo * sw + kj * dw;
12193                                                if hi < ph || wi < pw {
12194                                                    continue;
12195                                                }
12196                                                let hi = hi - ph;
12197                                                let wi = wi - pw;
12198                                                if hi >= h || wi >= w_in {
12199                                                    continue;
12200                                                }
12201                                                let xv = *x_ptr.add(in_chan + hi * w_in + wi)
12202                                                    as i32
12203                                                    - *x_zp;
12204                                                let wv = *w_ptr.add(wt_chan + ki * kw + kj) as i32
12205                                                    - *w_zp;
12206                                                acc += xv * wv;
12207                                            }
12208                                        }
12209                                    }
12210                                    let r = (acc as f32 * *mult).round() as i32 + *out_zp;
12211                                    let r = r.clamp(-128, 127) as i8;
12212                                    let dst = ((ni * c_out) + co) * h_out * w_out + ho * w_out + wo;
12213                                    *out_ptr.add(dst) = r;
12214                                }
12215                            }
12216                        }
12217                    }
12218                }
12219            }
12220
12221            Thunk::Quantize {
12222                x,
12223                q,
12224                len,
12225                chan_axis: _,
12226                chan_dim,
12227                inner,
12228                scales,
12229                zero_points,
12230            } => {
12231                let len = *len as usize;
12232                let chan_dim = *chan_dim as usize;
12233                let inner = *inner as usize;
12234                unsafe {
12235                    let xs = sl(*x, base, len);
12236                    let q_ptr = base.add(*q) as *mut i8;
12237                    for i in 0..len {
12238                        let c = if chan_dim == 1 {
12239                            0
12240                        } else {
12241                            (i / inner) % chan_dim
12242                        };
12243                        let inv_scale = 1.0 / scales[c];
12244                        let zp = zero_points[c];
12245                        let v = (xs[i] * inv_scale).round() as i32 + zp;
12246                        *q_ptr.add(i) = v.clamp(-128, 127) as i8;
12247                    }
12248                }
12249            }
12250
12251            Thunk::Dequantize {
12252                q,
12253                x,
12254                len,
12255                chan_axis: _,
12256                chan_dim,
12257                inner,
12258                scales,
12259                zero_points,
12260            } => {
12261                let len = *len as usize;
12262                let chan_dim = *chan_dim as usize;
12263                let inner = *inner as usize;
12264                unsafe {
12265                    let q_ptr = base.add(*q) as *const i8;
12266                    let out = sl_mut(*x, base, len);
12267                    for i in 0..len {
12268                        let c = if chan_dim == 1 {
12269                            0
12270                        } else {
12271                            (i / inner) % chan_dim
12272                        };
12273                        let scale = scales[c];
12274                        let zp = zero_points[c];
12275                        let qv = *q_ptr.add(i) as i32;
12276                        out[i] = (qv - zp) as f32 * scale;
12277                    }
12278                }
12279            }
12280
12281            Thunk::FakeQuantize {
12282                x,
12283                out,
12284                len,
12285                chan_axis: _,
12286                chan_dim,
12287                inner,
12288                bits,
12289                ste: _,
12290                scale_mode,
12291                state_off,
12292            } => {
12293                use rlx_ir::op::ScaleMode;
12294                let len = *len as usize;
12295                let chan_dim = *chan_dim as usize;
12296                let inner = *inner as usize;
12297                let q_max: f32 = match *bits {
12298                    8 => 127.0,
12299                    4 => 7.0,
12300                    2 => 1.0,
12301                    n => panic!("FakeQuantize: unsupported bits {n}"),
12302                };
12303                unsafe {
12304                    let xs = sl(*x, base, len);
12305                    let outs = sl_mut(*out, base, len);
12306
12307                    let mut scale = vec![0f32; chan_dim];
12308                    match scale_mode {
12309                        ScaleMode::PerBatch => {
12310                            let mut max_abs = vec![0f32; chan_dim];
12311                            for i in 0..len {
12312                                let c = if chan_dim == 1 {
12313                                    0
12314                                } else {
12315                                    (i / inner) % chan_dim
12316                                };
12317                                let a = xs[i].abs();
12318                                if a > max_abs[c] {
12319                                    max_abs[c] = a;
12320                                }
12321                            }
12322                            for c in 0..chan_dim {
12323                                scale[c] = (max_abs[c] / q_max).max(1e-12);
12324                            }
12325                        }
12326                        ScaleMode::EMA { decay } => {
12327                            // Per-channel current max-abs, then blend
12328                            // into the running state in place.
12329                            let mut max_abs = vec![0f32; chan_dim];
12330                            for i in 0..len {
12331                                let c = if chan_dim == 1 {
12332                                    0
12333                                } else {
12334                                    (i / inner) % chan_dim
12335                                };
12336                                let a = xs[i].abs();
12337                                if a > max_abs[c] {
12338                                    max_abs[c] = a;
12339                                }
12340                            }
12341                            let state =
12342                                sl_mut(state_off.expect("EMA needs state_off"), base, chan_dim);
12343                            for c in 0..chan_dim {
12344                                let cur = (max_abs[c] / q_max).max(1e-12);
12345                                // Cold-start: state==0 → seed directly.
12346                                let blended = if state[c] <= 0.0 {
12347                                    cur
12348                                } else {
12349                                    *decay * state[c] + (1.0 - *decay) * cur
12350                                };
12351                                state[c] = blended;
12352                                scale[c] = blended;
12353                            }
12354                        }
12355                        ScaleMode::Fixed => {
12356                            let state =
12357                                sl(state_off.expect("Fixed needs state_off"), base, chan_dim);
12358                            for c in 0..chan_dim {
12359                                scale[c] = state[c].max(1e-12);
12360                            }
12361                        }
12362                    }
12363
12364                    for i in 0..len {
12365                        let c = if chan_dim == 1 {
12366                            0
12367                        } else {
12368                            (i / inner) % chan_dim
12369                        };
12370                        let s = scale[c];
12371                        let qv = (xs[i] / s).round().clamp(-q_max, q_max);
12372                        outs[i] = qv * s;
12373                    }
12374                }
12375            }
12376
12377            Thunk::ActivationBackward {
12378                x,
12379                dy,
12380                dx,
12381                len,
12382                kind,
12383            } => {
12384                let len = *len as usize;
12385                unsafe {
12386                    let xs = sl(*x, base, len);
12387                    let dys = sl(*dy, base, len);
12388                    let out = sl_mut(*dx, base, len);
12389                    activation_backward_kernel(*kind, xs, dys, out);
12390                }
12391            }
12392
12393            Thunk::ActivationBackwardF64 {
12394                x,
12395                dy,
12396                dx,
12397                len,
12398                kind,
12399            } => {
12400                let len = *len as usize;
12401                unsafe {
12402                    let xs = sl_f64(*x, base, len);
12403                    let dys = sl_f64(*dy, base, len);
12404                    let out = sl_mut_f64(*dx, base, len);
12405                    activation_backward_kernel_f64(*kind, xs, dys, out);
12406                }
12407            }
12408
12409            Thunk::FakeQuantizeLSQ {
12410                x,
12411                scale_off,
12412                out,
12413                len,
12414                chan_axis: _,
12415                chan_dim,
12416                inner,
12417                bits,
12418            } => {
12419                let len = *len as usize;
12420                let chan_dim = *chan_dim as usize;
12421                let inner = *inner as usize;
12422                let q_max: f32 = match *bits {
12423                    8 => 127.0,
12424                    4 => 7.0,
12425                    2 => 1.0,
12426                    n => panic!("FakeQuantizeLSQ: bad bits {n}"),
12427                };
12428                unsafe {
12429                    let xs = sl(*x, base, len);
12430                    let scale = sl(*scale_off, base, chan_dim);
12431                    let outs = sl_mut(*out, base, len);
12432                    for i in 0..len {
12433                        let c = if chan_dim == 1 {
12434                            0
12435                        } else {
12436                            (i / inner) % chan_dim
12437                        };
12438                        let s = scale[c].max(1e-12);
12439                        let qv = (xs[i] / s).round().clamp(-q_max, q_max);
12440                        outs[i] = qv * s;
12441                    }
12442                }
12443            }
12444
12445            Thunk::FakeQuantizeLSQBackwardX {
12446                x,
12447                scale_off,
12448                dy,
12449                dx,
12450                len,
12451                chan_axis: _,
12452                chan_dim,
12453                inner,
12454                bits,
12455            } => {
12456                let len = *len as usize;
12457                let chan_dim = *chan_dim as usize;
12458                let inner = *inner as usize;
12459                let q_max: f32 = match *bits {
12460                    8 => 127.0,
12461                    4 => 7.0,
12462                    2 => 1.0,
12463                    n => panic!("FakeQuantizeLSQBackwardX: bad bits {n}"),
12464                };
12465                unsafe {
12466                    let xs = sl(*x, base, len);
12467                    let scale = sl(*scale_off, base, chan_dim);
12468                    let dys = sl(*dy, base, len);
12469                    let outs = sl_mut(*dx, base, len);
12470                    // STE-clipped: dx = dy when |x/s| ≤ q_max, else 0.
12471                    for i in 0..len {
12472                        let c = if chan_dim == 1 {
12473                            0
12474                        } else {
12475                            (i / inner) % chan_dim
12476                        };
12477                        let z = xs[i] / scale[c].max(1e-12);
12478                        outs[i] = if z.abs() <= q_max { dys[i] } else { 0.0 };
12479                    }
12480                }
12481            }
12482
12483            Thunk::FakeQuantizeLSQBackwardScale {
12484                x,
12485                scale_off,
12486                dy,
12487                dscale,
12488                len,
12489                chan_axis: _,
12490                chan_dim,
12491                inner,
12492                bits,
12493            } => {
12494                let len = *len as usize;
12495                let chan_dim = *chan_dim as usize;
12496                let inner = *inner as usize;
12497                let q_max: f32 = match *bits {
12498                    8 => 127.0,
12499                    4 => 7.0,
12500                    2 => 1.0,
12501                    n => panic!("FakeQuantizeLSQBackwardScale: bad bits {n}"),
12502                };
12503                unsafe {
12504                    let xs = sl(*x, base, len);
12505                    let scale = sl(*scale_off, base, chan_dim);
12506                    let dys = sl(*dy, base, len);
12507                    let outs = sl_mut(*dscale, base, chan_dim);
12508                    for v in outs.iter_mut() {
12509                        *v = 0.0;
12510                    }
12511                    // ψ(z) = -z + round(z) inside range, sign(z)·q_max outside.
12512                    // dscale[c] = sum_i ψ(x_i/s[c]) * upstream[i].
12513                    for i in 0..len {
12514                        let c = if chan_dim == 1 {
12515                            0
12516                        } else {
12517                            (i / inner) % chan_dim
12518                        };
12519                        let s = scale[c].max(1e-12);
12520                        let z = xs[i] / s;
12521                        let psi = if z.abs() <= q_max {
12522                            -z + z.round()
12523                        } else if z > 0.0 {
12524                            q_max
12525                        } else {
12526                            -q_max
12527                        };
12528                        outs[c] += psi * dys[i];
12529                    }
12530                }
12531            }
12532
12533            Thunk::FakeQuantizeBackward {
12534                x,
12535                dy,
12536                dx,
12537                len,
12538                chan_axis: _,
12539                chan_dim,
12540                inner,
12541                bits,
12542                ste,
12543            } => {
12544                use rlx_ir::op::SteKind;
12545                let len = *len as usize;
12546                let chan_dim = *chan_dim as usize;
12547                let inner = *inner as usize;
12548                let q_max: f32 = match *bits {
12549                    8 => 127.0,
12550                    4 => 7.0,
12551                    2 => 1.0,
12552                    n => panic!("FakeQuantizeBackward: bad bits {n}"),
12553                };
12554                unsafe {
12555                    let xs = sl(*x, base, len);
12556                    let dys = sl(*dy, base, len);
12557                    let outs = sl_mut(*dx, base, len);
12558
12559                    // Per-channel max-abs → scale, same as forward.
12560                    let mut max_abs = vec![0f32; chan_dim];
12561                    for i in 0..len {
12562                        let c = if chan_dim == 1 {
12563                            0
12564                        } else {
12565                            (i / inner) % chan_dim
12566                        };
12567                        let a = xs[i].abs();
12568                        if a > max_abs[c] {
12569                            max_abs[c] = a;
12570                        }
12571                    }
12572                    let mut scale = vec![0f32; chan_dim];
12573                    for c in 0..chan_dim {
12574                        scale[c] = (max_abs[c] / q_max).max(1e-12);
12575                    }
12576
12577                    match *ste {
12578                        SteKind::Identity => {
12579                            // dx = dy unchanged.
12580                            outs.copy_from_slice(dys);
12581                        }
12582                        SteKind::ClippedIdentity => {
12583                            // dx = dy * (|x| <= q_max·s); zero if the
12584                            // forward saturated.
12585                            for i in 0..len {
12586                                let c = if chan_dim == 1 {
12587                                    0
12588                                } else {
12589                                    (i / inner) % chan_dim
12590                                };
12591                                let bound = q_max * scale[c];
12592                                outs[i] = if xs[i].abs() <= bound { dys[i] } else { 0.0 };
12593                            }
12594                        }
12595                        SteKind::Tanh => {
12596                            // dx = dy * (1 - tanh²(x/s)).
12597                            for i in 0..len {
12598                                let c = if chan_dim == 1 {
12599                                    0
12600                                } else {
12601                                    (i / inner) % chan_dim
12602                                };
12603                                let t = (xs[i] / scale[c]).tanh();
12604                                outs[i] = dys[i] * (1.0 - t * t);
12605                            }
12606                        }
12607                        SteKind::HardTanh => {
12608                            // dx = dy * max(0, 1 - |x/(q_max·s)|).
12609                            for i in 0..len {
12610                                let c = if chan_dim == 1 {
12611                                    0
12612                                } else {
12613                                    (i / inner) % chan_dim
12614                                };
12615                                let bound = q_max * scale[c];
12616                                let attenuation = (1.0 - (xs[i] / bound).abs()).max(0.0);
12617                                outs[i] = dys[i] * attenuation;
12618                            }
12619                        }
12620                    }
12621                }
12622            }
12623
12624            Thunk::LayerNormBackwardInput {
12625                x,
12626                gamma,
12627                dy,
12628                dx,
12629                rows,
12630                h,
12631                eps,
12632            } => {
12633                let rows = *rows as usize;
12634                let h = *h as usize;
12635                let eps = *eps;
12636                unsafe {
12637                    let xs = sl(*x, base, rows * h);
12638                    let g = sl(*gamma, base, h);
12639                    let dys = sl(*dy, base, rows * h);
12640                    let out = sl_mut(*dx, base, rows * h);
12641                    let n_inv = 1.0 / h as f32;
12642                    for r in 0..rows {
12643                        let xr = &xs[r * h..(r + 1) * h];
12644                        let dyr = &dys[r * h..(r + 1) * h];
12645                        // Per-row mean and inv_std (recompute — no saved
12646                        // tensor from the forward pass).
12647                        let mut sum = 0f32;
12648                        for &v in xr {
12649                            sum += v;
12650                        }
12651                        let mean = sum * n_inv;
12652                        let mut var = 0f32;
12653                        for &v in xr {
12654                            let d = v - mean;
12655                            var += d * d;
12656                        }
12657                        let inv_std = 1.0 / (var * n_inv + eps).sqrt();
12658
12659                        // sums needed for the closed-form:
12660                        //   mean(dy·γ) and mean(dy·γ·x̂)
12661                        let mut s_sy = 0f32;
12662                        let mut s_sxh = 0f32;
12663                        for d in 0..h {
12664                            let xh = (xr[d] - mean) * inv_std;
12665                            let sy = dyr[d] * g[d];
12666                            s_sy += sy;
12667                            s_sxh += sy * xh;
12668                        }
12669                        let m_sy = s_sy * n_inv;
12670                        let m_sxh = s_sxh * n_inv;
12671
12672                        for d in 0..h {
12673                            let xh = (xr[d] - mean) * inv_std;
12674                            let sy = dyr[d] * g[d];
12675                            out[r * h + d] = inv_std * (sy - m_sy - xh * m_sxh);
12676                        }
12677                    }
12678                }
12679            }
12680
12681            Thunk::BatchNormInferenceBackwardInput {
12682                x,
12683                gamma,
12684                mean,
12685                var,
12686                dy,
12687                dx,
12688                count,
12689                channels,
12690                eps,
12691            } => {
12692                let count = *count as usize;
12693                let c = *channels as usize;
12694                let n = count * c;
12695                let eps = *eps;
12696                unsafe {
12697                    crate::kernels::batch_norm_inference_backward_input(
12698                        sl(*x, base, n),
12699                        sl(*gamma, base, c),
12700                        sl(*mean, base, c),
12701                        sl(*var, base, c),
12702                        sl(*dy, base, n),
12703                        sl_mut(*dx, base, n),
12704                        c,
12705                        eps,
12706                    );
12707                }
12708            }
12709
12710            Thunk::BatchNormInferenceBackwardGamma {
12711                x,
12712                mean,
12713                var,
12714                dy,
12715                dgamma,
12716                count,
12717                channels,
12718                eps,
12719            } => {
12720                let count = *count as usize;
12721                let c = *channels as usize;
12722                let n = count * c;
12723                let eps = *eps;
12724                unsafe {
12725                    crate::kernels::batch_norm_inference_backward_gamma(
12726                        sl(*x, base, n),
12727                        sl(*mean, base, c),
12728                        sl(*var, base, c),
12729                        sl(*dy, base, n),
12730                        sl_mut(*dgamma, base, c),
12731                        c,
12732                        eps,
12733                    );
12734                }
12735            }
12736
12737            Thunk::BatchNormInferenceBackwardBeta {
12738                dy,
12739                dbeta,
12740                count,
12741                channels,
12742            } => {
12743                let count = *count as usize;
12744                let c = *channels as usize;
12745                let n = count * c;
12746                unsafe {
12747                    crate::kernels::batch_norm_inference_backward_beta(
12748                        sl(*dy, base, n),
12749                        sl_mut(*dbeta, base, c),
12750                        c,
12751                    );
12752                }
12753            }
12754
12755            Thunk::LayerNormBackwardGamma {
12756                x,
12757                dy,
12758                dgamma,
12759                rows,
12760                h,
12761                eps,
12762            } => {
12763                let rows = *rows as usize;
12764                let h = *h as usize;
12765                let eps = *eps;
12766                unsafe {
12767                    let xs = sl(*x, base, rows * h);
12768                    let dys = sl(*dy, base, rows * h);
12769                    let out = sl_mut(*dgamma, base, h);
12770                    for v in out.iter_mut() {
12771                        *v = 0.0;
12772                    }
12773                    let n_inv = 1.0 / h as f32;
12774                    for r in 0..rows {
12775                        let xr = &xs[r * h..(r + 1) * h];
12776                        let dyr = &dys[r * h..(r + 1) * h];
12777                        let mut sum = 0f32;
12778                        for &v in xr {
12779                            sum += v;
12780                        }
12781                        let mean = sum * n_inv;
12782                        let mut var = 0f32;
12783                        for &v in xr {
12784                            let d = v - mean;
12785                            var += d * d;
12786                        }
12787                        let inv_std = 1.0 / (var * n_inv + eps).sqrt();
12788                        for d in 0..h {
12789                            let xh = (xr[d] - mean) * inv_std;
12790                            out[d] += dyr[d] * xh;
12791                        }
12792                    }
12793                }
12794            }
12795
12796            Thunk::RmsNormBackwardInput {
12797                x,
12798                gamma,
12799                beta,
12800                dy,
12801                dx,
12802                rows,
12803                h,
12804                eps,
12805            } => {
12806                let (rows, h) = (*rows as usize, *h as usize);
12807                unsafe {
12808                    let xs = sl(*x, base, rows * h);
12809                    let g = sl(*gamma, base, h);
12810                    let b = sl(*beta, base, h);
12811                    let dys = sl(*dy, base, rows * h);
12812                    let out = sl_mut(*dx, base, rows * h);
12813                    let mut dg = vec![0f32; h];
12814                    let mut db = vec![0f32; h];
12815                    for r in 0..rows {
12816                        crate::training_bwd::rms_norm_backward_row(
12817                            &xs[r * h..(r + 1) * h],
12818                            g,
12819                            b,
12820                            &dys[r * h..(r + 1) * h],
12821                            &mut out[r * h..(r + 1) * h],
12822                            &mut dg,
12823                            &mut db,
12824                            *eps,
12825                        );
12826                    }
12827                }
12828            }
12829
12830            Thunk::RmsNormBackwardGamma {
12831                x,
12832                gamma,
12833                beta,
12834                dy,
12835                dgamma,
12836                rows,
12837                h,
12838                eps,
12839            } => {
12840                let (rows, h) = (*rows as usize, *h as usize);
12841                unsafe {
12842                    let xs = sl(*x, base, rows * h);
12843                    let g = sl(*gamma, base, h);
12844                    let b = sl(*beta, base, h);
12845                    let dys = sl(*dy, base, rows * h);
12846                    let out = sl_mut(*dgamma, base, h);
12847                    for v in out.iter_mut() {
12848                        *v = 0.0;
12849                    }
12850                    let mut dx = vec![0f32; h];
12851                    let mut db = vec![0f32; h];
12852                    for r in 0..rows {
12853                        crate::training_bwd::rms_norm_backward_row(
12854                            &xs[r * h..(r + 1) * h],
12855                            g,
12856                            b,
12857                            &dys[r * h..(r + 1) * h],
12858                            &mut dx,
12859                            &mut *out,
12860                            &mut db,
12861                            *eps,
12862                        );
12863                    }
12864                }
12865            }
12866
12867            Thunk::RmsNormBackwardBeta {
12868                x,
12869                gamma,
12870                beta,
12871                dy,
12872                dbeta,
12873                rows,
12874                h,
12875                eps,
12876            } => {
12877                let (rows, h) = (*rows as usize, *h as usize);
12878                unsafe {
12879                    let xs = sl(*x, base, rows * h);
12880                    let g = sl(*gamma, base, h);
12881                    let b = sl(*beta, base, h);
12882                    let dys = sl(*dy, base, rows * h);
12883                    let out = sl_mut(*dbeta, base, h);
12884                    for v in out.iter_mut() {
12885                        *v = 0.0;
12886                    }
12887                    let mut dx = vec![0f32; h];
12888                    let mut dg = vec![0f32; h];
12889                    for r in 0..rows {
12890                        crate::training_bwd::rms_norm_backward_row(
12891                            &xs[r * h..(r + 1) * h],
12892                            g,
12893                            b,
12894                            &dys[r * h..(r + 1) * h],
12895                            &mut dx,
12896                            &mut dg,
12897                            &mut *out,
12898                            *eps,
12899                        );
12900                    }
12901                }
12902            }
12903
12904            Thunk::RopeBackward {
12905                dy,
12906                cos,
12907                sin,
12908                dx,
12909                batch,
12910                seq,
12911                hidden,
12912                head_dim,
12913                n_rot,
12914                cos_len,
12915            } => {
12916                let (b, s, hs, dh, nr, cl) = (
12917                    *batch as usize,
12918                    *seq as usize,
12919                    *hidden as usize,
12920                    *head_dim as usize,
12921                    *n_rot as usize,
12922                    *cos_len as usize,
12923                );
12924                let nh = hs / dh;
12925                let tab_half = dh / 2;
12926                unsafe {
12927                    let dys = sl(*dy, base, b * s * hs);
12928                    let cos_tab = sl(*cos, base, cl);
12929                    let sin_tab = sl(*sin, base, cl);
12930                    let out = sl_mut(*dx, base, b * s * hs);
12931                    for bi in 0..b {
12932                        for si in 0..s {
12933                            let tab_off = si.saturating_mul(tab_half) % cl.max(1);
12934                            let cp = &cos_tab[tab_off..tab_off + tab_half.min(cl)];
12935                            let sp = &sin_tab[tab_off..tab_off + tab_half.min(cl)];
12936                            for hi in 0..nh {
12937                                let base_idx = bi * s * hs + si * hs + hi * dh;
12938                                crate::training_bwd::rope_backward_row(
12939                                    &dys[base_idx..base_idx + dh],
12940                                    cp,
12941                                    sp,
12942                                    &mut out[base_idx..base_idx + dh],
12943                                    dh,
12944                                    nr,
12945                                );
12946                            }
12947                        }
12948                    }
12949                }
12950            }
12951
12952            Thunk::CumsumBackward {
12953                dy,
12954                dx,
12955                rows,
12956                cols,
12957                exclusive,
12958            } => {
12959                let (rows, cols) = (*rows as usize, *cols as usize);
12960                unsafe {
12961                    let dys = sl(*dy, base, rows * cols);
12962                    let out = sl_mut(*dx, base, rows * cols);
12963                    for r in 0..rows {
12964                        crate::training_bwd::cumsum_backward_row(
12965                            &dys[r * cols..(r + 1) * cols],
12966                            &mut out[r * cols..(r + 1) * cols],
12967                            *exclusive,
12968                        );
12969                    }
12970                }
12971            }
12972
12973            Thunk::GroupNormBackwardInput {
12974                x,
12975                gamma,
12976                beta: _beta,
12977                dy,
12978                dx,
12979                n,
12980                c,
12981                h,
12982                w,
12983                num_groups,
12984                eps,
12985            } => {
12986                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
12987                let plane = c * h * w;
12988                unsafe {
12989                    let xs = sl(*x, base, n * plane);
12990                    let g = sl(*gamma, base, c);
12991                    let dys = sl(*dy, base, n * plane);
12992                    let out = sl_mut(*dx, base, n * plane);
12993                    crate::training_bwd::group_norm_backward_input_nchw(
12994                        xs,
12995                        g,
12996                        dys,
12997                        out,
12998                        n,
12999                        c,
13000                        h,
13001                        w,
13002                        *num_groups as usize,
13003                        *eps,
13004                    );
13005                }
13006            }
13007
13008            Thunk::GroupNormBackwardGamma {
13009                x,
13010                dy,
13011                dgamma,
13012                n,
13013                c,
13014                h,
13015                w,
13016                num_groups,
13017                eps,
13018            } => {
13019                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
13020                let plane = c * h * w;
13021                unsafe {
13022                    let xs = sl(*x, base, n * plane);
13023                    let dys = sl(*dy, base, n * plane);
13024                    let out = sl_mut(*dgamma, base, c);
13025                    crate::training_bwd::group_norm_backward_gamma_nchw(
13026                        xs,
13027                        dys,
13028                        out,
13029                        n,
13030                        c,
13031                        h,
13032                        w,
13033                        *num_groups as usize,
13034                        *eps,
13035                    );
13036                }
13037            }
13038
13039            Thunk::GroupNormBackwardBeta {
13040                dy,
13041                dbeta,
13042                n,
13043                c,
13044                h,
13045                w,
13046            } => {
13047                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
13048                let plane = c * h * w;
13049                unsafe {
13050                    let dys = sl(*dy, base, n * plane);
13051                    let out = sl_mut(*dbeta, base, c);
13052                    crate::training_bwd::group_norm_backward_beta_nchw(dys, out, n, c, h, w);
13053                }
13054            }
13055
13056            Thunk::GatherBackward {
13057                dy,
13058                indices,
13059                dst,
13060                outer,
13061                axis_dim,
13062                num_idx,
13063                trailing,
13064            } => {
13065                let (outer, axis_dim, num_idx, trailing) = (
13066                    *outer as usize,
13067                    *axis_dim as usize,
13068                    *num_idx as usize,
13069                    *trailing as usize,
13070                );
13071                unsafe {
13072                    let dys = sl(*dy, base, outer * num_idx * trailing);
13073                    let ids = sl(*indices, base, num_idx);
13074                    let out = sl_mut(*dst, base, outer * axis_dim * trailing);
13075                    for v in out.iter_mut() {
13076                        *v = 0.0;
13077                    }
13078                    crate::training_bwd::gather_axis_backward(
13079                        dys, ids, out, outer, axis_dim, num_idx, trailing,
13080                    );
13081                }
13082            }
13083
13084            Thunk::MaxPool2dBackward {
13085                x,
13086                dy,
13087                dx,
13088                n,
13089                c,
13090                h,
13091                w,
13092                h_out,
13093                w_out,
13094                kh,
13095                kw,
13096                sh,
13097                sw,
13098                ph,
13099                pw,
13100            } => unsafe {
13101                execute_maxpool2d_backward_f32(
13102                    *x, *dy, *dx, *n, *c, *h, *w, *h_out, *w_out, *kh, *kw, *sh, *sw, *ph, *pw,
13103                    base,
13104                );
13105            },
13106
13107            Thunk::Conv2dBackwardInput {
13108                dy,
13109                w,
13110                dx,
13111                n,
13112                c_in,
13113                h,
13114                w_in,
13115                c_out,
13116                h_out,
13117                w_out,
13118                kh,
13119                kw,
13120                sh,
13121                sw,
13122                ph,
13123                pw,
13124                dh,
13125                dw,
13126                groups,
13127            } => {
13128                // Per-group GEMM + col2im. Two orders of magnitude faster
13129                // than the naive 6-deep nested loop on training shapes.
13130                //
13131                //   dcol_n_g = w_g^T  @  dy_n_g            (sgemm)
13132                //   dx_n_g  += col2im(dcol_n_g)            (scatter-add)
13133                //
13134                // Layouts (all row-major):
13135                //   w_g       [c_out_per_g, c_in_per_g · kh · kw]
13136                //   dy_n_g    [c_out_per_g, h_out · w_out]
13137                //   dcol_n_g  [c_in_per_g · kh · kw, h_out · w_out]
13138                //   dx_n_g    [c_in_per_g, h · w_in]
13139                let n = *n as usize;
13140                let c_in = *c_in as usize;
13141                let h = *h as usize;
13142                let w_in = *w_in as usize;
13143                let c_out = *c_out as usize;
13144                let h_out = *h_out as usize;
13145                let w_out = *w_out as usize;
13146                let kh = *kh as usize;
13147                let kw = *kw as usize;
13148                let sh = *sh as usize;
13149                let sw = *sw as usize;
13150                let ph = *ph as usize;
13151                let pw = *pw as usize;
13152                let dh = *dh as usize;
13153                let dw = *dw as usize;
13154                let groups = *groups as usize;
13155                let c_in_per_g = c_in / groups;
13156                let c_out_per_g = c_out / groups;
13157
13158                let m_dim = c_in_per_g * kh * kw;
13159                let n_dim = h_out * w_out;
13160                let k_dim = c_out_per_g;
13161
13162                let dy_stride_n = c_out * h_out * w_out;
13163                let dy_stride_g = c_out_per_g * h_out * w_out;
13164                let w_stride_g = c_out_per_g * c_in_per_g * kh * kw;
13165                let dx_stride_n = c_in * h * w_in;
13166                let dx_stride_g = c_in_per_g * h * w_in;
13167
13168                unsafe {
13169                    let dys = sl(*dy, base, n * c_out * h_out * w_out);
13170                    let ws = sl(*w, base, c_out * c_in_per_g * kh * kw);
13171                    let dxs = sl_mut(*dx, base, n * c_in * h * w_in);
13172                    for v in dxs.iter_mut() {
13173                        *v = 0.0;
13174                    }
13175
13176                    // Reused scratch buffer for the [m_dim, n_dim] dcol.
13177                    let mut dcol = vec![0f32; m_dim * n_dim];
13178
13179                    for ni in 0..n {
13180                        for g in 0..groups {
13181                            let w_g_off = g * w_stride_g;
13182                            let dy_n_g_off = ni * dy_stride_n + g * dy_stride_g;
13183                            let dx_n_g_off = ni * dx_stride_n + g * dx_stride_g;
13184
13185                            // dcol = w_g^T @ dy_n_g
13186                            // w_g  is stored as [k_dim rows, m_dim cols] row-major
13187                            // (i.e. K×M storage with lda = M = m_dim — exactly what
13188                            // sgemm_general wants for trans_a=true).
13189                            crate::blas::sgemm_general(
13190                                ws.as_ptr().add(w_g_off),
13191                                dys.as_ptr().add(dy_n_g_off),
13192                                dcol.as_mut_ptr(),
13193                                m_dim,
13194                                n_dim,
13195                                k_dim,
13196                                1.0,
13197                                0.0,
13198                                /*lda=*/ m_dim,
13199                                /*ldb=*/ n_dim,
13200                                /*ldc=*/ n_dim,
13201                                /*trans_a=*/ true,
13202                                /*trans_b=*/ false,
13203                            );
13204
13205                            // dx_n_g += col2im(dcol)
13206                            col2im(
13207                                &dcol,
13208                                &mut dxs[dx_n_g_off..dx_n_g_off + dx_stride_g],
13209                                c_in_per_g,
13210                                h,
13211                                w_in,
13212                                h_out,
13213                                w_out,
13214                                kh,
13215                                kw,
13216                                sh,
13217                                sw,
13218                                ph,
13219                                pw,
13220                                dh,
13221                                dw,
13222                            );
13223                        }
13224                    }
13225                }
13226            }
13227
13228            Thunk::Conv2dBackwardWeight {
13229                x,
13230                dy,
13231                dw,
13232                n,
13233                c_in,
13234                h,
13235                w,
13236                c_out,
13237                h_out,
13238                w_out,
13239                kh,
13240                kw,
13241                sh,
13242                sw,
13243                ph,
13244                pw,
13245                dh,
13246                dw_dil,
13247                groups,
13248            } => {
13249                let n = *n as usize;
13250                let c_in = *c_in as usize;
13251                let h = *h as usize;
13252                let w = *w as usize;
13253                // Per-group im2col + GEMM, summed across batch.
13254                //
13255                //   col_n_g  = im2col(x_n_g)               (gather)
13256                //   dw_g    += dy_n_g  @  col_n_g^T        (sgemm, β=1)
13257                //
13258                // Layouts:
13259                //   x_n_g     [c_in_per_g, h · w]
13260                //   col_n_g   [c_in_per_g · kh · kw, h_out · w_out]
13261                //   dy_n_g    [c_out_per_g, h_out · w_out]
13262                //   dw_g      [c_out_per_g, c_in_per_g · kh · kw]
13263                let c_out = *c_out as usize;
13264                let h_out = *h_out as usize;
13265                let w_out = *w_out as usize;
13266                let kh = *kh as usize;
13267                let kw = *kw as usize;
13268                let sh = *sh as usize;
13269                let sw = *sw as usize;
13270                let ph = *ph as usize;
13271                let pw = *pw as usize;
13272                let dh = *dh as usize;
13273                let dw_dil = *dw_dil as usize;
13274                let groups = *groups as usize;
13275                let c_in_per_g = c_in / groups;
13276                let c_out_per_g = c_out / groups;
13277
13278                let m_dim = c_out_per_g;
13279                let n_dim = c_in_per_g * kh * kw;
13280                let k_dim = h_out * w_out;
13281
13282                let x_stride_n = c_in * h * w;
13283                let x_stride_g = c_in_per_g * h * w;
13284                let dy_stride_n = c_out * h_out * w_out;
13285                let dy_stride_g = c_out_per_g * h_out * w_out;
13286                let dw_stride_g = c_out_per_g * c_in_per_g * kh * kw;
13287
13288                unsafe {
13289                    let xs = sl(*x, base, n * c_in * h * w);
13290                    let dys = sl(*dy, base, n * c_out * h_out * w_out);
13291                    let dws = sl_mut(*dw, base, c_out * c_in_per_g * kh * kw);
13292                    for v in dws.iter_mut() {
13293                        *v = 0.0;
13294                    }
13295
13296                    let mut col = vec![0f32; n_dim * k_dim];
13297
13298                    for ni in 0..n {
13299                        for g in 0..groups {
13300                            let x_n_g_off = ni * x_stride_n + g * x_stride_g;
13301                            im2col(
13302                                &xs[x_n_g_off..x_n_g_off + x_stride_g],
13303                                &mut col,
13304                                c_in_per_g,
13305                                h,
13306                                w,
13307                                h_out,
13308                                w_out,
13309                                kh,
13310                                kw,
13311                                sh,
13312                                sw,
13313                                ph,
13314                                pw,
13315                                dh,
13316                                dw_dil,
13317                            );
13318
13319                            let dy_n_g_off = ni * dy_stride_n + g * dy_stride_g;
13320                            let dw_g_off = g * dw_stride_g;
13321
13322                            // dw_g += dy_n_g @ col^T
13323                            //
13324                            // Output shape m × n_out = c_out_per_g × (c_in_per_g·kh·kw).
13325                            // dy_n_g is stored M×K row-major (lda = K = k_dim).
13326                            // col is stored as N×K row-major; with trans_b=true,
13327                            // sgemm_general uses ldb = K = k_dim and treats it as
13328                            // transposed. β=1 accumulates across the batch loop.
13329                            crate::blas::sgemm_general(
13330                                dys.as_ptr().add(dy_n_g_off),
13331                                col.as_ptr(),
13332                                dws.as_mut_ptr().add(dw_g_off),
13333                                m_dim,
13334                                n_dim,
13335                                k_dim,
13336                                1.0,
13337                                1.0,
13338                                /*lda=*/ k_dim,
13339                                /*ldb=*/ k_dim,
13340                                /*ldc=*/ n_dim,
13341                                /*trans_a=*/ false,
13342                                /*trans_b=*/ true,
13343                            );
13344                        }
13345                    }
13346                }
13347            }
13348
13349            Thunk::Im2Col {
13350                x,
13351                col,
13352                n,
13353                c_in,
13354                h,
13355                w,
13356                h_out,
13357                w_out,
13358                kh,
13359                kw,
13360                sh,
13361                sw,
13362                ph,
13363                pw,
13364                dh,
13365                dw_dil,
13366            } => {
13367                let c_in = *c_in as usize;
13368                let h = *h as usize;
13369                let w = *w as usize;
13370                let h_out = *h_out as usize;
13371                let w_out = *w_out as usize;
13372                let kh = *kh as usize;
13373                let kw = *kw as usize;
13374                let sh = *sh as usize;
13375                let sw = *sw as usize;
13376                let ph = *ph as usize;
13377                let pw = *pw as usize;
13378                let dh = *dh as usize;
13379                let dw_dil = *dw_dil as usize;
13380                let per_batch = c_in * h * w;
13381                unsafe {
13382                    let n_eff = if *n == 0 { 0usize } else { *n as usize };
13383                    let x_floats = if n_eff == 0 {
13384                        per_batch.max(1)
13385                    } else {
13386                        n_eff * per_batch
13387                    };
13388                    let xs = sl(*x, base, x_floats);
13389                    let n = if *n == 0 {
13390                        xs.len() / per_batch.max(1)
13391                    } else {
13392                        n_eff
13393                    };
13394                    let m = n * h_out * w_out;
13395                    let k = c_in * kh * kw;
13396                    let cols = sl_mut(*col, base, m * k);
13397                    crate::im2col::im2col_rows_layout(
13398                        xs, cols, n, c_in, h, w, h_out, w_out, kh, kw, sh, sw, ph, pw, dh, dw_dil,
13399                    );
13400                }
13401            }
13402
13403            Thunk::SoftmaxCrossEntropy {
13404                logits,
13405                labels,
13406                dst,
13407                n,
13408                c,
13409            } => {
13410                let n = *n as usize;
13411                let c = *c as usize;
13412                unsafe {
13413                    let lg = sl(*logits, base, n * c);
13414                    let lb = sl(*labels, base, n);
13415                    let out = sl_mut(*dst, base, n);
13416                    for ni in 0..n {
13417                        let row = &lg[ni * c..(ni + 1) * c];
13418                        // log-sum-exp: max-subtract for stability.
13419                        let mut m = f32::NEG_INFINITY;
13420                        for &v in row {
13421                            if v > m {
13422                                m = v;
13423                            }
13424                        }
13425                        let mut sum = 0f32;
13426                        for &v in row {
13427                            sum += (v - m).exp();
13428                        }
13429                        let lse = m + sum.ln();
13430                        let label_idx = lb[ni] as usize;
13431                        // loss = -(logits[label] - lse) = lse - logits[label].
13432                        out[ni] = lse - row[label_idx];
13433                    }
13434                }
13435            }
13436
13437            Thunk::SoftmaxCrossEntropyBackward {
13438                logits,
13439                labels,
13440                d_loss,
13441                dlogits,
13442                n,
13443                c,
13444            } => {
13445                let n = *n as usize;
13446                let c = *c as usize;
13447                unsafe {
13448                    let lg = sl(*logits, base, n * c);
13449                    let lb = sl(*labels, base, n);
13450                    let dl = sl(*d_loss, base, n);
13451                    let out = sl_mut(*dlogits, base, n * c);
13452                    for ni in 0..n {
13453                        let row = &lg[ni * c..(ni + 1) * c];
13454                        let label_idx = lb[ni] as usize;
13455                        let scale = dl[ni];
13456                        let mut m = f32::NEG_INFINITY;
13457                        for &v in row {
13458                            if v > m {
13459                                m = v;
13460                            }
13461                        }
13462                        let mut sum = 0f32;
13463                        for &v in row {
13464                            sum += (v - m).exp();
13465                        }
13466                        let inv_sum = 1.0 / sum;
13467                        let dst_row = &mut out[ni * c..(ni + 1) * c];
13468                        for k in 0..c {
13469                            let p = (row[k] - m).exp() * inv_sum;
13470                            let one_hot = if k == label_idx { 1.0 } else { 0.0 };
13471                            dst_row[k] = (p - one_hot) * scale;
13472                        }
13473                    }
13474                }
13475            }
13476
13477            Thunk::GatherAxis {
13478                table,
13479                idx,
13480                dst,
13481                outer,
13482                axis_dim,
13483                num_idx,
13484                trailing,
13485                idx_i64,
13486                table_bytes,
13487            } => {
13488                let outer = *outer as usize;
13489                let axis_dim = *axis_dim as usize;
13490                let num_idx = *num_idx as usize;
13491                let trailing = *trailing as usize;
13492                unsafe {
13493                    if *table_bytes == 8 {
13494                        let tab = sl_i64(*table, base, outer * axis_dim * trailing);
13495                        let out = sl_mut_i64(*dst, base, outer * num_idx * trailing);
13496                        for o in 0..outer {
13497                            let tab_outer = o * axis_dim * trailing;
13498                            let out_outer = o * num_idx * trailing;
13499                            if *idx_i64 != 0 {
13500                                let ids = sl_i64(*idx, base, num_idx);
13501                                for k in 0..num_idx {
13502                                    let row = ids[k].max(0) as usize;
13503                                    if row < axis_dim {
13504                                        let tab_row = tab_outer + row * trailing;
13505                                        let out_row = out_outer + k * trailing;
13506                                        out[out_row..out_row + trailing]
13507                                            .copy_from_slice(&tab[tab_row..tab_row + trailing]);
13508                                    }
13509                                }
13510                            } else {
13511                                let ids = sl(*idx, base, num_idx);
13512                                for k in 0..num_idx {
13513                                    let row = ids[k] as usize;
13514                                    if row < axis_dim {
13515                                        let tab_row = tab_outer + row * trailing;
13516                                        let out_row = out_outer + k * trailing;
13517                                        out[out_row..out_row + trailing]
13518                                            .copy_from_slice(&tab[tab_row..tab_row + trailing]);
13519                                    }
13520                                }
13521                            }
13522                        }
13523                    } else {
13524                        let tab = sl(*table, base, outer * axis_dim * trailing);
13525                        let out = sl_mut(*dst, base, outer * num_idx * trailing);
13526                        for o in 0..outer {
13527                            let tab_outer = o * axis_dim * trailing;
13528                            let out_outer = o * num_idx * trailing;
13529                            if *idx_i64 != 0 {
13530                                let ids = sl_i64(*idx, base, num_idx);
13531                                for k in 0..num_idx {
13532                                    let row = ids[k].max(0) as usize;
13533                                    if row < axis_dim {
13534                                        let tab_row = tab_outer + row * trailing;
13535                                        let out_row = out_outer + k * trailing;
13536                                        out[out_row..out_row + trailing]
13537                                            .copy_from_slice(&tab[tab_row..tab_row + trailing]);
13538                                    }
13539                                }
13540                            } else {
13541                                let ids = sl(*idx, base, num_idx);
13542                                for k in 0..num_idx {
13543                                    let row = ids[k] as usize;
13544                                    if row < axis_dim {
13545                                        let tab_row = tab_outer + row * trailing;
13546                                        let out_row = out_outer + k * trailing;
13547                                        out[out_row..out_row + trailing]
13548                                            .copy_from_slice(&tab[tab_row..tab_row + trailing]);
13549                                    }
13550                                }
13551                            }
13552                        }
13553                    }
13554                }
13555            }
13556
13557            Thunk::Transpose {
13558                src,
13559                dst,
13560                in_total,
13561                out_dims,
13562                in_strides,
13563                elem_bytes,
13564            } => {
13565                // N-D index walk: for each output flat index, decompose into
13566                // multi-dim coords using out_dims, then dot with in_strides
13567                // to find the source flat index. Stride 0 = broadcast (read
13568                // the same input element repeatedly along that dim).
13569                let rank = out_dims.len();
13570                let total: usize = out_dims.iter().map(|&d| d as usize).product();
13571                let in_total = *in_total as usize;
13572                unsafe {
13573                    if *elem_bytes == 8 {
13574                        let inp = sl_i64(*src, base, in_total);
13575                        let out = sl_mut_i64(*dst, base, total);
13576                        let mut idx = vec![0usize; rank];
13577                        for o in 0..total {
13578                            let mut src_idx = 0usize;
13579                            for d in 0..rank {
13580                                src_idx += idx[d] * in_strides[d] as usize;
13581                            }
13582                            out[o] = inp[src_idx];
13583                            for d in (0..rank).rev() {
13584                                idx[d] += 1;
13585                                if idx[d] < out_dims[d] as usize {
13586                                    break;
13587                                }
13588                                idx[d] = 0;
13589                            }
13590                        }
13591                    } else {
13592                        let inp = sl(*src, base, in_total);
13593                        let out = sl_mut(*dst, base, total);
13594                        let mut idx = vec![0usize; rank];
13595                        for o in 0..total {
13596                            let mut src_idx = 0usize;
13597                            for d in 0..rank {
13598                                src_idx += idx[d] * in_strides[d] as usize;
13599                            }
13600                            out[o] = inp[src_idx];
13601                            for d in (0..rank).rev() {
13602                                idx[d] += 1;
13603                                if idx[d] < out_dims[d] as usize {
13604                                    break;
13605                                }
13606                                idx[d] = 0;
13607                            }
13608                        }
13609                    }
13610                }
13611            }
13612
13613            // (Thunk::DenseSolveF64 / Thunk::ScanBackward had panic
13614            // stubs here as placeholders during the wire-up; both
13615            // are now reached by the real implementations earlier in
13616            // this same match — the stubs were dead code shadowed by
13617            // the specific-pattern arms above. Removed.)
13618            Thunk::CustomOp {
13619                kernel,
13620                inputs,
13621                output,
13622                attrs,
13623            } => {
13624                let (out_off, out_len, out_shape) = output;
13625                unsafe {
13626                    dispatch_custom_op(
13627                        &**kernel, inputs, *out_off, *out_len, out_shape, attrs, base,
13628                    );
13629                }
13630            }
13631        }
13632        if trace_done {
13633            eprintln!("[thunk {i} done]");
13634        }
13635    }
13636}
13637
13638/// Griewank treeverse: process backward iterations `[t_lo..=t_hi]` (with
13639/// the carry entering iteration `t_lo` supplied as `anchor_carry`) by
13640/// recursive binary subdivision. Total work `O((t_hi-t_lo+1) · log)`,
13641/// auxiliary memory `O(log · carry_bytes)` for the recursion stack.
13642///
13643/// Compared to the iterative segment-cached scheme, this trades extra
13644/// recompute for less working memory — each level of recursion holds
13645/// one `cb`-sized intermediate carry on the stack but never the whole
13646/// segment at once. With K saved outer checkpoints, the outer driver
13647/// invokes this helper once per segment.
13648///
13649/// `process_iter(t, carry_at_t)` is the per-iteration leaf action: it
13650/// runs `body_vjp` at iteration `t` with the supplied carry, threads
13651/// `dcarry` backward, and (for ScanBackwardXs) writes `dxs[t]`.
13652#[allow(clippy::too_many_arguments)]
13653unsafe fn griewank_process_segment(
13654    t_lo: usize,
13655    t_hi: usize,
13656    anchor_carry: &[u8],
13657    cb: usize,
13658    fwd_sched: &ThunkSchedule,
13659    fwd_init: &[u8],
13660    fwd_carry_in_off: usize,
13661    fwd_output_off: usize,
13662    fwd_x_offs: &[usize],
13663    base: *mut u8,
13664    outer_xs_offs: &[(usize, u32)],
13665    fwd_buf: &mut Vec<u8>,
13666    leaf_threshold: usize,
13667    process_iter: &mut dyn FnMut(usize, &[u8]),
13668) {
13669    unsafe {
13670        let size = t_hi - t_lo + 1;
13671        if size == 1 {
13672            process_iter(t_lo, anchor_carry);
13673            return;
13674        }
13675        if size <= leaf_threshold {
13676            // Walk forward, cache each carry, run backward in reverse.
13677            let mut cache: Vec<u8> = Vec::with_capacity(size * cb);
13678            cache.extend_from_slice(anchor_carry);
13679            fwd_buf.copy_from_slice(fwd_init);
13680            std::ptr::copy_nonoverlapping(
13681                anchor_carry.as_ptr(),
13682                fwd_buf.as_mut_ptr().add(fwd_carry_in_off),
13683                cb,
13684            );
13685            for i in 1..size {
13686                let cur_iter = t_lo + i - 1;
13687                for (idx, fb_x_off) in fwd_x_offs.iter().enumerate() {
13688                    let (outer_xs_off, x_psb) = outer_xs_offs[idx];
13689                    let xb = x_psb as usize;
13690                    std::ptr::copy_nonoverlapping(
13691                        base.add(outer_xs_off + cur_iter * xb),
13692                        fwd_buf.as_mut_ptr().add(*fb_x_off),
13693                        xb,
13694                    );
13695                }
13696                execute_thunks(fwd_sched, fwd_buf);
13697                if fwd_output_off != fwd_carry_in_off {
13698                    fwd_buf.copy_within(fwd_output_off..fwd_output_off + cb, fwd_carry_in_off);
13699                }
13700                cache.extend_from_slice(&fwd_buf[fwd_carry_in_off..fwd_carry_in_off + cb]);
13701            }
13702            // Process backward.
13703            for t in (t_lo..=t_hi).rev() {
13704                let idx = t - t_lo;
13705                let carry = &cache[idx * cb..(idx + 1) * cb];
13706                process_iter(t, carry);
13707            }
13708            return;
13709        }
13710
13711        // Split: walk forward from anchor to compute carry entering `mid`.
13712        // (We need `mid - t_lo` body executions: one per iteration in
13713        // [t_lo, mid).)
13714        let mid = t_lo + size / 2;
13715        fwd_buf.copy_from_slice(fwd_init);
13716        std::ptr::copy_nonoverlapping(
13717            anchor_carry.as_ptr(),
13718            fwd_buf.as_mut_ptr().add(fwd_carry_in_off),
13719            cb,
13720        );
13721        for cur_iter in t_lo..mid {
13722            for (idx, fb_x_off) in fwd_x_offs.iter().enumerate() {
13723                let (outer_xs_off, x_psb) = outer_xs_offs[idx];
13724                let xb = x_psb as usize;
13725                std::ptr::copy_nonoverlapping(
13726                    base.add(outer_xs_off + cur_iter * xb),
13727                    fwd_buf.as_mut_ptr().add(*fb_x_off),
13728                    xb,
13729                );
13730            }
13731            execute_thunks(fwd_sched, fwd_buf);
13732            if fwd_output_off != fwd_carry_in_off {
13733                fwd_buf.copy_within(fwd_output_off..fwd_output_off + cb, fwd_carry_in_off);
13734            }
13735        }
13736        let mid_carry: Vec<u8> = fwd_buf[fwd_carry_in_off..fwd_carry_in_off + cb].to_vec();
13737
13738        // Right half first (higher t values processed first to match the
13739        // canonical reverse-mode iteration order: dcarry threads from
13740        // t=length-1 down to t=0).
13741        griewank_process_segment(
13742            mid,
13743            t_hi,
13744            &mid_carry,
13745            cb,
13746            fwd_sched,
13747            fwd_init,
13748            fwd_carry_in_off,
13749            fwd_output_off,
13750            fwd_x_offs,
13751            base,
13752            outer_xs_offs,
13753            fwd_buf,
13754            leaf_threshold,
13755            process_iter,
13756        );
13757        // Then left half with original anchor.
13758        griewank_process_segment(
13759            t_lo,
13760            mid - 1,
13761            anchor_carry,
13762            cb,
13763            fwd_sched,
13764            fwd_init,
13765            fwd_carry_in_off,
13766            fwd_output_off,
13767            fwd_x_offs,
13768            base,
13769            outer_xs_offs,
13770            fwd_buf,
13771            leaf_threshold,
13772            process_iter,
13773        );
13774    }
13775}
13776
13777/// Execute a batched 1D FFT in the f64 2N-real-block layout.
13778/// Each "row" is `2N` f64 elements: first `N` real, then `N` imag.
13779/// The `outer` rows are independent and processed sequentially.
13780///
13781/// Both forward and inverse use the same Cooley-Tukey radix-2 DIT
13782/// kernel — only the twiddle-factor sign differs. Power-of-2 only
13783/// (the IR builder rejects non-power-of-2 sizes at graph-build time).
13784/// Batched 1D FFT on the f64 2N-real-block layout. Public so other
13785/// backend crates can invoke this as a host fallback against a
13786/// unified-memory arena (e.g. rlx-metal: sync the command buffer,
13787/// pass the Metal `Buffer::contents()` pointer as `base`, restart the
13788/// command buffer). Self-contained — no rlx-cpu state required.
13789///
13790/// Safety: `base + src` and `base + dst` must be valid for the
13791/// `outer * 2 * n_complex * sizeof::<f64>()` byte range and stay
13792/// alive for the duration of the call.
13793pub unsafe fn execute_fft1d_f64(
13794    src: usize,
13795    dst: usize,
13796    outer: usize,
13797    n_complex: usize,
13798    inverse: bool,
13799    norm_tag: u32,
13800    base: *mut u8,
13801) {
13802    let row_elems = 2 * n_complex;
13803    let mut re = vec![0f64; n_complex];
13804    let mut im = vec![0f64; n_complex];
13805    let norm = rlx_ir::fft::FftNorm::from_tag(norm_tag);
13806    let scale = norm.output_scale(n_complex, inverse);
13807    // Scratch reused across rows for the Bluestein path. Empty when
13808    // we're on the radix-2 fast path.
13809    let mut scratch = if n_complex.is_power_of_two() || n_complex <= 16 {
13810        BluesteinScratchF64::empty()
13811    } else {
13812        BluesteinScratchF64::build(n_complex, inverse)
13813    };
13814    for o in 0..outer {
13815        let row_offset = src + o * row_elems * std::mem::size_of::<f64>();
13816        let s = unsafe { sl_f64(row_offset, base, row_elems) };
13817        re.copy_from_slice(&s[..n_complex]);
13818        im.copy_from_slice(&s[n_complex..]);
13819        if n_complex.is_power_of_two() {
13820            fft_radix2_inplace_f64(&mut re, &mut im, inverse);
13821        } else if n_complex <= 16 {
13822            fft_naive_inplace_f64(&mut re, &mut im, inverse);
13823        } else {
13824            fft_bluestein_inplace_f64(&mut re, &mut im, inverse, &mut scratch);
13825        }
13826        if scale != 1.0 {
13827            re.iter_mut().for_each(|v| *v *= scale);
13828            im.iter_mut().for_each(|v| *v *= scale);
13829        }
13830        let dst_offset = dst + o * row_elems * std::mem::size_of::<f64>();
13831        let d = unsafe { sl_mut_f64(dst_offset, base, row_elems) };
13832        d[..n_complex].copy_from_slice(&re);
13833        d[n_complex..].copy_from_slice(&im);
13834    }
13835}
13836
13837/// f32 counterpart of `execute_fft1d_f64`. Same 2N-real-block layout
13838/// (first N real, second N imag per row), same unnormalized
13839/// convention; only the element width differs. Twiddle factors are
13840/// computed in f64 and cast to f32 to keep large-N error closer to
13841/// the f64 path (the savings from f32 are in memory bandwidth, not in
13842/// twiddle precision).
13843/// Host-fallback entry for `Op::GatedDeltaNet` (Metal / unified memory).
13844/// When `state == 0`, uses a zero-initialized scratch state per batch item.
13845pub unsafe fn execute_gated_delta_net_f32(
13846    q: usize,
13847    k: usize,
13848    v: usize,
13849    g: usize,
13850    beta: usize,
13851    state: usize,
13852    dst: usize,
13853    batch: usize,
13854    seq: usize,
13855    heads: usize,
13856    state_size: usize,
13857    base: *mut u8,
13858) {
13859    use rayon::prelude::*;
13860
13861    #[derive(Copy, Clone)]
13862    struct ArenaPtr(usize);
13863    unsafe impl Send for ArenaPtr {}
13864    unsafe impl Sync for ArenaPtr {}
13865    impl ArenaPtr {
13866        #[inline]
13867        fn get(self) -> *mut u8 {
13868            self.0 as *mut u8
13869        }
13870    }
13871
13872    unsafe {
13873        let arena = ArenaPtr(base as usize);
13874        let (b, s, h, n) = (batch, seq, heads, state_size);
13875        let scale = 1.0f32 / (n as f32).sqrt();
13876        let use_external = state != 0;
13877        let mut owned_state = vec![0f32; h * n * n];
13878
13879        crate::pool::num_threads();
13880
13881        assert!(
13882            n <= crate::gdn::GDN_MAX_STATE,
13883            "GatedDeltaNet state_size={n} exceeds stack scratch ({})",
13884            crate::gdn::GDN_MAX_STATE
13885        );
13886
13887        let qs = sl(q, arena.get(), b * s * h * n);
13888        let ks = sl(k, arena.get(), b * s * h * n);
13889        let vs = sl(v, arena.get(), b * s * h * n);
13890        let gs = sl(g, arena.get(), b * s * h);
13891        let betas = sl(beta, arena.get(), b * s * h);
13892        let _out = sl_mut(dst, arena.get(), b * s * h * n);
13893        let hs_n = h * n;
13894
13895        let run_head = |bi: usize, hi: usize, s_mat: &mut [f32], sk: &mut [f32]| {
13896            for ti in 0..s {
13897                let qkv_step = bi * s * hs_n + ti * hs_n + hi * n;
13898                let gb_step = bi * s * h + ti * h + hi;
13899                let out_row = sl_mut(dst + qkv_step * std::mem::size_of::<f32>(), arena.get(), n);
13900                crate::gdn::gdn_step_blas(
13901                    s_mat,
13902                    &qs[qkv_step..qkv_step + n],
13903                    &ks[qkv_step..qkv_step + n],
13904                    &vs[qkv_step..qkv_step + n],
13905                    gs[gb_step],
13906                    betas[gb_step],
13907                    out_row,
13908                    sk,
13909                    n,
13910                    scale,
13911                );
13912            }
13913        };
13914
13915        // Prefill (seq>1, ephemeral state): time-outer, parallel over heads —
13916        // better occupancy than head-outer when prompt length dominates.
13917        if !use_external && s > 1 {
13918            for bi in 0..b {
13919                (0..h).into_par_iter().for_each(|hi| {
13920                    let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
13921                    let sk = &mut sk_buf[..n];
13922                    let mut local_state =
13923                        [0f32; crate::gdn::GDN_MAX_STATE * crate::gdn::GDN_MAX_STATE];
13924                    let s_mat = &mut local_state[..n * n];
13925                    s_mat.fill(0.0);
13926                    run_head(bi, hi, s_mat, sk);
13927                });
13928            }
13929            return;
13930        }
13931
13932        if use_external {
13933            let state_bytes = state;
13934            (0..b * h).into_par_iter().for_each(|bhi| {
13935                let bi = bhi / h;
13936                let hi = bhi % h;
13937                let elem_off = bi * h * n * n + hi * n * n;
13938                let s_mat = sl_mut(
13939                    state_bytes + elem_off * std::mem::size_of::<f32>(),
13940                    arena.get(),
13941                    n * n,
13942                );
13943                let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
13944                run_head(bi, hi, s_mat, &mut sk_buf[..n]);
13945            });
13946        } else {
13947            for bi in 0..b {
13948                owned_state.fill(0.0);
13949                owned_state
13950                    .par_chunks_mut(n * n)
13951                    .enumerate()
13952                    .for_each(|(hi, s_mat)| {
13953                        let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
13954                        run_head(bi, hi, s_mat, &mut sk_buf[..n]);
13955                    });
13956            }
13957        }
13958    }
13959}
13960
13961/// Host-fallback: `Op::RmsNormBackwardInput` (GPU unified-memory / D2H arenas).
13962pub unsafe fn execute_rms_norm_backward_input_f32(
13963    x: usize,
13964    gamma: usize,
13965    beta: usize,
13966    dy: usize,
13967    dx: usize,
13968    rows: u32,
13969    h: u32,
13970    eps: f32,
13971    base: *mut u8,
13972) {
13973    let (rows, h) = (rows as usize, h as usize);
13974    let mut dg = vec![0f32; h];
13975    let mut db = vec![0f32; h];
13976    let xs = sl(x, base, rows * h);
13977    let dys = sl(dy, base, rows * h);
13978    let g = sl(gamma, base, h);
13979    let b = sl(beta, base, h);
13980    let out = sl_mut(dx, base, rows * h);
13981    for r in 0..rows {
13982        crate::training_bwd::rms_norm_backward_row(
13983            &xs[r * h..(r + 1) * h],
13984            g,
13985            b,
13986            &dys[r * h..(r + 1) * h],
13987            &mut out[r * h..(r + 1) * h],
13988            &mut dg,
13989            &mut db,
13990            eps,
13991        );
13992    }
13993}
13994
13995pub unsafe fn execute_rms_norm_backward_gamma_f32(
13996    x: usize,
13997    gamma: usize,
13998    beta: usize,
13999    dy: usize,
14000    dgamma: usize,
14001    rows: u32,
14002    h: u32,
14003    eps: f32,
14004    base: *mut u8,
14005) {
14006    let (rows, h) = (rows as usize, h as usize);
14007    let out = sl_mut(dgamma, base, h);
14008    out.fill(0.0);
14009    let mut dx = vec![0f32; h];
14010    let mut db = vec![0f32; h];
14011    let xs = sl(x, base, rows * h);
14012    let dys = sl(dy, base, rows * h);
14013    let g = sl(gamma, base, h);
14014    let b = sl(beta, base, h);
14015    for r in 0..rows {
14016        crate::training_bwd::rms_norm_backward_row(
14017            &xs[r * h..(r + 1) * h],
14018            g,
14019            b,
14020            &dys[r * h..(r + 1) * h],
14021            &mut dx,
14022            out,
14023            &mut db,
14024            eps,
14025        );
14026    }
14027}
14028
14029pub unsafe fn execute_rms_norm_backward_beta_f32(
14030    x: usize,
14031    gamma: usize,
14032    beta: usize,
14033    dy: usize,
14034    dbeta: usize,
14035    rows: u32,
14036    h: u32,
14037    eps: f32,
14038    base: *mut u8,
14039) {
14040    let (rows, h) = (rows as usize, h as usize);
14041    let out = sl_mut(dbeta, base, h);
14042    out.fill(0.0);
14043    let mut dx = vec![0f32; h];
14044    let mut dg = vec![0f32; h];
14045    let xs = sl(x, base, rows * h);
14046    let dys = sl(dy, base, rows * h);
14047    let g = sl(gamma, base, h);
14048    let b = sl(beta, base, h);
14049    for r in 0..rows {
14050        crate::training_bwd::rms_norm_backward_row(
14051            &xs[r * h..(r + 1) * h],
14052            g,
14053            b,
14054            &dys[r * h..(r + 1) * h],
14055            &mut dx,
14056            &mut dg,
14057            out,
14058            eps,
14059        );
14060    }
14061}
14062
14063#[allow(clippy::too_many_arguments)]
14064pub unsafe fn execute_conv2d_forward_f32(
14065    src: usize,
14066    weight: usize,
14067    dst: usize,
14068    n: u32,
14069    c_in: u32,
14070    h: u32,
14071    w: u32,
14072    c_out: u32,
14073    h_out: u32,
14074    w_out: u32,
14075    kh: u32,
14076    kw: u32,
14077    sh: u32,
14078    sw: u32,
14079    ph: u32,
14080    pw: u32,
14081    dh: u32,
14082    dw: u32,
14083    groups: u32,
14084    base: *mut u8,
14085) {
14086    let n = n as usize;
14087    let c_in = c_in as usize;
14088    let h = h as usize;
14089    let w = w as usize;
14090    let c_out = c_out as usize;
14091    let h_out = h_out as usize;
14092    let w_out = w_out as usize;
14093    let kh = kh as usize;
14094    let kw = kw as usize;
14095    let sh = sh as usize;
14096    let sw = sw as usize;
14097    let ph = ph as usize;
14098    let pw = pw as usize;
14099    let dh = dh as usize;
14100    let dw = dw as usize;
14101    let groups = groups as usize;
14102    let c_in_per_g = c_in / groups;
14103    let inp = sl(src, base, n * c_in * h * w);
14104    let wt = sl(weight, base, c_out * c_in_per_g * kh * kw);
14105    let out = sl_mut(dst, base, n * c_out * h_out * w_out);
14106    crate::conv_fwd::conv2d_forward_nchw_f32(
14107        inp, wt, out, n, c_in, h, w, c_out, h_out, w_out, kh, kw, sh, sw, ph, pw, dh, dw, groups,
14108    );
14109}
14110
14111pub unsafe fn execute_maxpool2d_backward_f32(
14112    x: usize,
14113    dy: usize,
14114    dx: usize,
14115    n: u32,
14116    c: u32,
14117    h: u32,
14118    w: u32,
14119    h_out: u32,
14120    w_out: u32,
14121    kh: u32,
14122    kw: u32,
14123    sh: u32,
14124    sw: u32,
14125    ph: u32,
14126    pw: u32,
14127    base: *mut u8,
14128) {
14129    let (n, c, h, w) = (n as usize, c as usize, h as usize, w as usize);
14130    let (h_out, w_out) = (h_out as usize, w_out as usize);
14131    let (kh, kw) = (kh as usize, kw as usize);
14132    let (sh, sw) = (sh as usize, sw as usize);
14133    let (ph, pw) = (ph as usize, pw as usize);
14134    let xs = sl(x, base, n * c * h * w);
14135    let dys = sl(dy, base, n * c * h_out * w_out);
14136    let dxs = sl_mut(dx, base, n * c * h * w);
14137    crate::training_bwd::maxpool2d_backward_nchw(
14138        xs, dys, dxs, n, c, h, w, h_out, w_out, kh, kw, sh, sw, ph, pw,
14139    );
14140}
14141
14142pub unsafe fn execute_rope_backward_f32(
14143    dy: usize,
14144    cos: usize,
14145    sin: usize,
14146    dx: usize,
14147    batch: u32,
14148    seq: u32,
14149    hidden: u32,
14150    head_dim: u32,
14151    n_rot: u32,
14152    cos_len: u32,
14153    base: *mut u8,
14154) {
14155    let (b, s, hs, dh, nr, cl) = (
14156        batch as usize,
14157        seq as usize,
14158        hidden as usize,
14159        head_dim as usize,
14160        n_rot as usize,
14161        cos_len as usize,
14162    );
14163    let nh = hs / dh;
14164    let tab_half = dh / 2;
14165    let dys = sl(dy, base, b * s * hs);
14166    let cos_tab = sl(cos, base, cl);
14167    let sin_tab = sl(sin, base, cl);
14168    let out = sl_mut(dx, base, b * s * hs);
14169    for bi in 0..b {
14170        for si in 0..s {
14171            let tab_off = si.saturating_mul(tab_half) % cl.max(1);
14172            let cp = &cos_tab[tab_off..tab_off + tab_half.min(cl)];
14173            let sp = &sin_tab[tab_off..tab_off + tab_half.min(cl)];
14174            for hi in 0..nh {
14175                let base_idx = bi * s * hs + si * hs + hi * dh;
14176                crate::training_bwd::rope_backward_row(
14177                    &dys[base_idx..base_idx + dh],
14178                    cp,
14179                    sp,
14180                    &mut out[base_idx..base_idx + dh],
14181                    dh,
14182                    nr,
14183                );
14184            }
14185        }
14186    }
14187}
14188
14189pub unsafe fn execute_cumsum_backward_f32(
14190    dy: usize,
14191    dx: usize,
14192    rows: u32,
14193    cols: u32,
14194    exclusive: bool,
14195    base: *mut u8,
14196) {
14197    let (rows, cols) = (rows as usize, cols as usize);
14198    let dys = sl(dy, base, rows * cols);
14199    let out = sl_mut(dx, base, rows * cols);
14200    for r in 0..rows {
14201        crate::training_bwd::cumsum_backward_row(
14202            &dys[r * cols..(r + 1) * cols],
14203            &mut out[r * cols..(r + 1) * cols],
14204            exclusive,
14205        );
14206    }
14207}
14208
14209pub unsafe fn execute_gather_backward_f32(
14210    dy: usize,
14211    indices: usize,
14212    dst: usize,
14213    outer: u32,
14214    axis_dim: u32,
14215    num_idx: u32,
14216    trailing: u32,
14217    base: *mut u8,
14218) {
14219    let (outer, axis_dim, num_idx, trailing) = (
14220        outer as usize,
14221        axis_dim as usize,
14222        num_idx as usize,
14223        trailing as usize,
14224    );
14225    let out = sl_mut(dst, base, outer * axis_dim * trailing);
14226    out.fill(0.0);
14227    crate::training_bwd::gather_axis_backward(
14228        sl(dy, base, outer * num_idx * trailing),
14229        sl(indices, base, num_idx),
14230        out,
14231        outer,
14232        axis_dim,
14233        num_idx,
14234        trailing,
14235    );
14236}
14237
14238/// Host-fallback entry for GGUF `Op::DequantMatMul` (Metal unified memory).
14239pub unsafe fn execute_dequant_matmul_gguf_f32(
14240    x: usize,
14241    w_q: usize,
14242    dst: usize,
14243    m: usize,
14244    k: usize,
14245    n: usize,
14246    scheme: rlx_ir::quant::QuantScheme,
14247    base: *mut u8,
14248) {
14249    unsafe {
14250        let block_bytes = scheme.gguf_block_bytes() as usize;
14251        let block_elems = scheme.gguf_block_size() as usize;
14252        let total_bytes = (k * n) / block_elems * block_bytes;
14253        let xs = sl(x, base, m * k);
14254        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, total_bytes);
14255        let out = sl_mut(dst, base, m * n);
14256        crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, scheme);
14257    }
14258}
14259
14260/// Host-fallback entry for GGUF `Op::DequantGroupedMatMul` (MoE expert stack).
14261pub unsafe fn execute_dequant_grouped_matmul_gguf_f32(
14262    input: usize,
14263    w_q: usize,
14264    expert_idx: usize,
14265    dst: usize,
14266    m: usize,
14267    k: usize,
14268    n: usize,
14269    num_experts: usize,
14270    scheme: rlx_ir::quant::QuantScheme,
14271    base: *mut u8,
14272) {
14273    unsafe {
14274        let block_bytes = scheme.gguf_block_bytes() as usize;
14275        let block_elems = scheme.gguf_block_size() as usize;
14276        let slab_bytes = (k * n) / block_elems * block_bytes;
14277        let xs = sl(input, base, m * k);
14278        let w_bytes =
14279            std::slice::from_raw_parts(base.add(w_q) as *const u8, num_experts * slab_bytes);
14280        let ids = sl(expert_idx, base, m);
14281        let out = sl_mut(dst, base, m * n);
14282        crate::gguf_matmul::gguf_grouped_matmul_bt(
14283            xs,
14284            w_bytes,
14285            ids,
14286            out,
14287            m,
14288            k,
14289            n,
14290            num_experts,
14291            scheme,
14292        );
14293    }
14294}
14295
14296/// Host-fallback entry for Int4 `Op::DequantMatMul` (Metal unified memory).
14297pub unsafe fn execute_dequant_matmul_int4_f32(
14298    x: usize,
14299    w_q: usize,
14300    scale: usize,
14301    zp: usize,
14302    dst: usize,
14303    m: usize,
14304    k: usize,
14305    n: usize,
14306    block_size: u32,
14307    is_asymmetric: bool,
14308    base: *mut u8,
14309) {
14310    let bs = block_size as usize;
14311    let n_blocks = k.div_ceil(bs);
14312    unsafe {
14313        let xs = sl(x, base, m * k);
14314        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, (k * n).div_ceil(2));
14315        let scales = sl(scale, base, n_blocks * n);
14316        let zps = if is_asymmetric {
14317            sl(zp, base, n_blocks * n)
14318        } else {
14319            &[][..]
14320        };
14321        let out = sl_mut(dst, base, m * n);
14322        dequant_matmul_int4(xs, w_bytes, scales, zps, out, m, k, n, bs, is_asymmetric);
14323    }
14324}
14325
14326/// Host-fallback entry for FP8 `Op::DequantMatMul` (Metal unified memory).
14327pub unsafe fn execute_dequant_matmul_fp8_f32(
14328    x: usize,
14329    w_q: usize,
14330    scale: usize,
14331    dst: usize,
14332    m: usize,
14333    k: usize,
14334    n: usize,
14335    e5m2: bool,
14336    base: *mut u8,
14337) {
14338    unsafe {
14339        let xs = sl(x, base, m * k);
14340        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, k * n);
14341        let scales = sl(scale, base, n);
14342        let out = sl_mut(dst, base, m * n);
14343        dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, e5m2);
14344    }
14345}
14346
14347/// Host-fallback entry for NVFP4 `Op::DequantMatMul` (Metal unified memory).
14348pub unsafe fn execute_dequant_matmul_nvfp4_f32(
14349    x: usize,
14350    w_q: usize,
14351    scale: usize,
14352    global_scale: usize,
14353    dst: usize,
14354    m: usize,
14355    k: usize,
14356    n: usize,
14357    base: *mut u8,
14358) {
14359    let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
14360    unsafe {
14361        let xs = sl(x, base, m * k);
14362        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, (k * n).div_ceil(2));
14363        let scale_bytes = std::slice::from_raw_parts(base.add(scale) as *const u8, n_scale);
14364        let gs = sl(global_scale, base, 1)[0];
14365        let out = sl_mut(dst, base, m * n);
14366        dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
14367    }
14368}
14369
14370/// Host-fallback entry for f16 `Op::GatedDeltaNet` tensors on Metal.
14371pub unsafe fn execute_gated_delta_net_f16(
14372    q: usize,
14373    k: usize,
14374    v: usize,
14375    g: usize,
14376    beta: usize,
14377    state: usize,
14378    dst: usize,
14379    batch: usize,
14380    seq: usize,
14381    heads: usize,
14382    state_size: usize,
14383    base: *mut u8,
14384) {
14385    use half::f16;
14386    unsafe {
14387        let read_f16 = |off: usize, len: usize| -> Vec<f32> {
14388            let raw = std::slice::from_raw_parts(base.add(off) as *const u8, len * 2);
14389            raw.chunks_exact(2)
14390                .map(|c| f16::from_le_bytes([c[0], c[1]]).to_f32())
14391                .collect()
14392        };
14393        let write_f16 = |off: usize, data: &[f32]| {
14394            let out = std::slice::from_raw_parts_mut(base.add(off), data.len() * 2);
14395            for (i, &v) in data.iter().enumerate() {
14396                let le = f16::from_f32(v).to_le_bytes();
14397                out[i * 2] = le[0];
14398                out[i * 2 + 1] = le[1];
14399            }
14400        };
14401
14402        let (b, s, h, n) = (batch, seq, heads, state_size);
14403        let q_f = read_f16(q, b * s * h * n);
14404        let k_f = read_f16(k, b * s * h * n);
14405        let v_f = read_f16(v, b * s * h * n);
14406        let g_f = read_f16(g, b * s * h);
14407        let b_f = read_f16(beta, b * s * h);
14408        let mut state_f = if state != 0 {
14409            read_f16(state, b * h * n * n)
14410        } else {
14411            vec![0f32; b * h * n * n]
14412        };
14413        let mut out_f = vec![0f32; b * s * h * n];
14414        let scale = 1.0f32 / (n as f32).sqrt();
14415        let mut sk_buf = vec![0f32; n];
14416        let mut owned_state = vec![0f32; h * n * n];
14417
14418        for bi in 0..b {
14419            let state_slice: &mut [f32] = if state != 0 {
14420                let start = bi * h * n * n;
14421                &mut state_f[start..start + h * n * n]
14422            } else {
14423                owned_state.fill(0.0);
14424                &mut owned_state
14425            };
14426
14427            for ti in 0..s {
14428                let qkv_step_base = bi * s * h * n + ti * h * n;
14429                let gb_step_base = bi * s * h + ti * h;
14430
14431                for hi in 0..h {
14432                    let q_row = &q_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
14433                    let k_row = &k_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
14434                    let v_row = &v_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
14435                    let g_t = g_f[gb_step_base + hi];
14436                    let beta_t = b_f[gb_step_base + hi];
14437
14438                    let s_base = hi * n * n;
14439                    let s_mat = &mut state_slice[s_base..s_base + n * n];
14440
14441                    let g_exp = g_t.exp();
14442                    for st in s_mat.iter_mut() {
14443                        *st *= g_exp;
14444                    }
14445
14446                    for j in 0..n {
14447                        let mut acc = 0f32;
14448                        for i in 0..n {
14449                            acc += s_mat[i * n + j] * k_row[i];
14450                        }
14451                        sk_buf[j] = acc;
14452                    }
14453
14454                    for j in 0..n {
14455                        sk_buf[j] = (v_row[j] - sk_buf[j]) * beta_t;
14456                    }
14457
14458                    for i in 0..n {
14459                        let ki = k_row[i];
14460                        for j in 0..n {
14461                            s_mat[i * n + j] += ki * sk_buf[j];
14462                        }
14463                    }
14464
14465                    let out_row = &mut out_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
14466                    for j in 0..n {
14467                        let mut acc = 0f32;
14468                        for i in 0..n {
14469                            acc += s_mat[i * n + j] * q_row[i];
14470                        }
14471                        out_row[j] = acc * scale;
14472                    }
14473                }
14474            }
14475        }
14476
14477        write_f16(dst, &out_f);
14478        if state != 0 {
14479            write_f16(state, &state_f);
14480        }
14481    }
14482}
14483
14484/// Host fallback for NCHW group norm (Metal unified-memory arena).
14485pub unsafe fn execute_group_norm_nchw_f32(
14486    src: usize,
14487    g: usize,
14488    b: usize,
14489    dst: usize,
14490    n: usize,
14491    c: usize,
14492    h: usize,
14493    w: usize,
14494    num_groups: usize,
14495    eps: f32,
14496    base: *mut u8,
14497) {
14498    let plane = c * h * w;
14499    for ni in 0..n {
14500        let input = unsafe { sl(src + ni * plane * std::mem::size_of::<f32>(), base, plane) };
14501        let gamma = unsafe { sl(g, base, c) };
14502        let beta = unsafe { sl(b, base, c) };
14503        let output = unsafe { sl_mut(dst + ni * plane * std::mem::size_of::<f32>(), base, plane) };
14504        crate::kernels::group_norm_nchw(input, gamma, beta, output, 1, c, h, w, num_groups, eps);
14505    }
14506}
14507
14508/// Host fallback for NCHW LayerNorm2d (SAM / candle semantics).
14509pub unsafe fn execute_layer_norm2d_nchw_f32(
14510    src: usize,
14511    g: usize,
14512    b: usize,
14513    dst: usize,
14514    n: usize,
14515    c: usize,
14516    h: usize,
14517    w: usize,
14518    eps: f32,
14519    base: *mut u8,
14520) {
14521    let plane = c * h * w;
14522    unsafe {
14523        let input = sl(src, base, n * plane);
14524        let gamma = sl(g, base, c);
14525        let beta = sl(b, base, c);
14526        let output = sl_mut(dst, base, n * plane);
14527        crate::kernels::layer_norm2d_nchw(input, gamma, beta, output, n, c, h, w, eps);
14528    }
14529}
14530
14531/// Host fallback for NCHW ConvTranspose2d.
14532pub unsafe fn execute_conv_transpose2d_nchw_f32(
14533    src: usize,
14534    weight: usize,
14535    dst: usize,
14536    n: usize,
14537    c_in: usize,
14538    h: usize,
14539    w_in: usize,
14540    c_out: usize,
14541    h_out: usize,
14542    w_out: usize,
14543    kh: usize,
14544    kw: usize,
14545    sh: usize,
14546    sw: usize,
14547    ph: usize,
14548    pw: usize,
14549    dh: usize,
14550    dw: usize,
14551    groups: usize,
14552    base: *mut u8,
14553) {
14554    let in_elems = n * c_in * h * w_in;
14555    let w_elems = c_in * (c_out / groups) * kh * kw;
14556    let out_elems = n * c_out * h_out * w_out;
14557    unsafe {
14558        let input = sl(src, base, in_elems);
14559        let wt = sl(weight, base, w_elems);
14560        let output = sl_mut(dst, base, out_elems);
14561        crate::kernels::conv_transpose2d_nchw(
14562            input, wt, output, n, c_in, h, w_in, c_out, h_out, w_out, kh, kw, sh, sw, ph, pw, dh,
14563            dw, groups,
14564        );
14565    }
14566}
14567
14568/// Host fallback for nearest 2× upsample on NCHW.
14569pub unsafe fn execute_resize_nearest_2x_f32(
14570    src: usize,
14571    dst: usize,
14572    n: usize,
14573    c: usize,
14574    h: usize,
14575    w: usize,
14576    base: *mut u8,
14577) {
14578    let in_plane = c * h * w;
14579    let out_plane = c * h * 2 * w * 2;
14580    for ni in 0..n {
14581        let input = unsafe {
14582            sl(
14583                src + ni * in_plane * std::mem::size_of::<f32>(),
14584                base,
14585                in_plane,
14586            )
14587        };
14588        let output = unsafe {
14589            sl_mut(
14590                dst + ni * out_plane * std::mem::size_of::<f32>(),
14591                base,
14592                out_plane,
14593            )
14594        };
14595        crate::kernels::resize_nearest_2x_nchw(input, output, c, h, w);
14596    }
14597}
14598
14599/// Host axial 2-D RoPE for Metal (and other) fallbacks on unified memory.
14600pub unsafe fn execute_axial_rope2d_f32(
14601    src: usize,
14602    dst: usize,
14603    batch: usize,
14604    seq: usize,
14605    hidden: usize,
14606    end_x: usize,
14607    end_y: usize,
14608    head_dim: usize,
14609    num_heads: usize,
14610    theta: f32,
14611    repeat_factor: usize,
14612    base: *mut u8,
14613) {
14614    let plane = seq * hidden;
14615    let plane_bytes = plane * std::mem::size_of::<f32>();
14616    for bi in 0..batch {
14617        let in_off = src + bi * plane_bytes;
14618        let input = unsafe { sl(in_off, base, plane) };
14619        let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
14620            input,
14621            num_heads,
14622            seq,
14623            head_dim,
14624            end_x,
14625            end_y,
14626            theta,
14627            repeat_factor,
14628        );
14629        let out_off = dst + bi * plane_bytes;
14630        let output = unsafe { sl_mut(out_off, base, plane) };
14631        output.copy_from_slice(&rotated);
14632    }
14633}
14634
14635/// Ternary pruned radix-2 butterfly stage on `[batch, n_fft, 2]` interleaved state.
14636pub unsafe fn execute_fft_butterfly_stage_f32(
14637    state_src: usize,
14638    state_dst: usize,
14639    gate_src: usize,
14640    rev_src: usize,
14641    tw_re_src: usize,
14642    tw_im_src: usize,
14643    batch: usize,
14644    n_fft: usize,
14645    stage: usize,
14646    base: *mut u8,
14647) {
14648    let half = n_fft / 2;
14649    let stride = 1usize << stage;
14650    let gate = unsafe { sl(gate_src, base, half) };
14651    let rev = unsafe { sl(rev_src, base, half) };
14652    let tw_re = unsafe { sl(tw_re_src, base, half) };
14653    let tw_im = unsafe { sl(tw_im_src, base, half) };
14654    let row_elems = n_fft * 2;
14655    for b in 0..batch {
14656        let in_off = state_src + b * row_elems * std::mem::size_of::<f32>();
14657        let out_off = state_dst + b * row_elems * std::mem::size_of::<f32>();
14658        let inp = unsafe { sl(in_off, base, row_elems) };
14659        let out = unsafe { sl_mut(out_off, base, row_elems) };
14660        out.copy_from_slice(inp);
14661        for bf in 0..half {
14662            if gate[bf] == 0.0 {
14663                continue;
14664            }
14665            let group = bf / stride;
14666            let k = bf % stride;
14667            let i0 = group * 2 * stride + k;
14668            let i1 = i0 + stride;
14669            let w_re = tw_re[bf];
14670            let w_im = tw_im[bf];
14671            let in_a_re = inp[i0 * 2];
14672            let in_a_im = inp[i0 * 2 + 1];
14673            let in_b_re = inp[i1 * 2];
14674            let in_b_im = inp[i1 * 2 + 1];
14675            let (b_re, b_im) = (
14676                in_b_re * w_re - in_b_im * w_im,
14677                in_b_re * w_im + in_b_im * w_re,
14678            );
14679            let (top_re, top_im) = (in_a_re + b_re, in_a_im + b_im);
14680            let (bot_re, bot_im) = (in_a_re - b_re, in_a_im - b_im);
14681            let (oa_re, oa_im, ob_re, ob_im) = if rev[bf] >= 0.5 {
14682                (bot_re, bot_im, top_re, top_im)
14683            } else {
14684                (top_re, top_im, bot_re, bot_im)
14685            };
14686            out[i0 * 2] = oa_re;
14687            out[i0 * 2 + 1] = oa_im;
14688            out[i1 * 2] = ob_re;
14689            out[i1 * 2 + 1] = ob_im;
14690        }
14691    }
14692}
14693
14694/// f32 mirror of `execute_fft1d_f64`. Same public-host-fallback role.
14695pub unsafe fn execute_fft1d_f32(
14696    src: usize,
14697    dst: usize,
14698    outer: usize,
14699    n_complex: usize,
14700    inverse: bool,
14701    norm_tag: u32,
14702    base: *mut u8,
14703) {
14704    let row_elems = 2 * n_complex;
14705    let mut re = vec![0f32; n_complex];
14706    let mut im = vec![0f32; n_complex];
14707    let norm = rlx_ir::fft::FftNorm::from_tag(norm_tag);
14708    let scale = norm.output_scale(n_complex, inverse) as f32;
14709    let mut scratch = if n_complex.is_power_of_two() || n_complex <= 16 {
14710        BluesteinScratchF32::empty()
14711    } else {
14712        BluesteinScratchF32::build(n_complex, inverse)
14713    };
14714    for o in 0..outer {
14715        let row_offset = src + o * row_elems * std::mem::size_of::<f32>();
14716        let s = unsafe { sl(row_offset, base, row_elems) };
14717        re.copy_from_slice(&s[..n_complex]);
14718        im.copy_from_slice(&s[n_complex..]);
14719        if n_complex.is_power_of_two() {
14720            fft_radix2_inplace_f32(&mut re, &mut im, inverse);
14721        } else if n_complex <= 16 {
14722            fft_naive_inplace_f32(&mut re, &mut im, inverse);
14723        } else {
14724            fft_bluestein_inplace_f32(&mut re, &mut im, inverse, &mut scratch);
14725        }
14726        if scale != 1.0 {
14727            re.iter_mut().for_each(|v| *v *= scale);
14728            im.iter_mut().for_each(|v| *v *= scale);
14729        }
14730        let dst_offset = dst + o * row_elems * std::mem::size_of::<f32>();
14731        let d = unsafe { sl_mut(dst_offset, base, row_elems) };
14732        d[..n_complex].copy_from_slice(&re);
14733        d[n_complex..].copy_from_slice(&im);
14734    }
14735}
14736
14737/// C64 interleaved layout: each complex element is `[re: f32, im: f32]`.
14738pub unsafe fn execute_fft1d_c64(
14739    src: usize,
14740    dst: usize,
14741    outer: usize,
14742    n_complex: usize,
14743    inverse: bool,
14744    norm_tag: u32,
14745    base: *mut u8,
14746) {
14747    let row_bytes = n_complex * 8;
14748    let mut re = vec![0f32; n_complex];
14749    let mut im = vec![0f32; n_complex];
14750    let norm = rlx_ir::fft::FftNorm::from_tag(norm_tag);
14751    let scale = norm.output_scale(n_complex, inverse) as f32;
14752    let mut scratch = if n_complex.is_power_of_two() || n_complex <= 16 {
14753        BluesteinScratchF32::empty()
14754    } else {
14755        BluesteinScratchF32::build(n_complex, inverse)
14756    };
14757    for o in 0..outer {
14758        let row_offset = src + o * row_bytes;
14759        for i in 0..n_complex {
14760            let elem_off = row_offset + i * 8;
14761            re[i] = f32::from_le_bytes([
14762                *base.add(elem_off),
14763                *base.add(elem_off + 1),
14764                *base.add(elem_off + 2),
14765                *base.add(elem_off + 3),
14766            ]);
14767            im[i] = f32::from_le_bytes([
14768                *base.add(elem_off + 4),
14769                *base.add(elem_off + 5),
14770                *base.add(elem_off + 6),
14771                *base.add(elem_off + 7),
14772            ]);
14773        }
14774        if n_complex.is_power_of_two() {
14775            fft_radix2_inplace_f32(&mut re, &mut im, inverse);
14776        } else if n_complex <= 16 {
14777            fft_naive_inplace_f32(&mut re, &mut im, inverse);
14778        } else {
14779            fft_bluestein_inplace_f32(&mut re, &mut im, inverse, &mut scratch);
14780        }
14781        if scale != 1.0 {
14782            re.iter_mut().for_each(|v| *v *= scale);
14783            im.iter_mut().for_each(|v| *v *= scale);
14784        }
14785        let dst_row = dst + o * row_bytes;
14786        for i in 0..n_complex {
14787            let elem_off = dst_row + i * 8;
14788            let re_b = re[i].to_le_bytes();
14789            let im_b = im[i].to_le_bytes();
14790            for j in 0..4 {
14791                *base.add(elem_off + j) = re_b[j];
14792                *base.add(elem_off + 4 + j) = im_b[j];
14793            }
14794        }
14795    }
14796}
14797
14798/// Dtype-dispatching host entry for `Op::LogMel` (shared by GPU host fallbacks).
14799pub unsafe fn execute_log_mel(
14800    spec: usize,
14801    filters: usize,
14802    dst: usize,
14803    outer: usize,
14804    n_fft: usize,
14805    n_bins: usize,
14806    n_mels: usize,
14807    base: *mut u8,
14808) {
14809    execute_log_mel_f32(spec, filters, dst, outer, n_fft, n_bins, n_mels, base);
14810}
14811
14812pub unsafe fn execute_log_mel_f32(
14813    spec: usize,
14814    filters: usize,
14815    dst: usize,
14816    outer: usize,
14817    n_fft: usize,
14818    n_bins: usize,
14819    n_mels: usize,
14820    base: *mut u8,
14821) {
14822    let spec_ptr = base.add(spec) as *const f32;
14823    let filt_ptr = base.add(filters) as *const f32;
14824    let dst_ptr = base.add(dst) as *mut f32;
14825    let spec = std::slice::from_raw_parts(spec_ptr, outer * n_fft * 2);
14826    let filters = std::slice::from_raw_parts(filt_ptr, n_mels * n_bins);
14827    let out = std::slice::from_raw_parts_mut(dst_ptr, outer * n_mels);
14828    rlx_ir::audio::log_mel_block_f32(spec, filters, outer, n_fft, n_bins, n_mels, out);
14829}
14830
14831pub unsafe fn execute_log_mel_backward_f32(
14832    spec: usize,
14833    filters: usize,
14834    dy: usize,
14835    dst: usize,
14836    outer: usize,
14837    n_fft: usize,
14838    n_bins: usize,
14839    n_mels: usize,
14840    base: *mut u8,
14841) {
14842    let spec_ptr = base.add(spec) as *const f32;
14843    let filt_ptr = base.add(filters) as *const f32;
14844    let dy_ptr = base.add(dy) as *const f32;
14845    let dst_ptr = base.add(dst) as *mut f32;
14846    let spec = std::slice::from_raw_parts(spec_ptr, outer * n_fft * 2);
14847    let filters = std::slice::from_raw_parts(filt_ptr, n_mels * n_bins);
14848    let dy = std::slice::from_raw_parts(dy_ptr, outer * n_mels);
14849    let d_spec = std::slice::from_raw_parts_mut(dst_ptr, outer * n_fft * 2);
14850    d_spec.fill(0.0);
14851    rlx_ir::audio::log_mel_block_vjp(spec, filters, dy, outer, n_fft, n_bins, n_mels, d_spec);
14852}
14853
14854/// Dtype-dispatching host entry for `Op::Fft` (shared by GPU host fallbacks).
14855pub unsafe fn execute_fft1d(
14856    src: usize,
14857    dst: usize,
14858    outer: usize,
14859    n_complex: usize,
14860    inverse: bool,
14861    norm_tag: u32,
14862    dtype: rlx_ir::DType,
14863    base: *mut u8,
14864) {
14865    match dtype {
14866        rlx_ir::DType::F32 => {
14867            execute_fft1d_f32(src, dst, outer, n_complex, inverse, norm_tag, base)
14868        }
14869        rlx_ir::DType::F64 => {
14870            execute_fft1d_f64(src, dst, outer, n_complex, inverse, norm_tag, base)
14871        }
14872        rlx_ir::DType::C64 => {
14873            execute_fft1d_c64(src, dst, outer, n_complex, inverse, norm_tag, base)
14874        }
14875        other => panic!("execute_fft1d: unsupported dtype {other:?}"),
14876    }
14877}
14878
/// f32 in-place radix-2 DIT Cooley-Tukey on split (real, imag) arrays.
///
/// `re.len() == im.len()` must be a power of two (checked with
/// `debug_assert!` only). Forward uses ω = exp(-2πi/n), inverse uses
/// ω = exp(+2πi/n); neither direction applies a 1/N scale.
///
/// The butterflies run in f32, but the twiddle recurrence is carried in f64
/// and rounded to f32 per use, so accumulated rotation drift does not
/// dominate the per-stage error at larger N.
fn fft_radix2_inplace_f32(re: &mut [f32], im: &mut [f32], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    debug_assert!(
        n.is_power_of_two(),
        "fft_radix2_f32: n={n} must be a power of two"
    );
    if n <= 1 {
        return;
    }

    // Bit-reverse permutation; `rev` is advanced as a reversed counter.
    let mut rev = 0usize;
    for idx in 1..n {
        let mut carry = n >> 1;
        while rev & carry != 0 {
            rev ^= carry;
            carry >>= 1;
        }
        rev ^= carry;
        if idx < rev {
            re.swap(idx, rev);
            im.swap(idx, rev);
        }
    }

    let direction = if inverse { 1.0_f64 } else { -1.0_f64 };
    let mut span = 2usize;
    while span <= n {
        let half = span / 2;
        let angle = direction * 2.0 * std::f64::consts::PI / (span as f64);
        let (step_re, step_im) = (angle.cos(), angle.sin());

        for block in (0..n).step_by(span) {
            // Twiddle restarts at 1 + 0i for every block.
            let (mut w_re, mut w_im) = (1.0_f64, 0.0_f64);
            for k in 0..half {
                let top = block + k;
                let bot = top + half;

                let (c, s) = (w_re as f32, w_im as f32);
                let t_re = c * re[bot] - s * im[bot];
                let t_im = c * im[bot] + s * re[bot];
                let (u_re, u_im) = (re[top], im[top]);

                re[top] = u_re + t_re;
                im[top] = u_im + t_im;
                re[bot] = u_re - t_re;
                im[bot] = u_im - t_im;

                let next_re = w_re * step_re - w_im * step_im;
                let next_im = w_re * step_im + w_im * step_re;
                w_re = next_re;
                w_im = next_im;
            }
        }
        span <<= 1;
    }
}
14940
/// In-place radix-2 DIT Cooley-Tukey FFT on split (real, imag) f64
/// arrays. `n = re.len() = im.len()` must be a power of two (checked with
/// `debug_assert!` only). Forward uses ω = exp(-2πi/n); inverse uses
/// ω = exp(+2πi/n) (no 1/N scale).
fn fft_radix2_inplace_f64(re: &mut [f64], im: &mut [f64], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    debug_assert!(
        n.is_power_of_two(),
        "fft_radix2: n={n} must be a power of two"
    );
    if n <= 1 {
        return;
    }

    // Bit-reverse permutation; `rev` is advanced as a reversed counter.
    let mut rev = 0usize;
    for idx in 1..n {
        let mut carry = n >> 1;
        while rev & carry != 0 {
            rev ^= carry;
            carry >>= 1;
        }
        rev ^= carry;
        if idx < rev {
            re.swap(idx, rev);
            im.swap(idx, rev);
        }
    }

    // Butterfly passes with ω_span = exp(±2πi/span), span = 2, 4, …, n.
    let direction: f64 = if inverse { 1.0 } else { -1.0 };
    let mut span = 2usize;
    while span <= n {
        let half = span / 2;
        let angle = direction * 2.0 * std::f64::consts::PI / (span as f64);
        let (step_re, step_im) = (angle.cos(), angle.sin());

        for block in (0..n).step_by(span) {
            // Twiddle restarts at 1 + 0i for every block.
            let (mut w_re, mut w_im) = (1.0_f64, 0.0_f64);
            for k in 0..half {
                let top = block + k;
                let bot = top + half;

                let t_re = w_re * re[bot] - w_im * im[bot];
                let t_im = w_re * im[bot] + w_im * re[bot];
                let (u_re, u_im) = (re[top], im[top]);

                re[top] = u_re + t_re;
                im[top] = u_im + t_im;
                re[bot] = u_re - t_re;
                im[bot] = u_im - t_im;

                let next_re = w_re * step_re - w_im * step_im;
                let next_im = w_re * step_im + w_im * step_re;
                w_re = next_re;
                w_im = next_im;
            }
        }
        span <<= 1;
    }
}
15002
15003/// Pre-computed chirp + filter-spectrum for one (N, direction) pair.
15004/// Built once per call to `execute_fft1d_f64` and reused across rows
15005/// when `outer > 1` — the chirp and FFT(b) don't depend on the input.
15006struct BluesteinScratchF64 {
15007    /// Power-of-two convolution length, ≥ 2N - 1.
15008    m: usize,
15009    /// `w[k] = exp(sign · iπ · k² / N)` for k=0..N, where sign matches
15010    /// the requested direction. Forward chirp on the way in, output
15011    /// chirp on the way out.
15012    w_re: Vec<f64>,
15013    w_im: Vec<f64>,
15014    /// FFT of the embedded filter `b[k] = conj(w[|k|])` in length-M.
15015    /// Doesn't depend on the input — precomputed once.
15016    bf_re: Vec<f64>,
15017    bf_im: Vec<f64>,
15018    /// Workspace reused per row (avoids per-row allocation).
15019    ar: Vec<f64>,
15020    ai: Vec<f64>,
15021}
15022
15023impl BluesteinScratchF64 {
15024    fn empty() -> Self {
15025        Self {
15026            m: 0,
15027            w_re: Vec::new(),
15028            w_im: Vec::new(),
15029            bf_re: Vec::new(),
15030            bf_im: Vec::new(),
15031            ar: Vec::new(),
15032            ai: Vec::new(),
15033        }
15034    }
15035
15036    fn build(n: usize, inverse: bool) -> Self {
15037        // M = next power of two ≥ 2N - 1 keeps the inner FFT on the
15038        // fast radix-2 path. For N=1 fall back to M=1 (no-op convolution).
15039        let m = if n <= 1 {
15040            1
15041        } else {
15042            (2 * n - 1).next_power_of_two()
15043        };
15044
15045        // Chirp arg reduced via k² mod 2N — without this, large N
15046        // bleeds precision into the trig call (n² grows quadratically).
15047        let mod_2n = (2 * n) as u64;
15048        let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
15049        let mut w_re = vec![0.0_f64; n];
15050        let mut w_im = vec![0.0_f64; n];
15051        for k in 0..n {
15052            let k2 = (k as u64).wrapping_mul(k as u64) % mod_2n;
15053            let theta = sign * std::f64::consts::PI * (k2 as f64) / (n as f64);
15054            w_re[k] = theta.cos();
15055            w_im[k] = theta.sin();
15056        }
15057
15058        // Embed b[k] = conj(w[|k|]) into length M with the negative
15059        // indices wrapping to the tail: b[-j] → B[M-j] for j=1..N-1.
15060        let mut bf_re = vec![0.0_f64; m];
15061        let mut bf_im = vec![0.0_f64; m];
15062        if n > 0 {
15063            bf_re[0] = w_re[0];
15064            bf_im[0] = -w_im[0];
15065            for k in 1..n {
15066                bf_re[k] = w_re[k];
15067                bf_im[k] = -w_im[k];
15068                bf_re[m - k] = w_re[k];
15069                bf_im[m - k] = -w_im[k];
15070            }
15071        }
15072        if m > 1 {
15073            fft_radix2_inplace_f64(&mut bf_re, &mut bf_im, false);
15074        }
15075
15076        Self {
15077            m,
15078            w_re,
15079            w_im,
15080            bf_re,
15081            bf_im,
15082            ar: vec![0.0_f64; m],
15083            ai: vec![0.0_f64; m],
15084        }
15085    }
15086}
15087
/// Direct O(N²) DFT for small non-pow2 N (faster than Bluestein setup).
///
/// Operates in place on split (real, imag) arrays. Forward uses
/// exp(-2πi·n·k/N), inverse uses exp(+2πi·n·k/N); no 1/N scale is applied.
/// Lengths 0 and 1 are returned untouched.
fn fft_naive_inplace_f64(re: &mut [f64], im: &mut [f64], inverse: bool) {
    let n = re.len();
    if n <= 1 {
        return;
    }
    let direction: f64 = if inverse { 1.0 } else { -1.0 };
    let mut spec_re = vec![0.0_f64; n];
    let mut spec_im = vec![0.0_f64; n];
    for (k, (bin_re, bin_im)) in spec_re.iter_mut().zip(spec_im.iter_mut()).enumerate() {
        for t in 0..n {
            let phase =
                direction * 2.0 * std::f64::consts::PI * (t as f64) * (k as f64) / (n as f64);
            let (cos_p, sin_p) = (phase.cos(), phase.sin());
            *bin_re += re[t] * cos_p - im[t] * sin_p;
            *bin_im += re[t] * sin_p + im[t] * cos_p;
        }
    }
    re.copy_from_slice(&spec_re);
    im.copy_from_slice(&spec_im);
}
15109
/// Direct O(N²) DFT on split (real, imag) f32 arrays, for small non-pow2 N.
///
/// Forward uses exp(-2πi·n·k/N), inverse uses exp(+2πi·n·k/N); no 1/N scale
/// is applied. Lengths 0 and 1 are returned untouched.
///
/// Twiddles follow the same policy as the other f32 FFT paths in this file:
/// they are evaluated in f64 and rounded to f32, while the accumulation
/// stays in f32. Because exp(±2πi·j/N) is periodic in `j` with period N,
/// only N distinct twiddles exist; they are tabulated once and looked up
/// with `(n·k) mod N`. This keeps the phase argument below 2π (an unreduced
/// f32 product loses accuracy as `n·k` grows) and replaces N² trig
/// evaluations with N.
fn fft_naive_inplace_f32(re: &mut [f32], im: &mut [f32], inverse: bool) {
    let n = re.len();
    if n <= 1 {
        return;
    }
    let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
    let step = sign * 2.0 * std::f64::consts::PI / (n as f64);

    // twiddle[j] = exp(sign · 2πi · j / N), j = 0..N, rounded to f32.
    let twiddle: Vec<(f32, f32)> = (0..n)
        .map(|j| {
            let theta = step * (j as f64);
            (theta.cos() as f32, theta.sin() as f32)
        })
        .collect();

    let mut out_re = vec![0.0_f32; n];
    let mut out_im = vec![0.0_f32; n];
    for k in 0..n {
        let mut acc_re = 0.0_f32;
        let mut acc_im = 0.0_f32;
        for nn in 0..n {
            let (c, s) = twiddle[(nn * k) % n];
            acc_re += re[nn] * c - im[nn] * s;
            acc_im += re[nn] * s + im[nn] * c;
        }
        out_re[k] = acc_re;
        out_im[k] = acc_im;
    }
    re.copy_from_slice(&out_re);
    im.copy_from_slice(&out_im);
}
15130
15131/// Bluestein (chirp-z) FFT for arbitrary N. Identity used:
15132///   `n·k = (n² + k² - (k-n)²) / 2`
15133/// which lets the DFT be written as a linear convolution sandwiched
15134/// between two chirp multiplies:
15135///   `X[k] = w[k] · ((x·w) ⊛ conj(w))[k]`   where `w[n] = exp(±iπ·n²/N)`.
15136/// The convolution is computed via a length-M radix-2 FFT (M ≥ 2N-1).
15137/// Both directions stay unnormalized to match the radix-2 path, so the
15138/// chain rule keeps working without scaling.
15139fn fft_bluestein_inplace_f64(
15140    re: &mut [f64],
15141    im: &mut [f64],
15142    _inverse: bool,
15143    s: &mut BluesteinScratchF64,
15144) {
15145    let n = re.len();
15146    debug_assert_eq!(im.len(), n);
15147    debug_assert_eq!(s.w_re.len(), n);
15148    if n <= 1 {
15149        return;
15150    }
15151    let m = s.m;
15152
15153    // Pre-chirp: a[k] = x[k] · w[k], zero-padded to M.
15154    for k in 0..m {
15155        s.ar[k] = 0.0;
15156        s.ai[k] = 0.0;
15157    }
15158    for k in 0..n {
15159        s.ar[k] = re[k] * s.w_re[k] - im[k] * s.w_im[k];
15160        s.ai[k] = re[k] * s.w_im[k] + im[k] * s.w_re[k];
15161    }
15162
15163    // Length-M forward FFT of the padded chirped input.
15164    fft_radix2_inplace_f64(&mut s.ar, &mut s.ai, false);
15165
15166    // Pointwise product with FFT(b). Stored back into (ar, ai).
15167    for k in 0..m {
15168        let ar = s.ar[k];
15169        let ai = s.ai[k];
15170        let br = s.bf_re[k];
15171        let bi = s.bf_im[k];
15172        s.ar[k] = ar * br - ai * bi;
15173        s.ai[k] = ar * bi + ai * br;
15174    }
15175
15176    // Inverse FFT — radix-2 here is the unnormalized inverse, so we
15177    // divide by M to recover the true circular convolution.
15178    fft_radix2_inplace_f64(&mut s.ar, &mut s.ai, true);
15179    let inv_m = 1.0 / (m as f64);
15180
15181    // Post-chirp: X[k] = w[k] · Y[k] / M for k = 0..N.
15182    for k in 0..n {
15183        let yr = s.ar[k] * inv_m;
15184        let yi = s.ai[k] * inv_m;
15185        re[k] = yr * s.w_re[k] - yi * s.w_im[k];
15186        im[k] = yr * s.w_im[k] + yi * s.w_re[k];
15187    }
15188}
15189
15190/// f32 mirror of `BluesteinScratchF64`. Chirp is computed in f64 for
15191/// precision (same justification as the radix-2 f32 path: twiddles in
15192/// f64, butterflies in f32). The actual conv buffers are f32.
15193struct BluesteinScratchF32 {
15194    m: usize,
15195    w_re: Vec<f32>,
15196    w_im: Vec<f32>,
15197    bf_re: Vec<f32>,
15198    bf_im: Vec<f32>,
15199    ar: Vec<f32>,
15200    ai: Vec<f32>,
15201}
15202
15203impl BluesteinScratchF32 {
15204    fn empty() -> Self {
15205        Self {
15206            m: 0,
15207            w_re: Vec::new(),
15208            w_im: Vec::new(),
15209            bf_re: Vec::new(),
15210            bf_im: Vec::new(),
15211            ar: Vec::new(),
15212            ai: Vec::new(),
15213        }
15214    }
15215
15216    fn build(n: usize, inverse: bool) -> Self {
15217        let m = if n <= 1 {
15218            1
15219        } else {
15220            (2 * n - 1).next_power_of_two()
15221        };
15222
15223        let mod_2n = (2 * n) as u64;
15224        let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
15225        let mut w_re = vec![0.0_f32; n];
15226        let mut w_im = vec![0.0_f32; n];
15227        for k in 0..n {
15228            let k2 = (k as u64).wrapping_mul(k as u64) % mod_2n;
15229            let theta = sign * std::f64::consts::PI * (k2 as f64) / (n as f64);
15230            w_re[k] = theta.cos() as f32;
15231            w_im[k] = theta.sin() as f32;
15232        }
15233
15234        let mut bf_re = vec![0.0_f32; m];
15235        let mut bf_im = vec![0.0_f32; m];
15236        if n > 0 {
15237            bf_re[0] = w_re[0];
15238            bf_im[0] = -w_im[0];
15239            for k in 1..n {
15240                bf_re[k] = w_re[k];
15241                bf_im[k] = -w_im[k];
15242                bf_re[m - k] = w_re[k];
15243                bf_im[m - k] = -w_im[k];
15244            }
15245        }
15246        if m > 1 {
15247            fft_radix2_inplace_f32(&mut bf_re, &mut bf_im, false);
15248        }
15249
15250        Self {
15251            m,
15252            w_re,
15253            w_im,
15254            bf_re,
15255            bf_im,
15256            ar: vec![0.0_f32; m],
15257            ai: vec![0.0_f32; m],
15258        }
15259    }
15260}
15261
15262fn fft_bluestein_inplace_f32(
15263    re: &mut [f32],
15264    im: &mut [f32],
15265    _inverse: bool,
15266    s: &mut BluesteinScratchF32,
15267) {
15268    let n = re.len();
15269    debug_assert_eq!(im.len(), n);
15270    debug_assert_eq!(s.w_re.len(), n);
15271    if n <= 1 {
15272        return;
15273    }
15274    let m = s.m;
15275
15276    for k in 0..m {
15277        s.ar[k] = 0.0;
15278        s.ai[k] = 0.0;
15279    }
15280    for k in 0..n {
15281        s.ar[k] = re[k] * s.w_re[k] - im[k] * s.w_im[k];
15282        s.ai[k] = re[k] * s.w_im[k] + im[k] * s.w_re[k];
15283    }
15284
15285    fft_radix2_inplace_f32(&mut s.ar, &mut s.ai, false);
15286
15287    for k in 0..m {
15288        let ar = s.ar[k];
15289        let ai = s.ai[k];
15290        let br = s.bf_re[k];
15291        let bi = s.bf_im[k];
15292        s.ar[k] = ar * br - ai * bi;
15293        s.ai[k] = ar * bi + ai * br;
15294    }
15295
15296    fft_radix2_inplace_f32(&mut s.ar, &mut s.ai, true);
15297    let inv_m = 1.0_f32 / (m as f32);
15298
15299    for k in 0..n {
15300        let yr = s.ar[k] * inv_m;
15301        let yi = s.ai[k] * inv_m;
15302        re[k] = yr * s.w_re[k] - yi * s.w_im[k];
15303        im[k] = yr * s.w_im[k] + yi * s.w_re[k];
15304    }
15305}
15306
15307/// Shared dispatch path for `Thunk::CustomOp`. Builds a typed
15308/// [`CpuTensorRef`] for each input *at that input's declared dtype*
15309/// (so a sparse-LU op with mixed F64/I32 inputs gets the right
15310/// typed slices) and a [`CpuTensorMut`] for the output, then calls
15311/// the kernel's single `execute` method.
15312unsafe fn dispatch_custom_op(
15313    kernel: &dyn crate::op_registry::CpuKernel,
15314    inputs: &[(usize, u32, Shape)],
15315    out_off: usize,
15316    out_len: u32,
15317    out_shape: &Shape,
15318    attrs: &[u8],
15319    base: *mut u8,
15320) {
15321    use crate::op_registry::{CpuTensorMut, CpuTensorRef};
15322    use rlx_ir::DType;
15323
15324    // One arm per `DType` variant — single source of truth for
15325    // "which dtypes the CPU custom-op dispatcher wires." If a new
15326    // DType lands in `rlx-ir`, the compiler flags this match as
15327    // non-exhaustive and the gap gets named at the right place.
15328    macro_rules! build_in_view {
15329        ($shape:expr, $off:expr, $n:expr, $variant:ident, $rust_ty:ty) => {
15330            CpuTensorRef::$variant {
15331                data: unsafe { sl_typed::<$rust_ty>($off, base, $n) },
15332                shape: $shape,
15333            }
15334        };
15335    }
15336    macro_rules! build_out_view {
15337        ($variant:ident, $rust_ty:ty) => {
15338            CpuTensorMut::$variant {
15339                data: unsafe { sl_mut_typed::<$rust_ty>(out_off, base, out_len as usize) },
15340                shape: out_shape,
15341            }
15342        };
15343    }
15344
15345    let in_views: Vec<CpuTensorRef<'_>> = inputs
15346        .iter()
15347        .map(|(off, len, shape)| {
15348            let n = *len as usize;
15349            let off = *off;
15350            match shape.dtype() {
15351                DType::F32 => build_in_view!(shape, off, n, F32, f32),
15352                DType::F64 => build_in_view!(shape, off, n, F64, f64),
15353                DType::F16 => build_in_view!(shape, off, n, F16, half::f16),
15354                DType::BF16 => build_in_view!(shape, off, n, BF16, half::bf16),
15355                DType::I8 => build_in_view!(shape, off, n, I8, i8),
15356                DType::I16 => build_in_view!(shape, off, n, I16, i16),
15357                DType::I32 => build_in_view!(shape, off, n, I32, i32),
15358                DType::I64 => build_in_view!(shape, off, n, I64, i64),
15359                DType::U8 => build_in_view!(shape, off, n, U8, u8),
15360                DType::U32 => build_in_view!(shape, off, n, U32, u32),
15361                DType::Bool => build_in_view!(shape, off, n, Bool, u8),
15362                // C64 isn't a CpuTensor variant today; the user-registered
15363                // op_registry path doesn't see complex inputs (those are
15364                // handled by built-in ops with dedicated kernels).
15365                DType::C64 => panic!(
15366                    "Op::Custom kernel input has DType::C64 — built-in \
15367                 complex ops handle their own kernels; user-registered \
15368                 ops don't yet see complex tensors"
15369                ),
15370            }
15371        })
15372        .collect();
15373
15374    let result = match out_shape.dtype() {
15375        DType::F32 => kernel.execute(&in_views, build_out_view!(F32, f32), attrs),
15376        DType::F64 => kernel.execute(&in_views, build_out_view!(F64, f64), attrs),
15377        DType::F16 => kernel.execute(&in_views, build_out_view!(F16, half::f16), attrs),
15378        DType::BF16 => kernel.execute(&in_views, build_out_view!(BF16, half::bf16), attrs),
15379        DType::I8 => kernel.execute(&in_views, build_out_view!(I8, i8), attrs),
15380        DType::I16 => kernel.execute(&in_views, build_out_view!(I16, i16), attrs),
15381        DType::I32 => kernel.execute(&in_views, build_out_view!(I32, i32), attrs),
15382        DType::I64 => kernel.execute(&in_views, build_out_view!(I64, i64), attrs),
15383        DType::U8 => kernel.execute(&in_views, build_out_view!(U8, u8), attrs),
15384        DType::U32 => kernel.execute(&in_views, build_out_view!(U32, u32), attrs),
15385        DType::Bool => kernel.execute(&in_views, build_out_view!(Bool, u8), attrs),
15386        DType::C64 => panic!("Op::Custom output DType::C64 not supported"),
15387    };
15388    if let Err(e) = result {
15389        panic!("Op::Custom('{}') CPU kernel failed: {e}", kernel.name());
15390    }
15391}
15392
/// Dtype-generic read-only view into the arena: reinterprets `len` elements
/// of `T` starting `offset` bytes past `base`.
///
/// The per-dtype `sl_*` / `sl_mut_*` helpers remain for call sites in
/// `thunk.rs` that know their dtype statically; the custom-op dispatcher
/// uses this generic form so it can cover every `DType` without one helper
/// per dtype.
///
/// An `offset` of `usize::MAX` is treated as "no buffer" and yields an
/// empty slice regardless of `len`.
///
/// # Safety
/// Unless `offset == usize::MAX`, `base + offset` must be aligned for `T`
/// and valid for reads of `len` elements for as long as the returned slice
/// is used; the `'static` lifetime is not checked by the compiler.
#[inline(always)]
unsafe fn sl_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static [T] {
    match offset {
        usize::MAX => &[],
        _ => unsafe {
            let first = base.add(offset) as *const T;
            std::slice::from_raw_parts(first, len)
        },
    }
}
15405
/// Dtype-generic mutable view into the arena: reinterprets `len` elements of
/// `T` starting `offset` bytes past `base`. Unlike `sl_typed` there is no
/// `usize::MAX` sentinel handling.
///
/// # Safety
/// `base + offset` must be aligned for `T`, valid for reads and writes of
/// `len` elements, and not accessed through any other reference while the
/// returned slice is in use; the `'static` lifetime is not checked.
#[inline(always)]
unsafe fn sl_mut_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static mut [T] {
    unsafe {
        let first = base.add(offset) as *mut T;
        std::slice::from_raw_parts_mut(first, len)
    }
}
15410
15411// Unsafe helpers to create slices from arena base + offset
15412#[inline(always)]
15413/// In-place per-element activation. Mirrors the dispatch in
15414/// `Thunk::ActivationInPlace`. Used by `Thunk::FusedMmBiasAct` to
15415/// apply the activation after `bias_add` for all non-Gelu cases.
15416fn apply_activation_inplace(d: &mut [f32], act: rlx_ir::op::Activation) {
15417    use rlx_ir::op::Activation;
15418    match act {
15419        Activation::Gelu => crate::kernels::par_gelu_inplace(d),
15420        Activation::GeluApprox => crate::kernels::par_gelu_approx_inplace(d),
15421        Activation::Silu => crate::kernels::par_silu_inplace(d),
15422        Activation::Relu => {
15423            for v in d.iter_mut() {
15424                *v = v.max(0.0);
15425            }
15426        }
15427        Activation::Sigmoid => {
15428            for v in d.iter_mut() {
15429                *v = 1.0 / (1.0 + (-*v).exp());
15430            }
15431        }
15432        Activation::Tanh => {
15433            for v in d.iter_mut() {
15434                *v = v.tanh();
15435            }
15436        }
15437        Activation::Exp => {
15438            for v in d.iter_mut() {
15439                *v = v.exp();
15440            }
15441        }
15442        Activation::Log => {
15443            for v in d.iter_mut() {
15444                *v = v.ln();
15445            }
15446        }
15447        Activation::Sqrt => {
15448            for v in d.iter_mut() {
15449                *v = v.sqrt();
15450            }
15451        }
15452        Activation::Rsqrt => {
15453            for v in d.iter_mut() {
15454                *v = 1.0 / v.sqrt();
15455            }
15456        }
15457        Activation::Neg => {
15458            for v in d.iter_mut() {
15459                *v = -*v;
15460            }
15461        }
15462        Activation::Abs => {
15463            for v in d.iter_mut() {
15464                *v = v.abs();
15465            }
15466        }
15467        Activation::Round => {
15468            for v in d.iter_mut() {
15469                *v = v.round();
15470            }
15471        }
15472        Activation::Sin => {
15473            for v in d.iter_mut() {
15474                *v = v.sin();
15475            }
15476        }
15477        Activation::Cos => {
15478            for v in d.iter_mut() {
15479                *v = v.cos();
15480            }
15481        }
15482        Activation::Tan => {
15483            for v in d.iter_mut() {
15484                *v = v.tan();
15485            }
15486        }
15487        Activation::Atan => {
15488            for v in d.iter_mut() {
15489                *v = v.atan();
15490            }
15491        }
15492    }
15493}
15494
/// Unfolds one image (a single batch element / group slice) into the
/// column matrix used by GEMM-based convolution.
///
/// `x` is the `[c_in, H, W]` source in row-major order; `col` is the
/// `[c_in · kH · kW, H_out · W_out]` destination, also row-major. Taps that
/// fall into the padded border are written as `0.0`, so every element of
/// `col` is overwritten.
///
/// `col[(ci · kH · kW + ki · kW + kj) · n_dim + ho · W_out + wo] =
///    x[ci, ho·sh + ki·dh − ph, wo·sw + kj·dw_dil − pw]`
///
/// `sh`/`sw` are the strides, `ph`/`pw` the paddings and `dh`/`dw_dil` the
/// dilations along height and width.
#[allow(clippy::too_many_arguments)]
fn im2col(
    x: &[f32],
    col: &mut [f32],
    c_in: usize,
    h: usize,
    w: usize,
    h_out: usize,
    w_out: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    dh: usize,
    dw_dil: usize,
) {
    // Source coordinate along one spatial axis for a given output position
    // and kernel tap; `None` when the tap lands in the padding.
    fn source_index(
        out_pos: usize,
        stride: usize,
        tap: usize,
        dilation: usize,
        pad: usize,
        extent: usize,
    ) -> Option<usize> {
        let pos = (out_pos * stride + tap * dilation) as isize - pad as isize;
        if pos < 0 || pos >= extent as isize {
            None
        } else {
            Some(pos as usize)
        }
    }

    let n_dim = h_out * w_out;
    debug_assert_eq!(col.len(), c_in * kh * kw * n_dim);
    debug_assert_eq!(x.len(), c_in * h * w);

    for ci in 0..c_in {
        for ki in 0..kh {
            for kj in 0..kw {
                let row_off = ((ci * kh + ki) * kw + kj) * n_dim;
                for ho in 0..h_out {
                    let dst_off = row_off + ho * w_out;
                    let Some(hi) = source_index(ho, sh, ki, dh, ph, h) else {
                        // Whole output row reads from vertical padding.
                        for wo in 0..w_out {
                            col[dst_off + wo] = 0.0;
                        }
                        continue;
                    };
                    let src_off = (ci * h + hi) * w;
                    for wo in 0..w_out {
                        col[dst_off + wo] = source_index(wo, sw, kj, dw_dil, pw, w)
                            .map_or(0.0, |wi| x[src_off + wi]);
                    }
                }
            }
        }
    }
}
15556
/// Folds a column matrix back into an image — the scatter-accumulate
/// inverse of `im2col`.
///
/// Values are *added* into `x`, never assigned, so the caller must zero `x`
/// beforehand unless accumulation onto existing contents is wanted (the
/// conv-input-grad path zeros once before the batch loop).
///
/// `x[ci, hi, wi] += col[(ci · kH · kW + ki · kW + kj) · n_dim + ho · W_out + wo]`
/// for all `(ki, kj, ho, wo)` whose `(hi, wi)` lands in `[0, H) × [0, W)`;
/// entries that map into the padding are skipped.
#[allow(clippy::too_many_arguments)]
fn col2im(
    col: &[f32],
    x: &mut [f32],
    c_in: usize,
    h: usize,
    w: usize,
    h_out: usize,
    w_out: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    dh: usize,
    dw_dil: usize,
) {
    // Image coordinate along one spatial axis for a given output position
    // and kernel tap; `None` when the tap lands in the padding.
    fn image_index(
        out_pos: usize,
        stride: usize,
        tap: usize,
        dilation: usize,
        pad: usize,
        extent: usize,
    ) -> Option<usize> {
        let pos = (out_pos * stride + tap * dilation) as isize - pad as isize;
        if pos < 0 || pos >= extent as isize {
            None
        } else {
            Some(pos as usize)
        }
    }

    let n_dim = h_out * w_out;
    debug_assert_eq!(col.len(), c_in * kh * kw * n_dim);
    debug_assert_eq!(x.len(), c_in * h * w);

    for ci in 0..c_in {
        for ki in 0..kh {
            for kj in 0..kw {
                let row_off = ((ci * kh + ki) * kw + kj) * n_dim;
                for ho in 0..h_out {
                    let Some(hi) = image_index(ho, sh, ki, dh, ph, h) else {
                        continue;
                    };
                    let img_off = (ci * h + hi) * w;
                    let col_off = row_off + ho * w_out;
                    for wo in 0..w_out {
                        if let Some(wi) = image_index(wo, sw, kj, dw_dil, pw, w) {
                            x[img_off + wi] += col[col_off + wo];
                        }
                    }
                }
            }
        }
    }
}
15612
15613/// Element-wise backward for `Op::Activation`. `xs` is the original
15614/// input to the forward activation; `dys` is the upstream gradient.
15615/// Writes `out[i] = (d/dx act(xs[i])) * dys[i]`.
15616/// Decompose a per-channel quantization shape into the
15617/// `(chan_axis, chan_dim, inner)` triplet the kernel needs to map a
15618/// flat output index to a channel index. Per-tensor (`axis = None`)
15619/// degenerates to `chan_dim = 1, inner = len`, which makes the
15620/// kernel's `(i / inner) % chan_dim` always 0 — same fast path the
15621/// scalar version used.
15622fn quant_layout(shape: &rlx_ir::Shape, axis: Option<usize>) -> (usize, usize, usize) {
15623    match axis {
15624        None => (0, 1, shape.num_elements().unwrap_or(0).max(1)),
15625        Some(d) => {
15626            let chan_dim = shape.dim(d).unwrap_static();
15627            let inner: usize = (d + 1..shape.rank())
15628                .map(|i| shape.dim(i).unwrap_static())
15629                .product::<usize>()
15630                .max(1);
15631            (d, chan_dim, inner)
15632        }
15633    }
15634}
15635
15636fn activation_backward_kernel(
15637    act: rlx_ir::op::Activation,
15638    xs: &[f32],
15639    dys: &[f32],
15640    out: &mut [f32],
15641) {
15642    use rlx_ir::op::Activation;
15643    let n = xs.len();
15644    debug_assert_eq!(dys.len(), n);
15645    debug_assert_eq!(out.len(), n);
15646    match act {
15647        Activation::Relu => {
15648            for i in 0..n {
15649                out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
15650            }
15651        }
15652        Activation::Sigmoid => {
15653            for i in 0..n {
15654                let s = 1.0 / (1.0 + (-xs[i]).exp());
15655                out[i] = s * (1.0 - s) * dys[i];
15656            }
15657        }
15658        Activation::Tanh => {
15659            for i in 0..n {
15660                let t = xs[i].tanh();
15661                out[i] = (1.0 - t * t) * dys[i];
15662            }
15663        }
15664        Activation::Silu => {
15665            // y = x * σ(x);  dy/dx = σ(x) * (1 + x * (1 - σ(x))).
15666            for i in 0..n {
15667                let s = 1.0 / (1.0 + (-xs[i]).exp());
15668                out[i] = s * (1.0 + xs[i] * (1.0 - s)) * dys[i];
15669            }
15670        }
15671        Activation::Gelu => {
15672            // Exact erf-based GELU:  y = 0.5 x (1 + erf(x / √2)).
15673            //   dy/dx = 0.5 (1 + erf(x/√2)) + (x / √(2π)) · exp(-x²/2)
15674            const INV_SQRT2: f32 = 0.707_106_77;
15675            const INV_SQRT_2PI: f32 = 0.398_942_3;
15676            for i in 0..n {
15677                let x = xs[i];
15678                let phi = 0.5 * (1.0 + erf_f32(x * INV_SQRT2));
15679                let pdf = INV_SQRT_2PI * (-(x * x) * 0.5).exp();
15680                out[i] = (phi + x * pdf) * dys[i];
15681            }
15682        }
15683        Activation::GeluApprox => {
15684            // Tanh-approximation:
15685            //   y = 0.5 x (1 + tanh(c · (x + 0.044715 x³))) where c = √(2/π).
15686            const C: f32 = 0.797_884_6; // √(2/π)
15687            const A: f32 = 0.044_715;
15688            for i in 0..n {
15689                let x = xs[i];
15690                let inner = C * (x + A * x * x * x);
15691                let t = inner.tanh();
15692                let dinner = C * (1.0 + 3.0 * A * x * x);
15693                let d = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * dinner;
15694                out[i] = d * dys[i];
15695            }
15696        }
15697        Activation::Exp => {
15698            for i in 0..n {
15699                out[i] = xs[i].exp() * dys[i];
15700            }
15701        }
15702        Activation::Log => {
15703            for i in 0..n {
15704                out[i] = dys[i] / xs[i];
15705            }
15706        }
15707        Activation::Sqrt => {
15708            // d/dx √x = 0.5 / √x — undefined at x=0; clamp to 0.
15709            for i in 0..n {
15710                let s = xs[i].sqrt();
15711                out[i] = if s > 0.0 { 0.5 * dys[i] / s } else { 0.0 };
15712            }
15713        }
15714        Activation::Rsqrt => {
15715            // d/dx (1/√x) = -0.5 · x^(-3/2).
15716            for i in 0..n {
15717                let s = xs[i].sqrt();
15718                out[i] = if s > 0.0 {
15719                    -0.5 * dys[i] / (xs[i] * s)
15720                } else {
15721                    0.0
15722                };
15723            }
15724        }
15725        Activation::Neg => {
15726            for i in 0..n {
15727                out[i] = -dys[i];
15728            }
15729        }
15730        Activation::Abs => {
15731            // sign(x); 0 at x=0.
15732            for i in 0..n {
15733                let x = xs[i];
15734                let s = if x > 0.0 {
15735                    1.0
15736                } else if x < 0.0 {
15737                    -1.0
15738                } else {
15739                    0.0
15740                };
15741                out[i] = s * dys[i];
15742            }
15743        }
15744        Activation::Round => {
15745            // STE: pretend the round was identity in the backward
15746            // pass. The round step has zero gradient almost
15747            // everywhere, so without this trick the optimizer can't
15748            // learn through it.
15749            out.copy_from_slice(dys);
15750        }
15751        Activation::Sin => {
15752            // d/dx sin(x) = cos(x).
15753            for i in 0..n {
15754                out[i] = xs[i].cos() * dys[i];
15755            }
15756        }
15757        Activation::Cos => {
15758            for i in 0..n {
15759                out[i] = -xs[i].sin() * dys[i];
15760            }
15761        }
15762        Activation::Tan => {
15763            // d/dx tan(x) = sec²(x) = 1 + tan²(x)
15764            for i in 0..n {
15765                let t = xs[i].tan();
15766                out[i] = (1.0 + t * t) * dys[i];
15767            }
15768        }
15769        Activation::Atan => {
15770            // d/dx atan(x) = 1 / (1 + x²)
15771            for i in 0..n {
15772                let x = xs[i];
15773                out[i] = dys[i] / (1.0 + x * x);
15774            }
15775        }
15776    }
15777}
15778
15779/// f64 sibling of `activation_backward_kernel`. Same math, twice the
15780/// precision — used by f64 graphs where the f32 kernel reading bytes
15781/// as `&[f32]` would silently discard half of every f64 value.
15782fn activation_backward_kernel_f64(
15783    act: rlx_ir::op::Activation,
15784    xs: &[f64],
15785    dys: &[f64],
15786    out: &mut [f64],
15787) {
15788    use rlx_ir::op::Activation;
15789    let n = xs.len();
15790    debug_assert_eq!(dys.len(), n);
15791    debug_assert_eq!(out.len(), n);
15792    match act {
15793        Activation::Relu => {
15794            for i in 0..n {
15795                out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
15796            }
15797        }
15798        Activation::Sigmoid => {
15799            for i in 0..n {
15800                let s = 1.0 / (1.0 + (-xs[i]).exp());
15801                out[i] = s * (1.0 - s) * dys[i];
15802            }
15803        }
15804        Activation::Tanh => {
15805            for i in 0..n {
15806                let t = xs[i].tanh();
15807                out[i] = (1.0 - t * t) * dys[i];
15808            }
15809        }
15810        Activation::Silu => {
15811            for i in 0..n {
15812                let s = 1.0 / (1.0 + (-xs[i]).exp());
15813                out[i] = s * (1.0 + xs[i] * (1.0 - s)) * dys[i];
15814            }
15815        }
15816        Activation::Gelu | Activation::GeluApprox => {
15817            // Both rare on f64 paths; use the high-quality libm erf.
15818            const INV_SQRT2: f64 = std::f64::consts::FRAC_1_SQRT_2;
15819            const INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
15820            for i in 0..n {
15821                let x = xs[i];
15822                let phi = 0.5 * (1.0 + erf_f64(x * INV_SQRT2));
15823                let pdf = INV_SQRT_2PI * (-(x * x) * 0.5).exp();
15824                out[i] = (phi + x * pdf) * dys[i];
15825            }
15826        }
15827        Activation::Exp => {
15828            for i in 0..n {
15829                out[i] = xs[i].exp() * dys[i];
15830            }
15831        }
15832        Activation::Log => {
15833            for i in 0..n {
15834                out[i] = dys[i] / xs[i];
15835            }
15836        }
15837        Activation::Sqrt => {
15838            for i in 0..n {
15839                let s = xs[i].sqrt();
15840                out[i] = if s > 0.0 { 0.5 * dys[i] / s } else { 0.0 };
15841            }
15842        }
15843        Activation::Rsqrt => {
15844            for i in 0..n {
15845                let s = xs[i].sqrt();
15846                out[i] = if s > 0.0 {
15847                    -0.5 * dys[i] / (xs[i] * s)
15848                } else {
15849                    0.0
15850                };
15851            }
15852        }
15853        Activation::Neg => {
15854            for i in 0..n {
15855                out[i] = -dys[i];
15856            }
15857        }
15858        Activation::Abs => {
15859            for i in 0..n {
15860                let x = xs[i];
15861                let s = if x > 0.0 {
15862                    1.0
15863                } else if x < 0.0 {
15864                    -1.0
15865                } else {
15866                    0.0
15867                };
15868                out[i] = s * dys[i];
15869            }
15870        }
15871        Activation::Round => {
15872            out.copy_from_slice(dys);
15873        }
15874        Activation::Sin => {
15875            for i in 0..n {
15876                out[i] = xs[i].cos() * dys[i];
15877            }
15878        }
15879        Activation::Cos => {
15880            for i in 0..n {
15881                out[i] = -xs[i].sin() * dys[i];
15882            }
15883        }
15884        Activation::Tan => {
15885            for i in 0..n {
15886                let t = xs[i].tan();
15887                out[i] = (1.0 + t * t) * dys[i];
15888            }
15889        }
15890        Activation::Atan => {
15891            for i in 0..n {
15892                let x = xs[i];
15893                out[i] = dys[i] / (1.0 + x * x);
15894            }
15895        }
15896    }
15897}
15898
/// f64 error function via the Abramowitz & Stegun 7.1.26 rational
/// approximation — the `erf_f32` coefficients, evaluated at f64 width.
///
/// The maximum error is about 1.5e-7 and comes from the polynomial itself,
/// not from the arithmetic. That is adequate for gradient kernels; swap in
/// a libm dependency if more precision is ever needed.
///
/// The result carries the sign of `x` (odd symmetry is exact).
#[inline(always)]
fn erf_f64(x: f64) -> f64 {
    const P: f64 = 0.327_591_1;
    // Polynomial coefficients, highest power first, signs folded in.
    const COEFFS: [f64; 5] = [
        1.061_405_43,
        -1.453_152_03,
        1.421_413_75,
        -0.284_496_74,
        0.254_829_59,
    ];

    let sign = x.signum();
    let mag = x.abs();
    let t = 1.0 / (1.0 + P * mag);
    // Horner evaluation: ((((a5·t + a4)·t + a3)·t + a2)·t + a1).
    let poly = COEFFS[1..].iter().fold(COEFFS[0], |acc, &c| acc * t + c);
    let tail = poly * t * (-mag * mag).exp();
    sign * (1.0 - tail)
}
15915
/// Inexpensive f32 error function using the Abramowitz & Stegun 7.1.26
/// rational approximation (maximum error about 1.5e-7 over all of ℝ, which
/// is plenty for f32 gradient kernels).
///
/// The result carries the sign of `x` (odd symmetry is exact).
#[inline(always)]
fn erf_f32(x: f32) -> f32 {
    const P: f32 = 0.327_591_1;
    // Polynomial coefficients, highest power first, signs folded in.
    const COEFFS: [f32; 5] = [
        1.061_405_4,
        -1.453_152_1,
        1.421_413_8,
        -0.284_496_74,
        0.254_829_6,
    ];

    let sign = x.signum();
    let mag = x.abs();
    let t = 1.0 / (1.0 + P * mag);
    // Horner evaluation: ((((a5·t + a4)·t + a3)·t + a2)·t + a1).
    let poly = COEFFS[1..].iter().fold(COEFFS[0], |acc, &c| acc * t + c);
    let tail = poly * t * (-mag * mag).exp();
    sign * (1.0 - tail)
}
15930
/// Builds the runtime closure for a strided row copy inside the arena.
///
/// When invoked with the arena base pointer, the closure copies `outer`
/// rows of `inner * elem_bytes` bytes each. Row `o` is read from byte offset
/// `src + o * src_stride * elem_bytes` and written to
/// `dst + o * dst_stride * elem_bytes`.
///
/// * `src`, `dst` — byte offsets of the first source / destination row.
/// * `src_stride`, `dst_stride`, `inner` — measured in elements.
/// * `elem_bytes` — size of one element in bytes.
///
/// The closure does nothing when a row is empty or `src == dst`, skips any
/// row whose source and destination offsets coincide, and stops at the
/// first row whose offset or end position would overflow `usize`. The arena
/// length is not known here, so no other bounds check is possible.
///
/// The caller must pass a `base` for which every addressed row lies inside
/// one live allocation. Source and destination rows may overlap: each row
/// is copied with memmove semantics.
fn narrow_thunk_closure(
    src: usize,
    dst: usize,
    outer: u32,
    src_stride: u32,
    dst_stride: u32,
    inner: u32,
    elem_bytes: u8,
) -> Arc<dyn Fn(*mut u8) + Send + Sync> {
    // Start offset of row `row`, or `None` if the start or the end of the
    // row cannot be represented in `usize`.
    fn row_start(origin: usize, stride: usize, row: usize, row_bytes: usize) -> Option<usize> {
        let start = origin.checked_add(row.checked_mul(stride)?)?;
        start.checked_add(row_bytes)?;
        Some(start)
    }

    let outer = outer as usize;
    let elem = elem_bytes as usize;
    let row_bytes = (inner as usize).saturating_mul(elem);
    let src_row_stride = (src_stride as usize).saturating_mul(elem);
    let dst_row_stride = (dst_stride as usize).saturating_mul(elem);

    Arc::new(move |base: *mut u8| {
        if row_bytes == 0 || src == dst {
            return;
        }
        for o in 0..outer {
            let (Some(s_off), Some(d_off)) = (
                row_start(src, src_row_stride, o, row_bytes),
                row_start(dst, dst_row_stride, o, row_bytes),
            ) else {
                // Offsets this large cannot address real memory; bail out
                // instead of wrapping around.
                break;
            };
            if s_off == d_off {
                continue;
            }
            // SAFETY: the caller guarantees `base` addresses an allocation
            // containing both rows; `ptr::copy` tolerates overlap between
            // the source and destination ranges.
            unsafe { std::ptr::copy(base.add(s_off), base.add(d_off), row_bytes) };
        }
    })
}
15971
/// Read-only `f32` view into the arena: `len` elements starting `offset`
/// bytes past `base`.
///
/// An `offset` of `usize::MAX` is treated as "no buffer" and yields an
/// empty slice regardless of `len`.
///
/// # Safety
/// Unless `offset == usize::MAX`, `base + offset` must be 4-byte aligned and
/// valid for reads of `len` `f32` values for as long as the returned slice
/// is used; the `'static` lifetime is not checked by the compiler.
#[inline(always)]
unsafe fn sl(offset: usize, base: *mut u8, len: usize) -> &'static [f32] {
    if offset == usize::MAX {
        return &[];
    }
    unsafe { std::slice::from_raw_parts(base.add(offset) as *const f32, len) }
}
15978
/// Mutable `f32` view into the arena: `len` elements starting `offset`
/// bytes past `base`. There is no `usize::MAX` sentinel handling here.
///
/// # Safety
/// `base + offset` must be 4-byte aligned, valid for reads and writes of
/// `len` `f32` values, and not accessed through any other reference while
/// the returned slice is in use; the `'static` lifetime is not checked.
#[inline(always)]
unsafe fn sl_mut(offset: usize, base: *mut u8, len: usize) -> &'static mut [f32] {
    unsafe {
        let first = base.add(offset) as *mut f32;
        std::slice::from_raw_parts_mut(first, len)
    }
}
15983
/// Read-only `f64` view into the arena: `len` elements starting `offset`
/// bytes past `base`.
///
/// An `offset` of `usize::MAX` is treated as "no buffer" and yields an
/// empty slice regardless of `len`.
///
/// # Safety
/// Unless `offset == usize::MAX`, `base + offset` must be 8-byte aligned and
/// valid for reads of `len` `f64` values for as long as the returned slice
/// is used; the `'static` lifetime is not checked by the compiler.
#[inline(always)]
unsafe fn sl_f64(offset: usize, base: *mut u8, len: usize) -> &'static [f64] {
    match offset {
        usize::MAX => &[],
        _ => unsafe {
            let first = base.add(offset) as *const f64;
            std::slice::from_raw_parts(first, len)
        },
    }
}
15991
/// Mutable `f64` view into the arena: `len` elements starting `offset`
/// bytes past `base`. There is no `usize::MAX` sentinel handling here.
///
/// # Safety
/// `base + offset` must be 8-byte aligned, valid for reads and writes of
/// `len` `f64` values, and not accessed through any other reference while
/// the returned slice is in use; the `'static` lifetime is not checked.
#[inline(always)]
unsafe fn sl_mut_f64(offset: usize, base: *mut u8, len: usize) -> &'static mut [f64] {
    unsafe {
        let first = base.add(offset) as *mut f64;
        std::slice::from_raw_parts_mut(first, len)
    }
}
15996
// i32 / i64 typed slice helpers — siblings of `sl` / `sl_f64`. The i32 pair
// is kept (hence `allow(dead_code)`) for integer-tensor thunks that haven't
// landed yet (Sample, Gather index buffers); deleting it now would force
// re-deriving the unsafe boilerplate when the next int-typed thunk lands.

/// Read-only `i32` view into the arena: `len` elements starting `offset`
/// bytes past `base`.
///
/// An `offset` of `usize::MAX` is treated as "no buffer" and yields an
/// empty slice regardless of `len`.
///
/// # Safety
/// Unless `offset == usize::MAX`, `base + offset` must be 4-byte aligned and
/// valid for reads of `len` `i32` values for as long as the returned slice
/// is used; the `'static` lifetime is not checked by the compiler.
#[inline(always)]
#[allow(dead_code)]
unsafe fn sl_i32(offset: usize, base: *mut u8, len: usize) -> &'static [i32] {
    match offset {
        usize::MAX => &[],
        _ => unsafe {
            let first = base.add(offset) as *const i32;
            std::slice::from_raw_parts(first, len)
        },
    }
}
16009
/// Mutable `i32` view into the arena: `len` elements starting `offset`
/// bytes past `base`. There is no `usize::MAX` sentinel handling here.
///
/// # Safety
/// `base + offset` must be 4-byte aligned, valid for reads and writes of
/// `len` `i32` values, and not accessed through any other reference while
/// the returned slice is in use; the `'static` lifetime is not checked.
#[inline(always)]
#[allow(dead_code)]
unsafe fn sl_mut_i32(offset: usize, base: *mut u8, len: usize) -> &'static mut [i32] {
    unsafe {
        let first = base.add(offset) as *mut i32;
        std::slice::from_raw_parts_mut(first, len)
    }
}
16015
/// Read-only `i64` view into the arena: `len` elements starting `offset`
/// bytes past `base`.
///
/// An `offset` of `usize::MAX` is treated as "no buffer" and yields an
/// empty slice regardless of `len`.
///
/// # Safety
/// Unless `offset == usize::MAX`, `base + offset` must be 8-byte aligned and
/// valid for reads of `len` `i64` values for as long as the returned slice
/// is used; the `'static` lifetime is not checked by the compiler.
#[inline(always)]
unsafe fn sl_i64(offset: usize, base: *mut u8, len: usize) -> &'static [i64] {
    match offset {
        usize::MAX => &[],
        _ => unsafe {
            let first = base.add(offset) as *const i64;
            std::slice::from_raw_parts(first, len)
        },
    }
}
16023
/// Mutable `i64` view into the arena: `len` elements starting `offset`
/// bytes past `base`. There is no `usize::MAX` sentinel handling here.
///
/// # Safety
/// `base + offset` must be 8-byte aligned, valid for reads and writes of
/// `len` `i64` values, and not accessed through any other reference while
/// the returned slice is in use; the `'static` lifetime is not checked.
#[inline(always)]
unsafe fn sl_mut_i64(offset: usize, base: *mut u8, len: usize) -> &'static mut [i64] {
    unsafe {
        let first = base.add(offset) as *mut i64;
        std::slice::from_raw_parts_mut(first, len)
    }
}
16028
/// f64 N-dimensional gather used by Transpose and Expand.
///
/// Walks the output in row-major order (last dimension fastest). For each
/// output coordinate the source element sits at
/// `Σ coord[d] · in_strides[d]`. `out_dims` is the output shape and
/// `in_strides[d]` is the source stride, in elements, belonging to output
/// dimension `d` — broadcast axes use stride 0.
fn transpose_walk_f64(inp: &[f64], out: &mut [f64], out_dims: &[u32], in_strides: &[u32]) {
    let rank = out_dims.len();
    let mut coord = vec![0u32; rank];
    for slot in out.iter_mut() {
        let src: usize = (0..rank)
            .map(|d| coord[d] as usize * in_strides[d] as usize)
            .sum();
        *slot = inp[src];

        // Odometer step: bump the last axis, carrying into earlier axes
        // whenever one wraps around to zero.
        for (c, &extent) in coord.iter_mut().zip(out_dims).rev() {
            *c += 1;
            if *c < extent {
                break;
            }
            *c = 0;
        }
    }
}
16051
16052/// f64 elementwise activation. Reads `inp`, writes `out`. For now
16053/// covers what the autodiff-emitted gradient graph needs (Neg, Exp,
16054/// Log, Sqrt, Rsqrt, Abs, Tanh, Sigmoid, Relu — the
16055/// transcendental-free subset). Approximate Gelu/Silu deferred until a
16056/// workload demands them at f64.
16057fn apply_activation_f64(inp: &[f64], out: &mut [f64], kind: Activation) {
16058    match kind {
16059        Activation::Neg => {
16060            for (o, &v) in out.iter_mut().zip(inp) {
16061                *o = -v;
16062            }
16063        }
16064        Activation::Exp => {
16065            for (o, &v) in out.iter_mut().zip(inp) {
16066                *o = v.exp();
16067            }
16068        }
16069        Activation::Log => {
16070            for (o, &v) in out.iter_mut().zip(inp) {
16071                *o = v.ln();
16072            }
16073        }
16074        Activation::Sqrt => {
16075            for (o, &v) in out.iter_mut().zip(inp) {
16076                *o = v.sqrt();
16077            }
16078        }
16079        Activation::Rsqrt => {
16080            for (o, &v) in out.iter_mut().zip(inp) {
16081                *o = 1.0 / v.sqrt();
16082            }
16083        }
16084        Activation::Abs => {
16085            for (o, &v) in out.iter_mut().zip(inp) {
16086                *o = v.abs();
16087            }
16088        }
16089        Activation::Tanh => {
16090            for (o, &v) in out.iter_mut().zip(inp) {
16091                *o = v.tanh();
16092            }
16093        }
16094        Activation::Sigmoid => {
16095            for (o, &v) in out.iter_mut().zip(inp) {
16096                *o = 1.0 / (1.0 + (-v).exp());
16097            }
16098        }
16099        Activation::Relu => {
16100            for (o, &v) in out.iter_mut().zip(inp) {
16101                *o = v.max(0.0);
16102            }
16103        }
16104        Activation::Round => {
16105            for (o, &v) in out.iter_mut().zip(inp) {
16106                *o = v.round_ties_even();
16107            }
16108        }
16109        Activation::Sin => {
16110            for (o, &v) in out.iter_mut().zip(inp) {
16111                *o = v.sin();
16112            }
16113        }
16114        Activation::Cos => {
16115            for (o, &v) in out.iter_mut().zip(inp) {
16116                *o = v.cos();
16117            }
16118        }
16119        Activation::Tan => {
16120            for (o, &v) in out.iter_mut().zip(inp) {
16121                *o = v.tan();
16122            }
16123        }
16124        Activation::Atan => {
16125            for (o, &v) in out.iter_mut().zip(inp) {
16126                *o = v.atan();
16127            }
16128        }
16129        Activation::Gelu | Activation::GeluApprox | Activation::Silu => {
16130            panic!(
16131                "apply_activation_f64: {kind:?} not yet implemented at f64. \
16132                    Add when a workload needs it."
16133            );
16134        }
16135    }
16136}
16137
16138#[inline]
16139fn binary_op_f64(op: BinaryOp, a: f64, b: f64) -> f64 {
16140    match op {
16141        BinaryOp::Add => a + b,
16142        BinaryOp::Sub => a - b,
16143        BinaryOp::Mul => a * b,
16144        BinaryOp::Div => a / b,
16145        BinaryOp::Max => a.max(b),
16146        BinaryOp::Min => a.min(b),
16147        BinaryOp::Pow => a.powf(b),
16148    }
16149}
16150
/// Sums an `f64` tensor along its middle axis.
///
/// `inp` is read as a row-major `[outer, reduced, inner]` tensor and `out`
/// is written as a row-major `[outer, inner]` tensor. Every output slot is
/// overwritten; a zero-length reduced axis yields `0.0`. Terms are added in
/// ascending `reduced` order starting from `0.0`.
///
/// Panics (via slice indexing) if `inp` or `out` is too short for the
/// given dimensions.
fn reduce_sum_f64(inp: &[f64], out: &mut [f64], outer: usize, reduced: usize, inner: usize) {
    for outer_idx in 0..outer {
        for inner_idx in 0..inner {
            // First element of the strided column being reduced; later
            // elements are `inner` apart.
            let first = outer_idx * reduced * inner + inner_idx;
            let total = (0..reduced).fold(0.0_f64, |sum, step| sum + inp[first + step * inner]);
            out[outer_idx * inner + inner_idx] = total;
        }
    }
}
16164
16165#[cfg(test)]
16166mod tests {
16167    use super::*;
16168    use rlx_ir::*;
16169
16170    /// Plan #45: when a Narrow's only consumer is a Rope, the thunk
16171    /// fusion pass collapses them — the Narrow becomes Nop, and the
16172    /// Rope reads from the parent buffer with its row stride. This
16173    /// test runs the unfused path (batch*seq > FusedAttnBlock
16174    /// threshold) and asserts the rewrite happened.
16175    #[test]
16176    fn narrow_rope_fuses_in_unfused_path() {
16177        let f = DType::F32;
16178        let mut g = Graph::new("nr_fuse");
16179        // Force batch*seq > 64 so FusedAttnBlock doesn't pre-empt us.
16180        let qkv = g.input("qkv", Shape::new(&[16, 8, 192], f)); // 16*8=128 > 64
16181        let cos = g.input("cos", Shape::new(&[16], f));
16182        let sin = g.input("sin", Shape::new(&[16], f));
16183        // Last-axis narrow: Q = qkv[..., 0..64]
16184        let q = g.narrow_(qkv, 2, 0, 64);
16185        let q_rope = g.rope(q, cos, sin, 16);
16186        g.set_outputs(vec![q_rope]);
16187
16188        let plan = rlx_opt::memory::plan_memory(&g);
16189        let arena = crate::arena::Arena::from_plan(plan);
16190        let sched = compile_thunks(&g, &arena);
16191
16192        let mut narrow_count = 0;
16193        let mut rope_with_stride: Option<u32> = None;
16194        for t in &sched.thunks {
16195            match t {
16196                Thunk::Narrow { .. } => narrow_count += 1,
16197                Thunk::Rope { src_row_stride, .. } => rope_with_stride = Some(*src_row_stride),
16198                _ => {}
16199            }
16200        }
16201        // After fusion the Narrow is gone; only the Rope remains, and
16202        // it now walks with the parent QKV's row stride (3 * 64 = 192).
16203        assert_eq!(
16204            narrow_count, 0,
16205            "Narrow→Rope fusion should leave zero Narrow thunks; saw {narrow_count}"
16206        );
16207        assert_eq!(
16208            rope_with_stride,
16209            Some(192),
16210            "Rope's src_row_stride should be 192 (parent qkv axis), saw {rope_with_stride:?}"
16211        );
16212    }
16213
16214    /// Plan #15: SSM selective scan matches a naive Python-style
16215    /// Python-style sequential reference.
16216    #[test]
16217    fn ssm_selective_scan_matches_reference() {
16218        use rlx_ir::Philox4x32;
16219        let bch = 1usize;
16220        let s = 4usize;
16221        let h = 3usize;
16222        let n = 2usize;
16223
16224        let mut rng = Philox4x32::new(13);
16225        let mut x = vec![0f32; bch * s * h];
16226        rng.fill_normal(&mut x);
16227        let mut delta = vec![0f32; bch * s * h];
16228        // Keep Δ small so exp(Δ·A) doesn't blow up.
16229        for v in delta.iter_mut() {
16230            *v = (rng.next_f32() - 0.5) * 0.1;
16231        }
16232        let mut a = vec![0f32; h * n];
16233        for v in a.iter_mut() {
16234            *v = -(rng.next_f32() * 0.5 + 0.1);
16235        } // negative for stability
16236        let mut b = vec![0f32; bch * s * n];
16237        rng.fill_normal(&mut b);
16238        let mut c = vec![0f32; bch * s * n];
16239        rng.fill_normal(&mut c);
16240
16241        // Reference scan.
16242        let mut expected = vec![0f32; bch * s * h];
16243        for bi in 0..bch {
16244            let mut state = vec![0f32; h * n];
16245            for si in 0..s {
16246                for ci in 0..h {
16247                    let d = delta[bi * s * h + si * h + ci];
16248                    let xv = x[bi * s * h + si * h + ci];
16249                    let mut acc = 0f32;
16250                    for ni in 0..n {
16251                        let da = (d * a[ci * n + ni]).exp();
16252                        state[ci * n + ni] =
16253                            da * state[ci * n + ni] + d * b[bi * s * n + si * n + ni] * xv;
16254                        acc += c[bi * s * n + si * n + ni] * state[ci * n + ni];
16255                    }
16256                    expected[bi * s * h + si * h + ci] = acc;
16257                }
16258            }
16259        }
16260
16261        // RLX path.
16262        let f = DType::F32;
16263        let mut g = Graph::new("ssm");
16264        let xn = g.input("x", Shape::new(&[bch, s, h], f));
16265        let dn = g.input("delta", Shape::new(&[bch, s, h], f));
16266        let an = g.param("a", Shape::new(&[h, n], f));
16267        let bn = g.param("b", Shape::new(&[bch, s, n], f));
16268        let cn = g.param("c", Shape::new(&[bch, s, n], f));
16269        let yn = g.selective_scan(xn, dn, an, bn, cn, n, Shape::new(&[bch, s, h], f));
16270        g.set_outputs(vec![yn]);
16271
16272        let plan = rlx_opt::memory::plan_memory(&g);
16273        let mut arena = crate::arena::Arena::from_plan(plan);
16274        let sched = compile_thunks(&g, &arena);
16275
16276        let xn_off = arena.byte_offset(xn);
16277        let dn_off = arena.byte_offset(dn);
16278        let an_off = arena.byte_offset(an);
16279        let bn_off = arena.byte_offset(bn);
16280        let cn_off = arena.byte_offset(cn);
16281        let yn_off = arena.byte_offset(yn);
16282        let buf = arena.raw_buf_mut();
16283        unsafe {
16284            let copy = |dst: *mut f32, data: &[f32]| {
16285                for (i, &v) in data.iter().enumerate() {
16286                    *dst.add(i) = v;
16287                }
16288            };
16289            copy(buf.as_mut_ptr().add(xn_off) as *mut f32, &x);
16290            copy(buf.as_mut_ptr().add(dn_off) as *mut f32, &delta);
16291            copy(buf.as_mut_ptr().add(an_off) as *mut f32, &a);
16292            copy(buf.as_mut_ptr().add(bn_off) as *mut f32, &b);
16293            copy(buf.as_mut_ptr().add(cn_off) as *mut f32, &c);
16294        }
16295        execute_thunks(&sched, arena.raw_buf_mut());
16296
16297        let actual: Vec<f32> = unsafe {
16298            let p = arena.raw_buf().as_ptr().add(yn_off) as *const f32;
16299            (0..bch * s * h).map(|i| *p.add(i)).collect()
16300        };
16301
16302        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
16303            assert!(
16304                (e - a).abs() < 1e-3,
16305                "mismatch at {i}: expected {e}, got {a}"
16306            );
16307        }
16308    }
16309
16310    /// Plan #26: 1×1 conv lowers to per-batch sgemm and matches the
16311    /// scalar 7-loop reference.
16312    #[test]
16313    fn conv_1x1_fast_path_matches_scalar() {
16314        use rlx_ir::Philox4x32;
16315        // [N=2, C_in=4, H=3, W=3]
16316        let n = 2usize;
16317        let c_in = 4usize;
16318        let h = 3usize;
16319        let w = 3usize;
16320        let c_out = 5usize;
16321        let mut rng = Philox4x32::new(31);
16322        let mut x = vec![0f32; n * c_in * h * w];
16323        rng.fill_normal(&mut x);
16324        let mut weight = vec![0f32; c_out * c_in];
16325        rng.fill_normal(&mut weight);
16326
16327        // Reference: scalar 1×1 conv = per-batch matmul
16328        // out[ni, co, hi, wi] = sum_ci weight[co, ci] * x[ni, ci, hi, wi]
16329        let mut expected = vec![0f32; n * c_out * h * w];
16330        for ni in 0..n {
16331            for co in 0..c_out {
16332                for hi in 0..h {
16333                    for wi in 0..w {
16334                        let mut acc = 0f32;
16335                        for ci in 0..c_in {
16336                            acc += weight[co * c_in + ci]
16337                                * x[((ni * c_in) + ci) * h * w + hi * w + wi];
16338                        }
16339                        expected[((ni * c_out) + co) * h * w + hi * w + wi] = acc;
16340                    }
16341                }
16342            }
16343        }
16344
16345        // RLX path: build a graph with Op::Conv (kernel=[1,1], stride=[1,1], etc).
16346        let f = DType::F32;
16347        let mut g = Graph::new("conv1x1");
16348        let xn = g.input("x", Shape::new(&[n, c_in, h, w], f));
16349        let wn = g.param("w", Shape::new(&[c_out, c_in, 1, 1], f));
16350        // Manually add Op::Conv since there's no `g.conv()` helper.
16351        let cn = g.add_node(
16352            rlx_ir::Op::Conv {
16353                kernel_size: vec![1, 1],
16354                stride: vec![1, 1],
16355                padding: vec![0, 0],
16356                dilation: vec![1, 1],
16357                groups: 1,
16358            },
16359            vec![xn, wn],
16360            Shape::new(&[n, c_out, h, w], f),
16361        );
16362        g.set_outputs(vec![cn]);
16363
16364        let plan = rlx_opt::memory::plan_memory(&g);
16365        let mut arena = crate::arena::Arena::from_plan(plan);
16366        let sched = compile_thunks(&g, &arena);
16367
16368        // Verify the fast path was selected.
16369        let saw_fast = sched
16370            .thunks
16371            .iter()
16372            .any(|t| matches!(t, Thunk::Conv2D1x1 { .. }));
16373        let saw_slow = sched
16374            .thunks
16375            .iter()
16376            .any(|t| matches!(t, Thunk::Conv2D { .. }));
16377        assert!(saw_fast, "1×1 conv should emit Conv2D1x1");
16378        assert!(!saw_slow, "1×1 conv must not fall through to scalar Conv2D");
16379
16380        let xn_off = arena.byte_offset(xn);
16381        let wn_off = arena.byte_offset(wn);
16382        let cn_off = arena.byte_offset(cn);
16383        let buf = arena.raw_buf_mut();
16384        unsafe {
16385            let xp = buf.as_mut_ptr().add(xn_off) as *mut f32;
16386            for (i, &v) in x.iter().enumerate() {
16387                *xp.add(i) = v;
16388            }
16389            let wp = buf.as_mut_ptr().add(wn_off) as *mut f32;
16390            for (i, &v) in weight.iter().enumerate() {
16391                *wp.add(i) = v;
16392            }
16393        }
16394        execute_thunks(&sched, arena.raw_buf_mut());
16395
16396        let actual: Vec<f32> = unsafe {
16397            let p = arena.raw_buf().as_ptr().add(cn_off) as *const f32;
16398            (0..(n * c_out * h * w)).map(|i| *p.add(i)).collect()
16399        };
16400
16401        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
16402            assert!(
16403                (e - a).abs() < 1e-3,
16404                "mismatch at {i}: expected {e}, got {a}"
16405            );
16406        }
16407    }
16408
16409    /// Plan #5: fused dequant matmul matches the dequant-then-matmul
16410    /// reference (i.e. `(scale * (q - z)) @ x` materialized).
16411    #[test]
16412    fn dequant_matmul_int8_sym_matches_reference() {
16413        use rlx_ir::Philox4x32;
16414        use rlx_ir::quant::QuantScheme;
16415
16416        let m = 3usize;
16417        let k = 8usize;
16418        let n = 4usize;
16419        let block_size = 4usize; // 2 blocks per column
16420        let blocks_per_col = k / block_size;
16421
16422        // Random inputs: x f32, w_q i8, scales f32. Symmetric → no zp.
16423        let mut rng = Philox4x32::new(99);
16424        let mut x = vec![0f32; m * k];
16425        rng.fill_normal(&mut x);
16426        let w_q: Vec<i8> = (0..(k * n))
16427            .map(|i| ((i as i32 * 13 + 7) % 127 - 63) as i8)
16428            .collect();
16429        let scales: Vec<f32> = (0..(blocks_per_col * n))
16430            .map(|i| 0.01 + 0.001 * i as f32)
16431            .collect();
16432
16433        // Reference: build f32 weights from (q * scale) per block.
16434        let mut w_f32 = vec![0f32; k * n];
16435        for p in 0..k {
16436            let block = p / block_size;
16437            for j in 0..n {
16438                let s = scales[block * n + j];
16439                w_f32[p * n + j] = w_q[p * n + j] as f32 * s;
16440            }
16441        }
16442        let mut expected = vec![0f32; m * n];
16443        for i in 0..m {
16444            for j in 0..n {
16445                let mut acc = 0f32;
16446                for p in 0..k {
16447                    acc += x[i * k + p] * w_f32[p * n + j];
16448                }
16449                expected[i * n + j] = acc;
16450            }
16451        }
16452
16453        // RLX path.
16454        let f = DType::F32;
16455        let mut g = Graph::new("dq");
16456        let xn = g.input("x", Shape::new(&[m, k], f));
16457        let wn = g.param("w", Shape::new(&[k, n], DType::I8));
16458        let sn = g.param("scale", Shape::new(&[blocks_per_col, n], f));
16459        let zn = g.param("zp", Shape::new(&[blocks_per_col, n], f)); // unused (sym)
16460        let dq = g.dequant_matmul(
16461            xn,
16462            wn,
16463            sn,
16464            zn,
16465            QuantScheme::Int8Block {
16466                block_size: block_size as u32,
16467            },
16468            Shape::new(&[m, n], f),
16469        );
16470        g.set_outputs(vec![dq]);
16471
16472        let plan = rlx_opt::memory::plan_memory(&g);
16473        let mut arena = crate::arena::Arena::from_plan(plan);
16474        let sched = compile_thunks(&g, &arena);
16475
16476        let xn_off = arena.byte_offset(xn);
16477        let wn_off = arena.byte_offset(wn);
16478        let sn_off = arena.byte_offset(sn);
16479        let zn_off = arena.byte_offset(zn);
16480        let dq_off = arena.byte_offset(dq);
16481        let buf = arena.raw_buf_mut();
16482        unsafe {
16483            // Seed f32 inputs.
16484            let xp = buf.as_mut_ptr().add(xn_off) as *mut f32;
16485            for (i, &v) in x.iter().enumerate() {
16486                *xp.add(i) = v;
16487            }
16488            let sp = buf.as_mut_ptr().add(sn_off) as *mut f32;
16489            for (i, &v) in scales.iter().enumerate() {
16490                *sp.add(i) = v;
16491            }
16492            let zp = buf.as_mut_ptr().add(zn_off) as *mut f32;
16493            for i in 0..(blocks_per_col * n) {
16494                *zp.add(i) = 0.0;
16495            }
16496            // Seed i8 weights byte-by-byte.
16497            let wp = buf.as_mut_ptr().add(wn_off) as *mut i8;
16498            for (i, &v) in w_q.iter().enumerate() {
16499                *wp.add(i) = v;
16500            }
16501        }
16502        execute_thunks(&sched, arena.raw_buf_mut());
16503
16504        let actual: Vec<f32> = unsafe {
16505            let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
16506            (0..m * n).map(|i| *p.add(i)).collect()
16507        };
16508
16509        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
16510            assert!(
16511                (e - a).abs() < 1e-3,
16512                "mismatch at {i}: expected {e}, got {a}"
16513            );
16514        }
16515    }
16516
16517    /// Plan #9: LoRA matmul matches the unfused 3-matmul reference.
16518    #[test]
16519    fn lora_matmul_matches_unfused_reference() {
16520        use rlx_ir::Philox4x32;
16521
16522        let m = 4usize;
16523        let k = 8usize;
16524        let n = 6usize;
16525        let r = 2usize;
16526        let scale = 0.5f32;
16527
16528        // Random inputs (deterministic via Philox).
16529        let mut rng = Philox4x32::new(42);
16530        let mut x = vec![0f32; m * k];
16531        rng.fill_normal(&mut x);
16532        let mut w = vec![0f32; k * n];
16533        rng.fill_normal(&mut w);
16534        let mut a = vec![0f32; k * r];
16535        rng.fill_normal(&mut a);
16536        let mut b = vec![0f32; r * n];
16537        rng.fill_normal(&mut b);
16538
16539        // Reference: out = x·W + scale * x·A·B. Naive triple-loop.
16540        let naive = |a_buf: &[f32], b_buf: &[f32], rows: usize, inner: usize, cols: usize| {
16541            let mut o = vec![0f32; rows * cols];
16542            for i in 0..rows {
16543                for j in 0..cols {
16544                    let mut acc = 0f32;
16545                    for p in 0..inner {
16546                        acc += a_buf[i * inner + p] * b_buf[p * cols + j];
16547                    }
16548                    o[i * cols + j] = acc;
16549                }
16550            }
16551            o
16552        };
16553        let xw = naive(&x, &w, m, k, n);
16554        let xa = naive(&x, &a, m, k, r);
16555        let xab = naive(&xa, &b, m, r, n);
16556        let mut expected = xw;
16557        for i in 0..(m * n) {
16558            expected[i] += scale * xab[i];
16559        }
16560
16561        // RLX path: build a graph with one LoraMatMul.
16562        let f = DType::F32;
16563        let mut g = Graph::new("lora");
16564        let xn = g.input("x", Shape::new(&[m, k], f));
16565        let wn = g.param("w", Shape::new(&[k, n], f));
16566        let an = g.param("a", Shape::new(&[k, r], f));
16567        let bn = g.param("b", Shape::new(&[r, n], f));
16568        let lm = g.lora_matmul(xn, wn, an, bn, scale, Shape::new(&[m, n], f));
16569        g.set_outputs(vec![lm]);
16570
16571        let plan = rlx_opt::memory::plan_memory(&g);
16572        let mut arena = crate::arena::Arena::from_plan(plan);
16573        let sched = compile_thunks(&g, &arena);
16574
16575        let xn_off = arena.byte_offset(xn);
16576        let wn_off = arena.byte_offset(wn);
16577        let an_off = arena.byte_offset(an);
16578        let bn_off = arena.byte_offset(bn);
16579        let lm_off = arena.byte_offset(lm);
16580        let buf = arena.raw_buf_mut();
16581        unsafe {
16582            let copy = |dst: *mut f32, data: &[f32]| {
16583                for (i, &v) in data.iter().enumerate() {
16584                    *dst.add(i) = v;
16585                }
16586            };
16587            copy(buf.as_mut_ptr().add(xn_off) as *mut f32, &x);
16588            copy(buf.as_mut_ptr().add(wn_off) as *mut f32, &w);
16589            copy(buf.as_mut_ptr().add(an_off) as *mut f32, &a);
16590            copy(buf.as_mut_ptr().add(bn_off) as *mut f32, &b);
16591        }
16592        execute_thunks(&sched, arena.raw_buf_mut());
16593
16594        let actual: Vec<f32> = unsafe {
16595            let p = arena.raw_buf().as_ptr().add(lm_off) as *const f32;
16596            (0..m * n).map(|i| *p.add(i)).collect()
16597        };
16598
16599        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
16600            assert!(
16601                (e - a).abs() < 1e-3,
16602                "mismatch at {i}: expected {e}, got {a}"
16603            );
16604        }
16605    }
16606
16607    /// Plan #42: fused sampling kernel determinism + greedy fallback.
16608    #[test]
16609    fn sample_temperature_zero_is_argmax() {
16610        // Very low temperature → distribution collapses on argmax.
16611        // Same seed → same output bit-for-bit.
16612        let f = DType::F32;
16613        let mut g = Graph::new("samp");
16614        let logits = g.input("logits", Shape::new(&[1, 8], f));
16615        let s = g.sample(logits, 0, 1.0, 1e-3, 42, Shape::new(&[1], f));
16616        g.set_outputs(vec![s]);
16617        let plan = rlx_opt::memory::plan_memory(&g);
16618        let mut arena = crate::arena::Arena::from_plan(plan);
16619        let sched = compile_thunks(&g, &arena);
16620
16621        let logits_off = arena.byte_offset(logits);
16622        let s_off = arena.byte_offset(s);
16623        let buf = arena.raw_buf_mut();
16624        unsafe {
16625            let p = buf.as_mut_ptr().add(logits_off) as *mut f32;
16626            // argmax = index 5 (value 9.0).
16627            let inputs = [0.1f32, 0.2, 0.3, 0.4, 0.5, 9.0, 0.7, 0.8];
16628            for (i, &v) in inputs.iter().enumerate() {
16629                *p.add(i) = v;
16630            }
16631        }
16632        execute_thunks(&sched, arena.raw_buf_mut());
16633
16634        let token = unsafe {
16635            let p = arena.raw_buf().as_ptr().add(s_off) as *const f32;
16636            *p as usize
16637        };
16638        assert_eq!(token, 5, "low-temp sampling should pick the argmax");
16639    }
16640
16641    #[test]
16642    fn sample_top_k_one_is_deterministic() {
16643        // top_k=1 forces only the argmax to have nonzero probability.
16644        let f = DType::F32;
16645        let mut g = Graph::new("samp_k1");
16646        let logits = g.input("logits", Shape::new(&[1, 4], f));
16647        let s = g.sample(logits, 1, 1.0, 1.0, 7, Shape::new(&[1], f));
16648        g.set_outputs(vec![s]);
16649        let plan = rlx_opt::memory::plan_memory(&g);
16650        let mut arena = crate::arena::Arena::from_plan(plan);
16651        let sched = compile_thunks(&g, &arena);
16652
16653        let logits_off = arena.byte_offset(logits);
16654        let s_off = arena.byte_offset(s);
16655        let buf = arena.raw_buf_mut();
16656        unsafe {
16657            let p = buf.as_mut_ptr().add(logits_off) as *mut f32;
16658            let inputs = [0.1f32, 5.0, 0.3, 0.4]; // argmax = 1
16659            for (i, &v) in inputs.iter().enumerate() {
16660                *p.add(i) = v;
16661            }
16662        }
16663        execute_thunks(&sched, arena.raw_buf_mut());
16664        let token = unsafe {
16665            let p = arena.raw_buf().as_ptr().add(s_off) as *const f32;
16666            *p as usize
16667        };
16668        assert_eq!(token, 1);
16669    }
16670
16671    /// Plan #44: cumsum primitive parity vs. naive scan.
16672    #[test]
16673    fn cumsum_inclusive_matches_naive() {
16674        let f = DType::F32;
16675        let mut g = Graph::new("cumsum");
16676        let x = g.input("x", Shape::new(&[2, 4], f));
16677        let cs = g.cumsum(x, -1, false, Shape::new(&[2, 4], f));
16678        g.set_outputs(vec![cs]);
16679        let plan = rlx_opt::memory::plan_memory(&g);
16680        let mut arena = crate::arena::Arena::from_plan(plan);
16681        let sched = compile_thunks(&g, &arena);
16682
16683        // Cache offsets up-front so we can drop the immutable borrow.
16684        let x_off = arena.byte_offset(x);
16685        let out_off = arena.byte_offset(cs);
16686        let buf = arena.raw_buf_mut();
16687        unsafe {
16688            let p = buf.as_mut_ptr().add(x_off) as *mut f32;
16689            let inputs = [1.0f32, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0];
16690            for (i, &v) in inputs.iter().enumerate() {
16691                *p.add(i) = v;
16692            }
16693        }
16694        execute_thunks(&sched, arena.raw_buf_mut());
16695
16696        let out: Vec<f32> = unsafe {
16697            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
16698            (0..8).map(|i| *p.add(i)).collect()
16699        };
16700        assert_eq!(out, vec![1.0, 3.0, 6.0, 10.0, 10.0, 30.0, 60.0, 100.0]);
16701    }
16702
16703    /// Plan #46 deep: Narrow×3 → Attention fusion. The three QKV
16704    /// narrows that BERT/Nomic emit on the unfused (batch*seq > 64)
16705    /// path collapse into a single strided-Attention thunk.
16706    #[test]
16707    fn narrow_attention_fuses_in_unfused_path() {
16708        let f = DType::F32;
16709        let mut g = Graph::new("nattn_fuse");
16710        // batch*seq = 8*16 = 128 > 64 so FusedAttnBlock skips.
16711        let qkv = g.input("qkv", Shape::new(&[8, 16, 192], f)); // 3*64 = 192
16712        let mask = g.input("mask", Shape::new(&[8, 16], f));
16713        let q = g.narrow_(qkv, 2, 0, 64);
16714        let k = g.narrow_(qkv, 2, 64, 64);
16715        let v = g.narrow_(qkv, 2, 128, 64);
16716        let attn = g.attention(q, k, v, mask, 4, 16, Shape::new(&[8, 16, 64], f));
16717        g.set_outputs(vec![attn]);
16718
16719        let plan = rlx_opt::memory::plan_memory(&g);
16720        let arena = crate::arena::Arena::from_plan(plan);
16721        let sched = compile_thunks(&g, &arena);
16722
16723        let mut narrow_count = 0;
16724        let mut attn_strides: Option<(u32, u32, u32)> = None;
16725        for t in &sched.thunks {
16726            match t {
16727                Thunk::Narrow { .. } => narrow_count += 1,
16728                Thunk::Attention {
16729                    q_row_stride,
16730                    k_row_stride,
16731                    v_row_stride,
16732                    ..
16733                } => attn_strides = Some((*q_row_stride, *k_row_stride, *v_row_stride)),
16734                _ => {}
16735            }
16736        }
16737        // After fusion the 3 narrows are gone; Attention now walks the
16738        // QKV with parent row stride = 192 (3 × 64) on all three inputs.
16739        assert_eq!(
16740            narrow_count, 0,
16741            "Narrow×3→Attention fusion should eliminate all 3 narrows; saw {narrow_count}"
16742        );
16743        assert_eq!(
16744            attn_strides,
16745            Some((192, 192, 192)),
16746            "Attention should walk Q/K/V with parent row stride 192"
16747        );
16748    }
16749
16750    // ── Backward / training op parity tests ────────────────────
16751    //
16752    // Strategy: build a graph that contains exactly the backward op
16753    // under test (plus its inputs as graph Inputs), execute, and
16754    // compare against a hand-rolled scalar reference. For
16755    // Conv2dBackwardInput we additionally check against the numerical
16756    // gradient of the forward Conv2D — that's the gold-standard test
16757    // that validates the math, not just consistency between two
16758    // implementations of the same formula.
16759
16760    fn run_graph(
16761        g: &Graph,
16762        inputs: &[(NodeId, &[f32])],
16763        out_id: NodeId,
16764        out_len: usize,
16765    ) -> Vec<f32> {
16766        let plan = rlx_opt::memory::plan_memory(g);
16767        let mut arena = crate::arena::Arena::from_plan(plan);
16768        let sched = compile_thunks(g, &arena);
16769        for &(id, data) in inputs {
16770            let off = arena.byte_offset(id);
16771            let buf = arena.raw_buf_mut();
16772            unsafe {
16773                let p = buf.as_mut_ptr().add(off) as *mut f32;
16774                for (i, &v) in data.iter().enumerate() {
16775                    *p.add(i) = v;
16776                }
16777            }
16778        }
16779        execute_thunks(&sched, arena.raw_buf_mut());
16780        let off = arena.byte_offset(out_id);
16781        unsafe {
16782            let p = arena.raw_buf().as_ptr().add(off) as *const f32;
16783            (0..out_len).map(|i| *p.add(i)).collect()
16784        }
16785    }
16786
16787    #[test]
16788    fn relu_backward_matches_mask() {
16789        let f = DType::F32;
16790        let len = 7usize;
16791        let x: Vec<f32> = vec![-2.0, -0.1, 0.0, 0.1, 1.0, 3.0, -5.0];
16792        let dy: Vec<f32> = vec![0.5, 1.5, 2.5, -0.7, 4.0, -1.0, 9.0];
16793
16794        let mut g = Graph::new("relu_bw");
16795        let xn = g.input("x", Shape::new(&[len], f));
16796        let dyn_ = g.input("dy", Shape::new(&[len], f));
16797        let dx = g.relu_backward(xn, dyn_);
16798        g.set_outputs(vec![dx]);
16799
16800        let actual = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dx, len);
16801        // Reference: gradient is dy where x>0 strictly, else 0.
16802        // (zero is not "positive" — the forward applied max(0, x), and at
16803        // x=0 the subgradient could be anything in [0, dy]; we pick 0.)
16804        let expected: Vec<f32> = x
16805            .iter()
16806            .zip(&dy)
16807            .map(|(&xi, &dyi)| if xi > 0.0 { dyi } else { 0.0 })
16808            .collect();
16809        for (a, e) in actual.iter().zip(&expected) {
16810            assert!((a - e).abs() < 1e-6, "relu_bw mismatch: {a} vs {e}");
16811        }
16812    }
16813
16814    #[test]
16815    fn maxpool2d_backward_routes_to_argmax() {
16816        let f = DType::F32;
16817        // [N=1, C=1, H=4, W=4] → 2x2 max-pool stride 2 → [1,1,2,2].
16818        let x: Vec<f32> = vec![
16819            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
16820        ];
16821        // Argmax of each 2x2 window:
16822        //   (0,0)→6 (idx 5), (0,1)→8 (idx 7),
16823        //   (1,0)→14(idx 13),(1,1)→16(idx 15).
16824        let dy: Vec<f32> = vec![0.5, 1.0, 2.0, 4.0];
16825
16826        let mut g = Graph::new("maxpool_bw");
16827        let xn = g.input("x", Shape::new(&[1, 1, 4, 4], f));
16828        let dyn_ = g.input("dy", Shape::new(&[1, 1, 2, 2], f));
16829        let dx = g.maxpool2d_backward(xn, dyn_, vec![2, 2], vec![2, 2], vec![0, 0]);
16830        g.set_outputs(vec![dx]);
16831
16832        let actual = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dx, 16);
16833        let mut expected = vec![0f32; 16];
16834        expected[5] = 0.5;
16835        expected[7] = 1.0;
16836        expected[13] = 2.0;
16837        expected[15] = 4.0;
16838        for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
16839            assert!((a - e).abs() < 1e-6, "maxpool_bw[{i}] mismatch: {a} vs {e}");
16840        }
16841    }
16842
16843    #[test]
16844    fn conv2d_backward_input_matches_numerical_gradient() {
16845        use rlx_ir::Philox4x32;
16846        // Small enough to numerically differentiate exhaustively but
16847        // big enough to exercise stride/padding edge cases.
16848        let n = 1usize;
16849        let c_in = 2usize;
16850        let h = 4usize;
16851        let w = 4usize;
16852        let c_out = 3usize;
16853        let kh = 3usize;
16854        let kw = 3usize;
16855        let ph = 1usize;
16856        let pw = 1usize;
16857        let sh = 1usize;
16858        let sw = 1usize;
16859        // Output dims with padding=1, stride=1: same as input.
16860        let h_out = (h + 2 * ph - kh) / sh + 1;
16861        let w_out = (w + 2 * pw - kw) / sw + 1;
16862        assert_eq!(h_out, 4);
16863        assert_eq!(w_out, 4);
16864
16865        let mut rng = Philox4x32::new(7);
16866        let mut x = vec![0f32; n * c_in * h * w];
16867        rng.fill_normal(&mut x);
16868        let mut wt = vec![0f32; c_out * c_in * kh * kw];
16869        rng.fill_normal(&mut wt);
16870        let mut dy = vec![0f32; n * c_out * h_out * w_out];
16871        rng.fill_normal(&mut dy);
16872
16873        // Analytical: Conv2dBackwardInput on (dy, w).
16874        let f = DType::F32;
16875        let mut g = Graph::new("conv_bwi");
16876        let dy_in = g.input("dy", Shape::new(&[n, c_out, h_out, w_out], f));
16877        let w_in = g.input("w", Shape::new(&[c_out, c_in, kh, kw], f));
16878        let dx = g.conv2d_backward_input(
16879            dy_in,
16880            w_in,
16881            Shape::new(&[n, c_in, h, w], f),
16882            vec![kh, kw],
16883            vec![sh, sw],
16884            vec![ph, pw],
16885            vec![1, 1],
16886            1,
16887        );
16888        g.set_outputs(vec![dx]);
16889        let analytical = run_graph(&g, &[(dy_in, &dy), (w_in, &wt)], dx, n * c_in * h * w);
16890
16891        // Numerical: for each x[i], finite-difference forward conv twice.
16892        // Forward: y[j] = sum over filter window of w * x ; dot(dy, y) is
16893        // the scalar we differentiate. Then dx[i] = ∂(dot(dy, y))/∂x[i].
16894        let forward = |x: &[f32]| -> Vec<f32> {
16895            let mut out = vec![0f32; n * c_out * h_out * w_out];
16896            for ni in 0..n {
16897                for co in 0..c_out {
16898                    for ho in 0..h_out {
16899                        for wo in 0..w_out {
16900                            let mut acc = 0f32;
16901                            for ci in 0..c_in {
16902                                for ki in 0..kh {
16903                                    for kj in 0..kw {
16904                                        let hi = ho * sh + ki;
16905                                        let wi = wo * sw + kj;
16906                                        if hi < ph || wi < pw {
16907                                            continue;
16908                                        }
16909                                        let hi = hi - ph;
16910                                        let wi = wi - pw;
16911                                        if hi >= h || wi >= w {
16912                                            continue;
16913                                        }
16914                                        let xv = x[((ni * c_in) + ci) * h * w + hi * w + wi];
16915                                        let wv = wt[((co * c_in) + ci) * kh * kw + ki * kw + kj];
16916                                        acc += xv * wv;
16917                                    }
16918                                }
16919                            }
16920                            out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = acc;
16921                        }
16922                    }
16923                }
16924            }
16925            out
16926        };
16927        let dot = |a: &[f32], b: &[f32]| -> f32 { a.iter().zip(b).map(|(&u, &v)| u * v).sum() };
16928        let eps = 1e-3f32;
16929        let mut numerical = vec![0f32; x.len()];
16930        for i in 0..x.len() {
16931            let saved = x[i];
16932            x[i] = saved + eps;
16933            let plus = dot(&forward(&x), &dy);
16934            x[i] = saved - eps;
16935            let minus = dot(&forward(&x), &dy);
16936            x[i] = saved;
16937            numerical[i] = (plus - minus) / (2.0 * eps);
16938        }
16939        for (i, (a, n)) in analytical.iter().zip(&numerical).enumerate() {
16940            // f32 + eps=1e-3 numerical grad → ~1e-3 absolute is realistic.
16941            assert!(
16942                (a - n).abs() < 5e-3,
16943                "conv_bw_input[{i}]: analytical {a} vs numerical {n}"
16944            );
16945        }
16946    }
16947
16948    #[test]
16949    fn conv2d_backward_weight_matches_numerical_gradient() {
16950        use rlx_ir::Philox4x32;
16951        let n = 2usize;
16952        let c_in = 2usize;
16953        let h = 4usize;
16954        let w = 4usize;
16955        let c_out = 2usize;
16956        let kh = 3usize;
16957        let kw = 3usize;
16958        let ph = 0usize;
16959        let pw = 0usize;
16960        let sh = 1usize;
16961        let sw = 1usize;
16962        let h_out = (h + 2 * ph - kh) / sh + 1;
16963        let w_out = (w + 2 * pw - kw) / sw + 1;
16964
16965        let mut rng = Philox4x32::new(11);
16966        let mut x = vec![0f32; n * c_in * h * w];
16967        rng.fill_normal(&mut x);
16968        let mut wt = vec![0f32; c_out * c_in * kh * kw];
16969        rng.fill_normal(&mut wt);
16970        let mut dy = vec![0f32; n * c_out * h_out * w_out];
16971        rng.fill_normal(&mut dy);
16972
16973        let f = DType::F32;
16974        let mut g = Graph::new("conv_bww");
16975        let xn = g.input("x", Shape::new(&[n, c_in, h, w], f));
16976        let dyn_ = g.input("dy", Shape::new(&[n, c_out, h_out, w_out], f));
16977        let dwn = g.conv2d_backward_weight(
16978            xn,
16979            dyn_,
16980            Shape::new(&[c_out, c_in, kh, kw], f),
16981            vec![kh, kw],
16982            vec![sh, sw],
16983            vec![ph, pw],
16984            vec![1, 1],
16985            1,
16986        );
16987        g.set_outputs(vec![dwn]);
16988        let analytical = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dwn, c_out * c_in * kh * kw);
16989
16990        let forward = |wt: &[f32]| -> Vec<f32> {
16991            let mut out = vec![0f32; n * c_out * h_out * w_out];
16992            for ni in 0..n {
16993                for co in 0..c_out {
16994                    for ho in 0..h_out {
16995                        for wo in 0..w_out {
16996                            let mut acc = 0f32;
16997                            for ci in 0..c_in {
16998                                for ki in 0..kh {
16999                                    for kj in 0..kw {
17000                                        let hi = ho + ki;
17001                                        let wi = wo + kj;
17002                                        let xv = x[((ni * c_in) + ci) * h * w + hi * w + wi];
17003                                        let wv = wt[((co * c_in) + ci) * kh * kw + ki * kw + kj];
17004                                        acc += xv * wv;
17005                                    }
17006                                }
17007                            }
17008                            out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = acc;
17009                        }
17010                    }
17011                }
17012            }
17013            out
17014        };
17015        let dot = |a: &[f32], b: &[f32]| -> f32 { a.iter().zip(b).map(|(&u, &v)| u * v).sum() };
17016        let eps = 1e-3f32;
17017        let mut numerical = vec![0f32; wt.len()];
17018        for i in 0..wt.len() {
17019            let saved = wt[i];
17020            wt[i] = saved + eps;
17021            let plus = dot(&forward(&wt), &dy);
17022            wt[i] = saved - eps;
17023            let minus = dot(&forward(&wt), &dy);
17024            wt[i] = saved;
17025            numerical[i] = (plus - minus) / (2.0 * eps);
17026        }
17027        for (i, (a, n)) in analytical.iter().zip(&numerical).enumerate() {
17028            assert!(
17029                (a - n).abs() < 5e-3,
17030                "conv_bw_weight[{i}]: analytical {a} vs numerical {n}"
17031            );
17032        }
17033    }
17034
17035    #[test]
17036    fn softmax_cross_entropy_matches_reference() {
17037        let f = DType::F32;
17038        let logits: Vec<f32> = vec![
17039            1.0, 2.0, 3.0, // row 0: max=3 (idx 2)
17040            -1.0, 0.0, 4.0, // row 1: max=4 (idx 2)
17041            5.0, 5.0, 5.0, // row 2: uniform
17042        ];
17043        let labels: Vec<f32> = vec![2.0, 0.0, 1.0];
17044
17045        let mut g = Graph::new("sce");
17046        let lg = g.input("logits", Shape::new(&[3, 3], f));
17047        let lb = g.input("labels", Shape::new(&[3], f));
17048        let loss = g.softmax_cross_entropy_with_logits(lg, lb);
17049        g.set_outputs(vec![loss]);
17050        let actual = run_graph(&g, &[(lg, &logits), (lb, &labels)], loss, 3);
17051
17052        // Reference per-row: -log(softmax(row)[label]).
17053        let mut expected = vec![0f32; 3];
17054        for ni in 0..3 {
17055            let row = &logits[ni * 3..(ni + 1) * 3];
17056            let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
17057            let sum: f32 = row.iter().map(|&v| (v - m).exp()).sum();
17058            let lse = m + sum.ln();
17059            let label_idx = labels[ni] as usize;
17060            expected[ni] = lse - row[label_idx];
17061        }
17062        for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
17063            assert!((a - e).abs() < 1e-5, "sce loss[{i}]: {a} vs {e}");
17064        }
17065    }
17066
17067    #[test]
17068    fn softmax_cross_entropy_backward_matches_numerical_gradient() {
17069        use rlx_ir::Philox4x32;
17070        let n = 4usize;
17071        let c = 5usize;
17072        let mut rng = Philox4x32::new(23);
17073        let mut logits = vec![0f32; n * c];
17074        rng.fill_normal(&mut logits);
17075        let labels: Vec<f32> = (0..n).map(|i| (i % c) as f32).collect();
17076        let mut d_loss = vec![0f32; n];
17077        rng.fill_normal(&mut d_loss);
17078
17079        let f = DType::F32;
17080        let mut g = Graph::new("sce_bw");
17081        let lg = g.input("logits", Shape::new(&[n, c], f));
17082        let lb = g.input("labels", Shape::new(&[n], f));
17083        let dl = g.input("d_loss", Shape::new(&[n], f));
17084        let dlogits = g.softmax_cross_entropy_backward(lg, lb, dl);
17085        g.set_outputs(vec![dlogits]);
17086        let analytical = run_graph(
17087            &g,
17088            &[(lg, &logits), (lb, &labels), (dl, &d_loss)],
17089            dlogits,
17090            n * c,
17091        );
17092
17093        // Numerical: differentiate dot(d_loss, sce_loss(logits)) w.r.t. each logit.
17094        let sce_loss = |logits: &[f32]| -> Vec<f32> {
17095            let mut out = vec![0f32; n];
17096            for ni in 0..n {
17097                let row = &logits[ni * c..(ni + 1) * c];
17098                let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
17099                let sum: f32 = row.iter().map(|&v| (v - m).exp()).sum();
17100                out[ni] = (m + sum.ln()) - row[labels[ni] as usize];
17101            }
17102            out
17103        };
17104        let dot = |a: &[f32], b: &[f32]| a.iter().zip(b).map(|(&u, &v)| u * v).sum::<f32>();
17105        let eps = 1e-3f32;
17106        let mut numerical = vec![0f32; logits.len()];
17107        for i in 0..logits.len() {
17108            let saved = logits[i];
17109            logits[i] = saved + eps;
17110            let plus = dot(&sce_loss(&logits), &d_loss);
17111            logits[i] = saved - eps;
17112            let minus = dot(&sce_loss(&logits), &d_loss);
17113            logits[i] = saved;
17114            numerical[i] = (plus - minus) / (2.0 * eps);
17115        }
17116        for (i, (a, num)) in analytical.iter().zip(&numerical).enumerate() {
17117            assert!(
17118                (a - num).abs() < 5e-3,
17119                "sce_bw[{i}]: analytical {a} vs numerical {num}"
17120            );
17121        }
17122    }
17123
17124    // ── End-to-end autodiff parity tests ──────────────────────
17125    //
17126    // Build a forward graph, run `grad_with_loss` to produce a graph
17127    // that emits [loss, gradients...], execute it through rlx-cpu,
17128    // and compare each gradient to a finite-difference estimate
17129    // produced by re-running the forward graph with each parameter
17130    // entry perturbed. f32 + ε=1e-3 puts the tolerance floor around
17131    // 5e-3 absolute error.
17132
17133    /// Initialize Op::Constant slots in the arena with their literal
17134    /// data. Mirrors the loop in rlx_runtime::backend (which serves
17135    /// the same role for production runs).
17136    fn fill_constants_into_arena(graph: &Graph, arena: &mut crate::arena::Arena) {
17137        for node in graph.nodes() {
17138            if let Op::Constant { data } = &node.op
17139                && arena.has_buffer(node.id)
17140                && !data.is_empty()
17141            {
17142                let buf = arena.slice_mut(node.id);
17143                let n_floats = data.len() / 4;
17144                let n = buf.len().min(n_floats);
17145                for i in 0..n {
17146                    let bytes = [
17147                        data[i * 4],
17148                        data[i * 4 + 1],
17149                        data[i * 4 + 2],
17150                        data[i * 4 + 3],
17151                    ];
17152                    buf[i] = f32::from_le_bytes(bytes);
17153                }
17154            }
17155        }
17156    }
17157
17158    /// Compile + arena-prep helper for these tests. Returns the
17159    /// schedule and a populated arena. `seed_inputs` writes f32 input
17160    /// data into the arena slot for each (NodeId, &[f32]) pair.
17161    fn prepare(
17162        graph: &Graph,
17163        seed_inputs: &[(NodeId, &[f32])],
17164    ) -> (ThunkSchedule, crate::arena::Arena) {
17165        let plan = rlx_opt::memory::plan_memory(graph);
17166        let mut arena = crate::arena::Arena::from_plan(plan);
17167        let sched = compile_thunks(graph, &arena);
17168        fill_constants_into_arena(graph, &mut arena);
17169        for &(id, data) in seed_inputs {
17170            let off = arena.byte_offset(id);
17171            let buf = arena.raw_buf_mut();
17172            unsafe {
17173                let p = buf.as_mut_ptr().add(off) as *mut f32;
17174                for (i, &v) in data.iter().enumerate() {
17175                    *p.add(i) = v;
17176                }
17177            }
17178        }
17179        (sched, arena)
17180    }
17181
17182    fn read_arena(arena: &crate::arena::Arena, id: NodeId, len: usize) -> Vec<f32> {
17183        let off = arena.byte_offset(id);
17184        unsafe {
17185            let p = arena.raw_buf().as_ptr().add(off) as *const f32;
17186            (0..len).map(|i| *p.add(i)).collect()
17187        }
17188    }
17189
17190    fn write_arena(arena: &mut crate::arena::Arena, id: NodeId, data: &[f32]) {
17191        let off = arena.byte_offset(id);
17192        let buf = arena.raw_buf_mut();
17193        unsafe {
17194            let p = buf.as_mut_ptr().add(off) as *mut f32;
17195            for (i, &v) in data.iter().enumerate() {
17196                *p.add(i) = v;
17197            }
17198        }
17199    }
17200
17201    /// f64 sibling of `prepare`. Writes f64 input data into the arena.
17202    fn prepare_f64(
17203        graph: &Graph,
17204        seed_inputs: &[(NodeId, &[f64])],
17205    ) -> (ThunkSchedule, crate::arena::Arena) {
17206        let plan = rlx_opt::memory::plan_memory(graph);
17207        let mut arena = crate::arena::Arena::from_plan(plan);
17208        let sched = compile_thunks(graph, &arena);
17209        fill_constants_into_arena(graph, &mut arena);
17210        for &(id, data) in seed_inputs {
17211            let off = arena.byte_offset(id);
17212            let buf = arena.raw_buf_mut();
17213            unsafe {
17214                let p = buf.as_mut_ptr().add(off) as *mut f64;
17215                for (i, &v) in data.iter().enumerate() {
17216                    *p.add(i) = v;
17217                }
17218            }
17219        }
17220        (sched, arena)
17221    }
17222
17223    fn read_arena_f64(arena: &crate::arena::Arena, id: NodeId, len: usize) -> Vec<f64> {
17224        let off = arena.byte_offset(id);
17225        unsafe {
17226            let p = arena.raw_buf().as_ptr().add(off) as *const f64;
17227            (0..len).map(|i| *p.add(i)).collect()
17228        }
17229    }
17230
17231    /// End-to-end f64 DenseSolve through the full compile + execute
17232    /// path. Validates: IR shape inference, memory planner f64 sizing,
17233    /// arena f64 accessors, Thunk::DenseSolveF64 lowering, executor
17234    /// dispatch, Accelerate dgesv FFI.
17235    ///
17236    /// System:
17237    ///   A = [[2, 1],
17238    ///        [1, 3]]   b = [5, 10]
17239    ///   ⇒  x = [1, 3]   (verified by hand)
17240    #[test]
17241    fn dense_solve_f64_end_to_end() {
17242        let mut g = Graph::new("solve_e2e");
17243        let a = g.input("A", Shape::new(&[2, 2], DType::F64));
17244        let b = g.input("b", Shape::new(&[2], DType::F64));
17245        let x = g.dense_solve(a, b, Shape::new(&[2], DType::F64));
17246        g.set_outputs(vec![x]);
17247
17248        let a_data = [2.0, 1.0, 1.0, 3.0_f64];
17249        let b_data = [5.0, 10.0_f64];
17250        let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
17251        execute_thunks(&sched, arena.raw_buf_mut());
17252
17253        let got = read_arena_f64(&arena, x, 2);
17254        let want = [1.0, 3.0_f64];
17255        for i in 0..2 {
17256            assert!(
17257                (got[i] - want[i]).abs() < 1e-12,
17258                "x[{i}] = {} (expected {})",
17259                got[i],
17260                want[i]
17261            );
17262        }
17263    }
17264
17265    /// Scaled-up f64 DenseSolve — tridiagonal Laplacian-shape (typical
17266    /// MNA structure for a passive RC mesh in Circulax). Validates
17267    /// that the solve scales beyond the trivial 2×2 and that the
17268    /// row-major ↔ col-major dance in `dgesv` is correct for the
17269    /// general case.
17270    #[test]
17271    fn dense_solve_f64_5x5_laplacian() {
17272        let n = 5usize;
17273        let mut g = Graph::new("solve_5x5");
17274        let a = g.input("A", Shape::new(&[n, n], DType::F64));
17275        let b = g.input("b", Shape::new(&[n], DType::F64));
17276        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
17277        g.set_outputs(vec![x]);
17278
17279        // 1-D Laplacian: 2 on diagonal, -1 on off-diagonals, 0 elsewhere.
17280        let mut a_data = vec![0.0_f64; n * n];
17281        for i in 0..n {
17282            a_data[i * n + i] = 2.0;
17283            if i > 0 {
17284                a_data[i * n + (i - 1)] = -1.0;
17285            }
17286            if i + 1 < n {
17287                a_data[i * n + (i + 1)] = -1.0;
17288            }
17289        }
17290        let b_data: Vec<f64> = (0..n).map(|i| (i + 1) as f64).collect();
17291        let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
17292        execute_thunks(&sched, arena.raw_buf_mut());
17293
17294        let got = read_arena_f64(&arena, x, n);
17295        // Verify A·x ≈ b by computing the residual.
17296        let mut residual = vec![0.0_f64; n];
17297        for i in 0..n {
17298            for j in 0..n {
17299                residual[i] += a_data[i * n + j] * got[j];
17300            }
17301        }
17302        for i in 0..n {
17303            assert!(
17304                (residual[i] - b_data[i]).abs() < 1e-10,
17305                "row {i}: residual {} vs b {}",
17306                residual[i],
17307                b_data[i]
17308            );
17309        }
17310    }
17311
17312    /// Hello Resistor: end-to-end f64 gradient through a dense solve.
17313    ///
17314    /// Forward:
17315    ///   A      : Param  [N, N]   f64
17316    ///   b      : Input  [N]      f64
17317    ///   x      = solve(A, b)            (DenseSolve)
17318    ///   loss   = sum(x)                 (Reduce::Sum)
17319    ///
17320    /// Backward (via grad_with_loss):
17321    ///   ones [N] = expand(d_output, [N])      (Reduce::Sum VJP)
17322    ///   dx_int   = solve(Aᵀ, ones)             (DenseSolve VJP step 1)
17323    ///   dA       = -outer(dx_int, x)           (DenseSolve VJP step 2)
17324    ///   db       = dx_int                       (DenseSolve VJP step 3)
17325    ///
17326    /// Closed form: with loss = sum(solve(A, b)) = ones·x and
17327    /// implicit-function calculus, db = (Aᵀ)⁻¹·ones, dA = -db ⊗ x.
17328    /// We verify this against the autodiff-emitted graph's output and
17329    /// against a finite-difference baseline.
17330    #[test]
17331    fn hello_resistor_gradient_end_to_end() {
17332        use rlx_opt::autodiff::grad_with_loss;
17333        let n = 3usize;
17334
17335        // ── Build forward graph ──
17336        let mut g = Graph::new("hello_resistor");
17337        let a = g.param("A", Shape::new(&[n, n], DType::F64));
17338        let b = g.input("b", Shape::new(&[n], DType::F64));
17339        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
17340        let loss = g.reduce(
17341            x,
17342            ReduceOp::Sum,
17343            vec![0],
17344            false,
17345            Shape::new(&[1], DType::F64),
17346        );
17347        g.set_outputs(vec![loss]);
17348
17349        // ── Run reverse-mode AD ──
17350        let bwd = grad_with_loss(&g, &[a, b]);
17351        assert_eq!(bwd.outputs.len(), 3, "expect [loss, dA, db]");
17352
17353        // ── Locate the inputs the bwd graph still needs from us ──
17354        // grad_with_loss copies forward nodes into bwd, so A/b/d_output
17355        // appear under their original names. Find them by name.
17356        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
17357            for node in graph.nodes() {
17358                let name = match &node.op {
17359                    rlx_ir::Op::Input { name } => Some(name.as_str()),
17360                    rlx_ir::Op::Param { name } => Some(name.as_str()),
17361                    _ => None,
17362                };
17363                if name == Some(want) {
17364                    return node.id;
17365                }
17366            }
17367            panic!("no node named {want:?} in bwd graph");
17368        };
17369        let a_bwd = find_by_name(&bwd, "A");
17370        let b_bwd = find_by_name(&bwd, "b");
17371        let d_out_bwd = find_by_name(&bwd, "d_output");
17372
17373        // ── Test data ──
17374        // A = [[2,1,0],[1,3,1],[0,1,2]]   (SPD tridiagonal, well-conditioned)
17375        // b = [1,2,3]
17376        let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
17377        let b_data = [1.0, 2.0, 3.0_f64];
17378        let d_output = [1.0_f64]; // ∂loss/∂loss
17379
17380        // ── Compile + execute backward graph ──
17381        let (sched, mut arena) = prepare_f64(
17382            &bwd,
17383            &[(a_bwd, &a_data), (b_bwd, &b_data), (d_out_bwd, &d_output)],
17384        );
17385        execute_thunks(&sched, arena.raw_buf_mut());
17386
17387        let loss_out = read_arena_f64(&arena, bwd.outputs[0], 1);
17388        let da_out = read_arena_f64(&arena, bwd.outputs[1], n * n);
17389        let db_out = read_arena_f64(&arena, bwd.outputs[2], n);
17390
17391        // ── Closed-form reference ──
17392        // x = A⁻¹ b ; loss = sum(x).
17393        let x_ref = {
17394            let mut a = a_data;
17395            let mut b = b_data;
17396            let info = crate::blas::dgesv(&mut a, &mut b, n, 1);
17397            assert_eq!(info, 0);
17398            b
17399        };
17400        let loss_ref: f64 = x_ref.iter().sum();
17401        // db = (Aᵀ)⁻¹ · 1
17402        let db_ref = {
17403            let mut at = [0.0_f64; 9];
17404            for i in 0..n {
17405                for j in 0..n {
17406                    at[i * n + j] = a_data[j * n + i];
17407                }
17408            }
17409            let mut ones = [1.0_f64; 3];
17410            let info = crate::blas::dgesv(&mut at, &mut ones, n, 1);
17411            assert_eq!(info, 0);
17412            ones
17413        };
17414        // dA = -outer(db, x) ; dA[i,j] = -db[i] * x[j]
17415        let mut da_ref = [0.0_f64; 9];
17416        for i in 0..n {
17417            for j in 0..n {
17418                da_ref[i * n + j] = -db_ref[i] * x_ref[j];
17419            }
17420        }
17421
17422        // ── Assertions vs analytic answer ──
17423        assert!(
17424            (loss_out[0] - loss_ref).abs() < 1e-10,
17425            "loss: got {}, want {}",
17426            loss_out[0],
17427            loss_ref
17428        );
17429        for i in 0..n {
17430            assert!(
17431                (db_out[i] - db_ref[i]).abs() < 1e-10,
17432                "db[{i}]: got {}, want {}",
17433                db_out[i],
17434                db_ref[i]
17435            );
17436        }
17437        for i in 0..n * n {
17438            assert!(
17439                (da_out[i] - da_ref[i]).abs() < 1e-10,
17440                "dA[{i}]: got {}, want {}",
17441                da_out[i],
17442                da_ref[i]
17443            );
17444        }
17445
17446        // ── Cross-check vs finite differences on db (a few entries) ──
17447        // ∂loss/∂b[k] ≈ (loss(b + h·e_k) - loss(b - h·e_k)) / (2h).
17448        let h = 1e-6_f64;
17449        for k in 0..n {
17450            let mut bp = b_data;
17451            bp[k] += h;
17452            let mut bm = b_data;
17453            bm[k] -= h;
17454            let lp = {
17455                let mut ac = a_data;
17456                let info = crate::blas::dgesv(&mut ac, &mut bp, n, 1);
17457                assert_eq!(info, 0);
17458                bp.iter().sum::<f64>()
17459            };
17460            let lm = {
17461                let mut ac = a_data;
17462                let info = crate::blas::dgesv(&mut ac, &mut bm, n, 1);
17463                assert_eq!(info, 0);
17464                bm.iter().sum::<f64>()
17465            };
17466            let fd = (lp - lm) / (2.0 * h);
17467            assert!(
17468                (db_out[k] - fd).abs() < 1e-7,
17469                "FD mismatch on db[{k}]: AD={} FD={}",
17470                db_out[k],
17471                fd
17472            );
17473        }
17474    }
17475
17476    /// Smallest possible Op::Scan basic test: geometric growth.
17477    /// init = [1, 1, 1] f64, body = (x → x + 0.1·x) = (x → 1.1·x),
17478    /// length = 10. Final carry must equal init·(1.1)^10 ≈ 2.5937…
17479    /// to f64 precision.
17480    #[test]
17481    fn scan_geometric_growth_f64() {
17482        let n = 3usize;
17483        let length = 10u32;
17484
17485        // Body: (x) → x + 0.1·x. One Input, one output, same shape/dtype.
17486        let mut body = Graph::new("scan_body");
17487        let x = body.input("carry", Shape::new(&[n], DType::F64));
17488        let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 0.1_f64.to_le_bytes()).collect();
17489        let scale = body.add_node(
17490            Op::Constant { data: scale_bytes },
17491            vec![],
17492            Shape::new(&[n], DType::F64),
17493        );
17494        let scaled = body.binary(BinaryOp::Mul, x, scale, Shape::new(&[n], DType::F64));
17495        let next = body.binary(BinaryOp::Add, x, scaled, Shape::new(&[n], DType::F64));
17496        body.set_outputs(vec![next]);
17497
17498        // Outer graph: scan(init, body, length).
17499        let mut g = Graph::new("scan_outer");
17500        let init = g.input("init", Shape::new(&[n], DType::F64));
17501        let final_carry = g.scan(init, body, length);
17502        g.set_outputs(vec![final_carry]);
17503
17504        let init_data = vec![1.0_f64; n];
17505        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
17506        execute_thunks(&sched, arena.raw_buf_mut());
17507        let got = read_arena_f64(&arena, final_carry, n);
17508        let want: f64 = 1.1_f64.powi(length as i32);
17509        for i in 0..n {
17510            assert!(
17511                (got[i] - want).abs() < 1e-12,
17512                "got[{i}] = {} want {}",
17513                got[i],
17514                want
17515            );
17516        }
17517    }
17518
17519    /// Per-step xs scan: cumulative-sum.
17520    ///   carry_0 = init
17521    ///   carry_{t+1} = carry_t + xs\[t\]
17522    ///   final = sum_{t<length} xs\[t\] + init
17523    /// Body has 2 inputs (carry, x_t) in that NodeId order; one output
17524    /// (next carry). Validates the per-step-input plumbing end-to-end.
17525    #[test]
17526    fn scan_with_xs_cumulative_sum() {
17527        let n = 3usize;
17528        let length = 4u32;
17529
17530        let mut body = Graph::new("cumsum_body");
17531        // carry must come first in NodeId order — declare it first.
17532        let carry = body.input("carry", Shape::new(&[n], DType::F64));
17533        let x_t = body.input("x_t", Shape::new(&[n], DType::F64));
17534        let next = body.binary(BinaryOp::Add, carry, x_t, Shape::new(&[n], DType::F64));
17535        body.set_outputs(vec![next]);
17536
17537        let mut g = Graph::new("cumsum_outer");
17538        let init = g.input("init", Shape::new(&[n], DType::F64));
17539        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
17540        let final_carry = g.scan_with_xs(init, &[xs], body, length);
17541        g.set_outputs(vec![final_carry]);
17542
17543        let init_data = vec![0.0_f64; n];
17544        let xs_data: Vec<f64> = (0..length as usize * n).map(|i| (i + 1) as f64).collect(); // 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12
17545        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data), (xs, &xs_data)]);
17546        execute_thunks(&sched, arena.raw_buf_mut());
17547        let got = read_arena_f64(&arena, final_carry, n);
17548
17549        // Reference: column-wise sum of xs rows + init. With our row-major
17550        // layout, column j of xs is xs_data[j], xs_data[n+j], xs_data[2n+j], ...
17551        // (per-step row at offset t*n contributes element j to slot j).
17552        let mut want = init_data.clone();
17553        for t in 0..length as usize {
17554            for j in 0..n {
17555                want[j] += xs_data[t * n + j];
17556            }
17557        }
17558        for i in 0..n {
17559            assert!(
17560                (got[i] - want[i]).abs() < 1e-12,
17561                "got[{i}] = {} want {}",
17562                got[i],
17563                want[i]
17564            );
17565        }
17566    }
17567
17568    /// Per-step xs scan composing with DenseSolve — Circulax-shaped:
17569    ///   carry_{t+1} = solve(M, carry_t + xs\[t\])
17570    /// Models a Backward-Euler step driven by a time-varying source.
17571    #[test]
17572    fn scan_with_xs_be_with_drive() {
17573        let n = 3usize;
17574        let length = 4u32;
17575        let dt = 0.1_f64;
17576
17577        let mut m_data = vec![0.0_f64; n * n];
17578        for i in 0..n {
17579            m_data[i * n + i] = 1.0 + dt * 2.0;
17580            if i > 0 {
17581                m_data[i * n + (i - 1)] = -dt;
17582            }
17583            if i + 1 < n {
17584                m_data[i * n + (i + 1)] = -dt;
17585            }
17586        }
17587        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
17588
17589        let mut body = Graph::new("be_drive_body");
17590        let carry = body.input("carry", Shape::new(&[n], DType::F64));
17591        let drive = body.input("drive", Shape::new(&[n], DType::F64));
17592        let m = body.add_node(
17593            Op::Constant { data: m_bytes },
17594            vec![],
17595            Shape::new(&[n, n], DType::F64),
17596        );
17597        let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
17598        let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
17599        body.set_outputs(vec![next]);
17600
17601        let mut g = Graph::new("be_drive_outer");
17602        let init = g.input("init", Shape::new(&[n], DType::F64));
17603        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
17604        let final_carry = g.scan_with_xs(init, &[xs], body, length);
17605        g.set_outputs(vec![final_carry]);
17606
17607        let init_data = vec![0.0_f64; n];
17608        // Drive the system with a unit pulse on element 0 at t=0,
17609        // zeros after.
17610        let mut xs_data = vec![0.0_f64; length as usize * n];
17611        xs_data[0] = 1.0;
17612
17613        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data), (xs, &xs_data)]);
17614        execute_thunks(&sched, arena.raw_buf_mut());
17615        let got = read_arena_f64(&arena, final_carry, n);
17616
17617        // Reference: per-step in pure Rust.
17618        let mut x = init_data.clone();
17619        for t in 0..length as usize {
17620            for j in 0..n {
17621                x[j] += xs_data[t * n + j];
17622            }
17623            let mut a_copy = m_data.clone();
17624            crate::blas::dgesv(&mut a_copy, &mut x, n, 1);
17625        }
17626        for i in 0..n {
17627            assert!(
17628                (got[i] - x[i]).abs() < 1e-12,
17629                "got[{i}] = {} ref {}",
17630                got[i],
17631                x[i]
17632            );
17633        }
17634    }
17635
17636    /// Reverse-mode AD through Op::BatchedDenseSolve. Forward solves
17637    /// `[B, N, N] · x = [B, N]`; loss = sum of all entries. Closed
17638    /// form: dB = (Aᵀ)⁻¹·1, dA = -(Aᵀ)⁻¹·1 ⊗ x. Verified analytically
17639    /// per batch (each slice matches what the unbatched DenseSolve VJP
17640    /// would compute).
17641    #[test]
17642    fn batched_dense_solve_gradient_matches_per_batch_analytic() {
17643        use rlx_opt::autodiff::grad_with_loss;
17644        let n = 3usize;
17645        let batch = 4usize;
17646
17647        let mut g = Graph::new("bds_grad");
17648        let a = g.param("A", Shape::new(&[batch, n, n], DType::F64));
17649        let b = g.input("b", Shape::new(&[batch, n], DType::F64));
17650        let x = g.batched_dense_solve(a, b, Shape::new(&[batch, n], DType::F64));
17651        let loss = g.reduce(
17652            x,
17653            ReduceOp::Sum,
17654            vec![0, 1],
17655            false,
17656            Shape::new(&[1], DType::F64),
17657        );
17658        g.set_outputs(vec![loss]);
17659
17660        let bwd = grad_with_loss(&g, &[a, b]);
17661
17662        let find = |graph: &Graph, want: &str| -> NodeId {
17663            for node in graph.nodes() {
17664                let name = match &node.op {
17665                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17666                    _ => None,
17667                };
17668                if name == Some(want) {
17669                    return node.id;
17670                }
17671            }
17672            panic!("no node named {want}");
17673        };
17674        let a_id = find(&bwd, "A");
17675        let b_id = find(&bwd, "b");
17676        let d_out_id = find(&bwd, "d_output");
17677
17678        let mut rng = rlx_ir::Philox4x32::new(0x57e1_u64);
17679        let mut a_data = vec![0.0_f64; batch * n * n];
17680        let mut b_data = vec![0.0_f64; batch * n];
17681        for bi in 0..batch {
17682            for i in 0..n {
17683                for j in 0..n {
17684                    a_data[bi * n * n + i * n + j] = rng.next_f32() as f64 * 0.1;
17685                }
17686                a_data[bi * n * n + i * n + i] += 1.0 + n as f64;
17687            }
17688            for i in 0..n {
17689                b_data[bi * n + i] = rng.next_f32() as f64;
17690            }
17691        }
17692        let d_seed = [1.0_f64];
17693
17694        let (sched, mut arena) = prepare_f64(
17695            &bwd,
17696            &[(a_id, &a_data), (b_id, &b_data), (d_out_id, &d_seed)],
17697        );
17698        execute_thunks(&sched, arena.raw_buf_mut());
17699        let da_out = read_arena_f64(&arena, bwd.outputs[1], batch * n * n);
17700        let db_out = read_arena_f64(&arena, bwd.outputs[2], batch * n);
17701
17702        // Reference: per-batch analytic solve. dB_i = (A_iᵀ)⁻¹ · 1,
17703        // dA_i = -dB_i ⊗ x_i.
17704        for bi in 0..batch {
17705            let a_slice: Vec<f64> = a_data[bi * n * n..(bi + 1) * n * n].to_vec();
17706            let mut b_slice: Vec<f64> = b_data[bi * n..(bi + 1) * n].to_vec();
17707            let mut a_copy = a_slice.clone();
17708            crate::blas::dgesv(&mut a_copy, &mut b_slice, n, 1);
17709            let x_ref = b_slice.clone();
17710            // dB: solve(A^T, ones)
17711            let mut at = vec![0.0_f64; n * n];
17712            for i in 0..n {
17713                for j in 0..n {
17714                    at[i * n + j] = a_slice[j * n + i];
17715                }
17716            }
17717            let mut ones = vec![1.0_f64; n];
17718            crate::blas::dgesv(&mut at, &mut ones, n, 1);
17719            let db_ref = ones;
17720            for i in 0..n {
17721                let got = db_out[bi * n + i];
17722                assert!(
17723                    (got - db_ref[i]).abs() < 1e-10,
17724                    "batch {bi}, db[{i}]: got {got} ref {}",
17725                    db_ref[i]
17726                );
17727            }
17728            // dA: -outer(db, x)
17729            for i in 0..n {
17730                for j in 0..n {
17731                    let got = da_out[bi * n * n + i * n + j];
17732                    let want = -db_ref[i] * x_ref[j];
17733                    assert!(
17734                        (got - want).abs() < 1e-10,
17735                        "batch {bi}, dA[{i},{j}]: got {got} ref {want}"
17736                    );
17737                }
17738            }
17739        }
17740    }
17741
17742    /// AD knob: gradient through `scan_checkpointed` automatically
17743    /// uses the recompute backward path. Compares dinit from a plain
17744    /// scan against the same forward written with `scan_checkpointed`,
17745    /// both run through `grad_with_loss`. They must match to f64.
17746    #[test]
17747    fn scan_checkpointed_grad_matches_plain_scan_grad() {
17748        use rlx_opt::autodiff::grad_with_loss;
17749        let n = 2usize;
17750        let length = 6u32;
17751
17752        let make_body = || {
17753            let mut body = Graph::new("ck_body");
17754            let carry = body.input("carry", Shape::new(&[n], DType::F64));
17755            let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 1.05_f64.to_le_bytes()).collect();
17756            let scale = body.add_node(
17757                Op::Constant { data: scale_bytes },
17758                vec![],
17759                Shape::new(&[n], DType::F64),
17760            );
17761            let next = body.binary(BinaryOp::Mul, carry, scale, Shape::new(&[n], DType::F64));
17762            body.set_outputs(vec![next]);
17763            body
17764        };
17765
17766        // Plain scan path.
17767        let mut g_plain = Graph::new("ck_plain");
17768        let init_p = g_plain.input("init", Shape::new(&[n], DType::F64));
17769        let final_p = g_plain.scan(init_p, make_body(), length);
17770        let loss_p = g_plain.reduce(
17771            final_p,
17772            ReduceOp::Sum,
17773            vec![0],
17774            false,
17775            Shape::new(&[1], DType::F64),
17776        );
17777        g_plain.set_outputs(vec![loss_p]);
17778        let bwd_p = grad_with_loss(&g_plain, &[init_p]);
17779
17780        // Checkpointed scan path with K=2 (length=6).
17781        let mut g_ck = Graph::new("ck_ckpt");
17782        let init_c = g_ck.input("init", Shape::new(&[n], DType::F64));
17783        let final_c = g_ck.scan_checkpointed(init_c, make_body(), length, 2);
17784        let loss_c = g_ck.reduce(
17785            final_c,
17786            ReduceOp::Sum,
17787            vec![0],
17788            false,
17789            Shape::new(&[1], DType::F64),
17790        );
17791        g_ck.set_outputs(vec![loss_c]);
17792        let bwd_c = grad_with_loss(&g_ck, &[init_c]);
17793
17794        let find = |graph: &Graph, want: &str| -> NodeId {
17795            for node in graph.nodes() {
17796                let name = match &node.op {
17797                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17798                    _ => None,
17799                };
17800                if name == Some(want) {
17801                    return node.id;
17802                }
17803            }
17804            panic!("no {want}");
17805        };
17806
17807        let init_data = vec![0.5_f64, -0.5];
17808        let d_seed = [1.0_f64];
17809
17810        let (s_p, mut a_p) = prepare_f64(
17811            &bwd_p,
17812            &[
17813                (find(&bwd_p, "init"), &init_data),
17814                (find(&bwd_p, "d_output"), &d_seed),
17815            ],
17816        );
17817        execute_thunks(&s_p, a_p.raw_buf_mut());
17818        let dinit_p = read_arena_f64(&a_p, bwd_p.outputs[1], n);
17819
17820        let (s_c, mut a_c) = prepare_f64(
17821            &bwd_c,
17822            &[
17823                (find(&bwd_c, "init"), &init_data),
17824                (find(&bwd_c, "d_output"), &d_seed),
17825            ],
17826        );
17827        execute_thunks(&s_c, a_c.raw_buf_mut());
17828        let dinit_c = read_arena_f64(&a_c, bwd_c.outputs[1], n);
17829
17830        for i in 0..n {
17831            assert!(
17832                (dinit_p[i] - dinit_c[i]).abs() < 1e-12,
17833                "dinit[{i}]: plain={} checkpointed={}",
17834                dinit_p[i],
17835                dinit_c[i]
17836            );
17837        }
17838    }
17839
17840    /// Recursive checkpointing end-to-end: build a ScanBackward
17841    /// configured with K=2 checkpoints (for length=4), and compare
17842    /// dinit against the same backward graph with full trajectory
17843    /// (K=0). Forward computes a cumulative-sum-style scan; loss = sum.
17844    /// Both paths must agree to f64 precision.
17845    #[test]
17846    fn recursive_checkpointing_matches_full_trajectory() {
17847        let n = 2usize;
17848        let length = 4u32;
17849
17850        // Body: carry + ones (deterministic, no xs)
17851        let build_body = || -> Graph {
17852            let mut body = Graph::new("rc_body");
17853            let carry = body.input("carry", Shape::new(&[n], DType::F64));
17854            let ones_bytes: Vec<u8> = (0..n).flat_map(|_| 1.0_f64.to_le_bytes()).collect();
17855            let ones = body.add_node(
17856                Op::Constant { data: ones_bytes },
17857                vec![],
17858                Shape::new(&[n], DType::F64),
17859            );
17860            let next = body.binary(BinaryOp::Add, carry, ones, Shape::new(&[n], DType::F64));
17861            body.set_outputs(vec![next]);
17862            body
17863        };
17864
17865        // body_vjp: same body + d_output, output dcarry. body_vjp is
17866        // used by ScanBackward to walk the chain rule per step.
17867        let body_vjp_for = || -> Graph {
17868            use rlx_opt::autodiff::grad;
17869            let body = build_body();
17870            // grad(body, [carry_id]) → graph with dcarry as the output.
17871            let carry_id = body
17872                .nodes()
17873                .iter()
17874                .find(|n| matches!(n.op, Op::Input { .. }))
17875                .map(|n| n.id)
17876                .unwrap();
17877            grad(&body, &[carry_id])
17878        };
17879
17880        // ── Forward (All-strategy): scan with full trajectory ──
17881        let mut g_full = Graph::new("rc_outer_full");
17882        let init_full = g_full.input("init", Shape::new(&[n], DType::F64));
17883        let traj_full_id = g_full.scan_trajectory(init_full, build_body(), length);
17884        // Hand-build a ScanBackward node that reads the full trajectory.
17885        let upstream_full = g_full.input("upstream", Shape::new(&[length as usize, n], DType::F64));
17886        let dinit_full_id = g_full.scan_backward(
17887            init_full,
17888            traj_full_id,
17889            upstream_full,
17890            &[],
17891            body_vjp_for(),
17892            length,
17893            true,
17894            Shape::new(&[n], DType::F64),
17895        );
17896        g_full.set_outputs(vec![dinit_full_id]);
17897
17898        // ── Forward (Recursive-2): scan saves only K=2 rows ──
17899        // Build the trajectory shape [K, *carry] = [2, 2].
17900        let k = 2u32;
17901        let mut g_rec = Graph::new("rc_outer_rec");
17902        let init_rec = g_rec.input("init", Shape::new(&[n], DType::F64));
17903        let traj_rec_id = g_rec.add_node(
17904            Op::Scan {
17905                body: Box::new(build_body()),
17906                length,
17907                save_trajectory: true,
17908                num_bcast: 0,
17909                num_xs: 0,
17910                num_checkpoints: k,
17911            },
17912            vec![init_rec],
17913            Shape::new(&[k as usize, n], DType::F64),
17914        );
17915        // Same upstream shape as the full version (the upstream is per
17916        // *forward step*, length rows — independent of K).
17917        let upstream_rec = g_rec.input("upstream", Shape::new(&[length as usize, n], DType::F64));
17918        let dinit_rec_id = g_rec.add_node(
17919            Op::ScanBackward {
17920                body_vjp: Box::new(body_vjp_for()),
17921                length,
17922                save_trajectory: true,
17923                num_xs: 0,
17924                num_checkpoints: k,
17925                forward_body: Some(Box::new(build_body())),
17926            },
17927            vec![init_rec, traj_rec_id, upstream_rec],
17928            Shape::new(&[n], DType::F64),
17929        );
17930        g_rec.set_outputs(vec![dinit_rec_id]);
17931
17932        // ── Run both, same inputs ──
17933        let init_data = vec![0.5_f64, -0.5];
17934        let upstream_data: Vec<f64> = (0..length as usize * n).map(|i| (i as f64) * 0.1).collect();
17935
17936        let find = |graph: &Graph, want: &str| -> NodeId {
17937            for node in graph.nodes() {
17938                if let Op::Input { name } = &node.op
17939                    && name == want
17940                {
17941                    return node.id;
17942                }
17943            }
17944            panic!("no input {want}");
17945        };
17946
17947        let (s_full, mut a_full) = prepare_f64(
17948            &g_full,
17949            &[
17950                (find(&g_full, "init"), &init_data),
17951                (find(&g_full, "upstream"), &upstream_data),
17952            ],
17953        );
17954        execute_thunks(&s_full, a_full.raw_buf_mut());
17955        let dinit_full = read_arena_f64(&a_full, g_full.outputs[0], n);
17956
17957        let (s_rec, mut a_rec) = prepare_f64(
17958            &g_rec,
17959            &[
17960                (find(&g_rec, "init"), &init_data),
17961                (find(&g_rec, "upstream"), &upstream_data),
17962            ],
17963        );
17964        execute_thunks(&s_rec, a_rec.raw_buf_mut());
17965        let dinit_rec = read_arena_f64(&a_rec, g_rec.outputs[0], n);
17966
17967        for i in 0..n {
17968            assert!(
17969                (dinit_full[i] - dinit_rec[i]).abs() < 1e-12,
17970                "i={i}: full={} rec={}",
17971                dinit_full[i],
17972                dinit_rec[i]
17973            );
17974        }
17975    }
17976
17977    /// vmap-of-grad: gradient through Scan, vmap'd over init.
17978    /// Forward (per row):
17979    ///   carry_{t+1} = carry_t + ones    (body adds a constant)
17980    ///   loss = sum(carry_length) = sum(init) + length·n
17981    /// Closed form: dloss/dinit_i = 1 for every i. vmap over init at
17982    /// batch=3 → dinit_batched is all-ones [3, n]. Cross-checks
17983    /// against per-row grad_with_loss runs. Validates the vmap rule
17984    /// for Op::ScanBackward.
17985    #[test]
17986    fn vmap_of_grad_scan_matches_per_row_runs() {
17987        use rlx_opt::autodiff::grad_with_loss;
17988        use rlx_opt::vmap::vmap;
17989        let n = 2usize;
17990        let length = 3u32;
17991        let batch = 3usize;
17992
17993        let mut body = Graph::new("scan_grad_body");
17994        let carry = body.input("carry", Shape::new(&[n], DType::F64));
17995        let ones_bytes: Vec<u8> = (0..n).flat_map(|_| 1.0_f64.to_le_bytes()).collect();
17996        let ones = body.add_node(
17997            Op::Constant { data: ones_bytes },
17998            vec![],
17999            Shape::new(&[n], DType::F64),
18000        );
18001        let next = body.binary(BinaryOp::Add, carry, ones, Shape::new(&[n], DType::F64));
18002        body.set_outputs(vec![next]);
18003
18004        let mut g = Graph::new("scan_grad_outer");
18005        let init = g.input("init", Shape::new(&[n], DType::F64));
18006        let final_x = g.scan(init, body, length);
18007        let loss = g.reduce(
18008            final_x,
18009            ReduceOp::Sum,
18010            vec![0],
18011            false,
18012            Shape::new(&[1], DType::F64),
18013        );
18014        g.set_outputs(vec![loss]);
18015
18016        let bwd = grad_with_loss(&g, &[init]);
18017        let bg = vmap(&bwd, &["init"], batch);
18018
18019        let find = |graph: &Graph, want: &str| -> NodeId {
18020            for node in graph.nodes() {
18021                let name = match &node.op {
18022                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
18023                    _ => None,
18024                };
18025                if name == Some(want) {
18026                    return node.id;
18027                }
18028            }
18029            panic!("no node named {want}");
18030        };
18031        let init_b = find(&bg, "init");
18032        let d_out_b = find(&bg, "d_output");
18033
18034        let init_data: Vec<f64> = (0..batch * n).map(|i| (i as f64) * 0.5).collect();
18035        let d_seed = [1.0_f64];
18036
18037        let (sched, mut arena) = prepare_f64(&bg, &[(init_b, &init_data), (d_out_b, &d_seed)]);
18038        execute_thunks(&sched, arena.raw_buf_mut());
18039        let dinit_b = read_arena_f64(&arena, bg.outputs[1], batch * n);
18040
18041        for i in 0..batch * n {
18042            assert!(
18043                (dinit_b[i] - 1.0).abs() < 1e-12,
18044                "dinit[{i}] = {} (expected 1.0)",
18045                dinit_b[i]
18046            );
18047        }
18048
18049        // Cross-check vs per-row grad_with_loss.
18050        for bi in 0..batch {
18051            let row = &init_data[bi * n..(bi + 1) * n];
18052            let mut g2 = Graph::new("per_row_grad");
18053            let init2 = g2.input("init", Shape::new(&[n], DType::F64));
18054            let mut body2 = Graph::new("per_row_body");
18055            let c2 = body2.input("carry", Shape::new(&[n], DType::F64));
18056            let ones2_bytes: Vec<u8> = (0..n).flat_map(|_| 1.0_f64.to_le_bytes()).collect();
18057            let ones2 = body2.add_node(
18058                Op::Constant { data: ones2_bytes },
18059                vec![],
18060                Shape::new(&[n], DType::F64),
18061            );
18062            let next2 = body2.binary(BinaryOp::Add, c2, ones2, Shape::new(&[n], DType::F64));
18063            body2.set_outputs(vec![next2]);
18064            let final2 = g2.scan(init2, body2, length);
18065            let loss2 = g2.reduce(
18066                final2,
18067                ReduceOp::Sum,
18068                vec![0],
18069                false,
18070                Shape::new(&[1], DType::F64),
18071            );
18072            g2.set_outputs(vec![loss2]);
18073            let bwd2 = grad_with_loss(&g2, &[init2]);
18074            let init2_id = find(&bwd2, "init");
18075            let d_out2_id = find(&bwd2, "d_output");
18076            let (s2, mut a2) = prepare_f64(&bwd2, &[(init2_id, row), (d_out2_id, &d_seed)]);
18077            execute_thunks(&s2, a2.raw_buf_mut());
18078            let row_dinit = read_arena_f64(&a2, bwd2.outputs[1], n);
18079            for j in 0..n {
18080                let got = dinit_b[bi * n + j];
18081                let want = row_dinit[j];
18082                assert!(
18083                    (got - want).abs() < 1e-12,
18084                    "row {bi}, j {j}: vmap'd={got} per-row={want}"
18085                );
18086            }
18087        }
18088    }
18089
18090    /// vmap of Op::Scan: batched cumulative-sum. Forward
18091    ///   carry_{t+1} = carry_t + xs\[t\]
18092    ///   final = init + sum(xs)
18093    /// vmap over both init and xs at batch=3. Each batch row should
18094    /// equal the scalar run of the same body+xs subset.
18095    #[test]
18096    fn vmap_scan_cumulative_sum_matches_scalar_runs() {
18097        use rlx_opt::vmap::vmap;
18098        let n = 2usize;
18099        let length = 4u32;
18100        let batch = 3usize;
18101
18102        // Body: (carry, x_t) → carry + x_t
18103        let mut body = Graph::new("scan_body_cumsum");
18104        let carry = body.input("carry", Shape::new(&[n], DType::F64));
18105        let x_t = body.input("x_t", Shape::new(&[n], DType::F64));
18106        let next = body.binary(BinaryOp::Add, carry, x_t, Shape::new(&[n], DType::F64));
18107        body.set_outputs(vec![next]);
18108
18109        let mut g = Graph::new("scan_outer_cumsum");
18110        let init = g.input("init", Shape::new(&[n], DType::F64));
18111        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
18112        let final_carry = g.scan_with_xs(init, &[xs], body, length);
18113        g.set_outputs(vec![final_carry]);
18114
18115        // vmap over both init and xs.
18116        let bg = vmap(&g, &["init", "xs"], batch);
18117
18118        // Test data — distinct per-batch rows.
18119        let init_data: Vec<f64> = (0..batch * n).map(|i| (i + 1) as f64).collect();
18120        // xs has shape [B, length, n] after vmap (the outer's xs is
18121        // [length, n]; vmap lifts it to [B, length, n]).
18122        let xs_data: Vec<f64> = (0..batch * length as usize * n)
18123            .map(|i| 0.1 * (i as f64))
18124            .collect();
18125
18126        let find = |graph: &Graph, want: &str| -> NodeId {
18127            for node in graph.nodes() {
18128                if let Op::Input { name } = &node.op
18129                    && name == want
18130                {
18131                    return node.id;
18132                }
18133            }
18134            panic!("no input {want}");
18135        };
18136        let init_b = find(&bg, "init");
18137        let xs_b = find(&bg, "xs");
18138        let (sched, mut arena) = prepare_f64(&bg, &[(init_b, &init_data), (xs_b, &xs_data)]);
18139        execute_thunks(&sched, arena.raw_buf_mut());
18140        let batched_out = read_arena_f64(&arena, bg.outputs[0], batch * n);
18141
18142        // Reference: per-batch scalar Scan.
18143        for bi in 0..batch {
18144            let init_slice = &init_data[bi * n..(bi + 1) * n];
18145            let mut x = init_slice.to_vec();
18146            for t in 0..length as usize {
18147                for j in 0..n {
18148                    x[j] += xs_data[bi * length as usize * n + t * n + j];
18149                }
18150            }
18151
18152            for i in 0..n {
18153                let got = batched_out[bi * n + i];
18154                assert!(
18155                    (got - x[i]).abs() < 1e-12,
18156                    "row {bi}, i {i}: got {got} ref {}",
18157                    x[i]
18158                );
18159            }
18160        }
18161    }
18162
18163    /// vmap of dense solve — Circulax-shaped batched parameter sweep.
18164    /// Forward: x = solve(A, b). vmap over both A (batched [B,N,N])
18165    /// and b (batched [B,N]). Run on CPU and compare each batch row
18166    /// against an independent scalar dgesv.
18167    #[test]
18168    fn vmap_dense_solve_matches_scalar_runs() {
18169        use rlx_opt::vmap::vmap;
18170        let n = 3usize;
18171        let batch = 4usize;
18172
18173        let mut g = Graph::new("solve_forward");
18174        let a = g.input("A", Shape::new(&[n, n], DType::F64));
18175        let b = g.input("b", Shape::new(&[n], DType::F64));
18176        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
18177        g.set_outputs(vec![x]);
18178
18179        // vmap both A and b across the batch.
18180        let bg = vmap(&g, &["A", "b"], batch);
18181
18182        // Independent A and b per batch row.
18183        let mut rng = rlx_ir::Philox4x32::new(0xb47c_u64);
18184        let mut a_data = vec![0.0_f64; batch * n * n];
18185        let mut b_data = vec![0.0_f64; batch * n];
18186        for bi in 0..batch {
18187            // Diagonally dominant A — guaranteed non-singular.
18188            for i in 0..n {
18189                for j in 0..n {
18190                    a_data[bi * n * n + i * n + j] = rng.next_f32() as f64 * 0.1;
18191                }
18192                a_data[bi * n * n + i * n + i] += 1.0 + n as f64;
18193            }
18194            for i in 0..n {
18195                b_data[bi * n + i] = rng.next_f32() as f64;
18196            }
18197        }
18198
18199        let find = |graph: &Graph, want: &str| -> NodeId {
18200            for node in graph.nodes() {
18201                if let Op::Input { name } = &node.op
18202                    && name == want
18203                {
18204                    return node.id;
18205                }
18206            }
18207            panic!("no input named {want}");
18208        };
18209        let ba = find(&bg, "A");
18210        let bb = find(&bg, "b");
18211        let (sched, mut arena) = prepare_f64(&bg, &[(ba, &a_data), (bb, &b_data)]);
18212        execute_thunks(&sched, arena.raw_buf_mut());
18213        let batched_x = read_arena_f64(&arena, bg.outputs[0], batch * n);
18214
18215        // Reference: per-batch dgesv.
18216        for bi in 0..batch {
18217            let mut a_slice: Vec<f64> = a_data[bi * n * n..(bi + 1) * n * n].to_vec();
18218            let mut b_slice: Vec<f64> = b_data[bi * n..(bi + 1) * n].to_vec();
18219            crate::blas::dgesv(&mut a_slice, &mut b_slice, n, 1);
18220            for i in 0..n {
18221                let got = batched_x[bi * n + i];
18222                let want = b_slice[i];
18223                assert!(
18224                    (got - want).abs() < 1e-12,
18225                    "row {bi}, i {i}: got {got} want {want}"
18226                );
18227            }
18228        }
18229    }
18230
18231    /// vmap end-to-end: build a graph that computes y = MatMul(x, w) + b
18232    /// and reduces to a per-element loss. vmap over x with batch=4.
18233    /// Run the batched graph and compare each output row against an
18234    /// independent scalar run of the original graph. Validates the
18235    /// structural lift + the runtime path for batched MatMul +
18236    /// batched Binary + batched Reduce.
18237    #[test]
18238    fn vmap_matmul_add_reduce_matches_scalar_runs() {
18239        use rlx_opt::vmap::vmap;
18240        let n = 3usize;
18241        let batch = 4usize;
18242
18243        // Forward graph: y = MatMul(reshape(x, [1,n]), w) + b ; loss = sum(y).
18244        let mut g = Graph::new("vmap_e2e_forward");
18245        let x = g.input("x", Shape::new(&[n], DType::F64));
18246        let w = g.input("w", Shape::new(&[n, n], DType::F64));
18247        let b = g.input("b", Shape::new(&[n], DType::F64));
18248        let x_row = g.add_node(
18249            Op::Reshape {
18250                new_shape: vec![1, n as i64],
18251            },
18252            vec![x],
18253            Shape::new(&[1, n], DType::F64),
18254        );
18255        let mm = g.matmul(x_row, w, Shape::new(&[1, n], DType::F64));
18256        let mm_flat = g.add_node(
18257            Op::Reshape {
18258                new_shape: vec![n as i64],
18259            },
18260            vec![mm],
18261            Shape::new(&[n], DType::F64),
18262        );
18263        let yv = g.binary(BinaryOp::Add, mm_flat, b, Shape::new(&[n], DType::F64));
18264        let loss = g.reduce(
18265            yv,
18266            ReduceOp::Sum,
18267            vec![0],
18268            false,
18269            Shape::new(&[1], DType::F64),
18270        );
18271        g.set_outputs(vec![loss]);
18272
18273        // Build the vmap'd version (batch over x; w and b shared).
18274        let bg = vmap(&g, &["x"], batch);
18275
18276        // Test data — distinct rows so we can verify the per-row dispatch.
18277        let mut rng = rlx_ir::Philox4x32::new(0xc1c0_u64);
18278        let n_w = n * n;
18279        let w_data: Vec<f64> = (0..n_w).map(|_| rng.next_f32() as f64).collect();
18280        let b_data: Vec<f64> = (0..n).map(|_| rng.next_f32() as f64).collect();
18281        let mut x_data_batched: Vec<f64> = Vec::with_capacity(batch * n);
18282        for _ in 0..batch * n {
18283            x_data_batched.push(rng.next_f32() as f64);
18284        }
18285
18286        // Run the batched graph.
18287        let find = |graph: &Graph, want: &str| -> NodeId {
18288            for node in graph.nodes() {
18289                if let Op::Input { name } = &node.op
18290                    && name == want
18291                {
18292                    return node.id;
18293                }
18294            }
18295            panic!("no input named {want}");
18296        };
18297        let bx = find(&bg, "x");
18298        let bw = find(&bg, "w");
18299        let bb = find(&bg, "b");
18300        let (sched, mut arena) =
18301            prepare_f64(&bg, &[(bx, &x_data_batched), (bw, &w_data), (bb, &b_data)]);
18302        execute_thunks(&sched, arena.raw_buf_mut());
18303        // Reduce::Sum on shifted axis 1 with keep_dim=false → output [B, 1]
18304        // (it preserves the leading batch axis but reduces what was [n] to [].
18305        // Since the original output was [1] f64 and the reduce was over
18306        // axis 0, after vmap the leading-axis-shifted reduce keeps the
18307        // leading 1 from the original output's [1] shape.)
18308        let batched_out = read_arena_f64(&arena, bg.outputs[0], batch);
18309
18310        // Reference: run the original (un-batched) graph once per batch row.
18311        for bi in 0..batch {
18312            let xs_slice = &x_data_batched[bi * n..(bi + 1) * n];
18313            let mut g2 = Graph::new("scalar_run");
18314            let x2 = g2.input("x", Shape::new(&[n], DType::F64));
18315            let w2 = g2.input("w", Shape::new(&[n, n], DType::F64));
18316            let b2 = g2.input("b", Shape::new(&[n], DType::F64));
18317            let xr = g2.add_node(
18318                Op::Reshape {
18319                    new_shape: vec![1, n as i64],
18320                },
18321                vec![x2],
18322                Shape::new(&[1, n], DType::F64),
18323            );
18324            let m = g2.matmul(xr, w2, Shape::new(&[1, n], DType::F64));
18325            let mf = g2.add_node(
18326                Op::Reshape {
18327                    new_shape: vec![n as i64],
18328                },
18329                vec![m],
18330                Shape::new(&[n], DType::F64),
18331            );
18332            let yv2 = g2.binary(BinaryOp::Add, mf, b2, Shape::new(&[n], DType::F64));
18333            let l2 = g2.reduce(
18334                yv2,
18335                ReduceOp::Sum,
18336                vec![0],
18337                false,
18338                Shape::new(&[1], DType::F64),
18339            );
18340            g2.set_outputs(vec![l2]);
18341            let (s2, mut a2) = prepare_f64(&g2, &[(x2, xs_slice), (w2, &w_data), (b2, &b_data)]);
18342            execute_thunks(&s2, a2.raw_buf_mut());
18343            let scalar_out = read_arena_f64(&a2, l2, 1);
18344            assert!(
18345                (batched_out[bi] - scalar_out[0]).abs() < 1e-12,
18346                "row {bi}: batched={} scalar={}",
18347                batched_out[bi],
18348                scalar_out[0]
18349            );
18350        }
18351    }
18352
18353    /// Full gradient through scan-with-xs: dinit AND dxs both checked
18354    /// against finite differences. Forward
18355    ///   carry_{t+1} = solve(M, carry_t + xs\[t\])
18356    ///   loss        = sum(carry_length)
18357    /// Verifies that grad_with_loss returns gradients w.r.t. both
18358    /// `init` and `xs` and that dxs matches per-element FD.
18359    #[test]
18360    fn scan_with_xs_dxs_matches_fd() {
18361        use rlx_opt::autodiff::grad_with_loss;
18362        let n = 3usize;
18363        let length = 3u32;
18364        let dt = 0.1_f64;
18365
18366        let mut m_data = vec![0.0_f64; n * n];
18367        for i in 0..n {
18368            m_data[i * n + i] = 1.0 + dt * 2.0;
18369            if i > 0 {
18370                m_data[i * n + (i - 1)] = -dt;
18371            }
18372            if i + 1 < n {
18373                m_data[i * n + (i + 1)] = -dt;
18374            }
18375        }
18376        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
18377
18378        let mut body = Graph::new("be_dxs_body");
18379        let carry = body.input("carry", Shape::new(&[n], DType::F64));
18380        let drive = body.input("drive", Shape::new(&[n], DType::F64));
18381        let m = body.add_node(
18382            Op::Constant { data: m_bytes },
18383            vec![],
18384            Shape::new(&[n, n], DType::F64),
18385        );
18386        let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
18387        let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
18388        body.set_outputs(vec![next]);
18389
18390        let mut g = Graph::new("be_dxs_outer");
18391        let init = g.input("init", Shape::new(&[n], DType::F64));
18392        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
18393        let final_carry = g.scan_with_xs(init, &[xs], body, length);
18394        let loss = g.reduce(
18395            final_carry,
18396            ReduceOp::Sum,
18397            vec![0],
18398            false,
18399            Shape::new(&[1], DType::F64),
18400        );
18401        g.set_outputs(vec![loss]);
18402
18403        // wrt = [init, xs] — get both gradients back.
18404        let bwd = grad_with_loss(&g, &[init, xs]);
18405        assert_eq!(bwd.outputs.len(), 3, "[loss, dinit, dxs]");
18406
18407        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
18408            for node in graph.nodes() {
18409                let name = match &node.op {
18410                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
18411                    _ => None,
18412                };
18413                if name == Some(want) {
18414                    return node.id;
18415                }
18416            }
18417            panic!("no node named {want:?}");
18418        };
18419        let init_bwd = find_by_name(&bwd, "init");
18420        let xs_bwd = find_by_name(&bwd, "xs");
18421        let d_out_bwd = find_by_name(&bwd, "d_output");
18422
18423        let init_data = vec![0.5_f64, 0.0, -0.5];
18424        let xs_data: Vec<f64> = (0..length as usize * n)
18425            .map(|i| 0.1_f64 * ((i as f64) - 4.0))
18426            .collect();
18427        let d_seed = [1.0_f64];
18428
18429        let (sched, mut arena) = prepare_f64(
18430            &bwd,
18431            &[
18432                (init_bwd, &init_data),
18433                (xs_bwd, &xs_data),
18434                (d_out_bwd, &d_seed),
18435            ],
18436        );
18437        execute_thunks(&sched, arena.raw_buf_mut());
18438        let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
18439        let dxs = read_arena_f64(&arena, bwd.outputs[2], length as usize * n);
18440
18441        let h = 1e-6;
18442        let loss_at = |x0: &[f64], xs_in: &[f64]| -> f64 {
18443            let mut acc = x0.to_vec();
18444            for t in 0..length as usize {
18445                for j in 0..n {
18446                    acc[j] += xs_in[t * n + j];
18447                }
18448                let mut a_copy = m_data.clone();
18449                crate::blas::dgesv(&mut a_copy, &mut acc, n, 1);
18450            }
18451            acc.iter().sum()
18452        };
18453
18454        // FD on dinit (sanity).
18455        for i in 0..n {
18456            let mut ip = init_data.to_vec();
18457            ip[i] += h;
18458            let mut im = init_data.to_vec();
18459            im[i] -= h;
18460            let fd = (loss_at(&ip, &xs_data) - loss_at(&im, &xs_data)) / (2.0 * h);
18461            assert!(
18462                (dinit[i] - fd).abs() < 1e-7,
18463                "FD dinit[{i}]: AD={} FD={}",
18464                dinit[i],
18465                fd
18466            );
18467        }
18468
18469        // FD on every dxs entry — full per-step gradient check.
18470        for t in 0..length as usize {
18471            for j in 0..n {
18472                let idx = t * n + j;
18473                let mut xp = xs_data.clone();
18474                xp[idx] += h;
18475                let mut xm = xs_data.clone();
18476                xm[idx] -= h;
18477                let fd = (loss_at(&init_data, &xp) - loss_at(&init_data, &xm)) / (2.0 * h);
18478                assert!(
18479                    (dxs[idx] - fd).abs() < 1e-7,
18480                    "FD dxs[t={t},j={j}]: AD={} FD={}",
18481                    dxs[idx],
18482                    fd
18483                );
18484            }
18485        }
18486    }
18487
18488    /// Gradient through a scan with per-step xs (Circulax-shaped).
18489    /// Forward:
18490    ///   carry_{t+1} = solve(M, carry_t + xs\[t\])
18491    ///   loss = sum(carry_length)
18492    /// dxs is out of MVP (asserted in the VJP rule's body_vjp `wrt`),
18493    /// but `dinit` flows correctly through the body's reverse Jacobian
18494    /// even with xs in the chain. Verify dinit against finite differences.
18495    #[test]
18496    fn scan_with_xs_gradient_dinit_matches_fd() {
18497        use rlx_opt::autodiff::grad_with_loss;
18498        let n = 3usize;
18499        let length = 3u32;
18500        let dt = 0.1_f64;
18501
18502        let mut m_data = vec![0.0_f64; n * n];
18503        for i in 0..n {
18504            m_data[i * n + i] = 1.0 + dt * 2.0;
18505            if i > 0 {
18506                m_data[i * n + (i - 1)] = -dt;
18507            }
18508            if i + 1 < n {
18509                m_data[i * n + (i + 1)] = -dt;
18510            }
18511        }
18512        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
18513
18514        let mut body = Graph::new("be_xs_grad_body");
18515        let carry = body.input("carry", Shape::new(&[n], DType::F64));
18516        let drive = body.input("drive", Shape::new(&[n], DType::F64));
18517        let m = body.add_node(
18518            Op::Constant { data: m_bytes },
18519            vec![],
18520            Shape::new(&[n, n], DType::F64),
18521        );
18522        let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
18523        let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
18524        body.set_outputs(vec![next]);
18525
18526        let mut g = Graph::new("be_xs_grad_outer");
18527        let init = g.input("init", Shape::new(&[n], DType::F64));
18528        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
18529        let final_carry = g.scan_with_xs(init, &[xs], body, length);
18530        let loss = g.reduce(
18531            final_carry,
18532            ReduceOp::Sum,
18533            vec![0],
18534            false,
18535            Shape::new(&[1], DType::F64),
18536        );
18537        g.set_outputs(vec![loss]);
18538
18539        let bwd = grad_with_loss(&g, &[init]);
18540
18541        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
18542            for node in graph.nodes() {
18543                let name = match &node.op {
18544                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
18545                    _ => None,
18546                };
18547                if name == Some(want) {
18548                    return node.id;
18549                }
18550            }
18551            panic!("no node named {want:?}");
18552        };
18553        let init_bwd = find_by_name(&bwd, "init");
18554        let xs_bwd = find_by_name(&bwd, "xs");
18555        let d_out_bwd = find_by_name(&bwd, "d_output");
18556
18557        let init_data = vec![0.5_f64, 0.0, -0.5];
18558        // Drive: small per-step pulse, varying per element.
18559        let xs_data: Vec<f64> = (0..length as usize * n)
18560            .map(|i| 0.1_f64 * ((i as f64) - 4.0))
18561            .collect();
18562        let d_seed = [1.0_f64];
18563
18564        let (sched, mut arena) = prepare_f64(
18565            &bwd,
18566            &[
18567                (init_bwd, &init_data),
18568                (xs_bwd, &xs_data),
18569                (d_out_bwd, &d_seed),
18570            ],
18571        );
18572        execute_thunks(&sched, arena.raw_buf_mut());
18573        let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
18574
18575        let h = 1e-6;
18576        let loss_at = |x0: &[f64]| -> f64 {
18577            let mut acc = x0.to_vec();
18578            for t in 0..length as usize {
18579                for j in 0..n {
18580                    acc[j] += xs_data[t * n + j];
18581                }
18582                let mut a_copy = m_data.clone();
18583                crate::blas::dgesv(&mut a_copy, &mut acc, n, 1);
18584            }
18585            acc.iter().sum()
18586        };
18587        for i in 0..n {
18588            let mut ip = init_data.to_vec();
18589            ip[i] += h;
18590            let mut im = init_data.to_vec();
18591            im[i] -= h;
18592            let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
18593            assert!(
18594                (dinit[i] - fd).abs() < 1e-7,
18595                "FD dinit[{i}]: AD={} FD={}",
18596                dinit[i],
18597                fd
18598            );
18599        }
18600    }
18601
18602    /// Gradient through a geometric-growth scan: forward
18603    ///   x_{t+1} = 1.1 · x_t,    x_0 = init
18604    ///   final   = x_length     = init · 1.1^length
18605    ///   loss    = sum(final)
18606    /// closed-form ∂loss/∂init\[i\] = 1.1^length for every i.
18607    /// Validates the VJP path: AD pre-pass rewrites save_trajectory=false
18608    /// to true, autodiff emits Op::ScanBackward, executor walks t back.
18609    #[test]
18610    fn scan_gradient_geometric_matches_closed_form() {
18611        use rlx_opt::autodiff::grad_with_loss;
18612        let n = 3usize;
18613        let length = 5u32;
18614
18615        let mut body = Graph::new("scan_grad_body");
18616        let x = body.input("carry", Shape::new(&[n], DType::F64));
18617        let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 1.1_f64.to_le_bytes()).collect();
18618        let scale = body.add_node(
18619            Op::Constant { data: scale_bytes },
18620            vec![],
18621            Shape::new(&[n], DType::F64),
18622        );
18623        let next = body.binary(BinaryOp::Mul, x, scale, Shape::new(&[n], DType::F64));
18624        body.set_outputs(vec![next]);
18625
18626        let mut g = Graph::new("scan_grad_outer");
18627        let init = g.input("init", Shape::new(&[n], DType::F64));
18628        let final_x = g.scan(init, body, length);
18629        let loss = g.reduce(
18630            final_x,
18631            ReduceOp::Sum,
18632            vec![0],
18633            false,
18634            Shape::new(&[1], DType::F64),
18635        );
18636        g.set_outputs(vec![loss]);
18637
18638        let bwd = grad_with_loss(&g, &[init]);
18639        assert_eq!(bwd.outputs.len(), 2);
18640
18641        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
18642            for node in graph.nodes() {
18643                let name = match &node.op {
18644                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
18645                    _ => None,
18646                };
18647                if name == Some(want) {
18648                    return node.id;
18649                }
18650            }
18651            panic!("no node named {want:?}");
18652        };
18653        let init_bwd = find_by_name(&bwd, "init");
18654        let d_out_bwd = find_by_name(&bwd, "d_output");
18655
18656        let init_data = vec![1.0_f64; n];
18657        let d_seed = [1.0_f64];
18658        let (sched, mut arena) = prepare_f64(&bwd, &[(init_bwd, &init_data), (d_out_bwd, &d_seed)]);
18659        execute_thunks(&sched, arena.raw_buf_mut());
18660        let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
18661
18662        let want = 1.1_f64.powi(length as i32);
18663        for i in 0..n {
18664            assert!(
18665                (dinit[i] - want).abs() < 1e-12,
18666                "dinit[{i}] = {} want {}",
18667                dinit[i],
18668                want
18669            );
18670        }
18671
18672        // Finite-difference cross-check on init[0].
18673        let h = 1e-6;
18674        let loss_at = |x: &[f64]| -> f64 {
18675            let mut acc = x.to_vec();
18676            for _ in 0..length {
18677                for v in acc.iter_mut() {
18678                    *v *= 1.1;
18679                }
18680            }
18681            acc.iter().sum()
18682        };
18683        let mut ip = init_data.clone();
18684        ip[0] += h;
18685        let mut im = init_data.clone();
18686        im[0] -= h;
18687        let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
18688        assert!(
18689            (dinit[0] - fd).abs() < 1e-7,
18690            "FD dinit[0]: AD={} FD={}",
18691            dinit[0],
18692            fd
18693        );
18694    }
18695
18696    /// Gradient through Backward Euler scan composing with DenseSolve.
18697    /// Asserts dinit matches finite-difference per coordinate.
18698    #[test]
18699    fn scan_gradient_backward_euler_matches_fd() {
18700        use rlx_opt::autodiff::grad_with_loss;
18701        let n = 4usize;
18702        let length = 3u32;
18703        let dt = 0.05_f64;
18704
18705        let mut m_data = vec![0.0_f64; n * n];
18706        for i in 0..n {
18707            m_data[i * n + i] = 1.0 + dt * 2.0;
18708            if i > 0 {
18709                m_data[i * n + (i - 1)] = -dt;
18710            }
18711            if i + 1 < n {
18712                m_data[i * n + (i + 1)] = -dt;
18713            }
18714        }
18715        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
18716
18717        let mut body = Graph::new("be_grad_body");
18718        let x = body.input("x", Shape::new(&[n], DType::F64));
18719        let m = body.add_node(
18720            Op::Constant { data: m_bytes },
18721            vec![],
18722            Shape::new(&[n, n], DType::F64),
18723        );
18724        let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
18725        body.set_outputs(vec![next]);
18726
18727        let mut g = Graph::new("be_grad_outer");
18728        let init = g.input("x0", Shape::new(&[n], DType::F64));
18729        let final_x = g.scan(init, body, length);
18730        let loss = g.reduce(
18731            final_x,
18732            ReduceOp::Sum,
18733            vec![0],
18734            false,
18735            Shape::new(&[1], DType::F64),
18736        );
18737        g.set_outputs(vec![loss]);
18738
18739        let bwd = grad_with_loss(&g, &[init]);
18740
18741        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
18742            for node in graph.nodes() {
18743                let name = match &node.op {
18744                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
18745                    _ => None,
18746                };
18747                if name == Some(want) {
18748                    return node.id;
18749                }
18750            }
18751            panic!("no node named {want:?}");
18752        };
18753        let init_bwd = find_by_name(&bwd, "x0");
18754        let d_out_bwd = find_by_name(&bwd, "d_output");
18755
18756        let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
18757        let d_seed = [1.0_f64];
18758        let (sched, mut arena) = prepare_f64(&bwd, &[(init_bwd, &init_data), (d_out_bwd, &d_seed)]);
18759        execute_thunks(&sched, arena.raw_buf_mut());
18760        let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
18761
18762        let h = 1e-6;
18763        let loss_at = |x0: &[f64]| -> f64 {
18764            let mut acc = x0.to_vec();
18765            for _ in 0..length {
18766                let mut a_copy = m_data.clone();
18767                crate::blas::dgesv(&mut a_copy, &mut acc, n, 1);
18768            }
18769            acc.iter().sum()
18770        };
18771        for i in 0..n {
18772            let mut ip = init_data.to_vec();
18773            ip[i] += h;
18774            let mut im = init_data.to_vec();
18775            im[i] -= h;
18776            let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
18777            assert!(
18778                (dinit[i] - fd).abs() < 1e-7,
18779                "FD dinit[{i}]: AD={} FD={}",
18780                dinit[i],
18781                fd
18782            );
18783        }
18784    }
18785
18786    /// Trajectory-mode scan: same Backward Euler body, but record the
18787    /// carry at every step. Output is `[length, n]` — row `t` is the
18788    /// state after step `t+1`. Validates the SaveAt-style waveform
18789    /// recording end-to-end, including that the last row equals what
18790    /// the no-trajectory variant would have returned.
18791    #[test]
18792    fn scan_trajectory_backward_euler_records_waveform() {
18793        let n = 4usize;
18794        let length = 5u32;
18795        let dt = 0.05_f64;
18796
18797        let mut m_data = vec![0.0_f64; n * n];
18798        for i in 0..n {
18799            m_data[i * n + i] = 1.0 + dt * 2.0;
18800            if i > 0 {
18801                m_data[i * n + (i - 1)] = -dt;
18802            }
18803            if i + 1 < n {
18804                m_data[i * n + (i + 1)] = -dt;
18805            }
18806        }
18807        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
18808
18809        let mut body = Graph::new("be_traj_body");
18810        let x = body.input("x", Shape::new(&[n], DType::F64));
18811        let m = body.add_node(
18812            Op::Constant { data: m_bytes },
18813            vec![],
18814            Shape::new(&[n, n], DType::F64),
18815        );
18816        let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
18817        body.set_outputs(vec![next]);
18818
18819        let mut g = Graph::new("be_traj_outer");
18820        let init = g.input("x0", Shape::new(&[n], DType::F64));
18821        let traj = g.scan_trajectory(init, body, length);
18822        g.set_outputs(vec![traj]);
18823
18824        let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
18825        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
18826        execute_thunks(&sched, arena.raw_buf_mut());
18827        let got = read_arena_f64(&arena, traj, length as usize * n);
18828
18829        // Reference: each step's solve, recorded.
18830        let mut want = Vec::<f64>::with_capacity(length as usize * n);
18831        let mut x_ref = init_data.to_vec();
18832        for _ in 0..length {
18833            let mut a_copy = m_data.clone();
18834            crate::blas::dgesv(&mut a_copy, &mut x_ref, n, 1);
18835            want.extend_from_slice(&x_ref);
18836        }
18837        for i in 0..length as usize * n {
18838            assert!(
18839                (got[i] - want[i]).abs() < 1e-12,
18840                "got[{i}] = {} ref {}",
18841                got[i],
18842                want[i]
18843            );
18844        }
18845
18846        // Sanity: trajectory rows are monotone-decreasing in mass
18847        // (Backward Euler diffuses; boundary leak removes mass).
18848        for t in 1..length as usize {
18849            let prev: f64 = got[(t - 1) * n..t * n].iter().sum();
18850            let curr: f64 = got[t * n..(t + 1) * n].iter().sum();
18851            assert!(
18852                curr <= prev + 1e-15,
18853                "mass should decay: row {} sum {prev}, row {t} sum {curr}",
18854                t - 1
18855            );
18856        }
18857
18858        // Last row of the trajectory equals what a non-trajectory
18859        // scan returns — verify by running the same forward through
18860        // the simpler API and comparing.
18861        let mut body2 = Graph::new("be_final_body");
18862        let x2 = body2.input("x", Shape::new(&[n], DType::F64));
18863        let m_bytes2: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
18864        let m2 = body2.add_node(
18865            Op::Constant { data: m_bytes2 },
18866            vec![],
18867            Shape::new(&[n, n], DType::F64),
18868        );
18869        let next2 = body2.dense_solve(m2, x2, Shape::new(&[n], DType::F64));
18870        body2.set_outputs(vec![next2]);
18871
18872        let mut g2 = Graph::new("be_final_outer");
18873        let init2 = g2.input("x0", Shape::new(&[n], DType::F64));
18874        let final_x = g2.scan(init2, body2, length);
18875        g2.set_outputs(vec![final_x]);
18876        let (sched2, mut arena2) = prepare_f64(&g2, &[(init2, &init_data)]);
18877        execute_thunks(&sched2, arena2.raw_buf_mut());
18878        let final_got = read_arena_f64(&arena2, final_x, n);
18879
18880        let last_row = &got[(length as usize - 1) * n..length as usize * n];
18881        for i in 0..n {
18882            assert!(
18883                (last_row[i] - final_got[i]).abs() < 1e-15,
18884                "last trajectory row[{i}] = {} vs final-scan = {}",
18885                last_row[i],
18886                final_got[i]
18887            );
18888        }
18889    }
18890
18891    /// Op::Scan composing with Op::DenseSolve — the Circulax-shaped
18892    /// pattern for Backward Euler.
18893    /// Body: x_{t+1} = solve(I + dt·A, x_t).
18894    /// 1-D heat-equation Laplacian A; analytic ground truth from
18895    /// composing the same per-step solve in Rust.
18896    #[test]
18897    fn scan_backward_euler_heat_f64() {
18898        let n = 4usize;
18899        let length = 5u32;
18900        let dt = 0.05_f64;
18901
18902        // Construct M = I + dt · L  where L is the Laplacian (-1, 2, -1).
18903        // M is constant across iterations; embed it in the body via Op::Constant.
18904        let mut m_data = vec![0.0_f64; n * n];
18905        for i in 0..n {
18906            m_data[i * n + i] = 1.0 + dt * 2.0;
18907            if i > 0 {
18908                m_data[i * n + (i - 1)] = -dt;
18909            }
18910            if i + 1 < n {
18911                m_data[i * n + (i + 1)] = -dt;
18912            }
18913        }
18914        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
18915
18916        let mut body = Graph::new("be_body");
18917        let x = body.input("x", Shape::new(&[n], DType::F64));
18918        let m = body.add_node(
18919            Op::Constant { data: m_bytes },
18920            vec![],
18921            Shape::new(&[n, n], DType::F64),
18922        );
18923        let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
18924        body.set_outputs(vec![next]);
18925
18926        let mut g = Graph::new("be_outer");
18927        let init = g.input("x0", Shape::new(&[n], DType::F64));
18928        let final_x = g.scan(init, body, length);
18929        g.set_outputs(vec![final_x]);
18930
18931        // Initial: a sharp pulse at index 1.
18932        let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
18933        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
18934        execute_thunks(&sched, arena.raw_buf_mut());
18935        let got = read_arena_f64(&arena, final_x, n);
18936
18937        // Reference: apply the same M-solve `length` times in pure Rust.
18938        let mut ref_x = init_data.to_vec();
18939        for _ in 0..length {
18940            let mut a_copy = m_data.clone();
18941            crate::blas::dgesv(&mut a_copy, &mut ref_x, n, 1);
18942        }
18943        for i in 0..n {
18944            assert!(
18945                (got[i] - ref_x[i]).abs() < 1e-12,
18946                "got[{i}] = {} ref {}",
18947                got[i],
18948                ref_x[i]
18949            );
18950        }
18951        // Sanity: pulse should diffuse, mass should be conserved-ish
18952        // (Backward Euler is mass-conserving for this stencil with
18953        // zero-flux boundaries — but our boundaries leak, so check
18954        // that mass strictly decreases instead).
18955        let mass: f64 = got.iter().sum();
18956        assert!(mass > 0.0 && mass < 1.0, "diffusion mass: {mass}");
18957    }
18958
18959    /// Multi-RHS forward DenseSolve: X = solve(A, B) with B [N, K]
18960    /// stays correct end-to-end. Verifies the executor/lowering and
18961    /// the LAPACK column-major dance both honour `nrhs > 1`.
18962    #[test]
18963    fn dense_solve_f64_multi_rhs_forward() {
18964        let n = 3usize;
18965        let k = 2usize;
18966        let mut g = Graph::new("solve_multi_rhs");
18967        let a = g.input("A", Shape::new(&[n, n], DType::F64));
18968        let b = g.input("B", Shape::new(&[n, k], DType::F64));
18969        let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
18970        g.set_outputs(vec![x]);
18971
18972        let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
18973        let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
18974        let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
18975        execute_thunks(&sched, arena.raw_buf_mut());
18976        let x_got = read_arena_f64(&arena, x, n * k);
18977        for c in 0..k {
18978            for i in 0..n {
18979                let mut acc = 0.0_f64;
18980                for j in 0..n {
18981                    acc += a_data[i * n + j] * x_got[j * k + c];
18982                }
18983                let want = b_data[i * k + c];
18984                assert!(
18985                    (acc - want).abs() < 1e-10,
18986                    "col {c} row {i}: got {acc} want {want}"
18987                );
18988            }
18989        }
18990    }
18991
18992    /// Multi-RHS reverse-mode VJP: dB = (Aᵀ)⁻¹·1, dA = -dB · Xᵀ.
18993    /// Verified analytically + finite differences on dB[0,0].
18994    #[test]
18995    fn dense_solve_f64_multi_rhs_gradient() {
18996        use rlx_opt::autodiff::grad_with_loss;
18997        let n = 3usize;
18998        let k = 2usize;
18999        let mut g = Graph::new("solve_mrhs_grad");
19000        let a = g.param("A", Shape::new(&[n, n], DType::F64));
19001        let b = g.input("B", Shape::new(&[n, k], DType::F64));
19002        let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
19003        let loss = g.reduce(
19004            x,
19005            ReduceOp::Sum,
19006            vec![0, 1],
19007            false,
19008            Shape::new(&[1], DType::F64),
19009        );
19010        g.set_outputs(vec![loss]);
19011
19012        let bwd = grad_with_loss(&g, &[a, b]);
19013        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
19014            for node in graph.nodes() {
19015                let name = match &node.op {
19016                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19017                    _ => None,
19018                };
19019                if name == Some(want) {
19020                    return node.id;
19021                }
19022            }
19023            panic!("no node named {want:?}");
19024        };
19025        let a_bwd = find_by_name(&bwd, "A");
19026        let b_bwd = find_by_name(&bwd, "B");
19027        let d_out = find_by_name(&bwd, "d_output");
19028
19029        let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
19030        let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
19031        let d_seed = [1.0_f64];
19032
19033        let (sched, mut arena) = prepare_f64(
19034            &bwd,
19035            &[(a_bwd, &a_data), (b_bwd, &b_data), (d_out, &d_seed)],
19036        );
19037        execute_thunks(&sched, arena.raw_buf_mut());
19038        let da_got = read_arena_f64(&arena, bwd.outputs[1], n * n);
19039        let db_got = read_arena_f64(&arena, bwd.outputs[2], n * k);
19040
19041        // Reference.
19042        let mut x_ref = b_data;
19043        {
19044            let mut a_copy = a_data;
19045            crate::blas::dgesv(&mut a_copy, &mut x_ref, n, k);
19046        }
19047        let mut at = [0.0_f64; 9];
19048        for i in 0..n {
19049            for j in 0..n {
19050                at[i * n + j] = a_data[j * n + i];
19051            }
19052        }
19053        let mut ones_nk = vec![1.0_f64; n * k];
19054        crate::blas::dgesv(&mut at, &mut ones_nk, n, k);
19055        let db_ref = ones_nk;
19056        let mut da_ref = [0.0_f64; 9];
19057        for i in 0..n {
19058            for j in 0..n {
19059                let mut acc = 0.0_f64;
19060                for c in 0..k {
19061                    acc += db_ref[i * k + c] * x_ref[j * k + c];
19062                }
19063                da_ref[i * n + j] = -acc;
19064            }
19065        }
19066        for i in 0..n * k {
19067            assert!(
19068                (db_got[i] - db_ref[i]).abs() < 1e-10,
19069                "dB[{i}]: got {} want {}",
19070                db_got[i],
19071                db_ref[i]
19072            );
19073        }
19074        for i in 0..n * n {
19075            assert!(
19076                (da_got[i] - da_ref[i]).abs() < 1e-10,
19077                "dA[{i}]: got {} want {}",
19078                da_got[i],
19079                da_ref[i]
19080            );
19081        }
19082
19083        // FD on dB[0,0].
19084        let h = 1e-6;
19085        let mut bp = b_data;
19086        bp[0] += h;
19087        let mut bm = b_data;
19088        bm[0] -= h;
19089        let xp = {
19090            let mut a_copy = a_data;
19091            crate::blas::dgesv(&mut a_copy, &mut bp, n, k);
19092            bp
19093        };
19094        let xm = {
19095            let mut a_copy = a_data;
19096            crate::blas::dgesv(&mut a_copy, &mut bm, n, k);
19097            bm
19098        };
19099        let lp: f64 = xp.iter().sum();
19100        let lm: f64 = xm.iter().sum();
19101        let fd = (lp - lm) / (2.0 * h);
19102        assert!(
19103            (db_got[0] - fd).abs() < 1e-7,
19104            "FD dB[0,0]: AD={} FD={}",
19105            db_got[0],
19106            fd
19107        );
19108    }
19109
19110    /// Multi-RHS forward-mode JVP w.r.t. B. Closed form: t_X = solve(A, t_B).
19111    #[test]
19112    fn dense_solve_f64_multi_rhs_jvp() {
19113        use rlx_opt::autodiff_fwd::jvp;
19114        let n = 3usize;
19115        let k = 2usize;
19116        let mut g = Graph::new("solve_mrhs_jvp");
19117        let a = g.input("A", Shape::new(&[n, n], DType::F64));
19118        let b = g.input("B", Shape::new(&[n, k], DType::F64));
19119        let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
19120        g.set_outputs(vec![x]);
19121
19122        let jg = jvp(&g, &[b]);
19123        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
19124            for node in graph.nodes() {
19125                let name = match &node.op {
19126                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19127                    _ => None,
19128                };
19129                if name == Some(want) {
19130                    return node.id;
19131                }
19132            }
19133            panic!("no node named {want:?}");
19134        };
19135        let a_id = find_by_name(&jg, "A");
19136        let b_id = find_by_name(&jg, "B");
19137        let tb_id = find_by_name(&jg, "tangent_B");
19138
19139        let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
19140        let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
19141        let tb_data = [0.5, 0.0, -0.25, 1.0, 1.0, -0.5_f64];
19142
19143        let (sched, mut arena) =
19144            prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (tb_id, &tb_data)]);
19145        execute_thunks(&sched, arena.raw_buf_mut());
19146        let tangent_x = read_arena_f64(&arena, jg.outputs[1], n * k);
19147
19148        let mut a_copy = a_data;
19149        let mut tb_copy = tb_data;
19150        crate::blas::dgesv(&mut a_copy, &mut tb_copy, n, k);
19151        for i in 0..n * k {
19152            assert!(
19153                (tangent_x[i] - tb_copy[i]).abs() < 1e-10,
19154                "t_X[{i}]: AD={} ref={}",
19155                tangent_x[i],
19156                tb_copy[i]
19157            );
19158        }
19159
19160        let h = 1e-6;
19161        let mut bp = b_data;
19162        let mut bm = b_data;
19163        for i in 0..n * k {
19164            bp[i] += h * tb_data[i];
19165            bm[i] -= h * tb_data[i];
19166        }
19167        let xp = {
19168            let mut a_copy = a_data;
19169            crate::blas::dgesv(&mut a_copy, &mut bp, n, k);
19170            bp
19171        };
19172        let xm = {
19173            let mut a_copy = a_data;
19174            crate::blas::dgesv(&mut a_copy, &mut bm, n, k);
19175            bm
19176        };
19177        for i in 0..n * k {
19178            let fd = (xp[i] - xm[i]) / (2.0 * h);
19179            assert!(
19180                (tangent_x[i] - fd).abs() < 1e-7,
19181                "FD t_X[{i}]: AD={} FD={}",
19182                tangent_x[i],
19183                fd
19184            );
19185        }
19186    }
19187
19188    /// Forward-mode JVP through DenseSolve, end-to-end at f64.
19189    ///
19190    /// Build forward x = solve(A, b), call `jvp(forward, [b])`,
19191    /// compile + run, and check the tangent output matches the
19192    /// closed form `t_x = solve(A, t_b)` plus a finite-difference
19193    /// cross-check `(solve(A, b + h·t_b) − solve(A, b − h·t_b)) / 2h`.
19194    #[test]
19195    fn jvp_dense_solve_b_runs_and_matches_fd() {
19196        use rlx_opt::autodiff_fwd::jvp;
19197        let n = 3usize;
19198
19199        // Forward.
19200        let mut g = Graph::new("jvp_b_e2e");
19201        let a = g.input("A", Shape::new(&[n, n], DType::F64));
19202        let b = g.input("b", Shape::new(&[n], DType::F64));
19203        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
19204        g.set_outputs(vec![x]);
19205
19206        // JVP graph perturbing b only.
19207        let jg = jvp(&g, &[b]);
19208        // The JVP graph holds a fresh "tangent_b" Input on top of A and b.
19209        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
19210            for node in graph.nodes() {
19211                let name = match &node.op {
19212                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19213                    _ => None,
19214                };
19215                if name == Some(want) {
19216                    return node.id;
19217                }
19218            }
19219            panic!("no node named {want:?}");
19220        };
19221        let a_id = find_by_name(&jg, "A");
19222        let b_id = find_by_name(&jg, "b");
19223        let tb_id = find_by_name(&jg, "tangent_b");
19224
19225        let a_data: [f64; 9] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
19226        let b_data: [f64; 3] = [1.0, 2.0, 3.0];
19227        // Pick an arbitrary perturbation direction.
19228        let tb_data: [f64; 3] = [0.5, -0.25, 1.0];
19229
19230        let (sched, mut arena) =
19231            prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (tb_id, &tb_data)]);
19232        execute_thunks(&sched, arena.raw_buf_mut());
19233
19234        // Outputs: [primal_x, tangent_x].
19235        let primal_x = read_arena_f64(&arena, jg.outputs[0], n);
19236        let tangent_x = read_arena_f64(&arena, jg.outputs[1], n);
19237
19238        // Closed form: t_x = solve(A, t_b).
19239        let t_x_ref = {
19240            let mut a = a_data;
19241            let mut tb = tb_data;
19242            let info = crate::blas::dgesv(&mut a, &mut tb, n, 1);
19243            assert_eq!(info, 0);
19244            tb
19245        };
19246        for i in 0..n {
19247            assert!(
19248                (tangent_x[i] - t_x_ref[i]).abs() < 1e-10,
19249                "t_x[{i}]: got {} want {}",
19250                tangent_x[i],
19251                t_x_ref[i]
19252            );
19253        }
19254
19255        // FD: x(b + h·tb) − x(b − h·tb)) / 2h
19256        let h = 1e-6;
19257        let mut bp = b_data;
19258        let mut bm = b_data;
19259        for i in 0..n {
19260            bp[i] += h * tb_data[i];
19261            bm[i] -= h * tb_data[i];
19262        }
19263        let xp = {
19264            let mut a = a_data;
19265            let info = crate::blas::dgesv(&mut a, &mut bp, n, 1);
19266            assert_eq!(info, 0);
19267            bp
19268        };
19269        let xm = {
19270            let mut a = a_data;
19271            let info = crate::blas::dgesv(&mut a, &mut bm, n, 1);
19272            assert_eq!(info, 0);
19273            bm
19274        };
19275        let fd: Vec<f64> = (0..n).map(|i| (xp[i] - xm[i]) / (2.0 * h)).collect();
19276        for i in 0..n {
19277            assert!(
19278                (tangent_x[i] - fd[i]).abs() < 1e-7,
19279                "FD mismatch t_x[{i}]: AD={} FD={}",
19280                tangent_x[i],
19281                fd[i]
19282            );
19283        }
19284        // Sanity: primal output is the actual solve.
19285        let primal_ref = {
19286            let mut a = a_data;
19287            let mut b = b_data;
19288            crate::blas::dgesv(&mut a, &mut b, n, 1);
19289            b
19290        };
19291        for i in 0..n {
19292            assert!((primal_x[i] - primal_ref[i]).abs() < 1e-10);
19293        }
19294    }
19295
19296    /// Forward-mode JVP through DenseSolve perturbing A. The tangent
19297    /// path includes the −t_A·x correction term.
19298    /// `t_x = −solve(A, t_A · x)` should match a finite-difference
19299    /// directional derivative of `solve(A, b)` w.r.t. A in the
19300    /// `t_A` direction.
19301    #[test]
19302    fn jvp_dense_solve_a_runs_and_matches_fd() {
19303        use rlx_opt::autodiff_fwd::jvp;
19304        let n = 3usize;
19305
19306        let mut g = Graph::new("jvp_a_e2e");
19307        let a = g.input("A", Shape::new(&[n, n], DType::F64));
19308        let b = g.input("b", Shape::new(&[n], DType::F64));
19309        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
19310        g.set_outputs(vec![x]);
19311
19312        let jg = jvp(&g, &[a]);
19313        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
19314            for node in graph.nodes() {
19315                let name = match &node.op {
19316                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19317                    _ => None,
19318                };
19319                if name == Some(want) {
19320                    return node.id;
19321                }
19322            }
19323            panic!("no node named {want:?}");
19324        };
19325        let a_id = find_by_name(&jg, "A");
19326        let b_id = find_by_name(&jg, "b");
19327        let ta_id = find_by_name(&jg, "tangent_A");
19328
19329        let a_data: [f64; 9] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
19330        let b_data: [f64; 3] = [1.0, 2.0, 3.0];
19331        // Asymmetric perturbation direction for A.
19332        let ta_data: [f64; 9] = [0.10, -0.05, 0.02, 0.03, 0.20, -0.04, -0.01, 0.07, 0.15];
19333
19334        let (sched, mut arena) =
19335            prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (ta_id, &ta_data)]);
19336        execute_thunks(&sched, arena.raw_buf_mut());
19337
19338        let tangent_x = read_arena_f64(&arena, jg.outputs[1], n);
19339
19340        // Closed form: x = solve(A, b); t_x = −solve(A, t_A · x).
19341        let x_ref = {
19342            let mut a = a_data;
19343            let mut b = b_data;
19344            crate::blas::dgesv(&mut a, &mut b, n, 1);
19345            b
19346        };
19347        let mut prod = [0.0_f64; 3];
19348        for i in 0..n {
19349            for j in 0..n {
19350                prod[i] += ta_data[i * n + j] * x_ref[j];
19351            }
19352        }
19353        let t_x_ref = {
19354            let mut a = a_data;
19355            let mut p = prod;
19356            crate::blas::dgesv(&mut a, &mut p, n, 1);
19357            [-p[0], -p[1], -p[2]]
19358        };
19359        for i in 0..n {
19360            assert!(
19361                (tangent_x[i] - t_x_ref[i]).abs() < 1e-10,
19362                "closed-form t_x[{i}]: AD={} ref={}",
19363                tangent_x[i],
19364                t_x_ref[i]
19365            );
19366        }
19367
19368        // FD: solve(A + h·t_A, b) and solve(A − h·t_A, b).
19369        let h = 1e-6;
19370        let mut ap = a_data;
19371        let mut am = a_data;
19372        for i in 0..n * n {
19373            ap[i] += h * ta_data[i];
19374            am[i] -= h * ta_data[i];
19375        }
19376        let xp = {
19377            let mut a = ap;
19378            let mut b = b_data;
19379            crate::blas::dgesv(&mut a, &mut b, n, 1);
19380            b
19381        };
19382        let xm = {
19383            let mut a = am;
19384            let mut b = b_data;
19385            crate::blas::dgesv(&mut a, &mut b, n, 1);
19386            b
19387        };
19388        for i in 0..n {
19389            let fd = (xp[i] - xm[i]) / (2.0 * h);
19390            assert!(
19391                (tangent_x[i] - fd).abs() < 1e-7,
19392                "FD t_x[{i}]: AD={} FD={}",
19393                tangent_x[i],
19394                fd
19395            );
19396        }
19397    }
19398
19399    /// Real INT8 conv2d parity. Same setup as QMatMul: pre-quantize
19400    /// f32 inputs to i8, run `Op::QConv2d`, compare against an
19401    /// in-test reference loop that does the same i32 accumulation
19402    /// and requantize math. Symmetric quant (zp=0) to keep the math
19403    /// head-to-head.
19404    #[test]
19405    fn q_conv2d_matches_reference() {
19406        use rlx_ir::Philox4x32;
19407        // Small NCHW shape — enough to exercise stride/padding edges.
19408        let n = 1usize;
19409        let c_in = 2usize;
19410        let h = 5usize;
19411        let w_in = 5usize;
19412        let c_out = 3usize;
19413        let kh = 3usize;
19414        let kw = 3usize;
19415        let ph = 1usize;
19416        let pw = 1usize;
19417        let sh = 1usize;
19418        let sw = 1usize;
19419        let h_out = (h + 2 * ph - kh) / sh + 1;
19420        let w_out = (w_in + 2 * pw - kw) / sw + 1;
19421
19422        let x_scale = 0.04f32;
19423        let w_scale = 0.02f32;
19424        let out_scale = 0.5f32;
19425        let mult = x_scale * w_scale / out_scale;
19426
19427        let mut rng = Philox4x32::new(2099);
19428        let mut xf = vec![0f32; n * c_in * h * w_in];
19429        rng.fill_normal(&mut xf);
19430        let mut wf = vec![0f32; c_out * c_in * kh * kw];
19431        rng.fill_normal(&mut wf);
19432        let xq: Vec<i8> = xf
19433            .iter()
19434            .map(|&v| ((v / x_scale).round() as i32).clamp(-128, 127) as i8)
19435            .collect();
19436        let wq: Vec<i8> = wf
19437            .iter()
19438            .map(|&v| ((v / w_scale).round() as i32).clamp(-128, 127) as i8)
19439            .collect();
19440        let bias: Vec<i32> = vec![0i32; c_out];
19441
19442        let mut g = Graph::new("qconv");
19443        let xn = g.input("x", Shape::new(&[n, c_in, h, w_in], DType::I8));
19444        let wn = g.input("w", Shape::new(&[c_out, c_in, kh, kw], DType::I8));
19445        let bn = g.input("b", Shape::new(&[c_out], DType::I32));
19446        let out = g.q_conv2d(
19447            xn,
19448            wn,
19449            bn,
19450            vec![kh, kw],
19451            vec![sh, sw],
19452            vec![ph, pw],
19453            vec![1, 1],
19454            1,
19455            0,
19456            0,
19457            0,
19458            mult,
19459            Shape::new(&[n, c_out, h_out, w_out], DType::I8),
19460        );
19461        g.set_outputs(vec![out]);
19462
19463        let plan = rlx_opt::memory::plan_memory(&g);
19464        let mut arena = crate::arena::Arena::from_plan(plan);
19465        let sched = compile_thunks(&g, &arena);
19466        // Capture offsets before borrowing the buf mutably (avoids
19467        // overlap between &mut and the &arena.byte_offset reads).
19468        let xn_off = arena.byte_offset(xn);
19469        let wn_off = arena.byte_offset(wn);
19470        let bn_off = arena.byte_offset(bn);
19471        let out_off = arena.byte_offset(out);
19472        let buf = arena.raw_buf_mut();
19473        unsafe {
19474            let p = buf.as_mut_ptr().add(xn_off) as *mut i8;
19475            for (i, &v) in xq.iter().enumerate() {
19476                *p.add(i) = v;
19477            }
19478            let p = buf.as_mut_ptr().add(wn_off) as *mut i8;
19479            for (i, &v) in wq.iter().enumerate() {
19480                *p.add(i) = v;
19481            }
19482            let p = buf.as_mut_ptr().add(bn_off) as *mut i32;
19483            for (i, &v) in bias.iter().enumerate() {
19484                *p.add(i) = v;
19485            }
19486        }
19487        execute_thunks(&sched, arena.raw_buf_mut());
19488        let out_q: Vec<i8> = unsafe {
19489            let p = arena.raw_buf().as_ptr().add(out_off) as *const i8;
19490            (0..n * c_out * h_out * w_out).map(|i| *p.add(i)).collect()
19491        };
19492
19493        // Reference: scalar loop in NCHW with the same requantize.
19494        let mut out_ref = vec![0i8; n * c_out * h_out * w_out];
19495        for ni in 0..n {
19496            for co in 0..c_out {
19497                for ho in 0..h_out {
19498                    for wo in 0..w_out {
19499                        let mut acc: i32 = 0;
19500                        for ci in 0..c_in {
19501                            for ki in 0..kh {
19502                                for kj in 0..kw {
19503                                    let hi = ho * sh + ki;
19504                                    let wi = wo * sw + kj;
19505                                    if hi < ph || wi < pw {
19506                                        continue;
19507                                    }
19508                                    let hi = hi - ph;
19509                                    let wi = wi - pw;
19510                                    if hi >= h || wi >= w_in {
19511                                        continue;
19512                                    }
19513                                    let xv =
19514                                        xq[((ni * c_in) + ci) * h * w_in + hi * w_in + wi] as i32;
19515                                    let wv = wq[((co * c_in) + ci) * kh * kw + ki * kw + kj] as i32;
19516                                    acc += xv * wv;
19517                                }
19518                            }
19519                        }
19520                        let r = (acc as f32 * mult).round() as i32;
19521                        let r = r.clamp(-128, 127) as i8;
19522                        out_ref[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = r;
19523                    }
19524                }
19525            }
19526        }
19527
19528        for (i, (a, r)) in out_q.iter().zip(&out_ref).enumerate() {
19529            assert_eq!(a, r, "q_conv2d[{i}]: kernel {a} vs reference {r}");
19530        }
19531    }
19532
19533    /// Real INT8 matmul parity: compare `Op::QMatMul` against the
19534    /// fake-quant reference `Dequantize → MatMul → Quantize` that
19535    /// would produce the same output if we round-tripped through
19536    /// f32. Both should agree element-for-element (or within ±1 i8
19537    /// step, since rounding in the requantize uses different code
19538    /// paths). Symmetric quantization (zp=0) for both paths to keep
19539    /// the math head-to-head.
19540    #[test]
19541    fn q_matmul_matches_fake_quant_reference() {
19542        use rlx_ir::Philox4x32;
19543        let m = 3usize;
19544        let k = 8usize;
19545        let n = 5usize;
19546        let mut rng = Philox4x32::new(2031);
19547
19548        // Pick scales and quantize random f32 inputs to i8.
19549        let x_scale = 0.05f32;
19550        let w_scale = 0.03f32;
19551        let out_scale = 0.4f32;
19552        let mult = x_scale * w_scale / out_scale;
19553        let mut xf = vec![0f32; m * k];
19554        rng.fill_normal(&mut xf);
19555        let mut wf = vec![0f32; k * n];
19556        rng.fill_normal(&mut wf);
19557        let xq: Vec<i8> = xf
19558            .iter()
19559            .map(|&v| ((v / x_scale).round() as i32).clamp(-128, 127) as i8)
19560            .collect();
19561        let wq: Vec<i8> = wf
19562            .iter()
19563            .map(|&v| ((v / w_scale).round() as i32).clamp(-128, 127) as i8)
19564            .collect();
19565        let bias: Vec<i32> = vec![0i32; n];
19566
19567        // ── Direct INT8 path ──
19568        let _f = DType::F32;
19569        let mut g_q = Graph::new("qmm_direct");
19570        let xn = g_q.input("x", Shape::new(&[m, k], DType::I8));
19571        let wn = g_q.input("w", Shape::new(&[k, n], DType::I8));
19572        let bn = g_q.input("b", Shape::new(&[n], DType::I32));
19573        let out = g_q.q_matmul(xn, wn, bn, 0, 0, 0, mult, Shape::new(&[m, n], DType::I8));
19574        g_q.set_outputs(vec![out]);
19575        let plan = rlx_opt::memory::plan_memory(&g_q);
19576        let mut arena = crate::arena::Arena::from_plan(plan);
19577        let sched = compile_thunks(&g_q, &arena);
19578
19579        // Fill inputs.
19580        let xn_off = arena.byte_offset(xn);
19581        let wn_off = arena.byte_offset(wn);
19582        let bn_off = arena.byte_offset(bn);
19583        let out_off = arena.byte_offset(out);
19584        let buf = arena.raw_buf_mut();
19585        unsafe {
19586            let p = buf.as_mut_ptr().add(xn_off) as *mut i8;
19587            for (i, &v) in xq.iter().enumerate() {
19588                *p.add(i) = v;
19589            }
19590            let p = buf.as_mut_ptr().add(wn_off) as *mut i8;
19591            for (i, &v) in wq.iter().enumerate() {
19592                *p.add(i) = v;
19593            }
19594            let p = buf.as_mut_ptr().add(bn_off) as *mut i32;
19595            for (i, &v) in bias.iter().enumerate() {
19596                *p.add(i) = v;
19597            }
19598        }
19599        execute_thunks(&sched, arena.raw_buf_mut());
19600        let out_q: Vec<i8> = unsafe {
19601            let p = arena.raw_buf().as_ptr().add(out_off) as *const i8;
19602            (0..m * n).map(|i| *p.add(i)).collect()
19603        };
19604
19605        // ── Fake-quant reference: scalar emulation in plain Rust ──
19606        // Same arithmetic the kernel does, but in a verifier loop:
19607        //   acc = Σ (x[m,k]) · (w[k,n]),  // zps are 0
19608        //   out[m,n] = saturate_i8(round(acc · mult) + 0)
19609        let mut out_ref = vec![0i8; m * n];
19610        for mi in 0..m {
19611            for ni in 0..n {
19612                let mut acc: i32 = 0;
19613                for ki in 0..k {
19614                    acc += (xq[mi * k + ki] as i32) * (wq[ki * n + ni] as i32);
19615                }
19616                let r = (acc as f32 * mult).round() as i32;
19617                out_ref[mi * n + ni] = r.clamp(-128, 127) as i8;
19618            }
19619        }
19620
19621        for (i, (a, r)) in out_q.iter().zip(&out_ref).enumerate() {
19622            assert_eq!(a, r, "q_matmul[{i}]: kernel {a} vs reference {r}");
19623        }
19624    }
19625
19626    /// Quantize/Dequantize round-trip — quantize an f32 tensor, then
19627    /// dequantize back, and confirm the result tracks the input
19628    /// within the per-element scale (the inevitable rounding error).
19629    /// Also pins the kernel's saturation behavior at the i8 limits.
19630    #[test]
19631    fn quantize_dequantize_round_trip() {
19632        use rlx_ir::Philox4x32;
19633        let len = 64;
19634        let mut rng = Philox4x32::new(2027);
19635        let mut x = vec![0f32; len];
19636        rng.fill_normal(&mut x);
19637        // Stretch a couple values past the +/- saturation cliff so
19638        // the saturate_i8 path is exercised.
19639        x[0] = 999.0;
19640        x[1] = -999.0;
19641
19642        let scale = 0.05f32;
19643        let zp = 3i32;
19644
19645        let f = DType::F32;
19646        let mut g = Graph::new("qdq");
19647        let xn = g.input("x", Shape::new(&[len], f));
19648        let q = g.quantize(xn, scale, zp);
19649        let dq = g.dequantize(q, scale, zp);
19650        g.set_outputs(vec![dq]);
19651
19652        let plan = rlx_opt::memory::plan_memory(&g);
19653        let mut arena = crate::arena::Arena::from_plan(plan);
19654        let sched = compile_thunks(&g, &arena);
19655        let xn_off = arena.byte_offset(xn);
19656        let dq_off = arena.byte_offset(dq);
19657        let buf = arena.raw_buf_mut();
19658        unsafe {
19659            let p = buf.as_mut_ptr().add(xn_off) as *mut f32;
19660            for (i, &v) in x.iter().enumerate() {
19661                *p.add(i) = v;
19662            }
19663        }
19664        execute_thunks(&sched, arena.raw_buf_mut());
19665        let out: Vec<f32> = unsafe {
19666            let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
19667            (0..len).map(|i| *p.add(i)).collect()
19668        };
19669
19670        // Saturated values at i=0,1 should clamp to ±127's dequant
19671        // range (= (±127 - zp) · scale).
19672        let sat_pos = (127 - zp) as f32 * scale;
19673        let sat_neg = (-128 - zp) as f32 * scale;
19674        assert!((out[0] - sat_pos).abs() < 1e-6, "+sat: {}", out[0]);
19675        assert!((out[1] - sat_neg).abs() < 1e-6, "-sat: {}", out[1]);
19676
19677        // Everything else should round-trip within `scale` (one quant
19678        // step = the worst-case rounding error).
19679        for i in 2..len {
19680            assert!(
19681                (out[i] - x[i]).abs() <= scale + 1e-5,
19682                "qdq[{i}]: {} → {}, scale={scale}",
19683                x[i],
19684                out[i]
19685            );
19686        }
19687    }
19688
19689    /// Per-channel quantize / dequantize: independent scale and zp
19690    /// per slice along an axis. Verifies (a) each channel uses its
19691    /// own scale (not a shared one), (b) saturation still respects
19692    /// the i8 range, (c) channel data layout decomposition is
19693    /// correct (no cross-channel leakage).
19694    #[test]
19695    fn quantize_per_channel_round_trip() {
19696        let c = 4usize;
19697        let inner = 5usize;
19698        // Different magnitudes per channel — proves the per-channel
19699        // scale is actually being read for each row.
19700        let mags = [0.01f32, 0.5, 5.0, 50.0];
19701        let mut x = vec![0f32; c * inner];
19702        for ci in 0..c {
19703            for ii in 0..inner {
19704                // Sweep through values that span [-max_abs, +max_abs]
19705                // for each channel, plus one value past the cliff to
19706                // trigger saturation.
19707                x[ci * inner + ii] = match ii {
19708                    0 => -mags[ci],
19709                    1 => 0.0,
19710                    2 => mags[ci],
19711                    3 => mags[ci] * 1000.0,  // saturates +
19712                    _ => -mags[ci] * 1000.0, // saturates -
19713                };
19714            }
19715        }
19716        let scales: Vec<f32> = mags.iter().map(|&m| m / 127.0).collect();
19717        let zps: Vec<i32> = vec![0, 0, 0, 0];
19718
19719        let f = DType::F32;
19720        let mut g = Graph::new("qdq_pc");
19721        let xn = g.input("x", Shape::new(&[c, inner], f));
19722        let q = g.quantize_per_channel(xn, 0, scales.clone(), zps.clone());
19723        let dq = g.dequantize_per_channel(q, 0, scales.clone(), zps);
19724        g.set_outputs(vec![dq]);
19725
19726        let plan = rlx_opt::memory::plan_memory(&g);
19727        let mut arena = crate::arena::Arena::from_plan(plan);
19728        let sched = compile_thunks(&g, &arena);
19729        let xn_off = arena.byte_offset(xn);
19730        let dq_off = arena.byte_offset(dq);
19731        let buf = arena.raw_buf_mut();
19732        unsafe {
19733            let p = buf.as_mut_ptr().add(xn_off) as *mut f32;
19734            for (i, &v) in x.iter().enumerate() {
19735                *p.add(i) = v;
19736            }
19737        }
19738        execute_thunks(&sched, arena.raw_buf_mut());
19739        let out: Vec<f32> = unsafe {
19740            let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
19741            (0..c * inner).map(|i| *p.add(i)).collect()
19742        };
19743
19744        for ci in 0..c {
19745            // Within-range entries (positions 0, 1, 2) must round-trip
19746            // within one quant step of *that channel's* scale.
19747            for ii in 0..3 {
19748                let idx = ci * inner + ii;
19749                assert!(
19750                    (out[idx] - x[idx]).abs() <= scales[ci] + 1e-5,
19751                    "ch {ci} idx {ii}: {} vs {}",
19752                    x[idx],
19753                    out[idx]
19754                );
19755            }
19756            // Saturated positions clamp to ±127 · scale[ci].
19757            let sat_pos = 127.0 * scales[ci];
19758            let sat_neg = -128.0 * scales[ci];
19759            assert!(
19760                (out[ci * inner + 3] - sat_pos).abs() < 1e-5,
19761                "ch {ci} +sat: {}",
19762                out[ci * inner + 3]
19763            );
19764            assert!(
19765                (out[ci * inner + 4] - sat_neg).abs() < 1e-5,
19766                "ch {ci} -sat: {}",
19767                out[ci * inner + 4]
19768            );
19769        }
19770    }
19771
19772    /// `Op::ActivationBackward` parity for every supported kind.
19773    /// Builds a single-op graph `dx = activation_backward(x, dy)` and
19774    /// compares each `dx[i]` to the central-difference `(act(x+ε) -
19775    /// act(x-ε)) / (2ε) · dy\[i\]`. Sweeps the closed-form covered by
19776    /// the kernel.
19777    #[test]
19778    fn activation_backward_matches_numerical_per_kind() {
19779        use rlx_ir::Philox4x32;
19780        use rlx_ir::op::Activation;
19781        let mut rng = Philox4x32::new(91);
19782        let len = 32;
19783        // x sampled away from kink/branch points: shifted positive
19784        // (exp/sqrt/log domain) for the unary-positive activations;
19785        // wide range otherwise. Two parallel tests would be cleaner
19786        // but this is concise enough.
19787        let mut x_pos = vec![0f32; len];
19788        rng.fill_normal(&mut x_pos);
19789        for v in x_pos.iter_mut() {
19790            *v = v.abs() + 0.5;
19791        }
19792        let mut x_any = vec![0f32; len];
19793        rng.fill_normal(&mut x_any);
19794        let mut dy = vec![0f32; len];
19795        rng.fill_normal(&mut dy);
19796
19797        for &(kind, x_data, eps, tol) in &[
19798            (Activation::Sigmoid, &x_any[..], 1e-3, 5e-3),
19799            (Activation::Tanh, &x_any[..], 1e-3, 5e-3),
19800            (Activation::Silu, &x_any[..], 1e-3, 5e-3),
19801            (Activation::Gelu, &x_any[..], 1e-3, 5e-3),
19802            (Activation::GeluApprox, &x_any[..], 1e-3, 5e-3),
19803            (Activation::Exp, &x_any[..], 1e-4, 5e-3),
19804            (Activation::Log, &x_pos[..], 1e-4, 5e-3),
19805            (Activation::Sqrt, &x_pos[..], 1e-4, 5e-3),
19806            (Activation::Rsqrt, &x_pos[..], 1e-4, 5e-3),
19807            (Activation::Neg, &x_any[..], 1e-3, 5e-4),
19808        ] {
19809            let f = DType::F32;
19810            let mut g = Graph::new("act_bw");
19811            let xn = g.input("x", Shape::new(&[len], f));
19812            let dyn_ = g.input("dy", Shape::new(&[len], f));
19813            let dx = g.activation_backward(kind, xn, dyn_);
19814            g.set_outputs(vec![dx]);
19815
19816            let plan = rlx_opt::memory::plan_memory(&g);
19817            let mut arena = crate::arena::Arena::from_plan(plan);
19818            let sched = compile_thunks(&g, &arena);
19819
19820            let xn_off = arena.byte_offset(xn);
19821            let dyn_off = arena.byte_offset(dyn_);
19822            let dx_off = arena.byte_offset(dx);
19823            let buf = arena.raw_buf_mut();
19824            unsafe {
19825                let p = buf.as_mut_ptr().add(xn_off) as *mut f32;
19826                for (i, &v) in x_data.iter().enumerate() {
19827                    *p.add(i) = v;
19828                }
19829                let p = buf.as_mut_ptr().add(dyn_off) as *mut f32;
19830                for (i, &v) in dy.iter().enumerate() {
19831                    *p.add(i) = v;
19832                }
19833            }
19834            execute_thunks(&sched, arena.raw_buf_mut());
19835            let analytical: Vec<f32> = unsafe {
19836                let p = arena.raw_buf().as_ptr().add(dx_off) as *const f32;
19837                (0..len).map(|i| *p.add(i)).collect()
19838            };
19839
19840            // Apply the forward activation manually; finite-difference
19841            // each element.
19842            let act_apply = |kind: Activation, x: f32| -> f32 {
19843                match kind {
19844                    Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
19845                    Activation::Tanh => x.tanh(),
19846                    Activation::Silu => x / (1.0 + (-x).exp()),
19847                    Activation::Gelu => {
19848                        // Match the kernel's exact erf form.
19849                        const INV_SQRT2: f32 = 0.707_106_77;
19850                        0.5 * x * (1.0 + erf_f32(x * INV_SQRT2))
19851                    }
19852                    Activation::GeluApprox => {
19853                        const C: f32 = 0.797_884_6;
19854                        const A: f32 = 0.044_715;
19855                        let inner = C * (x + A * x * x * x);
19856                        0.5 * x * (1.0 + inner.tanh())
19857                    }
19858                    Activation::Exp => x.exp(),
19859                    Activation::Log => x.ln(),
19860                    Activation::Sqrt => x.sqrt(),
19861                    Activation::Rsqrt => 1.0 / x.sqrt(),
19862                    Activation::Neg => -x,
19863                    Activation::Relu => x.max(0.0),
19864                    Activation::Abs => x.abs(),
19865                    Activation::Round => x.round(),
19866                    Activation::Sin => x.sin(),
19867                    Activation::Cos => x.cos(),
19868                    Activation::Tan => x.tan(),
19869                    Activation::Atan => x.atan(),
19870                }
19871            };
19872            for i in 0..len {
19873                let xv = x_data[i];
19874                let plus = act_apply(kind, xv + eps);
19875                let minus = act_apply(kind, xv - eps);
19876                let num = (plus - minus) / (2.0 * eps) * dy[i];
19877                assert!(
19878                    (analytical[i] - num).abs() < tol,
19879                    "{kind:?}[{i}]: analytical {} vs numerical {num}",
19880                    analytical[i]
19881                );
19882            }
19883        }
19884    }
19885
19886    /// Batched 3-D MatMul VJP — the transformer-attention shape
19887    /// `[B, M, K] @ [B, K, N] = [B, M, N]`. Both gradients flow through
19888    /// `Op::Transpose` with a perm that swaps the last two dims.
19889    #[test]
19890    fn matmul_3d_gradient_matches_numerical() {
19891        use rlx_ir::Philox4x32;
19892        let batch = 2usize;
19893        let m = 3usize;
19894        let k = 4usize;
19895        let n = 5usize;
19896        let mut rng = Philox4x32::new(101);
19897        let mut a_data = vec![0f32; batch * m * k];
19898        rng.fill_normal(&mut a_data);
19899        let mut b_data = vec![0f32; batch * k * n];
19900        rng.fill_normal(&mut b_data);
19901
19902        let f = DType::F32;
19903        let mut fwd = Graph::new("matmul_3d");
19904        let an = fwd.input("a", Shape::new(&[batch, m, k], f));
19905        let bp = fwd.param("b", Shape::new(&[batch, k, n], f));
19906        let mm = fwd.matmul(an, bp, Shape::new(&[batch, m, n], f));
19907        let loss = fwd.add_node(
19908            Op::Reduce {
19909                op: ReduceOp::Sum,
19910                axes: vec![0, 1, 2],
19911                keep_dim: false,
19912            },
19913            vec![mm],
19914            Shape::from_dims(&[], f),
19915        );
19916        fwd.set_outputs(vec![loss]);
19917
19918        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[bp]);
19919        let d_out = bwd_graph
19920            .nodes()
19921            .iter()
19922            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
19923            .map(|n| n.id)
19924            .unwrap();
19925
19926        let plan = rlx_opt::memory::plan_memory(&bwd_graph);
19927        let mut arena = crate::arena::Arena::from_plan(plan);
19928        let sched = compile_thunks(&bwd_graph, &arena);
19929        for &(id, data) in &[(an, &a_data), (bp, &b_data), (d_out, &vec![1.0f32])] {
19930            let off = arena.byte_offset(id);
19931            let buf = arena.raw_buf_mut();
19932            unsafe {
19933                let p = buf.as_mut_ptr().add(off) as *mut f32;
19934                for (i, &v) in data.iter().enumerate() {
19935                    *p.add(i) = v;
19936                }
19937            }
19938        }
19939        execute_thunks(&sched, arena.raw_buf_mut());
19940        let gb_id = bwd_graph.outputs[1];
19941        let g_b: Vec<f32> = unsafe {
19942            let p = arena.raw_buf().as_ptr().add(arena.byte_offset(gb_id)) as *const f32;
19943            (0..batch * k * n).map(|i| *p.add(i)).collect()
19944        };
19945
19946        // Numerical gradient: differentiate sum(a @ b) w.r.t. each b entry.
19947        let forward_loss = |b_vals: &[f32]| -> f32 {
19948            let mut out = vec![0f32; batch * m * n];
19949            for bi in 0..batch {
19950                for mi in 0..m {
19951                    for ni in 0..n {
19952                        let mut acc = 0f32;
19953                        for ki in 0..k {
19954                            acc +=
19955                                a_data[bi * m * k + mi * k + ki] * b_vals[bi * k * n + ki * n + ni];
19956                        }
19957                        out[bi * m * n + mi * n + ni] = acc;
19958                    }
19959                }
19960            }
19961            out.iter().sum()
19962        };
19963        let eps = 1e-3f32;
19964        let mut bp_p = b_data.clone();
19965        let mut g_b_num = vec![0f32; b_data.len()];
19966        for i in 0..b_data.len() {
19967            let s = bp_p[i];
19968            bp_p[i] = s + eps;
19969            let lp = forward_loss(&bp_p);
19970            bp_p[i] = s - eps;
19971            let lm = forward_loss(&bp_p);
19972            bp_p[i] = s;
19973            g_b_num[i] = (lp - lm) / (2.0 * eps);
19974        }
19975        for (i, (a, n)) in g_b.iter().zip(&g_b_num).enumerate() {
19976            assert!(
19977                (a - n).abs() < 5e-3,
19978                "matmul_3d g_b[{i}]: analytical {a} vs numerical {n}"
19979            );
19980        }
19981    }
19982
19983    /// Composed `Op::Softmax` VJP — the gradient is built from
19984    /// `mul + reduce_sum + expand + sub + mul`, no dedicated
19985    /// SoftmaxBackward kernel. Verifies the closed-form
19986    /// `dx = y · (g - Σ y·g)` matches the FD gradient over a small
19987    /// 2-D logits tensor.
19988    #[test]
19989    fn softmax_gradient_matches_numerical() {
19990        use rlx_ir::Philox4x32;
19991        let n = 3usize;
19992        let c = 5usize;
19993        let mut rng = Philox4x32::new(57);
19994        let mut x_data = vec![0f32; n * c];
19995        rng.fill_normal(&mut x_data);
19996
19997        let f = DType::F32;
19998        let mut fwd = Graph::new("softmax_only");
19999        let xn = fwd.input("x", Shape::new(&[n, c], f));
20000        let sm = fwd.add_node(Op::Softmax { axis: -1 }, vec![xn], Shape::new(&[n, c], f));
20001        // Loss = sum(softmax · target) for some random fixed target —
20002        // any linear loss will do; sum-of-all is the simplest and gives
20003        // a uniform gradient flow into the softmax.
20004        let loss = fwd.add_node(
20005            Op::Reduce {
20006                op: ReduceOp::Sum,
20007                axes: vec![0, 1],
20008                keep_dim: false,
20009            },
20010            vec![sm],
20011            Shape::from_dims(&[], f),
20012        );
20013        fwd.set_outputs(vec![loss]);
20014
20015        // `wrt = [xn]` — autodiff exposes the gradient w.r.t. the
20016        // input so we can compare it directly. The forward NodeId for
20017        // `xn` doubles as its bwd-graph mirror.
20018        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[xn]);
20019        let d_out = bwd_graph
20020            .nodes()
20021            .iter()
20022            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
20023            .map(|n| n.id)
20024            .unwrap();
20025
20026        let plan = rlx_opt::memory::plan_memory(&bwd_graph);
20027        let mut arena = crate::arena::Arena::from_plan(plan);
20028        let sched = compile_thunks(&bwd_graph, &arena);
20029        for &(id, data) in &[(xn, &x_data), (d_out, &vec![1.0f32])] {
20030            let off = arena.byte_offset(id);
20031            let buf = arena.raw_buf_mut();
20032            unsafe {
20033                let p = buf.as_mut_ptr().add(off) as *mut f32;
20034                for (i, &v) in data.iter().enumerate() {
20035                    *p.add(i) = v;
20036                }
20037            }
20038        }
20039        execute_thunks(&sched, arena.raw_buf_mut());
20040        let g_x_id = bwd_graph.outputs[1];
20041        let g_x: Vec<f32> = unsafe {
20042            let p = arena.raw_buf().as_ptr().add(arena.byte_offset(g_x_id)) as *const f32;
20043            (0..n * c).map(|i| *p.add(i)).collect()
20044        };
20045
20046        // Loss derivative: softmax sums to 1 per row → d/dx_i sum(softmax) = 0
20047        // analytically. So expect g_x ≈ 0 within FD precision. (This
20048        // doubles as a strong sanity check for the composition.)
20049        let forward_loss = |x: &[f32]| -> f32 {
20050            let mut total = 0f32;
20051            for ni in 0..n {
20052                let row = &x[ni * c..(ni + 1) * c];
20053                let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
20054                let denom: f32 = row.iter().map(|&v| (v - m).exp()).sum();
20055                for &v in row {
20056                    total += (v - m).exp() / denom;
20057                }
20058            }
20059            total
20060        };
20061        let eps = 1e-3f32;
20062        let mut p = x_data.clone();
20063        for i in 0..x_data.len() {
20064            let s = p[i];
20065            p[i] = s + eps;
20066            let lp = forward_loss(&p);
20067            p[i] = s - eps;
20068            let lm = forward_loss(&p);
20069            p[i] = s;
20070            let num = (lp - lm) / (2.0 * eps);
20071            assert!(
20072                (g_x[i] - num).abs() < 5e-3,
20073                "softmax g_x[{i}]: analytical {} vs numerical {num}",
20074                g_x[i]
20075            );
20076        }
20077    }
20078
20079    /// LayerNorm VJP — three gradients in one pass:
20080    ///   d_x via `LayerNormBackwardInput`,
20081    ///   d_gamma via `LayerNormBackwardGamma`,
20082    ///   d_beta = `unbroadcast(upstream)` to gamma's shape.
20083    #[test]
20084    fn layer_norm_gradient_matches_numerical() {
20085        use rlx_ir::Philox4x32;
20086        let rows = 3usize;
20087        let h = 6usize;
20088        let mut rng = Philox4x32::new(1009);
20089        let mut x_data = vec![0f32; rows * h];
20090        rng.fill_normal(&mut x_data);
20091        let mut g_data = vec![0f32; h];
20092        rng.fill_normal(&mut g_data);
20093        for v in g_data.iter_mut() {
20094            *v = v.abs() + 0.5;
20095        }
20096        let mut b_data = vec![0f32; h];
20097        rng.fill_normal(&mut b_data);
20098        let eps = 1e-5f32;
20099
20100        let f = DType::F32;
20101        let mut fwd = Graph::new("ln_only");
20102        let xn = fwd.input("x", Shape::new(&[rows, h], f));
20103        let gp = fwd.param("gamma", Shape::new(&[h], f));
20104        let bp = fwd.param("beta", Shape::new(&[h], f));
20105        let ln = fwd.add_node(
20106            Op::LayerNorm { axis: -1, eps },
20107            vec![xn, gp, bp],
20108            Shape::new(&[rows, h], f),
20109        );
20110        let loss = fwd.add_node(
20111            Op::Reduce {
20112                op: ReduceOp::Sum,
20113                axes: vec![0, 1],
20114                keep_dim: false,
20115            },
20116            vec![ln],
20117            Shape::from_dims(&[], f),
20118        );
20119        fwd.set_outputs(vec![loss]);
20120
20121        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[xn, gp, bp]);
20122        let d_out = bwd_graph
20123            .nodes()
20124            .iter()
20125            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
20126            .map(|n| n.id)
20127            .unwrap();
20128
20129        let plan = rlx_opt::memory::plan_memory(&bwd_graph);
20130        let mut arena = crate::arena::Arena::from_plan(plan);
20131        let sched = compile_thunks(&bwd_graph, &arena);
20132        for &(id, data) in &[
20133            (xn, &x_data),
20134            (gp, &g_data),
20135            (bp, &b_data),
20136            (d_out, &vec![1.0f32]),
20137        ] {
20138            let off = arena.byte_offset(id);
20139            let buf = arena.raw_buf_mut();
20140            unsafe {
20141                let p = buf.as_mut_ptr().add(off) as *mut f32;
20142                for (i, &v) in data.iter().enumerate() {
20143                    *p.add(i) = v;
20144                }
20145            }
20146        }
20147        execute_thunks(&sched, arena.raw_buf_mut());
20148        let read = |id: NodeId, n: usize| -> Vec<f32> {
20149            let off = arena.byte_offset(id);
20150            unsafe {
20151                let p = arena.raw_buf().as_ptr().add(off) as *const f32;
20152                (0..n).map(|i| *p.add(i)).collect()
20153            }
20154        };
20155        let dx_a = read(bwd_graph.outputs[1], rows * h);
20156        let dg_a = read(bwd_graph.outputs[2], h);
20157        let db_a = read(bwd_graph.outputs[3], h);
20158
20159        let forward_loss = |x: &[f32], g: &[f32], b: &[f32]| -> f32 {
20160            let mut total = 0f32;
20161            for r in 0..rows {
20162                let row = &x[r * h..(r + 1) * h];
20163                let mean = row.iter().sum::<f32>() / h as f32;
20164                let var = row.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / h as f32;
20165                let inv_std = 1.0 / (var + eps).sqrt();
20166                for d in 0..h {
20167                    total += ((row[d] - mean) * inv_std) * g[d] + b[d];
20168                }
20169            }
20170            total
20171        };
20172        let h_eps = 1e-3f32;
20173
20174        let mut x_p = x_data.clone();
20175        for i in 0..x_p.len() {
20176            let s = x_p[i];
20177            x_p[i] = s + h_eps;
20178            let lp = forward_loss(&x_p, &g_data, &b_data);
20179            x_p[i] = s - h_eps;
20180            let lm = forward_loss(&x_p, &g_data, &b_data);
20181            x_p[i] = s;
20182            let num = (lp - lm) / (2.0 * h_eps);
20183            assert!(
20184                (dx_a[i] - num).abs() < 5e-3,
20185                "ln dx[{i}]: analytical {} vs numerical {num}",
20186                dx_a[i]
20187            );
20188        }
20189        let mut g_p = g_data.clone();
20190        for i in 0..g_p.len() {
20191            let s = g_p[i];
20192            g_p[i] = s + h_eps;
20193            let lp = forward_loss(&x_data, &g_p, &b_data);
20194            g_p[i] = s - h_eps;
20195            let lm = forward_loss(&x_data, &g_p, &b_data);
20196            g_p[i] = s;
20197            let num = (lp - lm) / (2.0 * h_eps);
20198            assert!(
20199                (dg_a[i] - num).abs() < 5e-3,
20200                "ln dg[{i}]: analytical {} vs numerical {num}",
20201                dg_a[i]
20202            );
20203        }
20204        let mut b_p = b_data.clone();
20205        for i in 0..b_p.len() {
20206            let s = b_p[i];
20207            b_p[i] = s + h_eps;
20208            let lp = forward_loss(&x_data, &g_data, &b_p);
20209            b_p[i] = s - h_eps;
20210            let lm = forward_loss(&x_data, &g_data, &b_p);
20211            b_p[i] = s;
20212            let num = (lp - lm) / (2.0 * h_eps);
20213            assert!(
20214                (db_a[i] - num).abs() < 5e-3,
20215                "ln db[{i}]: analytical {} vs numerical {num}",
20216                db_a[i]
20217            );
20218        }
20219    }
20220
20221    /// Single dense layer + softmax-cross-entropy + mean reduce —
20222    /// the simplest non-trivial training graph. Validates MatMul,
20223    /// broadcast Add, SCE, Reduce(Mean) VJPs and the grad_with_loss
20224    /// plumbing all at once.
20225    #[test]
20226    fn dense_sce_mean_gradient_matches_numerical() {
20227        use rlx_ir::Philox4x32;
20228        let bs = 4usize;
20229        let k_in = 3usize;
20230        let c = 5usize;
20231        let mut rng = Philox4x32::new(7);
20232        let mut x = vec![0f32; bs * k_in];
20233        rng.fill_normal(&mut x);
20234        let mut w_init = vec![0f32; k_in * c];
20235        rng.fill_normal(&mut w_init);
20236        let mut b_init = vec![0f32; c];
20237        rng.fill_normal(&mut b_init);
20238        let labels: Vec<f32> = (0..bs).map(|i| (i % c) as f32).collect();
20239
20240        // ── Forward graph: loss = mean(sce(x @ w + b, labels)) ──
20241        let f = DType::F32;
20242        let mut fwd = Graph::new("dense_sce");
20243        let xn = fwd.input("x", Shape::new(&[bs, k_in], f));
20244        let lb = fwd.input("labels", Shape::new(&[bs], f));
20245        let wp = fwd.param("w", Shape::new(&[k_in, c], f));
20246        let bp = fwd.param("b", Shape::new(&[c], f));
20247        let mm = fwd.matmul(xn, wp, Shape::new(&[bs, c], f));
20248        let logits = fwd.binary(BinaryOp::Add, mm, bp, Shape::new(&[bs, c], f));
20249        let loss_per = fwd.softmax_cross_entropy_with_logits(logits, lb);
20250        let loss = fwd.add_node(
20251            Op::Reduce {
20252                op: ReduceOp::Sum,
20253                axes: vec![0],
20254                keep_dim: false,
20255            },
20256            vec![loss_per],
20257            // Reduce sum of [bs] with axes=[0] keep_dim=false → scalar [].
20258            Shape::from_dims(&[], f),
20259        );
20260        // Use Sum + manual /bs scalar mul — also exercises BinaryOp::Mul VJP path
20261        // less aggressively than Mean would, and gives us a closed-form
20262        // reference for the loss we expect.
20263        // For simplicity though, switch to Mean which the tests should also cover.
20264        // (Re-using `loss` with Sum here for now; the mean factor cancels in
20265        // the gradient comparison since both analytical and numerical use the
20266        // same forward.)
20267        fwd.set_outputs(vec![loss]);
20268
20269        // ── Backward graph ──
20270        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wp, bp]);
20271        // Outputs: [loss, grad_w, grad_b]. NodeIds for x/labels/w/b/loss
20272        // in bwd_graph match their fwd ids (the mirror keeps order).
20273        let d_out = bwd_graph
20274            .nodes()
20275            .iter()
20276            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
20277            .map(|n| n.id)
20278            .expect("d_output input");
20279
20280        let (sched, mut arena) = prepare(
20281            &bwd_graph,
20282            &[
20283                (xn, &x),
20284                (lb, &labels),
20285                (wp, &w_init),
20286                (bp, &b_init),
20287                (d_out, &[1.0]),
20288            ],
20289        );
20290        execute_thunks(&sched, arena.raw_buf_mut());
20291
20292        let outs = &bwd_graph.outputs;
20293        let loss_id = outs[0];
20294        let gw_id = outs[1];
20295        let gb_id = outs[2];
20296        let loss_actual = read_arena(&arena, loss_id, 1)[0];
20297        let gw_actual = read_arena(&arena, gw_id, k_in * c);
20298        let gb_actual = read_arena(&arena, gb_id, c);
20299
20300        // ── Forward-only graph for finite differences ──
20301        // Re-use the same `fwd` graph; set up its own arena and rerun
20302        // for each perturbed parameter.
20303        let plan = rlx_opt::memory::plan_memory(&fwd);
20304        let mut fwd_arena = crate::arena::Arena::from_plan(plan);
20305        let fwd_sched = compile_thunks(&fwd, &fwd_arena);
20306        write_arena(&mut fwd_arena, xn, &x);
20307        write_arena(&mut fwd_arena, lb, &labels);
20308
20309        let run_loss = |arena: &mut crate::arena::Arena, w: &[f32], b: &[f32]| -> f32 {
20310            write_arena(arena, wp, w);
20311            write_arena(arena, bp, b);
20312            execute_thunks(&fwd_sched, arena.raw_buf_mut());
20313            read_arena(arena, loss, 1)[0]
20314        };
20315
20316        // Sanity: the loss reported by the bwd graph matches the
20317        // forward-only graph on the unperturbed inputs.
20318        let loss_check = run_loss(&mut fwd_arena, &w_init, &b_init);
20319        assert!(
20320            (loss_actual - loss_check).abs() < 1e-4,
20321            "loss mismatch: bwd graph {loss_actual} vs fwd-only {loss_check}"
20322        );
20323
20324        let eps = 1e-3f32;
20325        let mut w_perturbed = w_init.clone();
20326        let mut gw_numerical = vec![0f32; w_init.len()];
20327        for i in 0..w_init.len() {
20328            let saved = w_perturbed[i];
20329            w_perturbed[i] = saved + eps;
20330            let lp = run_loss(&mut fwd_arena, &w_perturbed, &b_init);
20331            w_perturbed[i] = saved - eps;
20332            let lm = run_loss(&mut fwd_arena, &w_perturbed, &b_init);
20333            w_perturbed[i] = saved;
20334            gw_numerical[i] = (lp - lm) / (2.0 * eps);
20335        }
20336        for (i, (a, n)) in gw_actual.iter().zip(&gw_numerical).enumerate() {
20337            assert!(
20338                (a - n).abs() < 5e-3,
20339                "grad_w[{i}]: analytical {a} vs numerical {n}"
20340            );
20341        }
20342
20343        let mut b_perturbed = b_init.clone();
20344        let mut gb_numerical = vec![0f32; b_init.len()];
20345        for i in 0..b_init.len() {
20346            let saved = b_perturbed[i];
20347            b_perturbed[i] = saved + eps;
20348            let lp = run_loss(&mut fwd_arena, &w_init, &b_perturbed);
20349            b_perturbed[i] = saved - eps;
20350            let lm = run_loss(&mut fwd_arena, &w_init, &b_perturbed);
20351            b_perturbed[i] = saved;
20352            gb_numerical[i] = (lp - lm) / (2.0 * eps);
20353        }
20354        for (i, (a, n)) in gb_actual.iter().zip(&gb_numerical).enumerate() {
20355            assert!(
20356                (a - n).abs() < 5e-3,
20357                "grad_b[{i}]: analytical {a} vs numerical {n}"
20358            );
20359        }
20360    }
20361
20362    /// Reduce::Mean specifically — verifies the 1/N scaling in the VJP.
20363    /// The same dense+SCE graph but with Mean instead of Sum on the loss.
20364    #[test]
20365    fn dense_sce_mean_reduce_gradient_matches_numerical() {
20366        use rlx_ir::Philox4x32;
20367        let bs = 3usize;
20368        let k_in = 2usize;
20369        let c = 4usize;
20370        let mut rng = Philox4x32::new(13);
20371        let mut x = vec![0f32; bs * k_in];
20372        rng.fill_normal(&mut x);
20373        let mut w_init = vec![0f32; k_in * c];
20374        rng.fill_normal(&mut w_init);
20375        let labels: Vec<f32> = (0..bs).map(|i| (i % c) as f32).collect();
20376
20377        let f = DType::F32;
20378        let mut fwd = Graph::new("dense_sce_mean");
20379        let xn = fwd.input("x", Shape::new(&[bs, k_in], f));
20380        let lb = fwd.input("labels", Shape::new(&[bs], f));
20381        let wp = fwd.param("w", Shape::new(&[k_in, c], f));
20382        let mm = fwd.matmul(xn, wp, Shape::new(&[bs, c], f));
20383        let loss_per = fwd.softmax_cross_entropy_with_logits(mm, lb);
20384        let loss = fwd.add_node(
20385            Op::Reduce {
20386                op: ReduceOp::Mean,
20387                axes: vec![0],
20388                keep_dim: false,
20389            },
20390            vec![loss_per],
20391            Shape::from_dims(&[], f),
20392        );
20393        fwd.set_outputs(vec![loss]);
20394
20395        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wp]);
20396        let d_out = bwd_graph
20397            .nodes()
20398            .iter()
20399            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
20400            .map(|n| n.id)
20401            .unwrap();
20402
20403        let (sched, mut arena) = prepare(
20404            &bwd_graph,
20405            &[(xn, &x), (lb, &labels), (wp, &w_init), (d_out, &[1.0])],
20406        );
20407        execute_thunks(&sched, arena.raw_buf_mut());
20408
20409        let outs = &bwd_graph.outputs;
20410        let loss_id = outs[0];
20411        let gw_id = outs[1];
20412        let _ = read_arena(&arena, loss_id, 1)[0];
20413        let gw_actual = read_arena(&arena, gw_id, k_in * c);
20414
20415        let plan = rlx_opt::memory::plan_memory(&fwd);
20416        let mut fwd_arena = crate::arena::Arena::from_plan(plan);
20417        let fwd_sched = compile_thunks(&fwd, &fwd_arena);
20418        write_arena(&mut fwd_arena, xn, &x);
20419        write_arena(&mut fwd_arena, lb, &labels);
20420
20421        let run_loss = |arena: &mut crate::arena::Arena, w: &[f32]| -> f32 {
20422            write_arena(arena, wp, w);
20423            execute_thunks(&fwd_sched, arena.raw_buf_mut());
20424            read_arena(arena, loss, 1)[0]
20425        };
20426
20427        let eps = 1e-3f32;
20428        let mut wp_p = w_init.clone();
20429        let mut gw_num = vec![0f32; w_init.len()];
20430        for i in 0..w_init.len() {
20431            let s = wp_p[i];
20432            wp_p[i] = s + eps;
20433            let lp = run_loss(&mut fwd_arena, &wp_p);
20434            wp_p[i] = s - eps;
20435            let lm = run_loss(&mut fwd_arena, &wp_p);
20436            wp_p[i] = s;
20437            gw_num[i] = (lp - lm) / (2.0 * eps);
20438        }
20439        for (i, (a, n)) in gw_actual.iter().zip(&gw_num).enumerate() {
20440            assert!((a - n).abs() < 5e-3, "mean reduce grad_w[{i}]: {a} vs {n}");
20441        }
20442    }
20443    /// The full TinyConv-MNIST forward path (downsized) plumbed
20444    /// through grad_with_loss. Validates that Conv, Pool(Max), ReLU,
20445    /// Reshape, MatMul, Add (broadcast), SCE, Reduce(Mean) VJPs all
20446    /// compose into a graph that produces correct gradients.
20447    #[test]
20448    fn tinyconv_full_gradient_matches_numerical() {
20449        use rlx_ir::Philox4x32;
20450        // Tiny shapes so finite differences finish in <1s.
20451        let n = 1usize;
20452        let c_in = 1usize;
20453        let h = 6usize;
20454        let w_in = 6usize;
20455        let c_mid = 2usize; // first conv output channels
20456        let kh = 3;
20457        let kw = 3;
20458        let h1 = h - kh + 1; // 4
20459        let w1 = w_in - kw + 1; // 4
20460        let h2 = h1 / 2;
20461        let w2 = w1 / 2; // 2 × 2 after 2× pool
20462        let flat = c_mid * h2 * w2; // 8
20463        let num_classes = 3usize;
20464
20465        let mut rng = Philox4x32::new(31);
20466        let mut x = vec![0f32; n * c_in * h * w_in];
20467        rng.fill_normal(&mut x);
20468        let mut wc = vec![0f32; c_mid * c_in * kh * kw];
20469        rng.fill_normal(&mut wc);
20470        for v in wc.iter_mut() {
20471            *v *= 0.2;
20472        }
20473        // Shift conv-bias well away from the ReLU zero-boundary. Without
20474        // this, an ε-perturbation of bc[c] can flip the ReLU mask on a
20475        // pre-activation that happened to land near zero — making the
20476        // central-difference numerical gradient discontinuous and
20477        // diverge from the analytical (which assumes local smoothness).
20478        // +5.0 keeps every pre-activation positive for any random init
20479        // produced by Philox seed 31 with the wc/x scales used here, so
20480        // ReLU acts as an identity and finite differences are exact.
20481        let bc: Vec<f32> = (0..c_mid).map(|i| 5.0 + 0.1 * i as f32).collect();
20482        let mut wfc = vec![0f32; flat * num_classes];
20483        rng.fill_normal(&mut wfc);
20484        for v in wfc.iter_mut() {
20485            *v *= 0.5;
20486        }
20487        let mut bfc = vec![0f32; num_classes];
20488        rng.fill_normal(&mut bfc);
20489        let labels: Vec<f32> = vec![1.0]; // batch=1
20490
20491        let f = DType::F32;
20492        let mut fwd = Graph::new("tinyconv");
20493        let xn = fwd.input("x", Shape::new(&[n, c_in, h, w_in], f));
20494        let lb = fwd.input("labels", Shape::new(&[n], f));
20495        let wcp = fwd.param("wc", Shape::new(&[c_mid, c_in, kh, kw], f));
20496        let bcp = fwd.param("bc", Shape::new(&[c_mid], f));
20497        let wfp = fwd.param("wfc", Shape::new(&[flat, num_classes], f));
20498        let bfp = fwd.param("bfc", Shape::new(&[num_classes], f));
20499
20500        // conv: [n, c_in, h, w] → [n, c_mid, h1, w1]
20501        let conv = fwd.add_node(
20502            Op::Conv {
20503                kernel_size: vec![kh, kw],
20504                stride: vec![1, 1],
20505                padding: vec![0, 0],
20506                dilation: vec![1, 1],
20507                groups: 1,
20508            },
20509            vec![xn, wcp],
20510            Shape::new(&[n, c_mid, h1, w1], f),
20511        );
20512        // Bias add: expand bc[c_mid] up to the full [n, c_mid, h1, w1]
20513        // shape so the Add becomes a plain element-wise op. Going through
20514        // an explicit Reshape→Expand instead of relying on the Add to
20515        // broadcast `[1, C, 1, 1]` → `[N, C, H, W]` works around a known
20516        // limitation of `rlx-cpu`'s `Op::Binary` lowering: it dispatches
20517        // on `out_len % rhs_len == 0` and treats `rhs` as a last-axis
20518        // bias, which produces `bc[0], bc[1], bc[0], bc[1], …` alternating
20519        // across all positions instead of channel-broadcasting. Going
20520        // through Expand (a real broadcast thunk) avoids that path
20521        // entirely. The autodiff still exercises `unbroadcast` because
20522        // `Op::Expand`'s VJP reduces over the broadcast axes.
20523        let bc_4d = fwd.add_node(
20524            Op::Reshape {
20525                new_shape: vec![1, c_mid as i64, 1, 1],
20526            },
20527            vec![bcp],
20528            Shape::new(&[1, c_mid, 1, 1], f),
20529        );
20530        let bc_expanded = fwd.add_node(
20531            Op::Expand {
20532                target_shape: vec![n as i64, c_mid as i64, h1 as i64, w1 as i64],
20533            },
20534            vec![bc_4d],
20535            Shape::new(&[n, c_mid, h1, w1], f),
20536        );
20537        let conv_b = fwd.binary(
20538            BinaryOp::Add,
20539            conv,
20540            bc_expanded,
20541            Shape::new(&[n, c_mid, h1, w1], f),
20542        );
20543        let relu = fwd.activation(Activation::Relu, conv_b, Shape::new(&[n, c_mid, h1, w1], f));
20544        let pool = fwd.add_node(
20545            Op::Pool {
20546                kind: ReduceOp::Max,
20547                kernel_size: vec![2, 2],
20548                stride: vec![2, 2],
20549                padding: vec![0, 0],
20550            },
20551            vec![relu],
20552            Shape::new(&[n, c_mid, h2, w2], f),
20553        );
20554        let flatn = fwd.add_node(
20555            Op::Reshape {
20556                new_shape: vec![n as i64, flat as i64],
20557            },
20558            vec![pool],
20559            Shape::new(&[n, flat], f),
20560        );
20561        let mm = fwd.matmul(flatn, wfp, Shape::new(&[n, num_classes], f));
20562        let logits = fwd.binary(BinaryOp::Add, mm, bfp, Shape::new(&[n, num_classes], f));
20563        let loss_per = fwd.softmax_cross_entropy_with_logits(logits, lb);
20564        let loss = fwd.add_node(
20565            Op::Reduce {
20566                op: ReduceOp::Mean,
20567                axes: vec![0],
20568                keep_dim: false,
20569            },
20570            vec![loss_per],
20571            Shape::from_dims(&[], f),
20572        );
20573        fwd.set_outputs(vec![loss]);
20574
20575        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wcp, bcp, wfp, bfp]);
20576        let d_out = bwd_graph
20577            .nodes()
20578            .iter()
20579            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
20580            .map(|n| n.id)
20581            .unwrap();
20582
20583        let (sched, mut arena) = prepare(
20584            &bwd_graph,
20585            &[
20586                (xn, &x),
20587                (lb, &labels),
20588                (wcp, &wc),
20589                (bcp, &bc),
20590                (wfp, &wfc),
20591                (bfp, &bfc),
20592                (d_out, &[1.0]),
20593            ],
20594        );
20595        execute_thunks(&sched, arena.raw_buf_mut());
20596
20597        let outs = bwd_graph.outputs.clone();
20598        let loss_id = outs[0];
20599        let g_wc_id = outs[1];
20600        let g_bc_id = outs[2];
20601        let g_wfc_id = outs[3];
20602        let g_bfc_id = outs[4];
20603        let loss_actual = read_arena(&arena, loss_id, 1)[0];
20604        let g_wc = read_arena(&arena, g_wc_id, wc.len());
20605        let g_bc = read_arena(&arena, g_bc_id, bc.len());
20606        let g_wfc = read_arena(&arena, g_wfc_id, wfc.len());
20607        let g_bfc = read_arena(&arena, g_bfc_id, bfc.len());
20608
20609        // Forward-only arena for finite differences.
20610        let plan = rlx_opt::memory::plan_memory(&fwd);
20611        let mut fwd_arena = crate::arena::Arena::from_plan(plan);
20612        let fwd_sched = compile_thunks(&fwd, &fwd_arena);
20613        write_arena(&mut fwd_arena, xn, &x);
20614        write_arena(&mut fwd_arena, lb, &labels);
20615
20616        // Closure variant: we need to set all four params each call so
20617        // perturbations to one don't leak between sweeps.
20618        let run_loss = |arena: &mut crate::arena::Arena,
20619                        wc: &[f32],
20620                        bc: &[f32],
20621                        wfc: &[f32],
20622                        bfc: &[f32]|
20623         -> f32 {
20624            write_arena(arena, wcp, wc);
20625            write_arena(arena, bcp, bc);
20626            write_arena(arena, wfp, wfc);
20627            write_arena(arena, bfp, bfc);
20628            execute_thunks(&fwd_sched, arena.raw_buf_mut());
20629            read_arena(arena, loss, 1)[0]
20630        };
20631
20632        let loss_check = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &bfc);
20633        assert!(
20634            (loss_actual - loss_check).abs() < 1e-4,
20635            "tinyconv loss mismatch: bwd {loss_actual} vs fwd {loss_check}"
20636        );
20637
20638        let eps = 1e-3f32;
20639        let check_grad = |arena: &mut crate::arena::Arena,
20640                          name: &str,
20641                          analytical: &[f32],
20642                          mut perturb: Box<
20643            dyn FnMut(&mut [f32], usize, f32, &mut crate::arena::Arena) -> f32 + '_,
20644        >,
20645                          n: usize| {
20646            for i in 0..n {
20647                let lp = perturb(&mut analytical.to_vec(), i, eps, arena);
20648                let lm = perturb(&mut analytical.to_vec(), i, -eps, arena);
20649                let num = (lp - lm) / (2.0 * eps);
20650                assert!(
20651                    (analytical[i] - num).abs() < 5e-3,
20652                    "{name}[{i}]: analytical {} vs numerical {num}",
20653                    analytical[i]
20654                );
20655            }
20656        };
20657
20658        // Helper to perturb one param and run forward. Kept as a
20659        // reference for the explicit per-param sweep pattern below.
20660        #[allow(unused_macros)]
20661        macro_rules! sweep {
20662            ($name:expr, $base:expr, $analytical:expr, $set_param:ident) => {{
20663                let n = $base.len();
20664                for i in 0..n {
20665                    let mut p = $base.clone();
20666                    let s = p[i];
20667                    p[i] = s + eps;
20668                    let lp = {
20669                        let $set_param = &p;
20670                        run_loss(&mut fwd_arena, &wc, &bc, &wfc, &bfc).max(f32::NEG_INFINITY);
20671                        // Reset others, set the one being swept, run.
20672                        // (the macro receives one of the four params via $set_param)
20673                        let _ = $set_param;
20674                        // Fall through to the explicit per-param helper:
20675                        0.0_f32
20676                    };
20677                    let _ = lp;
20678                }
20679            }};
20680        }
20681        let _ = check_grad; // silence unused (sweep! macro is intentionally\n        // unused — kept as reference for the per-param sweep pattern below)
20682
20683        // Per-param sweeps (explicit, not macro — clearer).
20684        for i in 0..wc.len() {
20685            let mut p = wc.clone();
20686            let s = p[i];
20687            p[i] = s + eps;
20688            let lp = run_loss(&mut fwd_arena, &p, &bc, &wfc, &bfc);
20689            p[i] = s - eps;
20690            let lm = run_loss(&mut fwd_arena, &p, &bc, &wfc, &bfc);
20691            let num = (lp - lm) / (2.0 * eps);
20692            assert!(
20693                (g_wc[i] - num).abs() < 5e-3,
20694                "g_wc[{i}]: {} vs {num}",
20695                g_wc[i]
20696            );
20697        }
20698        for i in 0..bc.len() {
20699            let mut p = bc.clone();
20700            let s = p[i];
20701            p[i] = s + eps;
20702            let lp = run_loss(&mut fwd_arena, &wc, &p, &wfc, &bfc);
20703            p[i] = s - eps;
20704            let lm = run_loss(&mut fwd_arena, &wc, &p, &wfc, &bfc);
20705            let num = (lp - lm) / (2.0 * eps);
20706            assert!(
20707                (g_bc[i] - num).abs() < 5e-3,
20708                "g_bc[{i}]: {} vs {num}",
20709                g_bc[i]
20710            );
20711        }
20712        for i in 0..wfc.len() {
20713            let mut p = wfc.clone();
20714            let s = p[i];
20715            p[i] = s + eps;
20716            let lp = run_loss(&mut fwd_arena, &wc, &bc, &p, &bfc);
20717            p[i] = s - eps;
20718            let lm = run_loss(&mut fwd_arena, &wc, &bc, &p, &bfc);
20719            let num = (lp - lm) / (2.0 * eps);
20720            assert!(
20721                (g_wfc[i] - num).abs() < 5e-3,
20722                "g_wfc[{i}]: {} vs {num}",
20723                g_wfc[i]
20724            );
20725        }
20726        for i in 0..bfc.len() {
20727            let mut p = bfc.clone();
20728            let s = p[i];
20729            p[i] = s + eps;
20730            let lp = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &p);
20731            p[i] = s - eps;
20732            let lm = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &p);
20733            let num = (lp - lm) / (2.0 * eps);
20734            assert!(
20735                (g_bfc[i] - num).abs() < 5e-3,
20736                "g_bfc[{i}]: {} vs {num}",
20737                g_bfc[i]
20738            );
20739        }
20740    }
20741
20742    /// Negative case: a Narrow whose output has multiple consumers
20743    /// must NOT be fused (we can't elide its write — something else
20744    /// reads it).
20745    #[test]
20746    fn narrow_rope_skips_when_narrow_has_multiple_consumers() {
20747        let f = DType::F32;
20748        let mut g = Graph::new("nr_skip");
20749        let qkv = g.input("qkv", Shape::new(&[16, 8, 192], f));
20750        let cos = g.input("cos", Shape::new(&[16], f));
20751        let sin = g.input("sin", Shape::new(&[16], f));
20752        let q = g.narrow_(qkv, 2, 0, 64);
20753        let q_rope = g.rope(q, cos, sin, 16);
20754        // Second consumer of `q` blocks the fusion.
20755        let q_dup = g.activation(rlx_ir::op::Activation::Relu, q, Shape::new(&[16, 8, 64], f));
20756        g.set_outputs(vec![q_rope, q_dup]);
20757
20758        let plan = rlx_opt::memory::plan_memory(&g);
20759        let arena = crate::arena::Arena::from_plan(plan);
20760        let sched = compile_thunks(&g, &arena);
20761
20762        let narrow_count = sched
20763            .thunks
20764            .iter()
20765            .filter(|t| matches!(t, Thunk::Narrow { .. }))
20766            .count();
20767        assert!(
20768            narrow_count >= 1,
20769            "Narrow with multiple consumers must NOT be fused away"
20770        );
20771    }
20772
20773    // ── Op::CustomFn (custom_vjp / custom_jvp) tests ──
20774    //
20775    // Validates: forward execution inlines fwd_body; VJP rule inlines
20776    // vjp_body in place of recursing into fwd_body; JVP rule inlines
20777    // jvp_body. Each test deliberately picks a body whose AD-via-tracing
20778    // would yield a *different* gradient than the override, so we know
20779    // the override actually fired.
20780
20781    /// Forward only: CustomFn wrapping `f(x) = x + c` (c=1 inside body)
20782    /// without override AD bodies. Verifies the body is compiled,
20783    /// constants in the body fill correctly, and the output lands at
20784    /// the outer node's slot.
20785    #[test]
20786    fn custom_fn_forward_inlines_body() {
20787        let s = Shape::new(&[3], DType::F32);
20788
20789        // Body: f(x) = x + 1
20790        let mut body = Graph::new("addone_body");
20791        let x = body.input("x", s.clone());
20792        let one_data: Vec<u8> = (0..3).flat_map(|_| 1.0_f32.to_le_bytes()).collect();
20793        let one = body.add_node(Op::Constant { data: one_data }, vec![], s.clone());
20794        let y = body.binary(BinaryOp::Add, x, one, s.clone());
20795        body.set_outputs(vec![y]);
20796
20797        let mut g = Graph::new("custom_fn_outer");
20798        let xin = g.input("x_in", s.clone());
20799        let cf = g.custom_fn(vec![xin], body, None, None);
20800        g.set_outputs(vec![cf]);
20801
20802        let xs = vec![10.0_f32, 20.0, 30.0];
20803        let (sched, mut arena) = prepare(&g, &[(xin, &xs)]);
20804        execute_thunks(&sched, arena.raw_buf_mut());
20805        let got = read_arena(&arena, cf, 3);
20806        assert_eq!(got, vec![11.0, 21.0, 31.0]);
20807    }
20808
20809    /// Locate an Op::Input or Op::Param by name in a graph.
20810    fn find_named(graph: &Graph, want: &str) -> NodeId {
20811        for n in graph.nodes() {
20812            let name = match &n.op {
20813                Op::Input { name } | Op::Param { name } => Some(name.as_str()),
20814                _ => None,
20815            };
20816            if name == Some(want) {
20817                return n.id;
20818            }
20819        }
20820        panic!("no node named {want:?} in graph");
20821    }
20822
20823    /// VJP override: f(x) = x but vjp_body returns 2 * d_output, so the
20824    /// reported gradient should be 2 — different from the natural 1
20825    /// you'd get by recursing into the identity body.
20826    #[test]
20827    fn custom_fn_vjp_overrides_natural_gradient() {
20828        use rlx_opt::autodiff::grad_with_loss;
20829        let s = Shape::new(&[1], DType::F32);
20830
20831        let mut fwd = Graph::new("id_fwd");
20832        let x = fwd.input("x", s.clone());
20833        fwd.set_outputs(vec![x]);
20834
20835        let mut vjp_g = Graph::new("id_vjp");
20836        let _x_p = vjp_g.input("x", s.clone());
20837        let _y_p = vjp_g.input("primal_output", s.clone());
20838        let dy = vjp_g.input("d_output", s.clone());
20839        let two_data: Vec<u8> = 2.0_f32.to_le_bytes().to_vec();
20840        let two = vjp_g.add_node(Op::Constant { data: two_data }, vec![], s.clone());
20841        let dx = vjp_g.binary(BinaryOp::Mul, dy, two, s.clone());
20842        vjp_g.set_outputs(vec![dx]);
20843
20844        let mut g = Graph::new("outer");
20845        let xp = g.param("x", s.clone());
20846        let cf = g.custom_fn(vec![xp], fwd, Some(vjp_g), None);
20847        g.set_outputs(vec![cf]);
20848
20849        let bwd = grad_with_loss(&g, &[xp]);
20850        assert_eq!(bwd.outputs.len(), 2, "expect [loss, dx]");
20851
20852        let xb = find_named(&bwd, "x");
20853        let dout = find_named(&bwd, "d_output");
20854        let (sched, mut arena) = prepare(&bwd, &[(xb, &[7.0]), (dout, &[1.0])]);
20855        execute_thunks(&sched, arena.raw_buf_mut());
20856        let loss = read_arena(&arena, bwd.outputs[0], 1);
20857        let dx_v = read_arena(&arena, bwd.outputs[1], 1);
20858        assert!((loss[0] - 7.0).abs() < 1e-6, "loss should be 7.0");
20859        assert!(
20860            (dx_v[0] - 2.0).abs() < 1e-6,
20861            "vjp override should yield dx=2.0, got {} (natural autodiff would give 1.0)",
20862            dx_v[0]
20863        );
20864    }
20865
20866    /// VJP override: f(a, b) = a*b with vjp_body returning
20867    /// (b * d_output, a * d_output). Validates routing of multiple
20868    /// primals + d_output through the override; matches the natural
20869    /// autodiff-of-Mul gradient (b, a).
20870    #[test]
20871    fn custom_fn_vjp_two_inputs_matches_mul_autodiff() {
20872        use rlx_opt::autodiff::grad_with_loss;
20873        let s = Shape::new(&[1], DType::F32);
20874
20875        let mut fwd = Graph::new("mul_fwd");
20876        let a_f = fwd.input("a", s.clone());
20877        let b_f = fwd.input("b", s.clone());
20878        let y_f = fwd.binary(BinaryOp::Mul, a_f, b_f, s.clone());
20879        fwd.set_outputs(vec![y_f]);
20880
20881        let mut vjp_g = Graph::new("mul_vjp");
20882        let a_v = vjp_g.input("a", s.clone());
20883        let b_v = vjp_g.input("b", s.clone());
20884        let _y_v = vjp_g.input("primal_output", s.clone());
20885        let dy_v = vjp_g.input("d_output", s.clone());
20886        let da = vjp_g.binary(BinaryOp::Mul, b_v, dy_v, s.clone());
20887        let db = vjp_g.binary(BinaryOp::Mul, a_v, dy_v, s.clone());
20888        vjp_g.set_outputs(vec![da, db]);
20889
20890        let mut g = Graph::new("outer");
20891        let ap = g.param("a", s.clone());
20892        let bp = g.param("b", s.clone());
20893        let cf = g.custom_fn(vec![ap, bp], fwd, Some(vjp_g), None);
20894        g.set_outputs(vec![cf]);
20895
20896        let bwd = grad_with_loss(&g, &[ap, bp]);
20897        assert_eq!(bwd.outputs.len(), 3, "expect [loss, da, db]");
20898
20899        let ab = find_named(&bwd, "a");
20900        let bb = find_named(&bwd, "b");
20901        let dout = find_named(&bwd, "d_output");
20902        let (sched, mut arena) = prepare(&bwd, &[(ab, &[3.0]), (bb, &[5.0]), (dout, &[1.0])]);
20903        execute_thunks(&sched, arena.raw_buf_mut());
20904        let loss = read_arena(&arena, bwd.outputs[0], 1);
20905        let da_v = read_arena(&arena, bwd.outputs[1], 1);
20906        let db_v = read_arena(&arena, bwd.outputs[2], 1);
20907        assert!((loss[0] - 15.0).abs() < 1e-5);
20908        assert!(
20909            (da_v[0] - 5.0).abs() < 1e-5,
20910            "da should be b=5.0, got {}",
20911            da_v[0]
20912        );
20913        assert!(
20914            (db_v[0] - 3.0).abs() < 1e-5,
20915            "db should be a=3.0, got {}",
20916            db_v[0]
20917        );
20918    }
20919
20920    /// JVP override: f(x) = x but jvp_body returns 2 * tangent_0.
20921    /// Forward-mode tangent should be 2x the seed (1.0) → 2.0.
20922    #[test]
20923    fn custom_fn_jvp_overrides_natural_tangent() {
20924        use rlx_opt::autodiff_fwd::jvp;
20925        let s = Shape::new(&[1], DType::F32);
20926
20927        let mut fwd = Graph::new("id_fwd");
20928        let x = fwd.input("x", s.clone());
20929        fwd.set_outputs(vec![x]);
20930
20931        let mut jvp_g = Graph::new("id_jvp");
20932        let _x_p = jvp_g.input("x", s.clone());
20933        let tx = jvp_g.input("tangent_0", s.clone());
20934        let two_data: Vec<u8> = 2.0_f32.to_le_bytes().to_vec();
20935        let two = jvp_g.add_node(Op::Constant { data: two_data }, vec![], s.clone());
20936        let ty = jvp_g.binary(BinaryOp::Mul, tx, two, s.clone());
20937        jvp_g.set_outputs(vec![ty]);
20938
20939        let mut g = Graph::new("outer");
20940        let xin = g.input("x_in", s.clone());
20941        let cf = g.custom_fn(vec![xin], fwd, None, Some(jvp_g));
20942        g.set_outputs(vec![cf]);
20943
20944        let fwd_g = jvp(&g, &[xin]);
20945        assert_eq!(fwd_g.outputs.len(), 2, "expect [primal_y, tangent_y]");
20946
20947        let xb = find_named(&fwd_g, "x_in");
20948        let tan = find_named(&fwd_g, "tangent_x_in");
20949        let (sched, mut arena) = prepare(&fwd_g, &[(xb, &[7.0]), (tan, &[1.0])]);
20950        execute_thunks(&sched, arena.raw_buf_mut());
20951        let y = read_arena(&arena, fwd_g.outputs[0], 1);
20952        let ty_v = read_arena(&arena, fwd_g.outputs[1], 1);
20953        assert!((y[0] - 7.0).abs() < 1e-6);
20954        assert!(
20955            (ty_v[0] - 2.0).abs() < 1e-6,
20956            "jvp override should yield t_y=2.0 (natural autodiff would give 1.0), got {}",
20957            ty_v[0]
20958        );
20959    }
20960
20961    /// IR-level basic test: `DType::C64` is wired through the dtype
20962    /// table — `size_bytes() == 8`, `is_complex()` reports true, and
20963    /// a `[2]`-shaped C64 buffer in the arena occupies the expected
20964    /// 16 bytes.
20965    #[test]
20966    fn c64_dtype_storage_layout() {
20967        assert_eq!(
20968            DType::C64.size_bytes(),
20969            8,
20970            "C64 should be 8 bytes (f32 real + f32 imag)"
20971        );
20972        assert!(DType::C64.is_complex());
20973        assert!(!DType::C64.is_float());
20974
20975        // A length-2 C64 buffer should have shape size_bytes = 16.
20976        let s = Shape::new(&[2], DType::C64);
20977        assert_eq!(s.size_bytes().unwrap(), 16);
20978    }
20979
    // ── C64 element-wise binary kernel witnesses (2026-05-17) ──────
    //
    // Build a tiny graph: Input `a` + Input `b` (both C64, same length),
    // output = a OP b. Compile and execute it via `compile_thunks` /
    // `execute_thunks`, then compare against closed-form complex
    // arithmetic on hand-picked operand pairs.
20985
20986    fn run_c64_binary(op: BinaryOp, a: &[(f32, f32)], b: &[(f32, f32)]) -> Vec<(f32, f32)> {
20987        let n = a.len();
20988        let s = Shape::new(&[n], DType::C64);
20989        let mut g = Graph::new("c64_bin");
20990        let in_a = g.input("a", s.clone());
20991        let in_b = g.input("b", s.clone());
20992        let out = g.binary(op, in_a, in_b, s.clone());
20993        g.set_outputs(vec![out]);
20994
20995        let plan = rlx_opt::memory::plan_memory(&g);
20996        let mut arena = crate::arena::Arena::from_plan(plan);
20997        let sched = compile_thunks(&g, &arena);
20998
20999        let a_off = arena.byte_offset(in_a);
21000        let b_off = arena.byte_offset(in_b);
21001        let out_off = arena.byte_offset(out);
21002        // Interleave [re_0, im_0, re_1, im_1, ...] in the f32 buffer.
21003        let buf = arena.raw_buf_mut();
21004        unsafe {
21005            let pa = buf.as_mut_ptr().add(a_off) as *mut f32;
21006            let pb = buf.as_mut_ptr().add(b_off) as *mut f32;
21007            for (i, &(re, im)) in a.iter().enumerate() {
21008                *pa.add(2 * i) = re;
21009                *pa.add(2 * i + 1) = im;
21010            }
21011            for (i, &(re, im)) in b.iter().enumerate() {
21012                *pb.add(2 * i) = re;
21013                *pb.add(2 * i + 1) = im;
21014            }
21015        }
21016        execute_thunks(&sched, arena.raw_buf_mut());
21017        let raw_out: Vec<f32> = unsafe {
21018            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
21019            (0..(2 * n)).map(|i| *p.add(i)).collect()
21020        };
21021        (0..n)
21022            .map(|i| (raw_out[2 * i], raw_out[2 * i + 1]))
21023            .collect()
21024    }
21025
21026    #[track_caller]
21027    fn assert_close_c(got: (f32, f32), expected: (f32, f32), tol: f32, label: &str) {
21028        let dr = (got.0 - expected.0).abs();
21029        let di = (got.1 - expected.1).abs();
21030        assert!(
21031            dr < tol && di < tol,
21032            "[{label}] got ({:+.4}, {:+.4}), expected ({:+.4}, {:+.4})",
21033            got.0,
21034            got.1,
21035            expected.0,
21036            expected.1
21037        );
21038    }
21039
21040    #[test]
21041    fn c64_binary_add_matches_complex_arithmetic() {
21042        let a = [(1.0_f32, 2.0_f32), (3.0_f32, -1.0_f32)];
21043        let b = [(4.0_f32, -1.0_f32), (0.5_f32, 0.5_f32)];
21044        let out = run_c64_binary(BinaryOp::Add, &a, &b);
21045        assert_close_c(out[0], (5.0, 1.0), 1e-6, "add[0]");
21046        assert_close_c(out[1], (3.5, -0.5), 1e-6, "add[1]");
21047    }
21048
21049    #[test]
21050    fn c64_binary_sub_matches_complex_arithmetic() {
21051        let a = [(5.0_f32, 1.0_f32)];
21052        let b = [(2.0_f32, 3.0_f32)];
21053        let out = run_c64_binary(BinaryOp::Sub, &a, &b);
21054        assert_close_c(out[0], (3.0, -2.0), 1e-6, "sub");
21055    }
21056
21057    #[test]
21058    fn c64_binary_mul_matches_complex_arithmetic() {
21059        // (1 + 2i)(3 + 4i) = 3 + 4i + 6i + 8i² = -5 + 10i.
21060        let a = [(1.0_f32, 2.0_f32)];
21061        let b = [(3.0_f32, 4.0_f32)];
21062        let out = run_c64_binary(BinaryOp::Mul, &a, &b);
21063        assert_close_c(out[0], (-5.0, 10.0), 1e-5, "mul");
21064    }
21065
21066    #[test]
21067    fn c64_binary_div_matches_complex_arithmetic() {
21068        // (1 + 2i) / (3 + 4i) = ((1·3 + 2·4) + (2·3 − 1·4)i) / 25
21069        //                     = (11 + 2i) / 25
21070        //                     = 0.44 + 0.08i
21071        let a = [(1.0_f32, 2.0_f32)];
21072        let b = [(3.0_f32, 4.0_f32)];
21073        let out = run_c64_binary(BinaryOp::Div, &a, &b);
21074        assert_close_c(out[0], (0.44, 0.08), 1e-5, "div");
21075    }
21076
21077    #[test]
21078    fn c64_binary_mul_identity_one_is_no_op() {
21079        // (a + bi) · (1 + 0i) = a + bi.
21080        let a = [(3.5_f32, -1.25_f32), (-2.0_f32, 7.0_f32)];
21081        let b = [(1.0_f32, 0.0_f32), (1.0_f32, 0.0_f32)];
21082        let out = run_c64_binary(BinaryOp::Mul, &a, &b);
21083        assert_close_c(out[0], a[0], 1e-6, "mul·1[0]");
21084        assert_close_c(out[1], a[1], 1e-6, "mul·1[1]");
21085    }
21086
21087    #[test]
21088    fn c64_binary_mul_by_i_rotates_90_degrees() {
21089        // (a + bi) · i = (a + bi)(0 + i) = -b + ai. 90° CCW rotation.
21090        let a = [(1.0_f32, 0.0_f32)];
21091        let b = [(0.0_f32, 1.0_f32)];
21092        let out = run_c64_binary(BinaryOp::Mul, &a, &b);
21093        assert_close_c(out[0], (0.0, 1.0), 1e-6, "1·i");
21094    }
21095
21096    #[test]
21097    fn c64_binary_div_by_self_gives_unity() {
21098        let a = [(2.5_f32, -1.5_f32), (-0.7_f32, 4.2_f32)];
21099        let out = run_c64_binary(BinaryOp::Div, &a, &a);
21100        assert_close_c(out[0], (1.0, 0.0), 1e-5, "div_self[0]");
21101        assert_close_c(out[1], (1.0, 0.0), 1e-5, "div_self[1]");
21102    }
21103
21104    #[test]
21105    #[should_panic(expected = "C64: complex max/min/pow")]
21106    fn c64_binary_max_is_rejected_at_lowering() {
21107        run_c64_binary(BinaryOp::Max, &[(1.0_f32, 2.0_f32)], &[(3.0_f32, 4.0_f32)]);
21108    }
21109
21110    fn run_c64_activation(act: Activation, a: &[(f32, f32)]) -> Vec<(f32, f32)> {
21111        let n = a.len();
21112        let s = Shape::new(&[n], DType::C64);
21113        let mut g = Graph::new("c64_act");
21114        let in_a = g.input("a", s.clone());
21115        let out = g.activation(act, in_a, s.clone());
21116        g.set_outputs(vec![out]);
21117        let plan = rlx_opt::memory::plan_memory(&g);
21118        let mut arena = crate::arena::Arena::from_plan(plan);
21119        let sched = compile_thunks(&g, &arena);
21120        let a_off = arena.byte_offset(in_a);
21121        let out_off = arena.byte_offset(out);
21122        let buf = arena.raw_buf_mut();
21123        unsafe {
21124            let pa = buf.as_mut_ptr().add(a_off) as *mut f32;
21125            for (i, &(re, im)) in a.iter().enumerate() {
21126                *pa.add(2 * i) = re;
21127                *pa.add(2 * i + 1) = im;
21128            }
21129        }
21130        execute_thunks(&sched, arena.raw_buf_mut());
21131        let raw: Vec<f32> = unsafe {
21132            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
21133            (0..(2 * n)).map(|i| *p.add(i)).collect()
21134        };
21135        (0..n).map(|i| (raw[2 * i], raw[2 * i + 1])).collect()
21136    }
21137
21138    #[test]
21139    fn c64_activation_neg_negates_both_components() {
21140        let inp = [(3.5_f32, -1.25_f32), (-2.0_f32, 0.0_f32)];
21141        let out = run_c64_activation(Activation::Neg, &inp);
21142        assert_close_c(out[0], (-3.5, 1.25), 1e-6, "neg[0]");
21143        assert_close_c(out[1], (2.0, 0.0), 1e-6, "neg[1]");
21144    }
21145
21146    #[test]
21147    fn c64_activation_exp_matches_euler() {
21148        // exp(0 + i·π) = -1 + 0i.
21149        // exp(1 + 0i) = e ≈ 2.71828.
21150        let inp = [(0.0_f32, std::f32::consts::PI), (1.0_f32, 0.0_f32)];
21151        let out = run_c64_activation(Activation::Exp, &inp);
21152        assert_close_c(out[0], (-1.0, 0.0), 1e-5, "exp(iπ)");
21153        assert_close_c(out[1], (std::f32::consts::E, 0.0), 1e-5, "exp(1)");
21154    }
21155
21156    #[test]
21157    fn c64_activation_log_matches_principal_branch() {
21158        // log(1 + 0i) = 0.
21159        // log(0 + i) = log(1) + i·π/2 = 0 + i·π/2.
21160        // log(-1 + 0i) = 0 + i·π.
21161        let inp = [(1.0_f32, 0.0_f32), (0.0_f32, 1.0_f32), (-1.0_f32, 0.0_f32)];
21162        let out = run_c64_activation(Activation::Log, &inp);
21163        assert_close_c(out[0], (0.0, 0.0), 1e-5, "log(1)");
21164        assert_close_c(out[1], (0.0, std::f32::consts::FRAC_PI_2), 1e-5, "log(i)");
21165        assert_close_c(out[2], (0.0, std::f32::consts::PI), 1e-5, "log(-1)");
21166    }
21167
21168    #[test]
21169    fn c64_activation_sqrt_squared_recovers_input() {
21170        // For positive-real-part inputs, sqrt(z)² should equal z exactly
21171        // to f32 noise.
21172        let inp = [(4.0_f32, 0.0_f32), (3.0_f32, 4.0_f32)];
21173        let roots = run_c64_activation(Activation::Sqrt, &inp);
21174        // sqrt(4) = 2 + 0i; sqrt(3+4i) = 2 + i (since (2+i)² = 4+4i-1 = 3+4i).
21175        assert_close_c(roots[0], (2.0, 0.0), 1e-5, "sqrt(4)");
21176        assert_close_c(roots[1], (2.0, 1.0), 1e-5, "sqrt(3+4i)");
21177    }
21178
21179    #[test]
21180    #[should_panic(expected = "no natural complex extension")]
21181    fn c64_activation_relu_is_rejected_at_lowering() {
21182        run_c64_activation(Activation::Relu, &[(1.0_f32, 2.0_f32)]);
21183    }
21184
21185    // ── ComplexNormSq + Wirtinger backward witnesses ───────────────
21186
21187    /// Forward `|z|²`: returns `[n]` f32.
21188    fn run_complex_norm_sq(z: &[(f32, f32)]) -> Vec<f32> {
21189        let n = z.len();
21190        let mut g = Graph::new("cns_fwd");
21191        let in_z = g.input("z", Shape::new(&[n], DType::C64));
21192        let out = g.complex_norm_sq(in_z);
21193        g.set_outputs(vec![out]);
21194        let plan = rlx_opt::memory::plan_memory(&g);
21195        let mut arena = crate::arena::Arena::from_plan(plan);
21196        let sched = compile_thunks(&g, &arena);
21197        let z_off = arena.byte_offset(in_z);
21198        let out_off = arena.byte_offset(out);
21199        let buf = arena.raw_buf_mut();
21200        unsafe {
21201            let pz = buf.as_mut_ptr().add(z_off) as *mut f32;
21202            for (i, &(re, im)) in z.iter().enumerate() {
21203                *pz.add(2 * i) = re;
21204                *pz.add(2 * i + 1) = im;
21205            }
21206        }
21207        execute_thunks(&sched, arena.raw_buf_mut());
21208        unsafe {
21209            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
21210            (0..n).map(|i| *p.add(i)).collect()
21211        }
21212    }
21213
21214    /// Backward: given z and upstream g, return dz = g·z element-wise (C64).
21215    fn run_complex_norm_sq_bwd(z: &[(f32, f32)], g: &[f32]) -> Vec<(f32, f32)> {
21216        let n = z.len();
21217        let mut gr = Graph::new("cns_bwd");
21218        let in_z = gr.input("z", Shape::new(&[n], DType::C64));
21219        let in_g = gr.input("g", Shape::new(&[n], DType::F32));
21220        let out = gr.complex_norm_sq_backward(in_z, in_g);
21221        gr.set_outputs(vec![out]);
21222        let plan = rlx_opt::memory::plan_memory(&gr);
21223        let mut arena = crate::arena::Arena::from_plan(plan);
21224        let sched = compile_thunks(&gr, &arena);
21225        let z_off = arena.byte_offset(in_z);
21226        let g_off = arena.byte_offset(in_g);
21227        let out_off = arena.byte_offset(out);
21228        let buf = arena.raw_buf_mut();
21229        unsafe {
21230            let pz = buf.as_mut_ptr().add(z_off) as *mut f32;
21231            let pg = buf.as_mut_ptr().add(g_off) as *mut f32;
21232            for (i, &(re, im)) in z.iter().enumerate() {
21233                *pz.add(2 * i) = re;
21234                *pz.add(2 * i + 1) = im;
21235            }
21236            for (i, &v) in g.iter().enumerate() {
21237                *pg.add(i) = v;
21238            }
21239        }
21240        execute_thunks(&sched, arena.raw_buf_mut());
21241        unsafe {
21242            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
21243            (0..n).map(|i| (*p.add(2 * i), *p.add(2 * i + 1))).collect()
21244        }
21245    }
21246
21247    #[test]
21248    fn complex_norm_sq_matches_textbook() {
21249        // |3 + 4i|² = 9 + 16 = 25.
21250        // |1 + 0i|² = 1.
21251        // |0 + 0i|² = 0.
21252        let z = [(3.0_f32, 4.0_f32), (1.0_f32, 0.0_f32), (0.0_f32, 0.0_f32)];
21253        let out = run_complex_norm_sq(&z);
21254        assert!((out[0] - 25.0).abs() < 1e-5);
21255        assert!((out[1] - 1.0).abs() < 1e-6);
21256        assert!(out[2].abs() < 1e-6);
21257    }
21258
21259    #[test]
21260    fn complex_norm_sq_backward_matches_wirtinger_formula() {
21261        // Wirtinger: ∂|z|²/∂z̄ = z. With upstream g = 1, dz = z.
21262        let z = [(3.0_f32, 4.0_f32), (1.5_f32, -2.5_f32)];
21263        let g = [1.0_f32, 1.0_f32];
21264        let dz = run_complex_norm_sq_bwd(&z, &g);
21265        assert_close_c(dz[0], z[0], 1e-6, "dz[0] = g·z[0]");
21266        assert_close_c(dz[1], z[1], 1e-6, "dz[1] = g·z[1]");
21267    }
21268
21269    #[test]
21270    fn complex_norm_sq_backward_scales_with_upstream() {
21271        // With upstream g[i] ≠ 1: dz[i] = g[i]·z[i].
21272        let z = [(2.0_f32, 1.0_f32), (-1.0_f32, 3.0_f32)];
21273        let g = [0.5_f32, -2.0_f32];
21274        let dz = run_complex_norm_sq_bwd(&z, &g);
21275        assert_close_c(dz[0], (1.0, 0.5), 1e-6, "g=0.5 · (2,1)");
21276        assert_close_c(dz[1], (2.0, -6.0), 1e-6, "g=-2 · (-1,3)");
21277    }
21278
21279    /// Multi-output Op::CustomFn via the concat-with-Narrow design
21280    /// (rlx-ir::Graph::custom_fn_multi). Build a custom_fn whose
21281    /// fwd_body returns two outputs (x², 2x), then materialize each
21282    /// via the MultiOutputHandle and verify both numerically.
21283    #[test]
21284    fn custom_fn_multi_extracts_each_subgraph_output() {
21285        use rlx_ir::ops::special::MultiOutputHandle;
21286
21287        let _ = MultiOutputHandle {
21288            source: NodeId(0),
21289            sub_shapes: vec![],
21290            offsets: vec![],
21291        }; // import sanity
21292
21293        // Inner body: input x [3] f32, outputs (x², 2x) both [3] f32.
21294        let mut body = Graph::new("multi_body");
21295        let s3 = Shape::new(&[3], DType::F32);
21296        let x = body.input("x", s3.clone());
21297        let x_sq = body.binary(BinaryOp::Mul, x, x, s3.clone());
21298        let two = body.add_node(
21299            Op::Constant {
21300                data: vec![
21301                    2.0_f32.to_le_bytes(),
21302                    2.0_f32.to_le_bytes(),
21303                    2.0_f32.to_le_bytes(),
21304                ]
21305                .into_iter()
21306                .flatten()
21307                .collect(),
21308            },
21309            vec![],
21310            s3.clone(),
21311        );
21312        let two_x = body.binary(BinaryOp::Mul, two, x, s3.clone());
21313        body.set_outputs(vec![x_sq, two_x]);
21314
21315        // Outer graph: feed in_x → custom_fn_multi → handle.output(0/1).
21316        let mut outer = Graph::new("multi_outer");
21317        let in_x = outer.input("xin", s3.clone());
21318        let handle = outer.custom_fn_multi(vec![in_x], body);
21319        assert_eq!(handle.n_outputs(), 2);
21320        let out0 = handle.output(&mut outer, 0); // x²
21321        let out1 = handle.output(&mut outer, 1); // 2x
21322        outer.set_outputs(vec![out0, out1]);
21323
21324        let plan = rlx_opt::memory::plan_memory(&outer);
21325        let mut arena = crate::arena::Arena::from_plan(plan);
21326        let sched = compile_thunks(&outer, &arena);
21327        let xin_off = arena.byte_offset(in_x);
21328        let out0_off = arena.byte_offset(out0);
21329        let out1_off = arena.byte_offset(out1);
21330        let xs = [1.0_f32, 2.0, 3.0];
21331        unsafe {
21332            let p = arena.raw_buf_mut().as_mut_ptr().add(xin_off) as *mut f32;
21333            for (i, &v) in xs.iter().enumerate() {
21334                *p.add(i) = v;
21335            }
21336        }
21337        execute_thunks(&sched, arena.raw_buf_mut());
21338        let out0_v: Vec<f32> = unsafe {
21339            let p = arena.raw_buf().as_ptr().add(out0_off) as *const f32;
21340            (0..3).map(|i| *p.add(i)).collect()
21341        };
21342        let out1_v: Vec<f32> = unsafe {
21343            let p = arena.raw_buf().as_ptr().add(out1_off) as *const f32;
21344            (0..3).map(|i| *p.add(i)).collect()
21345        };
21346        // x² = [1, 4, 9]; 2x = [2, 4, 6].
21347        for i in 0..3 {
21348            assert!(
21349                (out0_v[i] - xs[i] * xs[i]).abs() < 1e-5,
21350                "out0[{i}] = {} != x² = {}",
21351                out0_v[i],
21352                xs[i] * xs[i]
21353            );
21354            assert!(
21355                (out1_v[i] - 2.0 * xs[i]).abs() < 1e-5,
21356                "out1[{i}] = {} != 2x = {}",
21357                out1_v[i],
21358                2.0 * xs[i]
21359            );
21360        }
21361    }
21362
21363    #[test]
21364    fn complex_norm_sq_gradient_matches_finite_difference() {
21365        // Numerical sanity: perturb z[0].re by ε, observe Δ|z|² ≈ 2·re·ε.
21366        let z = [(3.0_f32, 4.0_f32)];
21367        let eps = 1e-3_f32;
21368        let v0 = run_complex_norm_sq(&z)[0];
21369        let z_pert = [(3.0_f32 + eps, 4.0_f32)];
21370        let v1 = run_complex_norm_sq(&z_pert)[0];
21371        let fd_re = (v1 - v0) / eps;
21372        let analytic_re = 2.0 * z[0].0;
21373        assert!((fd_re - analytic_re).abs() < 1e-2);
21374
21375        // ∂/∂im at z = (3, 4) is 2·im = 8.
21376        let z_pert_im = [(3.0_f32, 4.0_f32 + eps)];
21377        let v2 = run_complex_norm_sq(&z_pert_im)[0];
21378        let fd_im = (v2 - v0) / eps;
21379        let analytic_im = 2.0 * z[0].1;
21380        assert!((fd_im - analytic_im).abs() < 1e-2);
21381
21382        // Compare with the Wirtinger backward at upstream g = 1.
21383        // Wirtinger ∂/∂z̄ = z gives dz = (re, im). The "real
21384        // gradient" wrt (re, im) is 2·(re, im), i.e. 2·dz = (2·re,
21385        // 2·im) — that's the factor 2 difference between Wirtinger
21386        // ∂/∂z̄ and the real-vector gradient on (re, im).
21387        let dz = run_complex_norm_sq_bwd(&z, &[1.0_f32]);
21388        assert!((2.0 * dz[0].0 - analytic_re).abs() < 1e-5);
21389        assert!((2.0 * dz[0].1 - analytic_im).abs() < 1e-5);
21390    }
21391
21392    /// Direct regression test for the 5-D mid-shape singleton broadcast
21393    /// (SAM rel_pos pattern: `[bh, h, w, 1, w] + [bh, h, w, h, w]`).
21394    /// The SAM port worked around this by `concat`-tiling the rhs; this
21395    /// test verifies the in-graph broadcast path is bit-correct.
21396    #[test]
21397    fn binary_full_5d_mid_singleton_broadcast() {
21398        let bh = 2usize;
21399        let h = 3;
21400        let w = 4;
21401        let f = DType::F32;
21402
21403        let mut g = Graph::new("bcast_5d");
21404        let lhs = g.input("lhs", Shape::new(&[bh, h, w, h, w], f));
21405        // rhs shape with size-1 at axis 3 (mid-shape singleton).
21406        let rhs = g.input("rhs", Shape::new(&[bh, h, w, 1, w], f));
21407        let out = g.binary(BinaryOp::Add, lhs, rhs, Shape::new(&[bh, h, w, h, w], f));
21408        g.set_outputs(vec![out]);
21409
21410        // Deterministic data.
21411        let lhs_data: Vec<f32> = (0..bh * h * w * h * w).map(|i| i as f32 * 0.01).collect();
21412        let rhs_data: Vec<f32> = (0..bh * h * w * w)
21413            .map(|i| (i as f32 + 100.0) * 0.01)
21414            .collect();
21415
21416        // Compute expected output by hand.
21417        let mut expected = vec![0f32; bh * h * w * h * w];
21418        for b_ in 0..bh {
21419            for hq in 0..h {
21420                for wq in 0..w {
21421                    for hk in 0..h {
21422                        for wk in 0..w {
21423                            let li = (((b_ * h + hq) * w + wq) * h + hk) * w + wk;
21424                            // rhs has hk dim = 1, so it's always index 0 there.
21425                            let ri = ((b_ * h + hq) * w + wq) * w + wk;
21426                            expected[li] = lhs_data[li] + rhs_data[ri];
21427                        }
21428                    }
21429                }
21430            }
21431        }
21432
21433        let plan = rlx_opt::memory::plan_memory(&g);
21434        let mut arena = crate::arena::Arena::from_plan(plan);
21435        let sched = compile_thunks(&g, &arena);
21436        let lhs_off = arena.byte_offset(lhs);
21437        let rhs_off = arena.byte_offset(rhs);
21438        let out_off = arena.byte_offset(out);
21439        let buf = arena.raw_buf_mut();
21440        unsafe {
21441            let p = buf.as_mut_ptr().add(lhs_off) as *mut f32;
21442            for (i, &v) in lhs_data.iter().enumerate() {
21443                *p.add(i) = v;
21444            }
21445            let p = buf.as_mut_ptr().add(rhs_off) as *mut f32;
21446            for (i, &v) in rhs_data.iter().enumerate() {
21447                *p.add(i) = v;
21448            }
21449        }
21450        execute_thunks(&sched, arena.raw_buf_mut());
21451        let actual: Vec<f32> = unsafe {
21452            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
21453            (0..bh * h * w * h * w).map(|i| *p.add(i)).collect()
21454        };
21455
21456        // Bit-exact check.
21457        let mut max_diff = 0f32;
21458        let mut max_idx = 0;
21459        for i in 0..actual.len() {
21460            let d = (actual[i] - expected[i]).abs();
21461            if d > max_diff {
21462                max_diff = d;
21463                max_idx = i;
21464            }
21465        }
21466        assert!(
21467            max_diff < 1e-6,
21468            "5D mid-shape singleton broadcast wrong: max |Δ| = {max_diff} at idx {max_idx} \
21469             (actual={}, expected={})",
21470            actual[max_idx],
21471            expected[max_idx]
21472        );
21473    }
21474
21475    #[test]
21476    fn layer_norm2d_and_conv_transpose2d_kernels() {
21477        let mut out = vec![0f32; 8];
21478        crate::kernels::layer_norm2d_nchw(
21479            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
21480            &[1.0, 1.0],
21481            &[0.0, 0.0],
21482            &mut out,
21483            1,
21484            2,
21485            2,
21486            2,
21487            1e-5,
21488        );
21489        let mean0: f32 = (1.0 + 3.0) / 2.0;
21490        assert!((out[0] - mean0).abs() > 0.1);
21491
21492        let mut up = vec![0f32; 4];
21493        crate::kernels::conv_transpose2d_nchw(
21494            &[2.0],
21495            &[1.0, 0.0, 0.0, 1.0],
21496            &mut up,
21497            1,
21498            1,
21499            1,
21500            1,
21501            1,
21502            2,
21503            2,
21504            2,
21505            2,
21506            2,
21507            2,
21508            0,
21509            0,
21510            1,
21511            1,
21512            1,
21513        );
21514        assert!((up[0] - 2.0).abs() < 1e-5);
21515        assert!((up[3] - 2.0).abs() < 1e-5);
21516    }
21517}