Skip to main content

rlx_cpu/
thunk.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Thunks — pre-compiled kernel dispatch with zero per-call overhead.
17//!
18//! At compile time, the graph is lowered into a flat `Vec<Thunk>` where each
19//! thunk holds pre-computed arena offsets, dimensions, and kernel type.
20//! At runtime, the executor just iterates thunks and calls kernels directly.
21
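22// Edition 2024: an `unsafe fn` body is no longer an implicit unsafe block (`unsafe_op_in_unsafe_fn` warns by default); allowed here so `sl`/`sl_mut` can stay `unsafe fn` without wrapping each operation in `unsafe {}`.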
23#![allow(unsafe_op_in_unsafe_fn)]
24//! No match dispatch, no HashMap lookup, no dimension computation.
25
26use crate::arena::Arena;
27use crate::op_registry::CpuKernel;
28use rlx_ir::op::{Activation, BinaryOp, CmpOp, ReduceOp};
29use rlx_ir::{Graph, NodeId, Op, Shape};
30use std::collections::HashMap;
31use std::sync::Arc;
32
33/// A pre-compiled kernel call with all args resolved to arena offsets.
34#[derive(Clone)]
35pub enum Thunk {
36    /// Skip (Input/Param already in arena)
37    Nop,
38    /// C = A @ B (BLAS sgemm)
39    Sgemm {
40        a: usize,
41        b: usize,
42        c: usize,
43        m: u32,
44        k: u32,
45        n: u32,
46    },
47    /// f64 dense solve `x = A⁻¹·b` via LAPACK dgesv.
48    /// `a`, `b`, `x` are byte-offsets into the arena. `n` is the matrix
49    /// dimension; `nrhs` is 1 for a vector RHS or >1 for multi-RHS.
50    /// The kernel materializes scratch copies of A and b internally
51    /// (LAPACK overwrites both with LU factors and solution).
52    DenseSolveF64 {
53        a: usize,
54        b: usize,
55        x: usize,
56        n: u32,
57        nrhs: u32,
58    },
59    /// f32 twin of `DenseSolveF64`. Calls LAPACK `sgesv` (or the
60    /// no-blas Rust fallback). Same arena byte-offset contract.
61    DenseSolveF32 {
62        a: usize,
63        b: usize,
64        x: usize,
65        n: u32,
66        nrhs: u32,
67    },
68    /// Batched f64 dense solve. `a`, `b`, `x` are byte-offsets to
69    /// the leading slice; `batch` is the number of independent
70    /// systems. Per slice the kernel calls `dgesv(A_i, b_i, n, nrhs)`
71    /// — LAPACK has no batched dgesv on Accelerate, so we loop.
72    BatchedDenseSolveF64 {
73        a: usize,
74        b: usize,
75        x: usize,
76        batch: u32,
77        n: u32,
78        nrhs: u32,
79    },
80    /// Batched f32 dense solve — loop of `sgesv` per batch slice.
81    BatchedDenseSolveF32 {
82        a: usize,
83        b: usize,
84        x: usize,
85        batch: u32,
86        n: u32,
87        nrhs: u32,
88    },
89    /// Batched f64 matmul. Both inputs and output have a leading
90    /// batch axis of size `batch`. Per-batch independent dgemm:
91    /// `C[i] = A[i] @ B[i]` for `i in 0..batch`. Used by VJP rules
92    /// that emit per-batch outer products (e.g., BatchedDenseSolve
93    /// VJP). The unbatched `Dgemm` thunk handles the rank-2 case.
94    BatchedDgemmF64 {
95        a: usize,
96        b: usize,
97        c: usize,
98        batch: u32,
99        m: u32,
100        k: u32,
101        n: u32,
102    },
103    /// Batched f32 matmul — same loop-per-batch shape as
104    /// `BatchedDgemmF64` but calling `sgemm`. Needed for attention
105    /// patterns where both operands carry a batch dim (e.g. q@k^T
106    /// and attn@v in decomposed self-attention). The 2-D `Sgemm`
107    /// flatten trick is wrong in that case because it treats `b` as
108    /// a single shared RHS across every batch.
109    BatchedSgemm {
110        a: usize,
111        b: usize,
112        c: usize,
113        batch: u32,
114        m: u32,
115        k: u32,
116        n: u32,
117    },
118    /// C = A @ B via Accelerate cblas_dgemm. Mirror of `Sgemm` at f64.
119    Dgemm {
120        a: usize,
121        b: usize,
122        c: usize,
123        m: u32,
124        k: u32,
125        n: u32,
126    },
127    /// f64 N-D index walk used for both `Op::Transpose` and `Op::Expand`.
128    /// `in_strides` carries 0s on broadcast axes (Expand) or permuted
129    /// strides (Transpose). Mirror of `Thunk::Transpose` at f64.
130    TransposeF64 {
131        src: usize,
132        dst: usize,
133        in_total: u32,
134        out_dims: Vec<u32>,
135        in_strides: Vec<u32>,
136    },
137    /// f64 element-wise activation. Single-input, single-output. The
138    /// kernel always reads from `src` and writes to `dst`, so it works
139    /// whether or not the planner aliased the two slots.
140    ActivationF64 {
141        src: usize,
142        dst: usize,
143        len: u32,
144        kind: Activation,
145    },
146    /// Element-wise complex squared-magnitude: `|z|² = re² + im²`.
147    /// Reads the C64 input at `src` as `2·len` f32 ([re,im] pairs),
148    /// writes `len` f32 to `dst`.
149    ComplexNormSqF32 {
150        src: usize,
151        dst: usize,
152        /// Logical element count (number of complex values).
153        len: u32,
154    },
155    /// Wirtinger backward for [`ComplexNormSqF32`]: `dz = g · z` as
156    /// C64. Reads `z` at `2·len` f32 + `g` at `len` f32; writes
157    /// `2·len` f32 to `dz`.
158    ComplexNormSqBackwardF32 {
159        z: usize,
160        g: usize,
161        dz: usize,
162        len: u32,
163    },
164    /// Element-wise C64 conjugate: writes `[re_i, -im_i]` per element.
165    /// Layout matches the rest of C64 here ([re,im] interleaved f32).
166    ConjugateC64 { src: usize, dst: usize, len: u32 },
167    /// C64 element-wise activation. Only kinds with well-defined
168    /// complex extensions are supported: Neg, Exp, Log, Sqrt.
169    /// Everything else (Sigmoid, Tanh, Relu, Abs, Sin/Cos/Tan/Atan,
170    /// Round, GeLU family) is rejected at lowering — those don't have
171    /// single natural complex definitions. `len` is the **complex
172    /// element count** (the f32 buffer holds `2·len` floats).
173    ActivationC64 {
174        src: usize,
175        dst: usize,
176        len: u32,
177        kind: Activation,
178    },
179    /// f64 contiguous reduction along a single axis range. Layout
180    /// `[outer, reduced, inner]` in memory; output is `[outer, inner]`.
181    /// Sum only for now (Mean composes via 1/N multiply post-pass).
182    ReduceSumF64 {
183        src: usize,
184        dst: usize,
185        outer: u32,
186        reduced: u32,
187        inner: u32,
188    },
189    /// f64 plain copy (Reshape / Cast at the same dtype). Mirrors `Copy`
190    /// but at 8 bytes per element.
191    CopyF64 { src: usize, dst: usize, len: u32 },
192    /// i64 element copy (Reshape/Cast on i64 tensors).
193    CopyI64 { src: usize, dst: usize, len: u32 },
194    /// Round f32 → i64 (ONNX Cast on duration scalar).
195    CastF32ToI64 { src: usize, dst: usize, len: u32 },
196    /// i64 → f32 (ONNX Cast on shape scalars, e.g. Albert head-dim).
197    CastI64ToF32 { src: usize, dst: usize, len: u32 },
198    /// bool → i32 (BERT attention mask grid).
199    CastBoolToI32 { src: usize, dst: usize, len: u32 },
200    /// i32 → f32 (BERT attention mask cast before subtract).
201    CastI32ToF32 { src: usize, dst: usize, len: u32 },
202    /// f64 element-wise binary with broadcast. `len`/`lhs_len`/`rhs_len`
203    /// are element counts; kernel does `out[i] = lhs[i % lhs_len] OP rhs[i % rhs_len]`.
204    /// Mirror of `BinaryFull` at 8 bytes per element.
205    BinaryFullF64 {
206        lhs: usize,
207        rhs: usize,
208        dst: usize,
209        len: u32,
210        lhs_len: u32,
211        rhs_len: u32,
212        op: BinaryOp,
213        /// Output shape dims (row-major). Empty in the fast path. See
214        /// `BinaryFull` doc for the broadcast convention.
215        out_dims_bcast: Vec<u32>,
216        bcast_lhs_strides: Vec<u32>,
217        bcast_rhs_strides: Vec<u32>,
218    },
219    /// f64 concat — byte-for-byte mirror of `Concat` but copies
220    /// 8 bytes per element. Element-counted offsets/strides match
221    /// the f32 variant; the executor scales by elem_size internally.
222    ConcatF64 {
223        dst: usize,
224        outer: u32,
225        inner: u32,
226        total_axis: u32,
227        inputs: Vec<(usize, u32)>,
228    },
229    /// C64 element-wise binary with broadcast. Same `len` /
230    /// `lhs_len` / `rhs_len` semantics as `BinaryFull` but each
231    /// "element" is one complex value (8 bytes = `[re, im]` as two
232    /// f32s). The executor reads the underlying f32 buffer at
233    /// `2·len` floats and walks element pairs. Supports Add / Sub /
234    /// Mul / Div; Max / Min / Pow have no single natural complex
235    /// definition and panic at lowering.
236    BinaryFullC64 {
237        lhs: usize,
238        rhs: usize,
239        dst: usize,
240        /// Complex element count (NOT f32 count). f32 buffer length
241        /// is `2·len`.
242        len: u32,
243        lhs_len: u32,
244        rhs_len: u32,
245        op: BinaryOp,
246        out_dims_bcast: Vec<u32>,
247        bcast_lhs_strides: Vec<u32>,
248        bcast_rhs_strides: Vec<u32>,
249    },
250    /// Bounded scan. Holds a recursively-compiled body schedule + a
251    /// pre-initialized body arena snapshot (constants filled). Each
252    /// outer execution clones the snapshot, copies the carry-in slot
253    /// from the outer arena, runs the body schedule `length` times,
254    /// then writes the final carry to the outer arena.
255    ///
256    /// Single-carry MVP — body has exactly one Input and one output,
257    /// both same shape and dtype.
258    Scan {
259        body: Arc<ThunkSchedule>,
260        body_init: Arc<Vec<u8>>, // pristine body arena bytes
261        body_input_off: usize,   // byte offset of the body's carry-Input slot
262        body_output_off: usize,  // byte offset of the body's output slot
263        outer_init_off: usize,   // outer-arena offset of the initial carry
264        outer_final_off: usize,  // outer-arena offset of the final carry / trajectory base
265        length: u32,
266        carry_bytes: u32, // carry size in bytes
267        /// When true, write each step's carry to the outer arena at
268        /// offset `outer_final_off + t * carry_bytes`, producing a
269        /// `[length, *carry]` stacked trajectory. When false, only the
270        /// final carry lands at `outer_final_off`.
271        save_trajectory: bool,
272        /// Per-step `xs` inputs. For each: (body_x_input_off,
273        /// outer_xs_base_off, per_step_bytes). Per iteration `t`, the
274        /// executor copies `outer_xs_base_off + t * per_step_bytes`
275        /// into `body_x_input_off`. Empty when the scan has no xs.
276        xs_inputs: Arc<Vec<(usize, usize, u32)>>,
277        /// Broadcast inputs — values constant across iterations. For
278        /// each: (body_bcast_input_off, outer_bcast_off, total_bytes).
279        /// Filled into `body_buf` ONCE before the scan loop starts
280        /// (xs in contrast are re-filled every iteration). Empty when
281        /// the scan has no bcasts.
282        bcast_inputs: Arc<Vec<(usize, usize, u32)>>,
283        /// Number of trajectory checkpoints (when `save_trajectory`).
284        /// `0` or `length` ⇒ save every iteration. Otherwise save only
285        /// `K` rows at indices `floor((k+1) * length / K) - 1` for
286        /// `k in 0..K`. Last index is always `length-1` so the final
287        /// carry is always cached.
288        num_checkpoints: u32,
289    },
290
291    /// Reverse-mode AD companion to `Thunk::Scan`. Walks `t = length-1
292    /// .. 0`, threading `dcarry` through the body's VJP. Per iteration:
293    /// writes `carry_t` (from outer init or trajectory), each `xs_i[t]`
294    /// slice, and the current `dcarry` into the body_vjp's Input
295    /// slots, runs body_vjp, reads new `dcarry` from its single output.
296    /// f64 carry only — the upstream-accumulation step in trajectory
297    /// mode does an element-wise f64 add.
298    ScanBackward {
299        body_vjp: Arc<ThunkSchedule>,
300        body_init: Arc<Vec<u8>>,
301        body_carry_in_off: usize, // body_vjp's mirrored body-carry-input slot
302        body_x_offs: Arc<Vec<usize>>, // body_vjp's mirrored x_t_i Input slots, in xs order
303        body_d_output_off: usize, // body_vjp's "d_output" Input slot
304        body_dcarry_out_off: usize, // body_vjp's gradient output
305        outer_init_off: usize,    // original init carry
306        outer_traj_off: usize,    // [length-or-K, *carry] trajectory base
307        outer_upstream_off: usize, // upstream gradient (carry shape, or [length, *carry])
308        /// Per-xs entries: (outer_xs_base_off, per_step_bytes). Read
309        /// `xs_i[t]` from `outer_xs_base_off + t * per_step_bytes`.
310        outer_xs_offs: Arc<Vec<(usize, u32)>>,
311        outer_dinit_off: usize, // output: dinit
312        length: u32,
313        carry_bytes: u32,
314        /// Bytes per element in the carry tensor: 4 for f32, 8 for f64.
315        /// Used to dispatch the trajectory-mode upstream accumulation
316        /// kernel (the dcarry += upstream\[t\] add must use the right
317        /// floating-point type — a hard-coded f64 add silently does
318        /// nothing for an f32 carry whose `cb` isn't divisible by 8).
319        carry_elem_size: u32,
320        save_trajectory: bool, // true → upstream is per-step; false → just final
321        /// Recursive checkpointing config. `0` or `length` ⇒ full
322        /// trajectory cached, no recompute (existing behavior).
323        /// `0 < K < length` ⇒ trajectory has only K rows; the executor
324        /// recomputes intermediate carries via `forward_body` between
325        /// checkpoints. Memory: O(K · carry_bytes); time: O(length).
326        num_checkpoints: u32,
327        /// Forward body schedule (same compiled body as the forward
328        /// Op::Scan), used for recompute when `num_checkpoints` is
329        /// active. `None` for the All strategy.
330        forward_body: Option<Arc<ThunkSchedule>>,
331        /// Pristine forward body arena bytes (constants filled).
332        forward_body_init: Option<Arc<Vec<u8>>>,
333        /// Forward body's carry-Input and output slot offsets — needed
334        /// to seed/read the body during recompute.
335        forward_body_carry_in_off: usize,
336        forward_body_output_off: usize,
337        /// Forward body's per-step xs Input slots (one per outer xs).
338        /// Same indexing convention as `body_x_offs`.
339        forward_body_x_offs: Arc<Vec<usize>>,
340    },
341
342    /// Companion to `ScanBackward` that materializes one stacked
343    /// `dxs_i`. Same backward loop; per iteration, after running
344    /// body_vjp, copies its `body_dxs_out_off` slot into the outer
345    /// arena at `outer_dxs_off + t * per_step_bytes`. dcarry threading
346    /// is identical — we still need it for the body_vjp recurrence
347    /// even though we don't write it back to the outer arena.
348    ScanBackwardXs {
349        body_vjp: Arc<ThunkSchedule>,
350        body_init: Arc<Vec<u8>>,
351        body_carry_in_off: usize,
352        body_x_offs: Arc<Vec<usize>>,
353        body_d_output_off: usize,
354        body_dcarry_out_off: usize,
355        body_dxs_out_off: usize, // the body_vjp output we extract per step
356        outer_init_off: usize,
357        outer_traj_off: usize,
358        outer_upstream_off: usize,
359        outer_xs_offs: Arc<Vec<(usize, u32)>>,
360        outer_dxs_off: usize, // base of the stacked [length, *per_step] output
361        length: u32,
362        carry_bytes: u32,
363        /// Same role as `Thunk::ScanBackward::carry_elem_size`.
364        carry_elem_size: u32,
365        per_step_bytes: u32, // bytes per row of the dxs output
366        save_trajectory: bool,
367        /// Recursive checkpointing config. Same semantics as
368        /// `Thunk::ScanBackward::num_checkpoints` — `0` or `length`
369        /// means "save every step's carry"; `0 < K < length` means
370        /// the trajectory has only K rows and the executor recomputes
371        /// intermediate carries via `forward_body` (which must be
372        /// `Some`). Implemented via segment-cached recompute,
373        /// mirroring the `ScanBackward` path.
374        num_checkpoints: u32,
375        forward_body: Option<Arc<ThunkSchedule>>,
376        forward_body_init: Option<Arc<Vec<u8>>>,
377        forward_body_carry_in_off: usize,
378        forward_body_output_off: usize,
379        forward_body_x_offs: Arc<Vec<usize>>,
380    },
381    /// User-defined sub-graph (`Op::CustomFn`) — runs `fwd_body` once.
382    /// Per execution: clone `body_init`, copy each primal input from the
383    /// outer arena into its body Input slot, run the body schedule,
384    /// copy the body's single output back to the outer arena.
385    CustomFn {
386        body: Arc<ThunkSchedule>,
387        body_init: Arc<Vec<u8>>,
388        /// Per primal input: (body_input_off, outer_input_off, bytes).
389        inputs: Arc<Vec<(usize, usize, u32)>>,
390        body_output_off: usize,
391        outer_output_off: usize,
392        out_bytes: u32,
393    },
394    /// C = A @ B; C += bias; C = act(C)
395    FusedMmBiasAct {
396        a: usize,
397        w: usize,
398        bias: usize,
399        c: usize,
400        m: u32,
401        k: u32,
402        n: u32,
403        act: Option<Activation>,
404    },
405    /// out = LN(x + residual + bias, gamma, beta)
406    FusedResidualLN {
407        x: usize,
408        res: usize,
409        bias: usize,
410        g: usize,
411        b: usize,
412        out: usize,
413        rows: u32,
414        h: u32,
415        eps: f32,
416        has_bias: bool,
417    },
418    /// out = RmsNorm(x + residual + bias, gamma, beta)
419    FusedResidualRmsNorm {
420        x: usize,
421        res: usize,
422        bias: usize,
423        g: usize,
424        b: usize,
425        out: usize,
426        rows: u32,
427        h: u32,
428        eps: f32,
429        has_bias: bool,
430    },
431    /// out = bias_add(data, bias, m, n) for Binary::Add with broadcast
432    BiasAdd {
433        src: usize,
434        bias: usize,
435        dst: usize,
436        m: u32,
437        n: u32,
438    },
439    /// Element-wise binary op with NumPy-style broadcast.
440    ///
441    /// Fast path (`lhs_len == rhs_len == len`): plain element-wise loop,
442    /// SIMD-vectorized on aarch64 for `Add`/`Mul`. `bcast_*` fields
443    /// are unused.
444    ///
445    /// Broadcast path: uses `out_dims_bcast` + `bcast_lhs_strides` +
446    /// `bcast_rhs_strides` to compute per-cell indices into each
447    /// operand. The strides are precomputed at thunk-construction
448    /// time from the operands' true shapes (with stride 0 on any axis
449    /// where the operand has size 1). This is the only correct way
450    /// to handle bidirectional broadcasts like `[N, 1] op [1, S]
451    /// → [N, S]`, which simple `i % lhs_len` modulo indexing maps to
452    /// wrong cells.
453    BinaryFull {
454        lhs: usize,
455        rhs: usize,
456        dst: usize,
457        len: u32,
458        lhs_len: u32,
459        rhs_len: u32,
460        op: BinaryOp,
461        /// Output shape dims (row-major). Empty in the fast path.
462        out_dims_bcast: Vec<u32>,
463        /// Per-dim stride into `lhs` (0 where lhs broadcasts).
464        bcast_lhs_strides: Vec<u32>,
465        /// Per-dim stride into `rhs`.
466        bcast_rhs_strides: Vec<u32>,
467        /// Element size (4 = F32, 8 = I64).
468        elem_bytes: u8,
469    },
470    /// Activation in-place
471    ActivationInPlace {
472        data: usize,
473        len: u32,
474        act: Activation,
475    },
476    /// Gather axis=0: table\[idx\] → out
477    Gather {
478        table: usize,
479        table_len: u32,
480        idx: usize,
481        dst: usize,
482        num_idx: u32,
483        trailing: u32,
484        /// 1 when the index tensor is i64 (ONNX Gather indices).
485        idx_i64: u8,
486        /// Element size of table/output (4 = f32, 8 = i64).
487        table_bytes: u8,
488    },
489    /// Narrow: copy slice (`elem_bytes` = source element size: 4 for f32, 8 for f64).
490    Narrow {
491        src: usize,
492        dst: usize,
493        outer: u32,
494        src_stride: u32,
495        dst_stride: u32,
496        inner: u32,
497        elem_bytes: u8,
498    },
499    /// Copy (reshape, expand)
500    Copy { src: usize, dst: usize, len: u32 },
501    /// LayerNorm standalone
502    LayerNorm {
503        src: usize,
504        g: usize,
505        b: usize,
506        dst: usize,
507        rows: u32,
508        h: u32,
509        eps: f32,
510    },
511    /// GroupNorm on NCHW `[N,C,H,W]`.
512    GroupNorm {
513        src: usize,
514        g: usize,
515        b: usize,
516        dst: usize,
517        n: u32,
518        c: u32,
519        h: u32,
520        w: u32,
521        num_groups: u32,
522        eps: f32,
523    },
524    /// BatchNorm inference: frozen mean/var, feature axis last.
525    BatchNormInference {
526        src: usize,
527        g: usize,
528        b: usize,
529        mean: usize,
530        var: usize,
531        dst: usize,
532        count: u32,
533        channels: u32,
534        eps: f32,
535    },
536    BatchNormInferenceBackwardInput {
537        x: usize,
538        gamma: usize,
539        mean: usize,
540        var: usize,
541        dy: usize,
542        dx: usize,
543        count: u32,
544        channels: u32,
545        eps: f32,
546    },
547    BatchNormInferenceBackwardGamma {
548        x: usize,
549        mean: usize,
550        var: usize,
551        dy: usize,
552        dgamma: usize,
553        count: u32,
554        channels: u32,
555        eps: f32,
556    },
557    BatchNormInferenceBackwardBeta {
558        dy: usize,
559        dbeta: usize,
560        count: u32,
561        channels: u32,
562    },
563    /// LayerNorm2d on NCHW (SAM / candle semantics).
564    LayerNorm2d {
565        src: usize,
566        g: usize,
567        b: usize,
568        dst: usize,
569        n: u32,
570        c: u32,
571        h: u32,
572        w: u32,
573        eps: f32,
574    },
575    /// ConvTranspose2d on NCHW.
576    ConvTranspose2d {
577        src: usize,
578        weight: usize,
579        dst: usize,
580        n: u32,
581        c_in: u32,
582        h: u32,
583        w_in: u32,
584        c_out: u32,
585        h_out: u32,
586        w_out: u32,
587        kh: u32,
588        kw: u32,
589        sh: u32,
590        sw: u32,
591        ph: u32,
592        pw: u32,
593        dh: u32,
594        dw: u32,
595        groups: u32,
596    },
597    /// Nearest 2× upsample on NCHW (per-batch slice).
598    ResizeNearest2x {
599        src: usize,
600        dst: usize,
601        n: u32,
602        c: u32,
603        h: u32,
604        w: u32,
605    },
606    /// SAM2 axial 2-D RoPE on `[batch, seq, num_heads * head_dim]`.
607    AxialRope2d {
608        src: usize,
609        dst: usize,
610        batch: u32,
611        seq: u32,
612        hidden: u32,
613        end_x: u32,
614        end_y: u32,
615        head_dim: u32,
616        num_heads: u32,
617        theta: f32,
618        repeat_factor: u32,
619    },
620    /// RMSNorm: out = (x / sqrt(mean(x^2) + eps)) * gamma + beta. No mean
621    /// subtraction, hence cheaper than LayerNorm. Used by Llama-class models.
622    RmsNorm {
623        src: usize,
624        g: usize,
625        b: usize,
626        dst: usize,
627        rows: u32,
628        h: u32,
629        eps: f32,
630    },
631    /// Softmax
632    Softmax { data: usize, rows: u32, cols: u32 },
633    /// Inclusive (or exclusive) cumulative sum along the last axis
634    /// (callers pre-flatten higher-dim cumsums via reshape views).
635    Cumsum {
636        src: usize,
637        dst: usize,
638        rows: u32,
639        cols: u32,
640        exclusive: bool,
641    },
642    /// Mamba-style selective scan (plan #15).
643    /// Inputs: x, delta \[b,s,h\], a \[h,n\], b \[b,s,n\], c \[b,s,n\].
644    /// Output: y \[b,s,h\]. State h carries through the seq.
645    SelectiveScan {
646        x: usize,
647        delta: usize,
648        a: usize,
649        b: usize,
650        c: usize,
651        dst: usize,
652        batch: u32,
653        seq: u32,
654        hidden: u32,
655        state_size: u32,
656    },
657
658    /// Gated DeltaNet linear-attention scan (Qwen3.5/3.6 trunk).
659    /// Inputs: q, k, v `[b, s, h, n]`; g, beta `[b, s, h]`. Output:
660    /// `[b, s, h, n]`. See `Op::GatedDeltaNet` for math.
661    GatedDeltaNet {
662        q: usize,
663        k: usize,
664        v: usize,
665        g: usize,
666        beta: usize,
667        /// When non-zero, load initial `[b, h, n, n]` state and write
668        /// the final state back in place after the scan.
669        state: usize,
670        dst: usize,
671        batch: u32,
672        seq: u32,
673        heads: u32,
674        state_size: u32,
675    },
676
677    /// 1×1 conv fast path (plan #26). The general Conv2D thunk
678    /// runs the textbook 7-deep loop; a 1×1 stride-1 padding-0
679    /// groups-1 conv is mathematically a per-batch matmul, and
680    /// dispatching it through BLAS is 3-10× faster than the
681    /// scalar nest. Common case: ViT patch-projection follow-on,
682    /// transformer "expert" reductions in some MoE designs.
683    ///
684    /// Per batch: weight `[c_out, c_in]` × input `[c_in, h*w]`
685    ///         = output `[c_out, h*w]`.
686    Conv2D1x1 {
687        src: usize,
688        weight: usize,
689        dst: usize,
690        n: u32,
691        c_in: u32,
692        c_out: u32,
693        hw: u32,
694    },
695
696    /// Fused dequant + matmul (plan #5). Today supports
697    /// `QuantScheme::Int8Block` (symmetric); other schemes panic
698    /// at lowering time with a clear message until kernels are added.
699    DequantMatMul {
700        x: usize,
701        w_q: usize,   // packed i8 bytes for Int8 schemes
702        scale: usize, // [k/block, n] f32 scale
703        zp: usize,    // [k/block, n] f32 zero-point (0 for sym)
704        dst: usize,
705        m: u32,
706        k: u32,
707        n: u32,
708        block_size: u32,
709        is_asymmetric: bool,
710    },
711
712    /// GGUF-format dequant + matmul. Weight is a packed byte tensor
713    /// in one of the K-quant super-block layouts (Q4_K, Q5_K, Q6_K,
714    /// Q8_K). Scales / mins live inside the packed bytes — no
715    /// side-channel scale tensor.
716    ///
717    /// Today this is a "dequant-to-scratch then sgemm" kernel — it
718    /// keeps the *arena* memory footprint down (weights stay packed)
719    /// but the dequant itself happens per matmul. A future fully
720    /// fused tile-streaming kernel would close the compute gap.
721    DequantMatMulGguf {
722        x: usize,   // f32 activations [m, k]
723        w_q: usize, // packed weight bytes (k*n elements packed)
724        dst: usize, // f32 output [m, n]
725        m: u32,
726        k: u32,
727        n: u32,
728        scheme: rlx_ir::quant::QuantScheme,
729    },
730
731    /// Int4 block dequant + matmul (packed nibbles, side scale/zp).
732    DequantMatMulInt4 {
733        x: usize,
734        w_q: usize,
735        scale: usize,
736        zp: usize,
737        dst: usize,
738        m: u32,
739        k: u32,
740        n: u32,
741        block_size: u32,
742        is_asymmetric: bool,
743    },
744
745    /// FP8 dequant + matmul (per-tensor or per-column scale).
746    DequantMatMulFp8 {
747        x: usize,
748        w_q: usize,
749        scale: usize,
750        dst: usize,
751        m: u32,
752        k: u32,
753        n: u32,
754        e5m2: bool,
755    },
756
757    /// NVFP4 (E2M1) block dequant + matmul — 16-wide groups, FP8 scales.
758    DequantMatMulNvfp4 {
759        x: usize,
760        w_q: usize,
761        scale: usize,
762        global_scale: usize,
763        dst: usize,
764        m: u32,
765        k: u32,
766        n: u32,
767    },
768
769    /// Fused LoRA matmul (plan #9): out = x·W + scale * (x·A)·B.
770    /// `r` is the LoRA rank (typically 4-64) — the rank-r
771    /// intermediate `x·A` lives in scratch, never on the arena.
772    LoraMatMul {
773        x: usize,
774        w: usize,
775        a: usize,
776        b: usize,
777        dst: usize,
778        m: u32,
779        k: u32,
780        n: u32,
781        r: u32,
782        scale: f32,
783    },
784    /// Fused sample: logits [batch, vocab] → token ids \[batch\].
785    /// See Op::Sample. Output values are f32-encoded usize indices
786    /// (matches the rest of the IR's "ids as f32" convention).
787    Sample {
788        logits: usize,
789        dst: usize,
790        batch: u32,
791        vocab: u32,
792        top_k: u32,       // 0 = disabled
793        top_p: f32,       // 1.0 = disabled
794        temperature: f32, // 1.0 = neutral
795        seed: u64,
796    },
797    /// Attention SDPA. `mask` is the offset of the optional mask tensor
798    /// (only meaningful when `mask_kind == MaskKind::Custom`); other
799    /// kinds synthesize the mask in-kernel.
800    ///
801    /// Q/K/V each carry a `_row_stride` (elements per source row).
802    /// Defaults to `heads * head_dim` — matches the standalone
803    /// "Q/K/V are their own contiguous buffers" case. The Narrow→
804    /// Attention fusion below rewrites these to the parent QKV stride
805    /// (typically `3 * heads * head_dim`) so the kernel reads QKV
806    /// directly without materializing the per-head buffers (plan #46).
807    Attention {
808        q: usize,
809        k: usize,
810        v: usize,
811        mask: usize,
812        out: usize,
813        batch: u32,
814        /// Query sequence length.
815        seq: u32,
816        /// Key/value sequence length. Differs from `seq` during cached decode.
817        kv_seq: u32,
818        heads: u32,
819        head_dim: u32,
820        mask_kind: rlx_ir::op::MaskKind,
821        q_row_stride: u32,
822        k_row_stride: u32,
823        v_row_stride: u32,
824        /// Memory layout flag. `false` (the historical default) →
825        /// `[B, S, H, D]` row-major: per-head offset is
826        /// `bi*S*H*D + si*H*D + hi*D`. `true` → `[B, H, S, D]`
827        /// (head-major), matching the convention used by rlx-cuda /
828        /// rlx-rocm / rlx-tpu: per-head offset is
829        /// `bi*H*S*D + hi*S*D + si*D`. Detected at lowering time
830        /// from the input shape vs `num_heads` / `head_dim`.
831        bhsd: bool,
832    },
833    /// [`Op::AttentionBackward`] — emits dQ, dK, or dV (see `wrt`).
834    AttentionBackward {
835        q: usize,
836        k: usize,
837        v: usize,
838        dy: usize,
839        mask: usize,
840        out: usize,
841        batch: u32,
842        seq: u32,
843        kv_seq: u32,
844        heads: u32,
845        head_dim: u32,
846        mask_kind: rlx_ir::op::MaskKind,
847        wrt: rlx_ir::op::AttentionBwdWrt,
848        bhsd: bool,
849    },
850    /// RoPE (rotary position embeddings).
851    /// `src_row_stride` is elements per source row (defaults to `hidden`
852    /// for the standalone case; set to `qkv_axis * inner` when the
853    /// thunk fusion pass below rewires Rope to read directly from the
854    /// fused QKV buffer — plan #45).
855    Rope {
856        src: usize,
857        cos: usize,
858        sin: usize,
859        dst: usize,
860        batch: u32,
861        seq: u32,
862        hidden: u32,
863        head_dim: u32,
864        n_rot: u32,
865        cos_len: u32,
866        src_row_stride: u32,
867    },
868    /// Fused attention block: QKV proj → split → \[RoPE\] → SDPA → output proj.
869    /// All intermediates stay in L1 cache. Zero arena writes between ops.
870    FusedAttnBlock {
871        hidden: usize,
872        qkv_w: usize,
873        out_w: usize,
874        mask: usize,
875        out: usize,
876        qkv_b: usize,
877        out_b: usize, // 0 = no bias
878        cos: usize,
879        sin: usize,
880        cos_len: u32, // 0 = no RoPE
881        batch: u32,
882        seq: u32,
883        hs: u32,
884        nh: u32,
885        dh: u32,
886        has_bias: bool,
887        has_rope: bool,
888    },
889    /// Fused ENTIRE transformer layer: attention + residual + LN + FFN + residual + LN.
890    /// Combines ~10 thunks into 1. All intermediates on stack. Zero arena traffic.
891    FusedBertLayer {
892        // attention
893        hidden: usize,
894        qkv_w: usize,
895        qkv_b: usize,
896        out_w: usize,
897        out_b: usize,
898        mask: usize,
899        // LN1
900        ln1_g: usize,
901        ln1_b: usize,
902        eps1: f32,
903        // FFN (GELU)
904        fc1_w: usize,
905        fc1_b: usize,
906        fc2_w: usize,
907        fc2_b: usize,
908        // LN2
909        ln2_g: usize,
910        ln2_b: usize,
911        eps2: f32,
912        // output
913        out: usize,
914        // dims
915        batch: u32,
916        seq: u32,
917        hs: u32,
918        nh: u32,
919        dh: u32,
920        int_dim: u32,
921    },
922    /// Fused Nomic transformer layer: attention+RoPE + residual + LN + SwiGLU FFN + residual + LN.
923    FusedNomicLayer {
924        hidden: usize,
925        qkv_w: usize,
926        out_w: usize,
927        mask: usize,
928        cos: usize,
929        sin: usize,
930        cos_len: u32,
931        ln1_g: usize,
932        ln1_b: usize,
933        eps1: f32,
934        fc11_w: usize,
935        fc12_w: usize,
936        fc2_w: usize,
937        ln2_g: usize,
938        ln2_b: usize,
939        eps2: f32,
940        out: usize,
941        batch: u32,
942        seq: u32,
943        hs: u32,
944        nh: u32,
945        dh: u32,
946        int_dim: u32,
947    },
948    /// Fused SwiGLU: out\[r,i\] = x\[r,i\] * silu(x[r, n_half+i]).
949    /// Input: [outer, 2*n_half] — concatenated up||gate per row.
950    /// Output: [outer, n_half].
951    FusedSwiGLU {
952        src: usize,
953        dst: usize,
954        n_half: u32,
955        total: u32,
956        gate_first: bool,
957    },
958    /// Concat along an axis: output[outer, axis, inner] = inputs concatenated.
959    /// Each entry of `inputs` is (src_offset, axis_len_for_that_input) in u32
960    /// elements. `outer`, `inner`, and `total_axis_len` are pre-computed
961    /// at compile time to avoid per-run shape work.
962    Concat {
963        dst: usize,
964        outer: u32,
965        inner: u32,
966        total_axis: u32,
967        inputs: Vec<(usize, u32)>,
968    },
969    /// Element-wise comparison: out = (lhs CMP rhs) ? 1 : 0 (Bool u8 or F32 0/1).
970    Compare {
971        lhs: usize,
972        rhs: usize,
973        dst: usize,
974        len: u32,
975        op: CmpOp,
976        /// Nonzero when lhs/rhs are i64 (mask/range ops).
977        inputs_i64: u8,
978        /// Input element size (1 = Bool, 4 = F32, 8 = I64).
979        inputs_elem_bytes: u8,
980        /// Output element size (1 = Bool, 4 = F32).
981        dst_elem_bytes: u8,
982    },
983    /// Reduction along a contiguous range of axes. Input layout (after
984    /// shape decomposition) is `[outer, reduced, inner]`; output is
985    /// `[outer, inner]`. The single-axis cases (axis=0 → outer=1;
986    /// axis=last → inner=1) and contiguous multi-axis (e.g. reduce over
987    /// [0, 1] of an [N, C, H, W] tensor → outer=1, reduced=N*C, inner=H*W)
988    /// all map onto this triplet. Non-contiguous axes are not supported
989    /// and bail to Nop in the compile pass.
990    Reduce {
991        src: usize,
992        dst: usize,
993        outer: u32,
994        reduced: u32,
995        inner: u32,
996        op: ReduceOp,
997    },
998    /// Top-K **indices** along the last axis. Input shape `[outer, axis_dim]`,
999    /// output `[outer, k]` (f32 or i64 per `indices_i64`). Ties broken by
1000    /// smaller index. Used by MoE gating + beam search.
1001    TopK {
1002        src: usize,
1003        dst: usize,
1004        outer: u32,
1005        axis_dim: u32,
1006        k: u32,
1007        indices_i64: u8,
1008    },
1009    /// Indexed batched matmul: out\[i\] = input\[i\] @ weight[expert_idx\[i\]].
1010    /// Naive impl per token; for real MoE workloads, sort-by-expert + run
1011    /// segmented GEMM would amortize. Done when there's a workload.
1012    GroupedMatMul {
1013        input: usize,
1014        weight: usize,
1015        expert_idx: usize,
1016        dst: usize,
1017        m: u32,
1018        k_dim: u32,
1019        n: u32,
1020        num_experts: u32,
1021    },
1022    /// GGUF K-quant packed expert stack + grouped matmul (MoE FFN).
1023    DequantGroupedMatMulGguf {
1024        input: usize,
1025        w_q: usize,
1026        expert_idx: usize,
1027        dst: usize,
1028        m: u32,
1029        k_dim: u32,
1030        n: u32,
1031        num_experts: u32,
1032        scheme: rlx_ir::quant::QuantScheme,
1033    },
1034    /// Materialize packed MoE weights to F32 `[E, K, N]` (autodiff helper).
1035    DequantMoEWeightsGguf {
1036        w_q: usize,
1037        dst: usize,
1038        k_dim: u32,
1039        n: u32,
1040        num_experts: u32,
1041        scheme: rlx_ir::quant::QuantScheme,
1042    },
1043    /// Scatter-add: dst[indices\[i\] * trailing + j] += updates[i * trailing + j].
1044    /// Output is zeroed first; multiple updates to the same row accumulate.
1045    ScatterAdd {
1046        updates: usize,
1047        indices: usize,
1048        dst: usize,
1049        num_updates: u32,
1050        out_dim: u32,
1051        trailing: u32,
1052    },
1053    /// Ternary select: out = cond != 0 ? on_true : on_false
1054    Where {
1055        cond: usize,
1056        on_true: usize,
1057        on_false: usize,
1058        dst: usize,
1059        len: u32,
1060        elem_bytes: u8,
1061        /// Element size for cond (1 = Bool mask, 4 = F32 0/1).
1062        cond_elem_bytes: u8,
1063    },
1064    /// General N-D transpose / broadcast. `out_dims[i]` is the output's dim
1065    /// i length; `in_strides[i]` is the input stride (in elements) used to
1066    /// index that dim — 0 for broadcast dims (Expand). `in_total` is the
1067    /// total element count in the source buffer (≤ output total when
1068    /// broadcasting). Strides are pre-computed at compile time.
1069    Transpose {
1070        src: usize,
1071        dst: usize,
1072        in_total: u32,
1073        out_dims: Vec<u32>,
1074        in_strides: Vec<u32>,
1075        elem_bytes: u8,
1076    },
1077    /// Gather along an arbitrary axis. `outer = product(dims[..axis])`,
1078    /// `trailing = product(dims[axis+1..])`, `axis_dim` = the dimension
1079    /// being indexed into. Output: outer × num_idx × trailing.
1080    /// (axis=0 still routes to the simpler Thunk::Gather fast path.)
1081    GatherAxis {
1082        table: usize,
1083        idx: usize,
1084        dst: usize,
1085        outer: u32,
1086        axis_dim: u32,
1087        num_idx: u32,
1088        trailing: u32,
1089        idx_i64: u8,
1090        table_bytes: u8,
1091    },
1092    /// 2D pooling (Max or Mean). Input layout [N, C, H, W], output
1093    /// [N, C, H_out, W_out]. Padding is implicit-zero; Mean divides by
1094    /// the full kernel area (matches torch's `count_include_pad=True`).
1095    Pool2D {
1096        src: usize,
1097        dst: usize,
1098        n: u32,
1099        c: u32,
1100        h: u32,
1101        w: u32,
1102        h_out: u32,
1103        w_out: u32,
1104        kh: u32,
1105        kw: u32,
1106        sh: u32,
1107        sw: u32,
1108        ph: u32,
1109        pw: u32,
1110        kind: ReduceOp,
1111    },
1112    /// 2D convolution. Input [N, C_in, H, W], weight [C_out, C_in_per_group, kH, kW],
1113    /// output [N, C_out, H_out, W_out]. Bias is a separate Op::Binary::Add
1114    /// after the conv (matching the IR's input layout — Op::Conv has 2 inputs).
1115    /// Naive direct convolution; sufficient for correctness, not optimised.
1116    Conv2D {
1117        src: usize,
1118        weight: usize,
1119        dst: usize,
1120        n: u32,
1121        c_in: u32,
1122        h: u32,
1123        w: u32,
1124        c_out: u32,
1125        h_out: u32,
1126        w_out: u32,
1127        kh: u32,
1128        kw: u32,
1129        sh: u32,
1130        sw: u32,
1131        ph: u32,
1132        pw: u32,
1133        dh: u32,
1134        dw: u32,
1135        groups: u32,
1136    },
1137
1138    // ── Backward / training kernels ─────────────────────────────
1139    /// Real INT8 matmul with i32 accumulation.
1140    ///   `out[m, n] = requantize(bias[n] + Σₖ (x[m,k]-x_zp)·(w[k,n]-w_zp), mult, out_zp)`
1141    /// Reads `x` and `w` as i8, `bias` as i32; writes `out` as i8.
1142    /// Same kernel shape as `rlx_cortexm::dense::dense_i8` — promoted
1143    /// to a desktop thunk so a quantized graph compiled here doesn't
1144    /// have to round-trip through fake-quant.
1145    QMatMul {
1146        x: usize,
1147        w: usize,
1148        bias: usize,
1149        out: usize,
1150        m: u32,
1151        k: u32,
1152        n: u32,
1153        x_zp: i32,
1154        w_zp: i32,
1155        out_zp: i32,
1156        mult: f32,
1157    },
1158
1159    /// Real INT8 conv2d, NCHW layout. Same loop shape as `Thunk::Conv2D`
1160    /// but with i8 reads, i32 accumulation, and per-output requantize
1161    /// to i8. Bias is i32 in the accumulator scale.
1162    QConv2d {
1163        x: usize,
1164        w: usize,
1165        bias: usize,
1166        out: usize,
1167        n: u32,
1168        c_in: u32,
1169        h: u32,
1170        w_in: u32,
1171        c_out: u32,
1172        h_out: u32,
1173        w_out: u32,
1174        kh: u32,
1175        kw: u32,
1176        sh: u32,
1177        sw: u32,
1178        ph: u32,
1179        pw: u32,
1180        dh: u32,
1181        dw: u32,
1182        groups: u32,
1183        x_zp: i32,
1184        w_zp: i32,
1185        out_zp: i32,
1186        mult: f32,
1187    },
1188
1189    /// INT8 quantize. Reads `x` as f32, writes `q` as i8.
1190    /// `chan = (i / inner) % chan_dim` selects the per-channel
1191    /// scale/zp; `chan_axis` is informational only (the kernel uses
1192    /// `chan_dim` and `inner` directly).
1193    /// For per-tensor, `chan_dim = 1` and `inner = len` so `chan` is
1194    /// always 0.
1195    Quantize {
1196        x: usize,
1197        q: usize,
1198        len: u32,
1199        chan_axis: u32,
1200        chan_dim: u32,
1201        inner: u32,
1202        scales: Vec<f32>,
1203        zero_points: Vec<i32>,
1204    },
1205
1206    /// INT8 dequantize — inverse of `Thunk::Quantize`.
1207    Dequantize {
1208        q: usize,
1209        x: usize,
1210        len: u32,
1211        chan_axis: u32,
1212        chan_dim: u32,
1213        inner: u32,
1214        scales: Vec<f32>,
1215        zero_points: Vec<i32>,
1216    },
1217
1218    /// QAT fake-quantize. Per-channel (or per-tensor) symmetric
1219    /// quantize-then-dequantize on the fly. Computes
1220    ///   `s[c] = max(|x[..., c, ...]|) / q_max`
1221    /// then
1222    ///   `out[i] = clamp(round(x[i]/s[c]), -q_max, q_max) * s[c]`
1223    /// with `q_max = {127, 7, 1}` for `bits = {8, 4, 2}`. Same
1224    /// channel-layout convention as `Thunk::Quantize`: every
1225    /// element's channel is `(i / inner) % chan_dim`. The kernel
1226    /// does two passes — one to scan max-abs per channel, one to
1227    /// quant-dequant per element.
1228    FakeQuantize {
1229        x: usize,
1230        out: usize,
1231        len: u32,
1232        chan_axis: u32,
1233        chan_dim: u32,
1234        inner: u32,
1235        bits: u8,
1236        /// STE variant — informational on the forward side (output is
1237        /// the same regardless), kernel-relevant in the matching
1238        /// `FakeQuantizeBackward` thunk.
1239        ste: rlx_ir::op::SteKind,
1240        /// Scale-tracking strategy. `PerBatch` recomputes
1241        /// `max_abs/q_max` every call (the original path). `EMA{decay}`
1242        /// blends per-batch max-abs into the `state_off` buffer; `Fixed`
1243        /// reads `state_off` and never updates it.
1244        scale_mode: rlx_ir::op::ScaleMode,
1245        /// `Some(off)` for `EMA` and `Fixed`; `None` for `PerBatch`.
1246        /// Points at a `[chan_dim]` f32 buffer holding the running scale
1247        /// per channel.
1248        state_off: Option<usize>,
1249    },
1250
1251    /// Backward pass for `Op::FakeQuantize` under one of four STE
1252    /// variants. Computes `dx[i]` from the f32 forward input `x` and
1253    /// the upstream gradient `dy`, using the same per-channel scale
1254    /// scheme as the forward.
1255    FakeQuantizeBackward {
1256        x: usize,
1257        dy: usize,
1258        dx: usize,
1259        len: u32,
1260        chan_axis: u32,
1261        chan_dim: u32,
1262        inner: u32,
1263        bits: u8,
1264        ste: rlx_ir::op::SteKind,
1265    },
1266
1267    /// LSQ forward — same kernel shape as `FakeQuantize` Fixed mode.
1268    /// Reads scale from `scale_off` (a `[chan_dim]` Param tensor).
1269    FakeQuantizeLSQ {
1270        x: usize,
1271        scale_off: usize,
1272        out: usize,
1273        len: u32,
1274        chan_axis: u32,
1275        chan_dim: u32,
1276        inner: u32,
1277        bits: u8,
1278    },
1279
1280    /// LSQ backward, x-gradient. STE-clipped: passes upstream
1281    /// through inside the quantization range, zeros outside.
1282    FakeQuantizeLSQBackwardX {
1283        x: usize,
1284        scale_off: usize,
1285        dy: usize,
1286        dx: usize,
1287        len: u32,
1288        chan_axis: u32,
1289        chan_dim: u32,
1290        inner: u32,
1291        bits: u8,
1292    },
1293
1294    /// LSQ backward, scale-gradient. Per-channel:
1295    ///   `dscale[c] = sum_i ψ(x[i]/s[c]) · upstream[i]`
1296    /// where `ψ(z) = -z + round(z)` if `|z| ≤ q_max` else
1297    /// `sign(z) · q_max`. Output shape: `[chan_dim]`.
1298    FakeQuantizeLSQBackwardScale {
1299        x: usize,
1300        scale_off: usize,
1301        dy: usize,
1302        dscale: usize,
1303        len: u32,
1304        chan_axis: u32,
1305        chan_dim: u32,
1306        inner: u32,
1307        bits: u8,
1308    },
1309
1310    /// ReLU backward: `dx[i] = dy[i] if x[i] > 0 else 0`.
1311    ReluBackward {
1312        x: usize,
1313        dy: usize,
1314        dx: usize,
1315        len: u32,
1316    },
1317    /// f64 sibling of `ReluBackward` — same shape as the f32 variant
1318    /// but reads/writes 8 bytes per element. Required because
1319    /// `ReluBackward`'s `&[f32]` slot view returns half of every f64
1320    /// otherwise → backward silently produces 0 gradients on an f64
1321    /// graph. Mirrors the `ActivationBackwardF64` split.
1322    ReluBackwardF64 {
1323        x: usize,
1324        dy: usize,
1325        dx: usize,
1326        len: u32,
1327    },
1328
1329    /// Generic element-wise activation backward.
1330    /// `dx[i] = (d/dx act(x))[i] · dy[i]`. The closure dispatch is
1331    /// per-element; expensive activations (Gelu) recompute internals
1332    /// inline rather than threading an extra "saved y" tensor through.
1333    ActivationBackward {
1334        x: usize,
1335        dy: usize,
1336        dx: usize,
1337        len: u32,
1338        kind: Activation,
1339    },
1340    /// f64 sibling of `ActivationBackward` — slot offsets, len in
1341    /// elements; kernel reads/writes 8 bytes per element. Required
1342    /// because `ActivationBackward`'s `&[f32]` slot view silently
1343    /// returns garbage on an f64 graph (cb % 4 still works but every
1344    /// loaded value is half of an f64 → wrong gradient).
1345    ActivationBackwardF64 {
1346        x: usize,
1347        dy: usize,
1348        dx: usize,
1349        len: u32,
1350        kind: Activation,
1351    },
1352
1353    /// LayerNorm backward — input gradient. Recomputes mean/var/x̂ from
1354    /// `x` and emits the closed-form `d_x` per row.
1355    LayerNormBackwardInput {
1356        x: usize,
1357        gamma: usize,
1358        dy: usize,
1359        dx: usize,
1360        rows: u32,
1361        h: u32,
1362        eps: f32,
1363    },
1364
1365    /// LayerNorm backward — gamma gradient. `d_gamma[d] = Σ_row dy·x̂`.
1366    LayerNormBackwardGamma {
1367        x: usize,
1368        dy: usize,
1369        dgamma: usize,
1370        rows: u32,
1371        h: u32,
1372        eps: f32,
1373    },
1374
1375    RmsNormBackwardInput {
1376        x: usize,
1377        gamma: usize,
1378        beta: usize,
1379        dy: usize,
1380        dx: usize,
1381        rows: u32,
1382        h: u32,
1383        eps: f32,
1384    },
1385    RmsNormBackwardGamma {
1386        x: usize,
1387        gamma: usize,
1388        beta: usize,
1389        dy: usize,
1390        dgamma: usize,
1391        rows: u32,
1392        h: u32,
1393        eps: f32,
1394    },
1395    RmsNormBackwardBeta {
1396        x: usize,
1397        gamma: usize,
1398        beta: usize,
1399        dy: usize,
1400        dbeta: usize,
1401        rows: u32,
1402        h: u32,
1403        eps: f32,
1404    },
1405    RopeBackward {
1406        dy: usize,
1407        cos: usize,
1408        sin: usize,
1409        dx: usize,
1410        batch: u32,
1411        seq: u32,
1412        hidden: u32,
1413        head_dim: u32,
1414        n_rot: u32,
1415        cos_len: u32,
1416    },
1417    CumsumBackward {
1418        dy: usize,
1419        dx: usize,
1420        rows: u32,
1421        cols: u32,
1422        exclusive: bool,
1423    },
1424    GatherBackward {
1425        dy: usize,
1426        indices: usize,
1427        dst: usize,
1428        outer: u32,
1429        axis_dim: u32,
1430        num_idx: u32,
1431        trailing: u32,
1432    },
1433
1434    GroupNormBackwardInput {
1435        x: usize,
1436        gamma: usize,
1437        beta: usize,
1438        dy: usize,
1439        dx: usize,
1440        n: u32,
1441        c: u32,
1442        h: u32,
1443        w: u32,
1444        num_groups: u32,
1445        eps: f32,
1446    },
1447    GroupNormBackwardGamma {
1448        x: usize,
1449        dy: usize,
1450        dgamma: usize,
1451        n: u32,
1452        c: u32,
1453        h: u32,
1454        w: u32,
1455        num_groups: u32,
1456        eps: f32,
1457    },
1458    GroupNormBackwardBeta {
1459        dy: usize,
1460        dbeta: usize,
1461        n: u32,
1462        c: u32,
1463        h: u32,
1464        w: u32,
1465    },
1466
1467    /// 2D max-pool backward (NCHW). Recomputes the argmax position
1468    /// inside each window and accumulates `dy` into `dx` at that
1469    /// position. Output is zeroed first; ties resolve to the first
1470    /// hit (lowest (kh,kw) index), matching what the forward kernel
1471    /// does with `acc.max(v)`.
1472    MaxPool2dBackward {
1473        x: usize,
1474        dy: usize,
1475        dx: usize,
1476        n: u32,
1477        c: u32,
1478        h: u32,
1479        w: u32,
1480        h_out: u32,
1481        w_out: u32,
1482        kh: u32,
1483        kw: u32,
1484        sh: u32,
1485        sw: u32,
1486        ph: u32,
1487        pw: u32,
1488    },
1489
1490    /// 2D conv backward w.r.t. input (`dx = conv_transpose(dy, w)`).
1491    /// `dy [N, C_out, H_out, W_out]`, `w [C_out, C_in_per_group, kH, kW]`,
1492    /// `dx [N, C_in, H, W]`.
1493    Conv2dBackwardInput {
1494        dy: usize,
1495        w: usize,
1496        dx: usize,
1497        n: u32,
1498        c_in: u32,
1499        h: u32,
1500        w_in: u32,
1501        c_out: u32,
1502        h_out: u32,
1503        w_out: u32,
1504        kh: u32,
1505        kw: u32,
1506        sh: u32,
1507        sw: u32,
1508        ph: u32,
1509        pw: u32,
1510        dh: u32,
1511        dw: u32,
1512        groups: u32,
1513    },
1514
1515    /// 2D conv backward w.r.t. weight. `x [N, C_in, H, W]`,
1516    /// `dy [N, C_out, H_out, W_out]`, `dw [C_out, C_in_per_group, kH, kW]`.
1517    /// `dw` is zeroed before accumulation.
1518    Conv2dBackwardWeight {
1519        x: usize,
1520        dy: usize,
1521        dw: usize,
1522        n: u32,
1523        c_in: u32,
1524        h: u32,
1525        w: u32,
1526        c_out: u32,
1527        h_out: u32,
1528        w_out: u32,
1529        kh: u32,
1530        kw: u32,
1531        sh: u32,
1532        sw: u32,
1533        ph: u32,
1534        pw: u32,
1535        dh: u32,
1536        dw_dil: u32,
1537        groups: u32,
1538    },
1539
1540    /// NCHW im2col for conv backward-weight matmul. Output `[M, C·kH·kW]`
1541    /// with `M = N · H_out · W_out`. `n == 0` means infer batch from `x`.
1542    Im2Col {
1543        x: usize,
1544        col: usize,
1545        n: u32,
1546        c_in: u32,
1547        h: u32,
1548        w: u32,
1549        h_out: u32,
1550        w_out: u32,
1551        kh: u32,
1552        kw: u32,
1553        sh: u32,
1554        sw: u32,
1555        ph: u32,
1556        pw: u32,
1557        dh: u32,
1558        dw_dil: u32,
1559    },
1560
1561    /// Fused softmax + cross-entropy loss with f32-encoded integer
1562    /// labels. `logits [N, C]`, `labels [N]`, output `[N]` per-row loss.
1563    /// Numerically stable (max-subtract before exp).
1564    SoftmaxCrossEntropy {
1565        logits: usize,
1566        labels: usize,
1567        dst: usize,
1568        n: u32,
1569        c: u32,
1570    },
1571
1572    /// Backward of the fused loss above.
1573    /// `dlogits[n, k] = (softmax(logits[n])[k] - one_hot(labels[n])[k]) * d_loss[n]`.
1574    SoftmaxCrossEntropyBackward {
1575        logits: usize,
1576        labels: usize,
1577        d_loss: usize,
1578        dlogits: usize,
1579        n: u32,
1580        c: u32,
1581    },
1582
1583    /// User-registered custom op (CPU side). Lowered from `Op::Custom`.
1584    /// `kernel` is resolved against the global CPU kernel registry at
1585    /// compile time and stored as `Arc<dyn CpuKernel>` so execution
1586    /// avoids per-call lookups. v1: f32 contiguous only — see
1587    /// `op_registry::CpuKernel::execute_f32`.
1588    CustomOp {
1589        kernel: Arc<dyn CpuKernel>,
1590        inputs: Vec<(usize, u32, Shape)>, // (offset, len_elements, shape)
1591        output: (usize, u32, Shape),      // (offset, len_elements, shape)
1592        attrs: Vec<u8>,
1593    },
1594
1595    /// 1D FFT along the last axis. Input/output are `[..., 2N]`
1596    /// real-block layout (first N real, second N imag along the
1597    /// transformed axis). `outer` is the product of all leading axes;
1598    /// `n_complex` is N (the number of complex points). Both halves
1599    /// of the real-block layout are read together by the kernel.
1600    /// `dtype` selects the f32 or f64 path; the two share structure
1601    /// but not buffers, so a flag at compile time avoids per-row
1602    /// dispatch.
1603    /// CPU reference 3D Gaussian splat render ([`rlx_ir::Op::GaussianSplatRender`]).
1604    GaussianSplatRender {
1605        positions_off: usize,
1606        positions_len: usize,
1607        scales_off: usize,
1608        scales_len: usize,
1609        rotations_off: usize,
1610        rotations_len: usize,
1611        opacities_off: usize,
1612        opacities_len: usize,
1613        colors_off: usize,
1614        colors_len: usize,
1615        sh_coeffs_off: usize,
1616        sh_coeffs_len: usize,
1617        meta_off: usize,
1618        dst_off: usize,
1619        dst_len: usize,
1620        width: u32,
1621        height: u32,
1622        tile_size: u32,
1623        radius_scale: f32,
1624        alpha_cutoff: f32,
1625        max_splat_steps: u32,
1626        transmittance_threshold: f32,
1627        max_list_entries: u32,
1628    },
1629    GaussianSplatRenderBackward {
1630        positions_off: usize,
1631        positions_len: usize,
1632        scales_off: usize,
1633        scales_len: usize,
1634        rotations_off: usize,
1635        rotations_len: usize,
1636        opacities_off: usize,
1637        opacities_len: usize,
1638        colors_off: usize,
1639        colors_len: usize,
1640        sh_coeffs_off: usize,
1641        sh_coeffs_len: usize,
1642        meta_off: usize,
1643        d_loss_off: usize,
1644        d_loss_len: usize,
1645        packed_off: usize,
1646        packed_len: usize,
1647        width: u32,
1648        height: u32,
1649        tile_size: u32,
1650        radius_scale: f32,
1651        alpha_cutoff: f32,
1652        max_splat_steps: u32,
1653        transmittance_threshold: f32,
1654        max_list_entries: u32,
1655        loss_grad_clip: f32,
1656        sh_band: u32,
1657        max_anisotropy: f32,
1658    },
1659    /// Strict IR stage 1 — project + bin + sort + rays ([`Op::GaussianSplatPrepare`]).
1660    GaussianSplatPrepare {
1661        positions_off: usize,
1662        positions_len: usize,
1663        scales_off: usize,
1664        scales_len: usize,
1665        rotations_off: usize,
1666        rotations_len: usize,
1667        opacities_off: usize,
1668        opacities_len: usize,
1669        colors_off: usize,
1670        colors_len: usize,
1671        sh_coeffs_off: usize,
1672        sh_coeffs_len: usize,
1673        meta_off: usize,
1674        meta_len: usize,
1675        prep_off: usize,
1676        prep_len: usize,
1677        width: u32,
1678        height: u32,
1679        tile_size: u32,
1680        radius_scale: f32,
1681        alpha_cutoff: f32,
1682        max_splat_steps: u32,
1683        transmittance_threshold: f32,
1684        max_list_entries: u32,
1685    },
1686    /// Strict IR stage 2 — tile raster from prepare buffer ([`Op::GaussianSplatRasterize`]).
1687    GaussianSplatRasterize {
1688        prep_off: usize,
1689        prep_len: usize,
1690        meta_off: usize,
1691        meta_len: usize,
1692        dst_off: usize,
1693        dst_len: usize,
1694        count: usize,
1695        width: u32,
1696        height: u32,
1697        tile_size: u32,
1698        alpha_cutoff: f32,
1699        max_splat_steps: u32,
1700        transmittance_threshold: f32,
1701        max_list_entries: u32,
1702    },
1703    Fft1d {
1704        src: usize,
1705        dst: usize,
1706        outer: u32,
1707        n_complex: u32,
1708        inverse: bool,
1709        norm_tag: u32,
1710        dtype: rlx_ir::DType,
1711    },
1712    FftButterflyStage {
1713        state_src: usize,
1714        state_dst: usize,
1715        gate_src: usize,
1716        rev_src: usize,
1717        tw_re_src: usize,
1718        tw_im_src: usize,
1719        batch: u32,
1720        n_fft: u32,
1721        stage: u32,
1722    },
1723    LogMel {
1724        spec: usize,
1725        filters: usize,
1726        dst: usize,
1727        outer: u32,
1728        n_fft: u32,
1729        n_bins: u32,
1730        n_mels: u32,
1731    },
1732    LogMelBackward {
1733        spec: usize,
1734        filters: usize,
1735        dy: usize,
1736        dst: usize,
1737        outer: u32,
1738        n_fft: u32,
1739        n_bins: u32,
1740        n_mels: u32,
1741    },
1742    WelchPeaks {
1743        spec: usize,
1744        dst: usize,
1745        welch_batch: u32,
1746        n_fft: u32,
1747        n_segments: u32,
1748        k: u32,
1749    },
1750}
1751
1752/// Compiled thunk schedule — the runtime hot path.
1753/// Nop thunks are filtered out at compile time for zero iteration overhead.
1754#[derive(Clone)]
1755pub struct ThunkSchedule {
1756    pub thunks: Vec<Thunk>,
1757    /// TIDE merged placement mask (union across layers).
1758    pub moe_resident: Option<std::sync::Arc<[bool]>>,
1759    /// Per MoE layer placement (`layer[e]`); preferred when set.
1760    pub moe_resident_layers: Option<std::sync::Arc<Vec<std::sync::Arc<[bool]>>>>,
1761    /// MoE router TopK capture (per-layer refresh).
1762    pub moe_topk_capture: Option<std::sync::Arc<crate::moe_topk_capture::MoeTopkCapture>>,
1763    /// Cached config values.
1764    pub mask_threshold: f32,
1765    pub mask_neg_inf: f32,
1766    pub score_skip: f32,
1767    /// Pre-compiled closure dispatch (zero match overhead). `Arc` (not
1768    /// `Box`) so the schedule can be `Clone` — multiple parallel
1769    /// executors share the same compiled closures (they're read-only
1770    /// `Fn(*mut u8)` so concurrent dispatch is safe; the arena pointer
1771    /// they receive is the only mutable state and is per-executor).
1772    pub compiled_fns: Vec<Arc<dyn Fn(*mut u8) + Send + Sync>>,
1773}
1774
1775impl ThunkSchedule {
1776    pub fn strip_nops(&mut self) {
1777        self.thunks.retain(|t| !matches!(t, Thunk::Nop));
1778        // compiled_fns must be rebuilt after stripping — caller should
1779        // call strip_nops() before compile_closures().
1780        self.compiled_fns.clear();
1781    }
1782}
1783
1784/// Get the arena byte offset for a node.
1785fn node_offset(arena: &Arena, id: NodeId) -> usize {
1786    if arena.has_buffer(id) {
1787        arena.byte_offset(id)
1788    } else {
1789        usize::MAX
1790    }
1791}
1792
1793/// Every byte-offset that a thunk reads from. Used by the Narrow→Rope
1794/// fusion (#45) to verify a Narrow's dst has exactly one consumer
1795/// before eliding it. Conservative: when in doubt about reads (an op
1796/// not yet listed here), the fusion will skip — correctness over
1797/// completeness.
1798fn thunk_read_offsets(t: &Thunk) -> Vec<usize> {
1799    match t {
1800        Thunk::Sgemm { a, b, .. } => vec![*a, *b],
1801        Thunk::DenseSolveF64 { a, b, .. } => vec![*a, *b],
1802        Thunk::DenseSolveF32 { a, b, .. } => vec![*a, *b],
1803        Thunk::BatchedDenseSolveF64 { a, b, .. } => vec![*a, *b],
1804        Thunk::BatchedDgemmF64 { a, b, .. } => vec![*a, *b],
1805        Thunk::BatchedSgemm { a, b, .. } => vec![*a, *b],
1806        Thunk::FusedMmBiasAct { a, w, bias, .. } => vec![*a, *w, *bias],
1807        Thunk::BiasAdd { src, bias, .. } => vec![*src, *bias],
1808        Thunk::BinaryFull { lhs, rhs, .. } => vec![*lhs, *rhs],
1809        Thunk::BinaryFullF64 { lhs, rhs, .. } => vec![*lhs, *rhs],
1810        Thunk::BinaryFullC64 { lhs, rhs, .. } => vec![*lhs, *rhs],
1811        Thunk::ComplexNormSqF32 { src, .. } => vec![*src],
1812        Thunk::ComplexNormSqBackwardF32 { z, g, .. } => vec![*z, *g],
1813        Thunk::ConjugateC64 { src, .. } => vec![*src],
1814        Thunk::Scan {
1815            outer_init_off,
1816            xs_inputs,
1817            ..
1818        } => {
1819            let mut v = vec![*outer_init_off];
1820            for (_, outer_xs_off, _) in xs_inputs.iter() {
1821                v.push(*outer_xs_off);
1822            }
1823            v
1824        }
1825        Thunk::ScanBackward {
1826            outer_init_off,
1827            outer_traj_off,
1828            outer_upstream_off,
1829            outer_xs_offs,
1830            ..
1831        } => {
1832            let mut v = vec![*outer_init_off, *outer_traj_off, *outer_upstream_off];
1833            for (off, _) in outer_xs_offs.iter() {
1834                v.push(*off);
1835            }
1836            v
1837        }
1838        Thunk::ScanBackwardXs {
1839            outer_init_off,
1840            outer_traj_off,
1841            outer_upstream_off,
1842            outer_xs_offs,
1843            ..
1844        } => {
1845            let mut v = vec![*outer_init_off, *outer_traj_off, *outer_upstream_off];
1846            for (off, _) in outer_xs_offs.iter() {
1847                v.push(*off);
1848            }
1849            v
1850        }
1851        Thunk::CustomFn { inputs, .. } => {
1852            inputs.iter().map(|(_, outer_off, _)| *outer_off).collect()
1853        }
1854        Thunk::ActivationInPlace { data, .. } => vec![*data],
1855        Thunk::LayerNorm { src, g, b, .. } | Thunk::GroupNorm { src, g, b, .. } => {
1856            vec![*src, *g, *b]
1857        }
1858        Thunk::BatchNormInference {
1859            src,
1860            g,
1861            b,
1862            mean,
1863            var,
1864            ..
1865        } => vec![*src, *g, *b, *mean, *var],
1866        Thunk::ResizeNearest2x { src, .. } => vec![*src],
1867        Thunk::AxialRope2d { src, .. } => vec![*src],
1868        Thunk::FusedResidualLN {
1869            x, res, bias, g, b, ..
1870        } => vec![*x, *res, *bias, *g, *b],
1871        Thunk::FusedResidualRmsNorm {
1872            x, res, bias, g, b, ..
1873        } => vec![*x, *res, *bias, *g, *b],
1874        Thunk::RmsNorm { src, g, b, .. } => vec![*src, *g, *b],
1875        Thunk::Softmax { data, .. } => vec![*data],
1876        Thunk::Cumsum { src, .. } => vec![*src],
1877        Thunk::Sample { logits, .. } => vec![*logits],
1878        Thunk::LoraMatMul { x, w, a, b, .. } => vec![*x, *w, *a, *b],
1879        Thunk::DequantMatMul {
1880            x, w_q, scale, zp, ..
1881        } => vec![*x, *w_q, *scale, *zp],
1882        Thunk::DequantMatMulGguf { x, w_q, .. } => vec![*x, *w_q],
1883        Thunk::DequantMatMulInt4 {
1884            x, w_q, scale, zp, ..
1885        } => vec![*x, *w_q, *scale, *zp],
1886        Thunk::DequantMatMulFp8 { x, w_q, scale, .. } => vec![*x, *w_q, *scale],
1887        Thunk::DequantMatMulNvfp4 {
1888            x,
1889            w_q,
1890            scale,
1891            global_scale,
1892            ..
1893        } => vec![*x, *w_q, *scale, *global_scale],
1894        Thunk::Conv2D1x1 { src, weight, .. } => vec![*src, *weight],
1895        Thunk::SelectiveScan {
1896            x, delta, a, b, c, ..
1897        } => vec![*x, *delta, *a, *b, *c],
1898        Thunk::GatedDeltaNet {
1899            q,
1900            k,
1901            v,
1902            g,
1903            beta,
1904            state,
1905            ..
1906        } => {
1907            let mut v = vec![*q, *k, *v, *g, *beta];
1908            if *state != 0 {
1909                v.push(*state);
1910            }
1911            v
1912        }
1913        Thunk::Attention { q, k, v, mask, .. } => vec![*q, *k, *v, *mask],
1914        Thunk::AttentionBackward {
1915            q, k, v, dy, mask, ..
1916        } => {
1917            let mut v = vec![*q, *k, *v, *dy];
1918            if *mask != 0 {
1919                v.push(*mask);
1920            }
1921            v
1922        }
1923        Thunk::Rope { src, cos, sin, .. } => vec![*src, *cos, *sin],
1924        Thunk::FusedAttnBlock {
1925            hidden,
1926            qkv_w,
1927            out_w,
1928            mask,
1929            qkv_b,
1930            out_b,
1931            cos,
1932            sin,
1933            ..
1934        } => vec![*hidden, *qkv_w, *out_w, *mask, *qkv_b, *out_b, *cos, *sin],
1935        Thunk::FusedSwiGLU { src, .. } => vec![*src],
1936        Thunk::Concat { inputs, .. } => inputs.iter().map(|(off, _)| *off).collect(),
1937        Thunk::ConcatF64 { inputs, .. } => inputs.iter().map(|(off, _)| *off).collect(),
1938        Thunk::Narrow { src, .. } => vec![*src],
1939        Thunk::Copy { src, .. } => vec![*src],
1940        Thunk::Gather { table, idx, .. } => vec![*table, *idx],
1941        // Anything not enumerated → return the dst as a "read" too,
1942        // forcing the fusion to bail (read_count >= 2 → skip). Keeps
1943        // this list safe to be incomplete.
1944        _ => vec![],
1945    }
1946}
1947
/// Fused dequant + matmul (plan #5). Int8-blockwise weights: each
/// `block_size` consecutive elements of a column share one f32
/// scale (and optionally a zero-point). The dequant happens inside
/// the inner accumulate so the f32 weight is never materialized.
///
/// `w_bytes` is the row-major i8 weight matrix `[k, n]`. `scales`
/// and `zps` are `[ceil(k / block_size), n]`. Each weight dequantizes to
/// `(q - zp) * scale`. When `asym=false`, `zps` is never read and may be
/// empty. Writes `out[i * n + j]` for `i < m`, `j < n`.
///
/// Loop order is `i → p → j`, so the `w`, `scales` and `zps` rows are
/// walked contiguously instead of striding by `n`. Each output element
/// still starts at `0.0` and adds its `k` products in ascending `p`
/// order, so results are bit-identical to the naive `i → j → p` form.
///
/// Today this is the reference scalar implementation — the win is
/// memory bandwidth, not flops, since LLM weights dominate the
/// working set. A NEON SIMD path that loads 16 i8 → splat-scale →
/// fused-multiply-add is the natural follow-on.
///
/// # Panics
/// If `block_size == 0` while `m`, `k` and `n` are all non-zero, or if
/// any slice is shorter than the shapes above require.
#[allow(clippy::too_many_arguments)]
fn dequant_matmul_int8(
    x: &[f32],       // [m, k]
    w_bytes: &[i8],  // [k, n]
    scales: &[f32],  // [k/block, n]
    zps: &[f32],     // [k/block, n] or empty
    out: &mut [f32], // [m, n]
    m: usize,
    k: usize,
    n: usize,
    block_size: usize,
    asym: bool,
) {
    // Nothing to write for an empty output width; also avoids touching `x`.
    if n == 0 {
        return;
    }
    for i in 0..m {
        let out_row = &mut out[i * n..i * n + n];
        out_row.fill(0.0);
        for p in 0..k {
            let xv = x[i * k + p];
            // Start of this row-block's entries in `scales` / `zps`.
            let block_row = (p / block_size) * n;
            let w_row = p * n;
            for (j, acc) in out_row.iter_mut().enumerate() {
                let s = scales[block_row + j];
                let z = if asym { zps[block_row + j] } else { 0.0 };
                let q = w_bytes[w_row + j] as f32;
                *acc += xv * ((q - z) * s);
            }
        }
    }
}
1991
/// Fused dequant + matmul for block-quantized 4-bit weights:
/// `out[m, n] = x[m, k] @ dequant(w[k, n])`.
///
/// `w_bytes` packs the row-major `[k, n]` weight codes two per byte.
/// Flat element `p * n + j` is the low nibble of byte `(p * n + j) / 2`
/// when that index is even, and the high nibble when it is odd. Codes are
/// read as unsigned `0..=15`.
///
/// `scales` / `zps` are `[ceil(k / block_size), n]`, and each weight
/// dequantizes to `(code - zp) * scale`. `zp` is `0` when `asym` is false,
/// in which case `zps` is never read. Writes `out[i * n + j]` for `i < m`,
/// `j < n`.
#[allow(clippy::too_many_arguments)]
fn dequant_matmul_int4(
    x: &[f32],
    w_bytes: &[u8],
    scales: &[f32],
    zps: &[f32],
    out: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
    block_size: usize,
    asym: bool,
) {
    // Fetch the 4-bit code stored at flat position `flat` of the [k, n] matrix.
    let code_at = |flat: usize| -> u8 {
        let packed = w_bytes[flat / 2];
        if flat & 1 == 0 { packed & 0x0F } else { packed >> 4 }
    };
    for row in 0..m {
        for col in 0..n {
            out[row * n + col] = (0..k).fold(0f32, |sum, p| {
                let param_idx = (p / block_size) * n + col;
                let scale = scales[param_idx];
                let zero = if asym { zps[param_idx] } else { 0.0 };
                let weight = (code_at(p * n + col) as f32 - zero) * scale;
                sum + x[row * k + p] * weight
            });
        }
    }
}
2025
/// Decode one OCP FP8 **E4M3** byte (`float8_e4m3fn`) to `f32`.
///
/// Layout: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.
/// Unlike IEEE-style formats, E4M3 has **no infinities**. The all-ones
/// exponent still encodes normal numbers (up to ±448), and only
/// `S.1111.111` (0x7F / 0xFF) is NaN. Decoding exponent `0xF` as
/// Inf/NaN, as an IEEE-style decoder would, turns every weight with
/// magnitude ≥ 256 into ±Inf or NaN.
///
/// Both zero encodings decode to `+0.0`.
fn fp8_e4m3_to_f32(b: u8) -> f32 {
    let sign = if b & 0x80 != 0 { -1.0 } else { 1.0 };
    let exp = (b >> 3) & 0x0F;
    let mant = b & 0x07;
    if exp == 0x0F && mant == 0x07 {
        // The only NaN encodings in E4M3.
        return f32::NAN;
    }
    if exp == 0 {
        if mant == 0 {
            return 0.0;
        }
        // Subnormal: (mant / 8) * 2^(1 - 7) = mant * 2^-9.
        return sign * (mant as f32) * 2f32.powi(-9);
    }
    // Normal, including exp == 0xF (largest finite value: 1.75 * 2^8 = 448).
    sign * (1.0 + mant as f32 / 8.0) * 2f32.powi(exp as i32 - 7)
}
2045
/// Decode one FP8 **E5M2** byte to `f32`.
///
/// Layout: 1 sign bit, 5 exponent bits (bias 15), 2 mantissa bits, with
/// IEEE-754 style specials. An all-ones exponent with a zero mantissa is
/// ±Inf, and with a non-zero mantissa it is NaN (returned as `f32::NAN`).
/// Exponent 0 encodes subnormals `mant * 2^-16`. Both zero encodings
/// decode to `+0.0`.
fn fp8_e5m2_to_f32(b: u8) -> f32 {
    let negative = b & 0x80 != 0;
    let biased_exp = i32::from((b >> 2) & 0x1F);
    let frac = b & 0x03;
    let magnitude = match (biased_exp, frac) {
        (0, 0) => return 0.0,
        (0, _) => f32::from(frac) * 2f32.powi(-16),
        (0x1F, 0) => f32::INFINITY,
        (0x1F, _) => return f32::NAN,
        (e, _) => (1.0 + f32::from(frac) / 4.0) * 2f32.powi(e - 15),
    };
    if negative { -magnitude } else { magnitude }
}
2065
2066#[allow(clippy::too_many_arguments)]
2067fn dequant_matmul_fp8(
2068    x: &[f32],
2069    w_bytes: &[u8],
2070    scales: &[f32],
2071    out: &mut [f32],
2072    m: usize,
2073    k: usize,
2074    n: usize,
2075    e5m2: bool,
2076) {
2077    let dequant = if e5m2 {
2078        fp8_e5m2_to_f32
2079    } else {
2080        fp8_e4m3_to_f32
2081    };
2082    for i in 0..m {
2083        for j in 0..n {
2084            let mut acc = 0f32;
2085            for p in 0..k {
2086                let w = dequant(w_bytes[p * n + j]);
2087                let s = scales.get(j).copied().unwrap_or(1.0);
2088                acc += x[i * k + p] * w * s;
2089            }
2090            out[i * n + j] = acc;
2091        }
2092    }
2093}
2094
2095#[allow(clippy::too_many_arguments)]
2096pub fn dequant_matmul_nvfp4(
2097    x: &[f32],
2098    w_bytes: &[u8],
2099    scale_bytes: &[u8],
2100    global_scale: f32,
2101    out: &mut [f32],
2102    m: usize,
2103    k: usize,
2104    n: usize,
2105) {
2106    use rlx_ir::{NVFP4_GROUP_SIZE, fp4_e2m1_to_f32, fp8_e4m3_scale_to_f32};
2107    let gs = NVFP4_GROUP_SIZE;
2108    for i in 0..m {
2109        for j in 0..n {
2110            let mut acc = 0f32;
2111            for p in 0..k {
2112                let byte_idx = (p * n + j) / 2;
2113                let nibble = if (p * n + j) & 1 == 0 {
2114                    w_bytes[byte_idx] & 0x0F
2115                } else {
2116                    w_bytes[byte_idx] >> 4
2117                };
2118                let block = p / gs;
2119                let scale = fp8_e4m3_scale_to_f32(scale_bytes[block * n + j]);
2120                let w = fp4_e2m1_to_f32(nibble) * scale * global_scale;
2121                acc += x[i * k + p] * w;
2122            }
2123            out[i * n + j] = acc;
2124        }
2125    }
2126}
2127
2128/// Fused sampling step: logits → top-k filter → top-p truncation
2129/// → softmax → multinomial sample. Operates on one row of length
2130/// `vocab` and returns the sampled index. Plan #42.
2131///
2132/// Internal scratch is on the stack via SmallVec-style fallback —
2133/// for `vocab > 8192` we heap-allocate a working buffer; below
2134/// that we keep things in a fixed array. (TODO: thread the
2135/// scratch through ThunkSchedule like sdpa_scores does.)
2136fn sample_row(
2137    logits: &[f32],
2138    top_k: usize,
2139    top_p: f32,
2140    temperature: f32,
2141    rng: &mut rlx_ir::Philox4x32,
2142) -> usize {
2143    let v = logits.len();
2144    if v == 0 {
2145        return 0;
2146    }
2147    let temp = temperature.max(1e-6);
2148    // Copy + temperature-scale into a working buffer.
2149    let mut scaled: Vec<f32> = logits.iter().map(|&x| x / temp).collect();
2150
2151    // Top-k: zero out everything but the k largest by setting to -inf.
2152    if top_k > 0 && top_k < v {
2153        // Partial selection: find k-th largest then mask below.
2154        let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
2155        // Sort descending; partial would be O(n log k), full sort is fine
2156        // for typical vocab sizes (32k-128k) — single-row work.
2157        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2158        let cutoff = indexed[top_k - 1].1;
2159        for x in scaled.iter_mut() {
2160            if *x < cutoff {
2161                *x = f32::NEG_INFINITY;
2162            }
2163        }
2164    }
2165
2166    // Stable softmax.
2167    let mut max_l = f32::NEG_INFINITY;
2168    for &x in &scaled {
2169        if x > max_l {
2170            max_l = x;
2171        }
2172    }
2173    let mut sum = 0.0f32;
2174    for x in scaled.iter_mut() {
2175        *x = (*x - max_l).exp();
2176        sum += *x;
2177    }
2178    let inv = 1.0 / sum.max(f32::MIN_POSITIVE);
2179    for x in scaled.iter_mut() {
2180        *x *= inv;
2181    }
2182
2183    // Top-p: keep the smallest set of tokens whose cumulative
2184    // probability exceeds top_p (after sorting descending).
2185    if top_p < 1.0 {
2186        let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
2187        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2188        let mut cum = 0.0f32;
2189        let mut keep = vec![false; v];
2190        for (idx, p) in indexed.iter() {
2191            keep[*idx] = true;
2192            cum += *p;
2193            if cum >= top_p {
2194                break;
2195            }
2196        }
2197        let mut new_sum = 0.0f32;
2198        for (i, x) in scaled.iter_mut().enumerate() {
2199            if !keep[i] {
2200                *x = 0.0;
2201            }
2202            new_sum += *x;
2203        }
2204        let inv = 1.0 / new_sum.max(f32::MIN_POSITIVE);
2205        for x in scaled.iter_mut() {
2206            *x *= inv;
2207        }
2208    }
2209
2210    // Multinomial sample via inverse-CDF.
2211    let r = rng.next_f32();
2212    let mut acc = 0.0f32;
2213    for (i, &p) in scaled.iter().enumerate() {
2214        acc += p;
2215        if r <= acc {
2216            return i;
2217        }
2218    }
2219    v - 1 // floating-point edge case fallback
2220}
2221
2222/// Apply a synthetic (kernel-generated) attention mask to a `[q_seq, k_seq]`
2223/// scores matrix. Custom masks are read from a tensor and not handled here.
2224/// `None` is a no-op so callers don't need to special-case it.
2225#[inline]
2226fn apply_synthetic_mask(
2227    scores: &mut [f32],
2228    q_seq: usize,
2229    k_seq: usize,
2230    kind: rlx_ir::op::MaskKind,
2231) {
2232    let neg = crate::config::RuntimeConfig::global().attn_mask_neg_inf;
2233    let q_offset = k_seq.saturating_sub(q_seq);
2234    match kind {
2235        rlx_ir::op::MaskKind::None | rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias => {}
2236        rlx_ir::op::MaskKind::Causal => {
2237            for qi in 0..q_seq {
2238                let abs_q = q_offset + qi;
2239                for ki in (abs_q + 1)..k_seq {
2240                    scores[qi * k_seq + ki] = neg;
2241                }
2242            }
2243        }
2244        rlx_ir::op::MaskKind::SlidingWindow(w) => {
2245            for qi in 0..q_seq {
2246                let abs_q = q_offset + qi;
2247                let lo = abs_q.saturating_sub(w);
2248                for ki in 0..k_seq {
2249                    if ki < lo || ki > abs_q {
2250                        scores[qi * k_seq + ki] = neg;
2251                    }
2252                }
2253            }
2254        }
2255    }
2256}
2257
2258/// NCL `[N,C,L]` or NCHW `[N,C,H,W]` → `(n, c, h, w)` for 2D conv/norm thunks.
2259fn conv_nchw_dims(shape: &Shape) -> (u32, u32, u32, u32) {
2260    match shape.rank() {
2261        3 => (
2262            shape.dim(0).unwrap_static() as u32,
2263            shape.dim(1).unwrap_static() as u32,
2264            1,
2265            shape.dim(2).unwrap_static() as u32,
2266        ),
2267        4 => (
2268            shape.dim(0).unwrap_static() as u32,
2269            shape.dim(1).unwrap_static() as u32,
2270            shape.dim(2).unwrap_static() as u32,
2271            shape.dim(3).unwrap_static() as u32,
2272        ),
2273        r => panic!("conv_nchw_dims: expected rank 3 or 4, got {r}"),
2274    }
2275}
2276
2277/// Compile graph into thunk schedule.
2278pub fn compile_thunks(graph: &Graph, arena: &Arena) -> ThunkSchedule {
2279    let mut thunks = Vec::with_capacity(graph.len());
2280
2281    for node in graph.nodes() {
2282        // View ops (Reshape / same-dtype Cast / axis-0 Narrow) are aliased
2283        // to their parent's slot by the memory planner — no copy needed.
2284        // Plan #46.
2285        if rlx_opt::is_pure_view(graph, node) {
2286            thunks.push(Thunk::Nop);
2287            continue;
2288        }
2289        let t = match &node.op {
2290            Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => Thunk::Nop,
2291
2292            Op::FusedMatMulBiasAct { activation } => {
2293                let shape = &node.shape;
2294                let n = shape.dim(shape.rank() - 1).unwrap_static();
2295                let total = shape.num_elements().unwrap();
2296                let m = total / n;
2297                let a_len = get_len(graph, node.inputs[0]);
2298                let k = a_len / m;
2299                Thunk::FusedMmBiasAct {
2300                    a: node_offset(arena, node.inputs[0]),
2301                    w: node_offset(arena, node.inputs[1]),
2302                    bias: node_offset(arena, node.inputs[2]),
2303                    c: node_offset(arena, node.id),
2304                    m: m as u32,
2305                    k: k as u32,
2306                    n: n as u32,
2307                    act: *activation,
2308                }
2309            }
2310
2311            Op::FusedResidualLN { has_bias, eps } => {
2312                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2313                let total = node.shape.num_elements().unwrap();
2314                let rows = total / h;
2315                let (g_idx, b_idx) = if *has_bias { (3, 4) } else { (2, 3) };
2316                Thunk::FusedResidualLN {
2317                    x: node_offset(arena, node.inputs[0]),
2318                    res: node_offset(arena, node.inputs[1]),
2319                    bias: if *has_bias {
2320                        node_offset(arena, node.inputs[2])
2321                    } else {
2322                        0
2323                    },
2324                    g: node_offset(arena, node.inputs[g_idx]),
2325                    b: node_offset(arena, node.inputs[b_idx]),
2326                    out: node_offset(arena, node.id),
2327                    rows: rows as u32,
2328                    h: h as u32,
2329                    eps: *eps,
2330                    has_bias: *has_bias,
2331                }
2332            }
2333
2334            Op::FusedResidualRmsNorm { has_bias, eps } => {
2335                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2336                let total = node.shape.num_elements().unwrap();
2337                let rows = total / h;
2338                let (g_idx, b_idx) = if *has_bias { (3, 4) } else { (2, 3) };
2339                Thunk::FusedResidualRmsNorm {
2340                    x: node_offset(arena, node.inputs[0]),
2341                    res: node_offset(arena, node.inputs[1]),
2342                    bias: if *has_bias {
2343                        node_offset(arena, node.inputs[2])
2344                    } else {
2345                        0
2346                    },
2347                    g: node_offset(arena, node.inputs[g_idx]),
2348                    b: node_offset(arena, node.inputs[b_idx]),
2349                    out: node_offset(arena, node.id),
2350                    rows: rows as u32,
2351                    h: h as u32,
2352                    eps: *eps,
2353                    has_bias: *has_bias,
2354                }
2355            }
2356
2357            Op::MatMul => {
2358                let shape = &node.shape;
2359                let a_shape = &graph.node(node.inputs[0]).shape;
2360                let b_shape = &graph.node(node.inputs[1]).shape;
2361                // Prefer inferred matmul shape from operands — ONNX bundle
2362                // meta often over-ranks outputs (e.g. [seq, seq, H]).
2363                let eff =
2364                    rlx_ir::shape::matmul_shape(a_shape, b_shape).unwrap_or_else(|_| shape.clone());
2365                let rank = eff.rank().max(2);
2366                let n = eff.dim(rank - 1).unwrap_static();
2367                let k_dim = a_shape.dim(a_shape.rank().max(2) - 1).unwrap_static();
2368                // Batched GEMM only when both operands carry batch dimensions.
2369                // 3D×2D (activations × shared weight) must flatten to one Sgemm.
2370                let both_batched = a_shape.rank() >= 3 && b_shape.rank() >= 3;
2371                let batched_3d = rank >= 3 && both_batched && a_shape.rank() + b_shape.rank() > 4;
2372                if batched_3d && shape.dtype() == rlx_ir::DType::F64 {
2373                    let mut batch_prod = 1usize;
2374                    for d in 0..rank - 2 {
2375                        batch_prod *= eff.dim(d).unwrap_static();
2376                    }
2377                    let m_dim = eff.dim(rank - 2).unwrap_static();
2378                    Thunk::BatchedDgemmF64 {
2379                        a: node_offset(arena, node.inputs[0]),
2380                        b: node_offset(arena, node.inputs[1]),
2381                        c: node_offset(arena, node.id),
2382                        batch: batch_prod as u32,
2383                        m: m_dim as u32,
2384                        k: k_dim as u32,
2385                        n: n as u32,
2386                    }
2387                } else if batched_3d && shape.dtype() == rlx_ir::DType::F32 {
2388                    let mut batch_prod = 1usize;
2389                    for d in 0..rank - 2 {
2390                        batch_prod *= eff.dim(d).unwrap_static();
2391                    }
2392                    let m_dim = eff.dim(rank - 2).unwrap_static();
2393                    Thunk::BatchedSgemm {
2394                        a: node_offset(arena, node.inputs[0]),
2395                        b: node_offset(arena, node.inputs[1]),
2396                        c: node_offset(arena, node.id),
2397                        batch: batch_prod as u32,
2398                        m: m_dim as u32,
2399                        k: k_dim as u32,
2400                        n: n as u32,
2401                    }
2402                } else {
2403                    let m = if a_shape.rank() >= 3 && b_shape.rank() <= 2 {
2404                        let mut m_prod = 1usize;
2405                        for d in 0..a_shape.rank() - 1 {
2406                            m_prod *= a_shape.dim(d).unwrap_static();
2407                        }
2408                        m_prod
2409                    } else if a_shape.rank() >= 2 {
2410                        a_shape.dim(a_shape.rank() - 2).unwrap_static()
2411                    } else {
2412                        eff.num_elements().unwrap_or(1) / n.max(1)
2413                    };
2414                    match shape.dtype() {
2415                        rlx_ir::DType::F64 => Thunk::Dgemm {
2416                            a: node_offset(arena, node.inputs[0]),
2417                            b: node_offset(arena, node.inputs[1]),
2418                            c: node_offset(arena, node.id),
2419                            m: m as u32,
2420                            k: k_dim as u32,
2421                            n: n as u32,
2422                        },
2423                        _ => Thunk::Sgemm {
2424                            a: node_offset(arena, node.inputs[0]),
2425                            b: node_offset(arena, node.inputs[1]),
2426                            c: node_offset(arena, node.id),
2427                            m: m as u32,
2428                            k: k_dim as u32,
2429                            n: n as u32,
2430                        },
2431                    }
2432                }
2433            }
2434
2435            Op::Binary(op) => {
2436                let lhs_len = get_len(graph, node.inputs[0]);
2437                let rhs_len = get_len(graph, node.inputs[1]);
2438                let out_len = node.shape.num_elements().unwrap();
2439                if node.shape.dtype() == rlx_ir::DType::C64 {
2440                    // Native C64 element-wise. Add/Sub/Mul/Div lower
2441                    // to `BinaryFullC64`; the rest don't have a
2442                    // single natural complex definition.
2443                    match op {
2444                        BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div => {}
2445                        BinaryOp::Max | BinaryOp::Min | BinaryOp::Pow => panic!(
2446                            "Op::Binary({op:?}) on DType::C64: complex \
2447                             max/min/pow have no single natural definition \
2448                             — caller should drop to 2N-real-block (see \
2449                             spike-ac) and pick a convention there"
2450                        ),
2451                    }
2452                }
2453                // Compute broadcast strides for the slow path. Empty
2454                // vectors when no broadcast is needed (the fast-path
2455                // kernel ignores them anyway).
2456                let (out_dims_bcast, bcast_lhs_strides, bcast_rhs_strides) =
2457                    if lhs_len == out_len && rhs_len == out_len {
2458                        (Vec::new(), Vec::new(), Vec::new())
2459                    } else {
2460                        let lhs_dims = get_static_dims(graph, node.inputs[0]);
2461                        let rhs_dims = get_static_dims(graph, node.inputs[1]);
2462                        let out_dims_v = get_static_dims(graph, node.id);
2463                        if lhs_dims.is_empty() || rhs_dims.is_empty() || out_dims_v.is_empty() {
2464                            // Dynamic shape — fall back to the legacy
2465                            // modulo path (correct for scalar / last-
2466                            // axis broadcast, which is the only
2467                            // dynamic case in practice).
2468                            (Vec::new(), Vec::new(), Vec::new())
2469                        } else {
2470                            let ls = broadcast_strides(&lhs_dims, &out_dims_v);
2471                            let rs = broadcast_strides(&rhs_dims, &out_dims_v);
2472                            let od: Vec<u32> = out_dims_v.iter().map(|x| *x as u32).collect();
2473                            (od, ls, rs)
2474                        }
2475                    };
2476                if node.shape.dtype() == rlx_ir::DType::C64 {
2477                    Thunk::BinaryFullC64 {
2478                        lhs: node_offset(arena, node.inputs[0]),
2479                        rhs: node_offset(arena, node.inputs[1]),
2480                        dst: node_offset(arena, node.id),
2481                        len: out_len as u32,
2482                        lhs_len: lhs_len as u32,
2483                        rhs_len: rhs_len as u32,
2484                        op: *op,
2485                        out_dims_bcast,
2486                        bcast_lhs_strides,
2487                        bcast_rhs_strides,
2488                    }
2489                } else if node.shape.dtype() == rlx_ir::DType::F64 {
2490                    // f64 path — no BiasAdd fast-path (yet); use the
2491                    // general binary-with-broadcast kernel.
2492                    Thunk::BinaryFullF64 {
2493                        lhs: node_offset(arena, node.inputs[0]),
2494                        rhs: node_offset(arena, node.inputs[1]),
2495                        dst: node_offset(arena, node.id),
2496                        len: out_len as u32,
2497                        lhs_len: lhs_len as u32,
2498                        rhs_len: rhs_len as u32,
2499                        op: *op,
2500                        out_dims_bcast,
2501                        bcast_lhs_strides,
2502                        bcast_rhs_strides,
2503                    }
2504                } else if matches!(op, BinaryOp::Add)
2505                    && rhs_len < out_len
2506                    && out_len % rhs_len == 0
2507                    && is_trailing_bias_broadcast(
2508                        graph.node(node.inputs[1]).shape.dims(),
2509                        graph.node(node.id).shape.dims(),
2510                    )
2511                {
2512                    // `BiasAdd` is only correct when the bias is a
2513                    // *trailing* broadcast — rhs dims match the right-
2514                    // hand side of the output dims (with size-1 only
2515                    // allowed in left-padded outer positions).
2516                    // SAM's rel-pos `[bh, h, w, 1, w] + [bh, h, w, h, w]`
2517                    // has rhs_len divide out_len cleanly but is a
2518                    // mid-shape singleton, NOT a trailing broadcast.
2519                    // Routing it through BiasAdd silently treats it as
2520                    // last-`rhs_len`-cols repeated — wrong values.
2521                    Thunk::BiasAdd {
2522                        src: node_offset(arena, node.inputs[0]),
2523                        bias: node_offset(arena, node.inputs[1]),
2524                        dst: node_offset(arena, node.id),
2525                        m: (out_len / rhs_len) as u32,
2526                        n: rhs_len as u32,
2527                    }
2528                } else {
2529                    let lhs_len = get_len(graph, node.inputs[0]);
2530                    Thunk::BinaryFull {
2531                        lhs: node_offset(arena, node.inputs[0]),
2532                        rhs: node_offset(arena, node.inputs[1]),
2533                        dst: node_offset(arena, node.id),
2534                        len: out_len as u32,
2535                        lhs_len: lhs_len as u32,
2536                        rhs_len: rhs_len as u32,
2537                        op: *op,
2538                        out_dims_bcast,
2539                        bcast_lhs_strides,
2540                        bcast_rhs_strides,
2541                        elem_bytes: node.shape.dtype().size_bytes() as u8,
2542                    }
2543                }
2544            }
2545
2546            Op::Activation(act) => {
2547                let len = node.shape.num_elements().unwrap();
2548                let in_off = node_offset(arena, node.inputs[0]);
2549                let out_off = node_offset(arena, node.id);
2550                if node.shape.dtype() == rlx_ir::DType::C64 {
2551                    // Only Neg/Exp/Log/Sqrt have natural complex
2552                    // extensions used in signal-processing graphs.
2553                    // Everything else (Sigmoid, Tanh, Relu, Abs,
2554                    // Sin/Cos/Tan/Atan, Round, GeLU family) is rejected.
2555                    match act {
2556                        Activation::Neg | Activation::Exp | Activation::Log | Activation::Sqrt => {}
2557                        other => panic!(
2558                            "Op::Activation({other:?}) on DType::C64: no \
2559                             natural complex extension — supported on C64: \
2560                             Neg, Exp, Log, Sqrt"
2561                        ),
2562                    }
2563                    Thunk::ActivationC64 {
2564                        src: in_off,
2565                        dst: out_off,
2566                        len: len as u32,
2567                        kind: *act,
2568                    }
2569                } else if node.shape.dtype() == rlx_ir::DType::F64 {
2570                    Thunk::ActivationF64 {
2571                        src: in_off,
2572                        dst: out_off,
2573                        len: len as u32,
2574                        kind: *act,
2575                    }
2576                } else if in_off == out_off {
2577                    // ActivationInPlace operates on a single buffer. When the
2578                    // planner has assigned input and output the same slot
2579                    // (typical post-fusion case), we just run on that slot.
2580                    Thunk::ActivationInPlace {
2581                        data: out_off,
2582                        len: len as u32,
2583                        act: *act,
2584                    }
2585                } else {
2586                    // Two-step: copy input → output, then activate output in place.
2587                    // The schedule executes them in this order; downstream
2588                    // thunks see the activated output at out_off.
2589                    thunks.push(Thunk::Copy {
2590                        src: in_off,
2591                        dst: out_off,
2592                        len: len as u32,
2593                    });
2594                    Thunk::ActivationInPlace {
2595                        data: out_off,
2596                        len: len as u32,
2597                        act: *act,
2598                    }
2599                }
2600            }
2601
2602            Op::Gather { axis } if *axis == 0 => {
2603                let table_shape = &graph.node(node.inputs[0]).shape;
2604                let table_total = table_shape.num_elements().unwrap();
2605                let trailing: usize = (1..table_shape.rank())
2606                    .map(|i| table_shape.dim(i).unwrap_static())
2607                    .product();
2608                let idx_len = get_len(graph, node.inputs[1]);
2609                let idx_i64 =
2610                    u8::from(graph.node(node.inputs[1]).shape.dtype() == rlx_ir::DType::I64);
2611                let table_bytes = graph.node(node.inputs[0]).shape.dtype().size_bytes() as u8;
2612                Thunk::Gather {
2613                    table: node_offset(arena, node.inputs[0]),
2614                    table_len: table_total as u32,
2615                    idx: node_offset(arena, node.inputs[1]),
2616                    dst: node_offset(arena, node.id),
2617                    num_idx: idx_len as u32,
2618                    trailing: trailing as u32,
2619                    idx_i64,
2620                    table_bytes,
2621                }
2622            }
2623
2624            Op::Gather { axis } => {
2625                // Non-zero axis: outer × num_idx × trailing layout.
2626                let table_shape = &graph.node(node.inputs[0]).shape;
2627                let rank = table_shape.rank();
2628                let outer: usize = (0..*axis)
2629                    .map(|i| table_shape.dim(i).unwrap_static())
2630                    .product::<usize>()
2631                    .max(1);
2632                let trailing: usize = (*axis + 1..rank)
2633                    .map(|i| table_shape.dim(i).unwrap_static())
2634                    .product::<usize>()
2635                    .max(1);
2636                let axis_dim = table_shape.dim(*axis).unwrap_static();
2637                let idx_len = get_len(graph, node.inputs[1]);
2638                let idx_i64 =
2639                    u8::from(graph.node(node.inputs[1]).shape.dtype() == rlx_ir::DType::I64);
2640                let table_bytes = graph.node(node.inputs[0]).shape.dtype().size_bytes() as u8;
2641                Thunk::GatherAxis {
2642                    table: node_offset(arena, node.inputs[0]),
2643                    idx: node_offset(arena, node.inputs[1]),
2644                    dst: node_offset(arena, node.id),
2645                    outer: outer as u32,
2646                    axis_dim: axis_dim as u32,
2647                    num_idx: idx_len as u32,
2648                    trailing: trailing as u32,
2649                    idx_i64,
2650                    table_bytes,
2651                }
2652            }
2653
2654            Op::Narrow { axis, start, len } => {
2655                let in_shape = &graph.node(node.inputs[0]).shape;
2656                let elem_bytes = in_shape.dtype().size_bytes() as u8;
2657                let rank = in_shape.rank();
2658                let outer: usize = (0..*axis)
2659                    .map(|i| in_shape.dim(i).unwrap_static())
2660                    .product::<usize>()
2661                    .max(1);
2662                let inner: usize = (*axis + 1..rank)
2663                    .map(|i| in_shape.dim(i).unwrap_static())
2664                    .product::<usize>()
2665                    .max(1);
2666                let in_axis = in_shape.dim(*axis).unwrap_static();
2667                let src_byte_offset =
2668                    node_offset(arena, node.inputs[0]) + start * inner * elem_bytes as usize;
2669                Thunk::Narrow {
2670                    src: src_byte_offset,
2671                    dst: node_offset(arena, node.id),
2672                    outer: outer as u32,
2673                    src_stride: (in_axis * inner) as u32, // elements per outer step in source
2674                    dst_stride: (*len * inner) as u32,    // elements per outer step in dest
2675                    inner: (*len * inner) as u32,         // elements to copy per outer step
2676                    elem_bytes,
2677                }
2678            }
2679
2680            Op::Reshape { .. } | Op::StopGradient => {
2681                // Pure layout change: same total element count, plain copy.
2682                let len = node.shape.num_elements().unwrap();
2683                let src = node_offset(arena, node.inputs[0]);
2684                let dst = node_offset(arena, node.id);
2685                match node.shape.dtype() {
2686                    rlx_ir::DType::F64 => Thunk::CopyF64 {
2687                        src,
2688                        dst,
2689                        len: len as u32,
2690                    },
2691                    rlx_ir::DType::I64 => Thunk::CopyI64 {
2692                        src,
2693                        dst,
2694                        len: len as u32,
2695                    },
2696                    _ => Thunk::Copy {
2697                        src,
2698                        dst,
2699                        len: len as u32,
2700                    },
2701                }
2702            }
2703
2704            Op::Cast { to } => {
2705                let in_node = graph.node(node.inputs[0]);
2706                let in_dtype = in_node.shape.dtype();
2707                let out_dtype = *to;
2708                let len = node.shape.num_elements().unwrap();
2709                let src = node_offset(arena, node.inputs[0]);
2710                let dst = node_offset(arena, node.id);
2711                if in_dtype == rlx_ir::DType::F32 && out_dtype == rlx_ir::DType::I64 {
2712                    Thunk::CastF32ToI64 {
2713                        src,
2714                        dst,
2715                        len: len as u32,
2716                    }
2717                } else if in_dtype == rlx_ir::DType::I64 && out_dtype == rlx_ir::DType::F32 {
2718                    Thunk::CastI64ToF32 {
2719                        src,
2720                        dst,
2721                        len: len as u32,
2722                    }
2723                } else if in_dtype == rlx_ir::DType::Bool && out_dtype == rlx_ir::DType::I32 {
2724                    Thunk::CastBoolToI32 {
2725                        src,
2726                        dst,
2727                        len: len as u32,
2728                    }
2729                } else if in_dtype == rlx_ir::DType::I32 && out_dtype == rlx_ir::DType::F32 {
2730                    Thunk::CastI32ToF32 {
2731                        src,
2732                        dst,
2733                        len: len as u32,
2734                    }
2735                } else if in_dtype == out_dtype {
2736                    match out_dtype {
2737                        rlx_ir::DType::F64 => Thunk::CopyF64 {
2738                            src,
2739                            dst,
2740                            len: len as u32,
2741                        },
2742                        rlx_ir::DType::I64 => Thunk::CopyI64 {
2743                            src,
2744                            dst,
2745                            len: len as u32,
2746                        },
2747                        _ => Thunk::Copy {
2748                            src,
2749                            dst,
2750                            len: len as u32,
2751                        },
2752                    }
2753                } else {
2754                    Thunk::Copy {
2755                        src,
2756                        dst,
2757                        len: len as u32,
2758                    }
2759                }
2760            }
2761
2762            Op::Quantize {
2763                axis,
2764                scales,
2765                zero_points,
2766            } => {
2767                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2768                Thunk::Quantize {
2769                    x: node_offset(arena, node.inputs[0]),
2770                    q: node_offset(arena, node.id),
2771                    len: node.shape.num_elements().unwrap() as u32,
2772                    chan_axis: chan_axis as u32,
2773                    chan_dim: chan_dim as u32,
2774                    inner: inner as u32,
2775                    scales: scales.clone(),
2776                    zero_points: zero_points.clone(),
2777                }
2778            }
2779
2780            Op::FakeQuantize {
2781                bits,
2782                axis,
2783                ste,
2784                scale_mode,
2785            } => {
2786                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2787                let state_off = match scale_mode {
2788                    rlx_ir::op::ScaleMode::PerBatch => None,
2789                    rlx_ir::op::ScaleMode::EMA { .. } | rlx_ir::op::ScaleMode::Fixed => {
2790                        // Second input carries the [chan_dim] scale state.
2791                        debug_assert_eq!(
2792                            node.inputs.len(),
2793                            2,
2794                            "EMA/Fixed FakeQuantize needs a state input"
2795                        );
2796                        Some(node_offset(arena, node.inputs[1]))
2797                    }
2798                };
2799                Thunk::FakeQuantize {
2800                    x: node_offset(arena, node.inputs[0]),
2801                    out: node_offset(arena, node.id),
2802                    len: node.shape.num_elements().unwrap() as u32,
2803                    chan_axis: chan_axis as u32,
2804                    chan_dim: chan_dim as u32,
2805                    inner: inner as u32,
2806                    bits: *bits,
2807                    ste: *ste,
2808                    scale_mode: *scale_mode,
2809                    state_off,
2810                }
2811            }
2812
2813            Op::FakeQuantizeLSQ { bits, axis } => {
2814                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2815                Thunk::FakeQuantizeLSQ {
2816                    x: node_offset(arena, node.inputs[0]),
2817                    scale_off: node_offset(arena, node.inputs[1]),
2818                    out: node_offset(arena, node.id),
2819                    len: node.shape.num_elements().unwrap() as u32,
2820                    chan_axis: chan_axis as u32,
2821                    chan_dim: chan_dim as u32,
2822                    inner: inner as u32,
2823                    bits: *bits,
2824                }
2825            }
2826
2827            Op::FakeQuantizeLSQBackwardX { bits, axis } => {
2828                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2829                Thunk::FakeQuantizeLSQBackwardX {
2830                    x: node_offset(arena, node.inputs[0]),
2831                    scale_off: node_offset(arena, node.inputs[1]),
2832                    dy: node_offset(arena, node.inputs[2]),
2833                    dx: node_offset(arena, node.id),
2834                    len: node.shape.num_elements().unwrap() as u32,
2835                    chan_axis: chan_axis as u32,
2836                    chan_dim: chan_dim as u32,
2837                    inner: inner as u32,
2838                    bits: *bits,
2839                }
2840            }
2841
2842            Op::FakeQuantizeLSQBackwardScale { bits, axis } => {
2843                // Output shape is [chan_dim] — node.shape doesn't
2844                // describe the input data layout, but inputs[0] does.
2845                let in_shape = &graph.node(node.inputs[0]).shape;
2846                let (chan_axis, chan_dim, inner) = quant_layout(in_shape, *axis);
2847                Thunk::FakeQuantizeLSQBackwardScale {
2848                    x: node_offset(arena, node.inputs[0]),
2849                    scale_off: node_offset(arena, node.inputs[1]),
2850                    dy: node_offset(arena, node.inputs[2]),
2851                    dscale: node_offset(arena, node.id),
2852                    len: in_shape.num_elements().unwrap() as u32,
2853                    chan_axis: chan_axis as u32,
2854                    chan_dim: chan_dim as u32,
2855                    inner: inner as u32,
2856                    bits: *bits,
2857                }
2858            }
2859
2860            Op::FakeQuantizeBackward { bits, axis, ste } => {
2861                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2862                Thunk::FakeQuantizeBackward {
2863                    x: node_offset(arena, node.inputs[0]),
2864                    dy: node_offset(arena, node.inputs[1]),
2865                    dx: node_offset(arena, node.id),
2866                    len: node.shape.num_elements().unwrap() as u32,
2867                    chan_axis: chan_axis as u32,
2868                    chan_dim: chan_dim as u32,
2869                    inner: inner as u32,
2870                    bits: *bits,
2871                    ste: *ste,
2872                }
2873            }
2874
2875            Op::Dequantize {
2876                axis,
2877                scales,
2878                zero_points,
2879            } => {
2880                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2881                Thunk::Dequantize {
2882                    q: node_offset(arena, node.inputs[0]),
2883                    x: node_offset(arena, node.id),
2884                    len: node.shape.num_elements().unwrap() as u32,
2885                    chan_axis: chan_axis as u32,
2886                    chan_dim: chan_dim as u32,
2887                    inner: inner as u32,
2888                    scales: scales.clone(),
2889                    zero_points: zero_points.clone(),
2890                }
2891            }
2892
2893            Op::Expand { .. } => {
2894                // Broadcast: build per-output-dim strides where any input dim
2895                // of size 1 has stride 0 (read the same element repeatedly).
2896                // Reuses the Thunk::Transpose runtime — N-D walk with strides
2897                // is identical; only the strides differ.
2898                let in_shape = &graph.node(node.inputs[0]).shape;
2899                let out_shape = &node.shape;
2900                let in_rank = in_shape.rank();
2901                let out_rank = out_shape.rank();
2902                // Implicit leading 1s if input has lower rank.
2903                let pad = out_rank.saturating_sub(in_rank);
2904                let in_dims: Vec<usize> = (0..out_rank)
2905                    .map(|i| {
2906                        if i < pad {
2907                            1
2908                        } else {
2909                            in_shape.dim(i - pad).unwrap_static()
2910                        }
2911                    })
2912                    .collect();
2913                // Row-major input strides (over the padded shape).
2914                let mut in_strides_full = vec![1usize; out_rank];
2915                for d in (0..out_rank.saturating_sub(1)).rev() {
2916                    in_strides_full[d] = in_strides_full[d + 1] * in_dims[d + 1];
2917                }
2918                let out_dims: Vec<u32> = (0..out_rank)
2919                    .map(|i| out_shape.dim(i).unwrap_static() as u32)
2920                    .collect();
2921                // Stride is 0 for broadcast dims (in_dim == 1 && out_dim > 1).
2922                let in_strides: Vec<u32> = (0..out_rank)
2923                    .map(|i| {
2924                        if in_dims[i] == 1 && (out_dims[i] as usize) > 1 {
2925                            0
2926                        } else {
2927                            in_strides_full[i] as u32
2928                        }
2929                    })
2930                    .collect();
2931                let in_total = in_dims.iter().product::<usize>() as u32;
2932                let src = node_offset(arena, node.inputs[0]);
2933                let dst = node_offset(arena, node.id);
2934                let elem_bytes = node.shape.dtype().size_bytes() as u8;
2935                match node.shape.dtype() {
2936                    rlx_ir::DType::F64 => Thunk::TransposeF64 {
2937                        src,
2938                        dst,
2939                        in_total,
2940                        out_dims,
2941                        in_strides,
2942                    },
2943                    _ => Thunk::Transpose {
2944                        src,
2945                        dst,
2946                        in_total,
2947                        out_dims,
2948                        in_strides,
2949                        elem_bytes,
2950                    },
2951                }
2952            }
2953
2954            Op::RmsNorm { eps, .. } => {
2955                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2956                let total = node.shape.num_elements().unwrap();
2957                Thunk::RmsNorm {
2958                    src: node_offset(arena, node.inputs[0]),
2959                    g: node_offset(arena, node.inputs[1]),
2960                    b: node_offset(arena, node.inputs[2]),
2961                    dst: node_offset(arena, node.id),
2962                    rows: (total / h) as u32,
2963                    h: h as u32,
2964                    eps: *eps,
2965                }
2966            }
2967
2968            Op::LayerNorm { eps, .. } => {
2969                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2970                let total = node.shape.num_elements().unwrap();
2971                Thunk::LayerNorm {
2972                    src: node_offset(arena, node.inputs[0]),
2973                    g: node_offset(arena, node.inputs[1]),
2974                    b: node_offset(arena, node.inputs[2]),
2975                    dst: node_offset(arena, node.id),
2976                    rows: (total / h) as u32,
2977                    h: h as u32,
2978                    eps: *eps,
2979                }
2980            }
2981
2982            Op::GroupNorm { num_groups, eps } => {
2983                let in_shape = &graph.node(node.inputs[0]).shape;
2984                let (n, c, h, w) = conv_nchw_dims(in_shape);
2985                Thunk::GroupNorm {
2986                    src: node_offset(arena, node.inputs[0]),
2987                    g: node_offset(arena, node.inputs[1]),
2988                    b: node_offset(arena, node.inputs[2]),
2989                    dst: node_offset(arena, node.id),
2990                    n,
2991                    c,
2992                    h,
2993                    w,
2994                    num_groups: *num_groups as u32,
2995                    eps: *eps,
2996                }
2997            }
2998
2999            Op::BatchNormInference { eps } => {
3000                let in_shape = &graph.node(node.inputs[0]).shape;
3001                let rank = in_shape.rank();
3002                let channels = in_shape.dim(rank - 1).unwrap_static();
3003                let total = in_shape.num_elements().unwrap_or(0);
3004                let count = (total / channels.max(1)) as u32;
3005                Thunk::BatchNormInference {
3006                    src: node_offset(arena, node.inputs[0]),
3007                    g: node_offset(arena, node.inputs[1]),
3008                    b: node_offset(arena, node.inputs[2]),
3009                    mean: node_offset(arena, node.inputs[3]),
3010                    var: node_offset(arena, node.inputs[4]),
3011                    dst: node_offset(arena, node.id),
3012                    count,
3013                    channels: channels as u32,
3014                    eps: *eps,
3015                }
3016            }
3017
3018            Op::BatchNormInferenceBackwardInput { eps } => {
3019                let x_shape = &graph.node(node.inputs[0]).shape;
3020                let rank = x_shape.rank();
3021                let channels = x_shape.dim(rank - 1).unwrap_static();
3022                let total = x_shape.num_elements().unwrap_or(0);
3023                Thunk::BatchNormInferenceBackwardInput {
3024                    x: node_offset(arena, node.inputs[0]),
3025                    gamma: node_offset(arena, node.inputs[1]),
3026                    mean: node_offset(arena, node.inputs[2]),
3027                    var: node_offset(arena, node.inputs[3]),
3028                    dy: node_offset(arena, node.inputs[4]),
3029                    dx: node_offset(arena, node.id),
3030                    count: (total / channels.max(1)) as u32,
3031                    channels: channels as u32,
3032                    eps: *eps,
3033                }
3034            }
3035
3036            Op::BatchNormInferenceBackwardGamma { eps } => {
3037                let x_shape = &graph.node(node.inputs[0]).shape;
3038                let rank = x_shape.rank();
3039                let channels = x_shape.dim(rank - 1).unwrap_static();
3040                let total = x_shape.num_elements().unwrap_or(0);
3041                let _gamma_shape = &graph.node(node.id).shape;
3042                Thunk::BatchNormInferenceBackwardGamma {
3043                    x: node_offset(arena, node.inputs[0]),
3044                    mean: node_offset(arena, node.inputs[1]),
3045                    var: node_offset(arena, node.inputs[2]),
3046                    dy: node_offset(arena, node.inputs[3]),
3047                    dgamma: node_offset(arena, node.id),
3048                    count: (total / channels.max(1)) as u32,
3049                    channels: channels as u32,
3050                    eps: *eps,
3051                }
3052            }
3053
3054            Op::BatchNormInferenceBackwardBeta => {
3055                let dy_shape = &graph.node(node.inputs[0]).shape;
3056                let rank = dy_shape.rank();
3057                let channels = dy_shape.dim(rank - 1).unwrap_static();
3058                let total = dy_shape.num_elements().unwrap_or(0);
3059                Thunk::BatchNormInferenceBackwardBeta {
3060                    dy: node_offset(arena, node.inputs[0]),
3061                    dbeta: node_offset(arena, node.id),
3062                    count: (total / channels.max(1)) as u32,
3063                    channels: channels as u32,
3064                }
3065            }
3066
3067            Op::LayerNorm2d { eps } => {
3068                let in_shape = &graph.node(node.inputs[0]).shape;
3069                let (n, c, h, w) = conv_nchw_dims(in_shape);
3070                Thunk::LayerNorm2d {
3071                    src: node_offset(arena, node.inputs[0]),
3072                    g: node_offset(arena, node.inputs[1]),
3073                    b: node_offset(arena, node.inputs[2]),
3074                    dst: node_offset(arena, node.id),
3075                    n,
3076                    c,
3077                    h,
3078                    w,
3079                    eps: *eps,
3080                }
3081            }
3082
3083            Op::ConvTranspose2d {
3084                kernel_size,
3085                stride,
3086                padding,
3087                dilation,
3088                output_padding: _,
3089                groups,
3090            } => {
3091                let in_shape = &graph.node(node.inputs[0]).shape;
3092                let out_shape = &node.shape;
3093                let (n, c_in, h, w_in) = conv_nchw_dims(in_shape);
3094                let (_, c_out, h_out, w_out) = conv_nchw_dims(out_shape);
3095                Thunk::ConvTranspose2d {
3096                    src: node_offset(arena, node.inputs[0]),
3097                    weight: node_offset(arena, node.inputs[1]),
3098                    dst: node_offset(arena, node.id),
3099                    n,
3100                    c_in,
3101                    h,
3102                    w_in,
3103                    c_out,
3104                    h_out,
3105                    w_out,
3106                    kh: kernel_size[0] as u32,
3107                    kw: kernel_size[1] as u32,
3108                    sh: stride.first().copied().unwrap_or(1) as u32,
3109                    sw: stride.get(1).copied().unwrap_or(1) as u32,
3110                    ph: padding.first().copied().unwrap_or(0) as u32,
3111                    pw: padding.get(1).copied().unwrap_or(0) as u32,
3112                    dh: dilation.first().copied().unwrap_or(1) as u32,
3113                    dw: dilation.get(1).copied().unwrap_or(1) as u32,
3114                    groups: *groups as u32,
3115                }
3116            }
3117
3118            Op::ResizeNearest2x => {
3119                let in_shape = &graph.node(node.inputs[0]).shape;
3120                let (n, c, h, w) = conv_nchw_dims(in_shape);
3121                Thunk::ResizeNearest2x {
3122                    src: node_offset(arena, node.inputs[0]),
3123                    dst: node_offset(arena, node.id),
3124                    n,
3125                    c,
3126                    h,
3127                    w,
3128                }
3129            }
3130
3131            Op::AxialRope2d {
3132                end_x,
3133                end_y,
3134                head_dim,
3135                num_heads,
3136                theta,
3137                repeat_factor,
3138            } => {
3139                let in_shape = &graph.node(node.inputs[0]).shape;
3140                let batch = in_shape.dim(0).unwrap_static() as u32;
3141                let seq = in_shape.dim(1).unwrap_static() as u32;
3142                let hidden = in_shape.dim(2).unwrap_static() as u32;
3143                Thunk::AxialRope2d {
3144                    src: node_offset(arena, node.inputs[0]),
3145                    dst: node_offset(arena, node.id),
3146                    batch,
3147                    seq,
3148                    hidden,
3149                    end_x: *end_x as u32,
3150                    end_y: *end_y as u32,
3151                    head_dim: *head_dim as u32,
3152                    num_heads: *num_heads as u32,
3153                    theta: *theta,
3154                    repeat_factor: *repeat_factor as u32,
3155                }
3156            }
3157
3158            Op::Softmax { axis } => {
3159                let rank = node.shape.rank();
3160                let ax = if *axis < 0 {
3161                    (rank as i32 + axis) as usize
3162                } else {
3163                    *axis as usize
3164                };
3165                let cols = node.shape.dim(ax).unwrap_static();
3166                let total = node.shape.num_elements().unwrap();
3167                let in_off = node_offset(arena, node.inputs[0]);
3168                let out_off = node_offset(arena, node.id);
3169                // Softmax kernel runs in-place on its data buffer. If the
3170                // planner gave input and output separate slots (their live
3171                // ranges overlap, so no aliasing), the output starts
3172                // uninitialized — emit a Copy first so the data is there.
3173                // Same pattern as Op::Activation.
3174                if in_off != out_off {
3175                    thunks.push(Thunk::Copy {
3176                        src: in_off,
3177                        dst: out_off,
3178                        len: total as u32,
3179                    });
3180                }
3181                Thunk::Softmax {
3182                    data: out_off,
3183                    rows: (total / cols) as u32,
3184                    cols: cols as u32,
3185                }
3186            }
3187
3188            Op::SelectiveScan { state_size } => {
3189                let in_shape = &graph.node(node.inputs[0]).shape;
3190                let (batch, seq, hidden) = (
3191                    in_shape.dim(0).unwrap_static(),
3192                    in_shape.dim(1).unwrap_static(),
3193                    in_shape.dim(2).unwrap_static(),
3194                );
3195                Thunk::SelectiveScan {
3196                    x: node_offset(arena, node.inputs[0]),
3197                    delta: node_offset(arena, node.inputs[1]),
3198                    a: node_offset(arena, node.inputs[2]),
3199                    b: node_offset(arena, node.inputs[3]),
3200                    c: node_offset(arena, node.inputs[4]),
3201                    dst: node_offset(arena, node.id),
3202                    batch: batch as u32,
3203                    seq: seq as u32,
3204                    hidden: hidden as u32,
3205                    state_size: *state_size as u32,
3206                }
3207            }
3208
3209            Op::GatedDeltaNet {
3210                state_size,
3211                carry_state,
3212            } => {
3213                let q_shape = &graph.node(node.inputs[0]).shape;
3214                let (batch, seq, heads) = (
3215                    q_shape.dim(0).unwrap_static(),
3216                    q_shape.dim(1).unwrap_static(),
3217                    q_shape.dim(2).unwrap_static(),
3218                );
3219                let state_off = if *carry_state {
3220                    node_offset(arena, node.inputs[5])
3221                } else {
3222                    0
3223                };
3224                Thunk::GatedDeltaNet {
3225                    q: node_offset(arena, node.inputs[0]),
3226                    k: node_offset(arena, node.inputs[1]),
3227                    v: node_offset(arena, node.inputs[2]),
3228                    g: node_offset(arena, node.inputs[3]),
3229                    beta: node_offset(arena, node.inputs[4]),
3230                    state: state_off,
3231                    dst: node_offset(arena, node.id),
3232                    batch: batch as u32,
3233                    seq: seq as u32,
3234                    heads: heads as u32,
3235                    state_size: *state_size as u32,
3236                }
3237            }
3238
3239            Op::QMatMul {
3240                x_zp,
3241                w_zp,
3242                out_zp,
3243                mult,
3244            } => {
3245                let x_shape = &graph.node(node.inputs[0]).shape;
3246                let w_shape = &graph.node(node.inputs[1]).shape;
3247                let m = x_shape.dim(0).unwrap_static();
3248                let k = x_shape.dim(1).unwrap_static();
3249                let n = w_shape.dim(1).unwrap_static();
3250                Thunk::QMatMul {
3251                    x: node_offset(arena, node.inputs[0]),
3252                    w: node_offset(arena, node.inputs[1]),
3253                    bias: node_offset(arena, node.inputs[2]),
3254                    out: node_offset(arena, node.id),
3255                    m: m as u32,
3256                    k: k as u32,
3257                    n: n as u32,
3258                    x_zp: *x_zp,
3259                    w_zp: *w_zp,
3260                    out_zp: *out_zp,
3261                    mult: *mult,
3262                }
3263            }
3264
3265            Op::QConv2d {
3266                kernel_size,
3267                stride,
3268                padding,
3269                dilation,
3270                groups,
3271                x_zp,
3272                w_zp,
3273                out_zp,
3274                mult,
3275            } => {
3276                let in_shape = &graph.node(node.inputs[0]).shape;
3277                let w_shape = &graph.node(node.inputs[1]).shape;
3278                let out_shape = &node.shape;
3279                if kernel_size.len() == 2
3280                    && in_shape.rank() == 4
3281                    && w_shape.rank() == 4
3282                    && out_shape.rank() == 4
3283                {
3284                    Thunk::QConv2d {
3285                        x: node_offset(arena, node.inputs[0]),
3286                        w: node_offset(arena, node.inputs[1]),
3287                        bias: node_offset(arena, node.inputs[2]),
3288                        out: node_offset(arena, node.id),
3289                        n: in_shape.dim(0).unwrap_static() as u32,
3290                        c_in: in_shape.dim(1).unwrap_static() as u32,
3291                        h: in_shape.dim(2).unwrap_static() as u32,
3292                        w_in: in_shape.dim(3).unwrap_static() as u32,
3293                        c_out: out_shape.dim(1).unwrap_static() as u32,
3294                        h_out: out_shape.dim(2).unwrap_static() as u32,
3295                        w_out: out_shape.dim(3).unwrap_static() as u32,
3296                        kh: kernel_size[0] as u32,
3297                        kw: kernel_size[1] as u32,
3298                        sh: stride.first().copied().unwrap_or(1) as u32,
3299                        sw: stride.get(1).copied().unwrap_or(1) as u32,
3300                        ph: padding.first().copied().unwrap_or(0) as u32,
3301                        pw: padding.get(1).copied().unwrap_or(0) as u32,
3302                        dh: dilation.first().copied().unwrap_or(1) as u32,
3303                        dw: dilation.get(1).copied().unwrap_or(1) as u32,
3304                        groups: *groups as u32,
3305                        x_zp: *x_zp,
3306                        w_zp: *w_zp,
3307                        out_zp: *out_zp,
3308                        mult: *mult,
3309                    }
3310                } else {
3311                    Thunk::Nop
3312                }
3313            }
3314
3315            Op::DequantMatMul { scheme } => {
3316                use rlx_ir::quant::QuantScheme;
3317                let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3318                let total = node.shape.num_elements().unwrap();
3319                let m = total / n.max(1);
3320                let x_total = graph.node(node.inputs[0]).shape.num_elements().unwrap();
3321                let k = x_total / m.max(1);
3322                if scheme.is_gguf() {
3323                    Thunk::DequantMatMulGguf {
3324                        x: node_offset(arena, node.inputs[0]),
3325                        w_q: node_offset(arena, node.inputs[1]),
3326                        dst: node_offset(arena, node.id),
3327                        m: m as u32,
3328                        k: k as u32,
3329                        n: n as u32,
3330                        scheme: *scheme,
3331                    }
3332                } else {
3333                    match scheme {
3334                        QuantScheme::Nvfp4Block => Thunk::DequantMatMulNvfp4 {
3335                            x: node_offset(arena, node.inputs[0]),
3336                            w_q: node_offset(arena, node.inputs[1]),
3337                            scale: node_offset(arena, node.inputs[2]),
3338                            global_scale: node_offset(arena, node.inputs[3]),
3339                            dst: node_offset(arena, node.id),
3340                            m: m as u32,
3341                            k: k as u32,
3342                            n: n as u32,
3343                        },
3344                        QuantScheme::Int4Block { block_size } => Thunk::DequantMatMulInt4 {
3345                            x: node_offset(arena, node.inputs[0]),
3346                            w_q: node_offset(arena, node.inputs[1]),
3347                            scale: node_offset(arena, node.inputs[2]),
3348                            zp: node_offset(arena, node.inputs[3]),
3349                            dst: node_offset(arena, node.id),
3350                            m: m as u32,
3351                            k: k as u32,
3352                            n: n as u32,
3353                            block_size: *block_size,
3354                            is_asymmetric: false,
3355                        },
3356                        QuantScheme::Fp8E4m3 => Thunk::DequantMatMulFp8 {
3357                            x: node_offset(arena, node.inputs[0]),
3358                            w_q: node_offset(arena, node.inputs[1]),
3359                            scale: node_offset(arena, node.inputs[2]),
3360                            dst: node_offset(arena, node.id),
3361                            m: m as u32,
3362                            k: k as u32,
3363                            n: n as u32,
3364                            e5m2: false,
3365                        },
3366                        QuantScheme::Fp8E5m2 => Thunk::DequantMatMulFp8 {
3367                            x: node_offset(arena, node.inputs[0]),
3368                            w_q: node_offset(arena, node.inputs[1]),
3369                            scale: node_offset(arena, node.inputs[2]),
3370                            dst: node_offset(arena, node.id),
3371                            m: m as u32,
3372                            k: k as u32,
3373                            n: n as u32,
3374                            e5m2: true,
3375                        },
3376                        QuantScheme::Int8Block { block_size } => Thunk::DequantMatMul {
3377                            x: node_offset(arena, node.inputs[0]),
3378                            w_q: node_offset(arena, node.inputs[1]),
3379                            scale: node_offset(arena, node.inputs[2]),
3380                            zp: node_offset(arena, node.inputs[3]),
3381                            dst: node_offset(arena, node.id),
3382                            m: m as u32,
3383                            k: k as u32,
3384                            n: n as u32,
3385                            block_size: *block_size,
3386                            is_asymmetric: false,
3387                        },
3388                        QuantScheme::Int8BlockAsym { block_size } => Thunk::DequantMatMul {
3389                            x: node_offset(arena, node.inputs[0]),
3390                            w_q: node_offset(arena, node.inputs[1]),
3391                            scale: node_offset(arena, node.inputs[2]),
3392                            zp: node_offset(arena, node.inputs[3]),
3393                            dst: node_offset(arena, node.id),
3394                            m: m as u32,
3395                            k: k as u32,
3396                            n: n as u32,
3397                            block_size: *block_size,
3398                            is_asymmetric: true,
3399                        },
3400                        other => panic!(
3401                            "DequantMatMul on CPU supports Int8/Int4/FP8/NVFP4 legacy or GGUF schemes; got {other}"
3402                        ),
3403                    }
3404                }
3405            }
3406
3407            Op::LoraMatMul { scale } => {
3408                // x [m, k], w [k, n], a [k, r], b [r, n].
3409                let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3410                let total = node.shape.num_elements().unwrap();
3411                let m = total / n.max(1);
3412                let x_total = graph.node(node.inputs[0]).shape.num_elements().unwrap();
3413                let k = x_total / m.max(1);
3414                let a_total = graph.node(node.inputs[2]).shape.num_elements().unwrap();
3415                let r = a_total / k.max(1);
3416                Thunk::LoraMatMul {
3417                    x: node_offset(arena, node.inputs[0]),
3418                    w: node_offset(arena, node.inputs[1]),
3419                    a: node_offset(arena, node.inputs[2]),
3420                    b: node_offset(arena, node.inputs[3]),
3421                    dst: node_offset(arena, node.id),
3422                    m: m as u32,
3423                    k: k as u32,
3424                    n: n as u32,
3425                    r: r as u32,
3426                    scale: *scale,
3427                }
3428            }
3429
3430            Op::Sample {
3431                top_k,
3432                top_p,
3433                temperature,
3434                seed,
3435            } => {
3436                let in_shape = &graph.node(node.inputs[0]).shape;
3437                // Logits are [batch, vocab] (or [vocab] → batch=1).
3438                let (batch, vocab) = if in_shape.rank() >= 2 {
3439                    (
3440                        in_shape.dim(0).unwrap_static(),
3441                        in_shape.dim(in_shape.rank() - 1).unwrap_static(),
3442                    )
3443                } else {
3444                    (1, in_shape.num_elements().unwrap_or(0))
3445                };
3446                Thunk::Sample {
3447                    logits: node_offset(arena, node.inputs[0]),
3448                    dst: node_offset(arena, node.id),
3449                    batch: batch as u32,
3450                    vocab: vocab as u32,
3451                    top_k: *top_k as u32,
3452                    top_p: *top_p,
3453                    temperature: *temperature,
3454                    seed: *seed,
3455                }
3456            }
3457
3458            Op::Cumsum { axis, exclusive } => {
3459                // For now CPU only supports last-axis cumsum (the
3460                // common case for sampling / ragged offsets).
3461                // Other axes can lower via Transpose → Cumsum →
3462                // Transpose; not on the hot path today.
3463                let rank = node.shape.rank();
3464                let ax = if *axis < 0 {
3465                    (rank as i32 + axis) as usize
3466                } else {
3467                    *axis as usize
3468                };
3469                assert_eq!(
3470                    ax,
3471                    rank - 1,
3472                    "Cumsum only supports the last axis on CPU today"
3473                );
3474                let cols = node.shape.dim(ax).unwrap_static();
3475                let total = node.shape.num_elements().unwrap();
3476                Thunk::Cumsum {
3477                    src: node_offset(arena, node.inputs[0]),
3478                    dst: node_offset(arena, node.id),
3479                    rows: (total / cols) as u32,
3480                    cols: cols as u32,
3481                    exclusive: *exclusive,
3482                }
3483            }
3484
            // Attention: resolve q / k / v / (optional mask) / output arena
            // offsets plus batch, query and key sequence lengths, and whether
            // the rank-4 operands are laid out head-major (`bhsd`).
            Op::Attention {
                num_heads,
                head_dim,
                mask_kind,
                // NOTE(review): both fields are discarded here and
                // `Thunk::Attention` has no field to carry them, so a
                // non-default score scale or logit softcap is not forwarded to
                // the CPU kernel — confirm the kernel's built-in behavior
                // matches what callers expect.
                score_scale: _,
                attn_logit_softcap: _,
            } => {
                // Layout dispatch: rank-4 input could be either
                // `[B, S, H, D]` (CPU's historical convention) or
                // `[B, H, S, D]` (the convention the GPU/TPU backends
                // share). Disambiguate by which axis matches
                // `num_heads`. Rank-3 is always `[B, S, H*D]`.
                // NOTE(review): a `[B, S, H, D]` input whose S equals
                // `num_heads` satisfies `d1 == num_heads` and is taken as
                // `[B, H, S, D]`; the shape alone cannot separate the two.
                let q_shape = &graph.node(node.inputs[0]).shape;
                let k_shape = &graph.node(node.inputs[1]).shape;
                let rank = q_shape.rank();
                let (batch, seq, kv_seq, bhsd) = if rank == 4 {
                    let d1 = q_shape.dim(1).unwrap_static();
                    let d2 = q_shape.dim(2).unwrap_static();
                    if d1 == *num_heads {
                        // [B, H, S, D]
                        (
                            q_shape.dim(0).unwrap_static(),
                            d2,
                            k_shape.dim(2).unwrap_static(),
                            true,
                        )
                    } else {
                        // [B, S, H, D]
                        (
                            q_shape.dim(0).unwrap_static(),
                            d1,
                            k_shape.dim(1).unwrap_static(),
                            false,
                        )
                    }
                } else if rank >= 3 {
                    // Rank 3 (and, since rank 4 is handled above, rank >= 5):
                    // dims 0 and 1 are read as batch and sequence.
                    (
                        q_shape.dim(0).unwrap_static(),
                        q_shape.dim(1).unwrap_static(),
                        k_shape.dim(1).unwrap_static(),
                        false,
                    )
                } else {
                    // Rank <= 2: no batch axis; dim 0 is the sequence length.
                    (
                        1,
                        q_shape.dim(0).unwrap_static(),
                        k_shape.dim(0).unwrap_static(),
                        false,
                    )
                };
                // Only Custom / Bias masks read a tensor (inputs[3]); other
                // mask kinds get a placeholder offset of 0.
                let mask_off = if matches!(
                    mask_kind,
                    rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
                ) {
                    node_offset(arena, node.inputs[3])
                } else {
                    0
                };
                let hs = (*num_heads * *head_dim) as u32;
                Thunk::Attention {
                    q: node_offset(arena, node.inputs[0]),
                    k: node_offset(arena, node.inputs[1]),
                    v: node_offset(arena, node.inputs[2]),
                    mask: mask_off,
                    out: node_offset(arena, node.id),
                    batch: batch as u32,
                    seq: seq as u32,
                    kv_seq: kv_seq as u32,
                    heads: *num_heads as u32,
                    head_dim: *head_dim as u32,
                    mask_kind: *mask_kind,
                    // Defaults: each input is its own contiguous buffer
                    // with row stride = hidden. Rewritten by the
                    // Narrow→Attention fusion when applicable.
                    q_row_stride: hs,
                    k_row_stride: hs,
                    v_row_stride: hs,
                    bhsd,
                }
            }
3565
            // Attention backward. Inputs are q, k, v, dy (inputs[0..4]) plus
            // an optional mask at inputs[4]. `wrt` is forwarded to the thunk
            // unchanged (presumably selecting which gradient is produced —
            // confirm against the kernel).
            Op::AttentionBackward {
                num_heads,
                head_dim,
                mask_kind,
                wrt,
            } => {
                // Same layout dispatch as the forward `Attention` lowering:
                // rank 4 is `[B, H, S, D]` when dim 1 equals `num_heads`,
                // otherwise `[B, S, H, D]`.
                // NOTE(review): ambiguous when S == num_heads in a
                // `[B, S, H, D]` input — the head-major branch wins.
                let q_shape = &graph.node(node.inputs[0]).shape;
                let k_shape = &graph.node(node.inputs[1]).shape;
                let rank = q_shape.rank();
                let (batch, seq, kv_seq, bhsd) = if rank == 4 {
                    let d1 = q_shape.dim(1).unwrap_static();
                    let d2 = q_shape.dim(2).unwrap_static();
                    if d1 == *num_heads {
                        (
                            q_shape.dim(0).unwrap_static(),
                            d2,
                            k_shape.dim(2).unwrap_static(),
                            true,
                        )
                    } else {
                        (
                            q_shape.dim(0).unwrap_static(),
                            d1,
                            k_shape.dim(1).unwrap_static(),
                            false,
                        )
                    }
                } else if rank >= 3 {
                    // Rank 3 (and rank >= 5): dims 0 and 1 are batch and sequence.
                    (
                        q_shape.dim(0).unwrap_static(),
                        q_shape.dim(1).unwrap_static(),
                        k_shape.dim(1).unwrap_static(),
                        false,
                    )
                } else {
                    // Rank <= 2: no batch axis; dim 0 is the sequence length.
                    (
                        1,
                        q_shape.dim(0).unwrap_static(),
                        k_shape.dim(0).unwrap_static(),
                        false,
                    )
                };
                // The mask sits after dy, hence inputs[4]; only Custom / Bias
                // masks read it, otherwise the offset is a 0 placeholder.
                let mask_off = if matches!(
                    mask_kind,
                    rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
                ) {
                    node_offset(arena, node.inputs[4])
                } else {
                    0
                };
                Thunk::AttentionBackward {
                    q: node_offset(arena, node.inputs[0]),
                    k: node_offset(arena, node.inputs[1]),
                    v: node_offset(arena, node.inputs[2]),
                    dy: node_offset(arena, node.inputs[3]),
                    mask: mask_off,
                    out: node_offset(arena, node.id),
                    batch: batch as u32,
                    seq: seq as u32,
                    kv_seq: kv_seq as u32,
                    heads: *num_heads as u32,
                    head_dim: *head_dim as u32,
                    mask_kind: *mask_kind,
                    wrt: *wrt,
                    bhsd,
                }
            }
3633
            // Fused attention block: resolves offsets for the hidden input,
            // QKV / output projection weights, mask, and the optional bias and
            // rope tensors whose input positions depend on `has_bias` /
            // `has_rope`.
            Op::FusedAttentionBlock {
                num_heads,
                head_dim,
                has_bias,
                has_rope,
            } => {
                let x_shape = &graph.node(node.inputs[0]).shape;
                let (batch, seq) = if x_shape.rank() >= 3 {
                    (
                        x_shape.dim(0).unwrap_static(),
                        x_shape.dim(1).unwrap_static(),
                    )
                } else {
                    // Rank < 3: the second-to-last dim is the sequence length
                    // and batch = total / (seq * num_heads * head_dim).
                    // NOTE(review): `rank() - 2` underflows for rank < 2, and
                    // the division panics if any of those factors is 0.
                    let total = x_shape.num_elements().unwrap();
                    let s = x_shape.dim(x_shape.rank() - 2).unwrap_static();
                    (total / (s * num_heads * head_dim), s)
                };
                let hs = (*num_heads * *head_dim) as u32;
                // Inputs: hidden, qkv_w, out_w, mask, [qkv_b, out_b], [cos, sin]
                // `idx` walks the optional tail: biases (if any) occupy 4 and
                // 5, so rope tensors start at 6 with bias and at 4 without.
                let mut idx = 4;
                let (qkv_b_off, out_b_off) = if *has_bias {
                    let qb = node_offset(arena, node.inputs[idx]);
                    let ob = node_offset(arena, node.inputs[idx + 1]);
                    idx += 2;
                    (qb, ob)
                } else {
                    (0, 0)
                };
                // `cos_len` comes from `get_len` on the cos input; all three
                // rope values are 0 placeholders when rope is disabled.
                let (cos_off, sin_off, cl) = if *has_rope {
                    let c = node_offset(arena, node.inputs[idx]);
                    let s = node_offset(arena, node.inputs[idx + 1]);
                    let clen = get_len(graph, node.inputs[idx]);
                    (c, s, clen as u32)
                } else {
                    (0, 0, 0)
                };

                Thunk::FusedAttnBlock {
                    hidden: node_offset(arena, node.inputs[0]),
                    qkv_w: node_offset(arena, node.inputs[1]),
                    out_w: node_offset(arena, node.inputs[2]),
                    mask: node_offset(arena, node.inputs[3]),
                    out: node_offset(arena, node.id),
                    qkv_b: qkv_b_off,
                    out_b: out_b_off,
                    cos: cos_off,
                    sin: sin_off,
                    cos_len: cl,
                    batch: batch as u32,
                    seq: seq as u32,
                    hs,
                    nh: *num_heads as u32,
                    dh: *head_dim as u32,
                    has_bias: *has_bias,
                    has_rope: *has_rope,
                }
            }
3691
3692            Op::Rope { head_dim, n_rot } => {
3693                let x_shape = &graph.node(node.inputs[0]).shape;
3694                let (batch, seq, hidden) = if x_shape.rank() >= 3 {
3695                    (
3696                        x_shape.dim(0).unwrap_static(),
3697                        x_shape.dim(1).unwrap_static(),
3698                        x_shape.dim(2).unwrap_static(),
3699                    )
3700                } else {
3701                    let total = x_shape.num_elements().unwrap();
3702                    (
3703                        1,
3704                        x_shape.dim(0).unwrap_static(),
3705                        total / x_shape.dim(0).unwrap_static(),
3706                    )
3707                };
3708                let cos_len = get_len(graph, node.inputs[1]);
3709                Thunk::Rope {
3710                    src: node_offset(arena, node.inputs[0]),
3711                    cos: node_offset(arena, node.inputs[1]),
3712                    sin: node_offset(arena, node.inputs[2]),
3713                    dst: node_offset(arena, node.id),
3714                    batch: batch as u32,
3715                    seq: seq as u32,
3716                    hidden: hidden as u32,
3717                    head_dim: *head_dim as u32,
3718                    n_rot: *n_rot as u32,
3719                    cos_len: cos_len as u32,
3720                    // Default: source rows are tightly packed (rewritten
3721                    // by the Narrow→Rope fusion pass below if Rope ends
3722                    // up reading from a wider parent like QKV).
3723                    src_row_stride: hidden as u32,
3724                }
3725            }
3726
3727            Op::FusedSwiGLU {
3728                cast_to: _,
3729                gate_first,
3730            } => {
3731                let n_half = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3732                let total = node.shape.num_elements().unwrap();
3733                Thunk::FusedSwiGLU {
3734                    src: node_offset(arena, node.inputs[0]),
3735                    dst: node_offset(arena, node.id),
3736                    n_half: n_half as u32,
3737                    total: total as u32,
3738                    gate_first: *gate_first,
3739                }
3740            }
3741
3742            Op::Conv {
3743                kernel_size,
3744                stride,
3745                padding,
3746                dilation,
3747                groups,
3748            } => {
3749                let in_shape = &graph.node(node.inputs[0]).shape;
3750                let w_shape = &graph.node(node.inputs[1]).shape;
3751                let out_shape = &node.shape;
3752                // 1×1 fast path (plan #26): kH=kW=1, stride=1,
3753                // padding=0, dilation=1, groups=1. Emits a single
3754                // Conv2D1x1 thunk that BLAS-dispatches per batch.
3755                let is_1x1_simple = kernel_size.len() == 2
3756                    && kernel_size[0] == 1
3757                    && kernel_size[1] == 1
3758                    && stride.iter().all(|&s| s == 1)
3759                    && padding.iter().all(|&p| p == 0)
3760                    && dilation.iter().all(|&d| d == 1)
3761                    && *groups == 1;
3762                if is_1x1_simple
3763                    && in_shape.rank() >= 3
3764                    && out_shape.rank() >= 3
3765                    && w_shape.rank() >= 2
3766                {
3767                    let (n, c_in, h, w) = conv_nchw_dims(in_shape);
3768                    let (_, c_out, _, _) = conv_nchw_dims(out_shape);
3769                    Thunk::Conv2D1x1 {
3770                        src: node_offset(arena, node.inputs[0]),
3771                        weight: node_offset(arena, node.inputs[1]),
3772                        dst: node_offset(arena, node.id),
3773                        n,
3774                        c_in,
3775                        c_out,
3776                        hw: h.saturating_mul(w),
3777                    }
3778                } else if kernel_size.len() == 2
3779                    && in_shape.rank() >= 3
3780                    && w_shape.rank() >= 2
3781                    && out_shape.rank() >= 3
3782                {
3783                    let (n, c_in, h, w_in) = conv_nchw_dims(in_shape);
3784                    let (_, c_out, h_out, w_out) = conv_nchw_dims(out_shape);
3785                    Thunk::Conv2D {
3786                        src: node_offset(arena, node.inputs[0]),
3787                        weight: node_offset(arena, node.inputs[1]),
3788                        dst: node_offset(arena, node.id),
3789                        n,
3790                        c_in,
3791                        h,
3792                        w: w_in,
3793                        c_out,
3794                        h_out,
3795                        w_out,
3796                        kh: kernel_size[0] as u32,
3797                        kw: kernel_size[1] as u32,
3798                        sh: stride.first().copied().unwrap_or(1) as u32,
3799                        sw: stride.get(1).copied().unwrap_or(1) as u32,
3800                        ph: padding.first().copied().unwrap_or(0) as u32,
3801                        pw: padding.get(1).copied().unwrap_or(0) as u32,
3802                        dh: dilation.first().copied().unwrap_or(1) as u32,
3803                        dw: dilation.get(1).copied().unwrap_or(1) as u32,
3804                        groups: *groups as u32,
3805                    }
3806                } else {
3807                    Thunk::Nop
3808                }
3809            }
3810
3811            Op::Pool {
3812                kind,
3813                kernel_size,
3814                stride,
3815                padding,
3816            } => {
3817                // Currently support 2D pooling on rank-4 NCHW tensors.
3818                let in_shape = &graph.node(node.inputs[0]).shape;
3819                let out_shape = &node.shape;
3820                if kernel_size.len() == 2 && in_shape.rank() == 4 && out_shape.rank() == 4 {
3821                    Thunk::Pool2D {
3822                        src: node_offset(arena, node.inputs[0]),
3823                        dst: node_offset(arena, node.id),
3824                        n: in_shape.dim(0).unwrap_static() as u32,
3825                        c: in_shape.dim(1).unwrap_static() as u32,
3826                        h: in_shape.dim(2).unwrap_static() as u32,
3827                        w: in_shape.dim(3).unwrap_static() as u32,
3828                        h_out: out_shape.dim(2).unwrap_static() as u32,
3829                        w_out: out_shape.dim(3).unwrap_static() as u32,
3830                        kh: kernel_size[0] as u32,
3831                        kw: kernel_size[1] as u32,
3832                        sh: stride.first().copied().unwrap_or(1) as u32,
3833                        sw: stride.get(1).copied().unwrap_or(1) as u32,
3834                        ph: padding.first().copied().unwrap_or(0) as u32,
3835                        pw: padding.get(1).copied().unwrap_or(0) as u32,
3836                        kind: *kind,
3837                    }
3838                } else {
3839                    Thunk::Nop
3840                }
3841            }
3842
3843            Op::Transpose { perm } => {
3844                // Pre-compute (out_dims, in_strides_for_each_out_dim) so the
3845                // runtime loop is just an N-D index walk + scatter.
3846                let in_shape = &graph.node(node.inputs[0]).shape;
3847                let in_rank = in_shape.rank();
3848                if perm.iter().any(|&p| p >= in_rank) {
3849                    Thunk::Nop
3850                } else {
3851                    let in_dims: Vec<usize> = (0..in_rank)
3852                        .map(|i| in_shape.dim(i).unwrap_static())
3853                        .collect();
3854                    // Row-major input strides: stride[d] = product of dims[d+1..].
3855                    let mut in_strides_full = vec![1usize; in_rank];
3856                    for d in (0..in_rank.saturating_sub(1)).rev() {
3857                        in_strides_full[d] = in_strides_full[d + 1] * in_dims[d + 1];
3858                    }
3859                    let out_dims: Vec<u32> = perm.iter().map(|&p| in_dims[p] as u32).collect();
3860                    let in_strides: Vec<u32> =
3861                        perm.iter().map(|&p| in_strides_full[p] as u32).collect();
3862                    let in_total = in_dims.iter().product::<usize>() as u32;
3863                    let src = node_offset(arena, node.inputs[0]);
3864                    let dst = node_offset(arena, node.id);
3865                    let elem_bytes = node.shape.dtype().size_bytes() as u8;
3866                    match node.shape.dtype() {
3867                        rlx_ir::DType::F64 => Thunk::TransposeF64 {
3868                            src,
3869                            dst,
3870                            in_total,
3871                            out_dims,
3872                            in_strides,
3873                        },
3874                        _ => Thunk::Transpose {
3875                            src,
3876                            dst,
3877                            in_total,
3878                            out_dims,
3879                            in_strides,
3880                            elem_bytes,
3881                        },
3882                    }
3883                }
3884            }
3885
3886            Op::ScatterAdd => {
3887                // updates: [num_updates, ...trailing], indices: [num_updates],
3888                // output: [out_dim, ...trailing]
3889                let upd_shape = &graph.node(node.inputs[0]).shape;
3890                let out_shape = &node.shape;
3891                let num_updates = upd_shape.dim(0).unwrap_static();
3892                let out_dim = out_shape.dim(0).unwrap_static();
3893                let trailing: usize = (1..out_shape.rank())
3894                    .map(|i| out_shape.dim(i).unwrap_static())
3895                    .product::<usize>()
3896                    .max(1);
3897                Thunk::ScatterAdd {
3898                    updates: node_offset(arena, node.inputs[0]),
3899                    indices: node_offset(arena, node.inputs[1]),
3900                    dst: node_offset(arena, node.id),
3901                    num_updates: num_updates as u32,
3902                    out_dim: out_dim as u32,
3903                    trailing: trailing as u32,
3904                }
3905            }
3906
3907            Op::GroupedMatMul => {
3908                // Inputs: [input(M, K), weight(E, K, N), expert_idx(M)]
3909                let in_shape = &graph.node(node.inputs[0]).shape;
3910                let w_shape = &graph.node(node.inputs[1]).shape;
3911                let m = in_shape.dim(in_shape.rank() - 2).unwrap_static();
3912                let k_dim = in_shape.dim(in_shape.rank() - 1).unwrap_static();
3913                let num_experts = w_shape.dim(0).unwrap_static();
3914                let n = w_shape.dim(2).unwrap_static();
3915                Thunk::GroupedMatMul {
3916                    input: node_offset(arena, node.inputs[0]),
3917                    weight: node_offset(arena, node.inputs[1]),
3918                    expert_idx: node_offset(arena, node.inputs[2]),
3919                    dst: node_offset(arena, node.id),
3920                    m: m as u32,
3921                    k_dim: k_dim as u32,
3922                    n: n as u32,
3923                    num_experts: num_experts as u32,
3924                }
3925            }
3926
3927            Op::DequantGroupedMatMul { scheme } => {
3928                let in_shape = &graph.node(node.inputs[0]).shape;
3929                let w_shape = &graph.node(node.inputs[1]).shape;
3930                let m = in_shape.dim(in_shape.rank() - 2).unwrap_static();
3931                let k_dim = in_shape.dim(in_shape.rank() - 1).unwrap_static();
3932                let out_shape = &node.shape;
3933                let n = out_shape.dim(out_shape.rank() - 1).unwrap_static();
3934                let block_elems = scheme.gguf_block_size() as usize;
3935                let block_bytes = scheme.gguf_block_bytes() as usize;
3936                let slab_bytes = (k_dim * n) / block_elems * block_bytes;
3937                let total_bytes = w_shape.num_elements().unwrap();
3938                let num_experts = total_bytes / slab_bytes.max(1);
3939                Thunk::DequantGroupedMatMulGguf {
3940                    input: node_offset(arena, node.inputs[0]),
3941                    w_q: node_offset(arena, node.inputs[1]),
3942                    expert_idx: node_offset(arena, node.inputs[2]),
3943                    dst: node_offset(arena, node.id),
3944                    m: m as u32,
3945                    k_dim: k_dim as u32,
3946                    n: n as u32,
3947                    num_experts: num_experts as u32,
3948                    scheme: *scheme,
3949                }
3950            }
3951
3952            Op::DequantMoEWeights { scheme } => {
3953                let w_shape = &graph.node(node.inputs[0]).shape;
3954                let out_shape = &node.shape;
3955                let num_experts = out_shape.dim(0).unwrap_static();
3956                let k_dim = out_shape.dim(1).unwrap_static();
3957                let n = out_shape.dim(2).unwrap_static();
3958                let block_elems = scheme.gguf_block_size() as usize;
3959                let block_bytes = scheme.gguf_block_bytes() as usize;
3960                let slab_bytes = (k_dim * n) / block_elems * block_bytes;
3961                let total_bytes = w_shape.num_elements().unwrap();
3962                assert_eq!(
3963                    total_bytes,
3964                    num_experts * slab_bytes,
3965                    "DequantMoEWeights packed bytes mismatch"
3966                );
3967                Thunk::DequantMoEWeightsGguf {
3968                    w_q: node_offset(arena, node.inputs[0]),
3969                    dst: node_offset(arena, node.id),
3970                    k_dim: k_dim as u32,
3971                    n: n as u32,
3972                    num_experts: num_experts as u32,
3973                    scheme: *scheme,
3974                }
3975            }
3976
3977            Op::TopK { k } => {
3978                let in_shape = &graph.node(node.inputs[0]).shape;
3979                let rank = in_shape.rank();
3980                let axis_dim = in_shape.dim(rank - 1).unwrap_static();
3981                let outer = in_shape.num_elements().unwrap() / axis_dim;
3982                let indices_i64 = u8::from(graph.node(node.id).shape.dtype() == rlx_ir::DType::I64);
3983                Thunk::TopK {
3984                    src: node_offset(arena, node.inputs[0]),
3985                    dst: node_offset(arena, node.id),
3986                    outer: outer as u32,
3987                    axis_dim: axis_dim as u32,
3988                    k: *k as u32,
3989                    indices_i64,
3990                }
3991            }
3992
3993            Op::Reduce {
3994                op,
3995                axes,
3996                keep_dim: _,
3997            } => {
3998                // Decompose the input shape into [outer, reduced, inner]
3999                // around the reduced axis range. Non-contiguous reduced
4000                // axes aren't supported here — caller must transpose them
4001                // contiguous first (the coverage tool would surface the
4002                // gap if a model needs it).
4003                let in_shape = &graph.node(node.inputs[0]).shape;
4004                let rank = in_shape.rank();
4005                let mut sorted = axes.clone();
4006                sorted.sort();
4007                sorted.dedup();
4008                let contiguous = sorted.windows(2).all(|w| w[1] == w[0] + 1)
4009                    && !sorted.is_empty()
4010                    && *sorted.last().unwrap() < rank;
4011                if !contiguous {
4012                    Thunk::Nop
4013                } else {
4014                    let first = sorted[0];
4015                    let last = *sorted.last().unwrap();
4016                    let outer: usize = (0..first)
4017                        .map(|i| in_shape.dim(i).unwrap_static())
4018                        .product::<usize>()
4019                        .max(1);
4020                    let reduced: usize = (first..=last)
4021                        .map(|i| in_shape.dim(i).unwrap_static())
4022                        .product();
4023                    let inner: usize = (last + 1..rank)
4024                        .map(|i| in_shape.dim(i).unwrap_static())
4025                        .product::<usize>()
4026                        .max(1);
4027                    let src = node_offset(arena, node.inputs[0]);
4028                    let dst = node_offset(arena, node.id);
4029                    if node.shape.dtype() == rlx_ir::DType::F64 && matches!(op, ReduceOp::Sum) {
4030                        Thunk::ReduceSumF64 {
4031                            src,
4032                            dst,
4033                            outer: outer as u32,
4034                            reduced: reduced as u32,
4035                            inner: inner as u32,
4036                        }
4037                    } else {
4038                        Thunk::Reduce {
4039                            src,
4040                            dst,
4041                            outer: outer as u32,
4042                            reduced: reduced as u32,
4043                            inner: inner as u32,
4044                            op: *op,
4045                        }
4046                    }
4047                }
4048            }
4049
4050            Op::Compare(cmp) => {
4051                let len = node.shape.num_elements().unwrap();
4052                let in_dtype = graph.node(node.inputs[0]).shape.dtype();
4053                let inputs_i64 = u8::from(in_dtype == rlx_ir::DType::I64);
4054                Thunk::Compare {
4055                    lhs: node_offset(arena, node.inputs[0]),
4056                    rhs: node_offset(arena, node.inputs[1]),
4057                    dst: node_offset(arena, node.id),
4058                    len: len as u32,
4059                    op: *cmp,
4060                    inputs_i64,
4061                    inputs_elem_bytes: in_dtype.size_bytes() as u8,
4062                    dst_elem_bytes: node.shape.dtype().size_bytes() as u8,
4063                }
4064            }
4065
4066            Op::Where => {
4067                let len = node.shape.num_elements().unwrap();
4068                let elem_bytes = node.shape.dtype().size_bytes() as u8;
4069                let cond_elem_bytes = graph.node(node.inputs[0]).shape.dtype().size_bytes() as u8;
4070                Thunk::Where {
4071                    cond: node_offset(arena, node.inputs[0]),
4072                    on_true: node_offset(arena, node.inputs[1]),
4073                    on_false: node_offset(arena, node.inputs[2]),
4074                    dst: node_offset(arena, node.id),
4075                    len: len as u32,
4076                    elem_bytes,
4077                    cond_elem_bytes,
4078                }
4079            }
4080
4081            Op::ReluBackward => {
4082                let len: usize = (0..node.shape.rank())
4083                    .map(|i| node.shape.dim(i).unwrap_static())
4084                    .product();
4085                let x = node_offset(arena, node.inputs[0]);
4086                let dy = node_offset(arena, node.inputs[1]);
4087                let dx = node_offset(arena, node.id);
4088                match node.shape.dtype() {
4089                    rlx_ir::DType::F64 => Thunk::ReluBackwardF64 {
4090                        x,
4091                        dy,
4092                        dx,
4093                        len: len as u32,
4094                    },
4095                    _ => Thunk::ReluBackward {
4096                        x,
4097                        dy,
4098                        dx,
4099                        len: len as u32,
4100                    },
4101                }
4102            }
4103
4104            Op::ComplexNormSq => {
4105                let len: usize = (0..node.shape.rank())
4106                    .map(|i| node.shape.dim(i).unwrap_static())
4107                    .product();
4108                let src = node_offset(arena, node.inputs[0]);
4109                let dst = node_offset(arena, node.id);
4110                Thunk::ComplexNormSqF32 {
4111                    src,
4112                    dst,
4113                    len: len as u32,
4114                }
4115            }
4116
4117            Op::ComplexNormSqBackward => {
4118                let len: usize = (0..node.shape.rank())
4119                    .map(|i| node.shape.dim(i).unwrap_static())
4120                    .product();
4121                let z = node_offset(arena, node.inputs[0]);
4122                let g = node_offset(arena, node.inputs[1]);
4123                let dz = node_offset(arena, node.id);
4124                Thunk::ComplexNormSqBackwardF32 {
4125                    z,
4126                    g,
4127                    dz,
4128                    len: len as u32,
4129                }
4130            }
4131
4132            Op::Conjugate => {
4133                let len: usize = (0..node.shape.rank())
4134                    .map(|i| node.shape.dim(i).unwrap_static())
4135                    .product();
4136                Thunk::ConjugateC64 {
4137                    src: node_offset(arena, node.inputs[0]),
4138                    dst: node_offset(arena, node.id),
4139                    len: len as u32,
4140                }
4141            }
4142
4143            Op::ActivationBackward { kind } => {
4144                let len: usize = (0..node.shape.rank())
4145                    .map(|i| node.shape.dim(i).unwrap_static())
4146                    .product();
4147                let x = node_offset(arena, node.inputs[0]);
4148                let dy = node_offset(arena, node.inputs[1]);
4149                let dx = node_offset(arena, node.id);
4150                match node.shape.dtype() {
4151                    rlx_ir::DType::F64 => Thunk::ActivationBackwardF64 {
4152                        x,
4153                        dy,
4154                        dx,
4155                        len: len as u32,
4156                        kind: *kind,
4157                    },
4158                    _ => Thunk::ActivationBackward {
4159                        x,
4160                        dy,
4161                        dx,
4162                        len: len as u32,
4163                        kind: *kind,
4164                    },
4165                }
4166            }
4167
4168            Op::LayerNormBackwardInput { eps, .. } => {
4169                // axis = -1 only (matches forward LayerNorm thunk).
4170                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
4171                let total = node.shape.num_elements().unwrap();
4172                Thunk::LayerNormBackwardInput {
4173                    x: node_offset(arena, node.inputs[0]),
4174                    gamma: node_offset(arena, node.inputs[1]),
4175                    dy: node_offset(arena, node.inputs[2]),
4176                    dx: node_offset(arena, node.id),
4177                    rows: (total / h) as u32,
4178                    h: h as u32,
4179                    eps: *eps,
4180                }
4181            }
4182
4183            Op::LayerNormBackwardGamma { eps, .. } => {
4184                let x_shape = &graph.node(node.inputs[0]).shape;
4185                let h = x_shape.dim(x_shape.rank() - 1).unwrap_static();
4186                let x_total = x_shape.num_elements().unwrap();
4187                Thunk::LayerNormBackwardGamma {
4188                    x: node_offset(arena, node.inputs[0]),
4189                    dy: node_offset(arena, node.inputs[1]),
4190                    dgamma: node_offset(arena, node.id),
4191                    rows: (x_total / h) as u32,
4192                    h: h as u32,
4193                    eps: *eps,
4194                }
4195            }
4196
4197            Op::RmsNormBackwardInput { eps, .. }
4198            | Op::RmsNormBackwardGamma { eps, .. }
4199            | Op::RmsNormBackwardBeta { eps, .. } => {
4200                let x_shape = &graph.node(node.inputs[0]).shape;
4201                let h = x_shape.dim(x_shape.rank() - 1).unwrap_static();
4202                let rows = (x_shape.num_elements().unwrap() / h) as u32;
4203                let off = |i: usize| node_offset(arena, node.inputs[i]);
4204                let common = (off(0), off(1), off(2), off(3), rows, h as u32, *eps);
4205                match &node.op {
4206                    Op::RmsNormBackwardInput { .. } => Thunk::RmsNormBackwardInput {
4207                        x: common.0,
4208                        gamma: common.1,
4209                        beta: common.2,
4210                        dy: common.3,
4211                        dx: node_offset(arena, node.id),
4212                        rows: common.4,
4213                        h: common.5,
4214                        eps: common.6,
4215                    },
4216                    Op::RmsNormBackwardGamma { .. } => Thunk::RmsNormBackwardGamma {
4217                        x: common.0,
4218                        gamma: common.1,
4219                        beta: common.2,
4220                        dy: common.3,
4221                        dgamma: node_offset(arena, node.id),
4222                        rows: common.4,
4223                        h: common.5,
4224                        eps: common.6,
4225                    },
4226                    Op::RmsNormBackwardBeta { .. } => Thunk::RmsNormBackwardBeta {
4227                        x: common.0,
4228                        gamma: common.1,
4229                        beta: common.2,
4230                        dy: common.3,
4231                        dbeta: node_offset(arena, node.id),
4232                        rows: common.4,
4233                        h: common.5,
4234                        eps: common.6,
4235                    },
4236                    _ => unreachable!(),
4237                }
4238            }
4239
4240            Op::RopeBackward { head_dim, n_rot } => {
4241                let dy_shape = &graph.node(node.inputs[0]).shape;
4242                let (batch, seq, hidden) = if dy_shape.rank() >= 3 {
4243                    (
4244                        dy_shape.dim(0).unwrap_static(),
4245                        dy_shape.dim(1).unwrap_static(),
4246                        dy_shape.dim(2).unwrap_static(),
4247                    )
4248                } else {
4249                    (
4250                        1,
4251                        dy_shape.dim(0).unwrap_static(),
4252                        dy_shape.dim(1).unwrap_static(),
4253                    )
4254                };
4255                let cos_shape = &graph.node(node.inputs[1]).shape;
4256                let cos_len = cos_shape.num_elements().unwrap();
4257                Thunk::RopeBackward {
4258                    dy: node_offset(arena, node.inputs[0]),
4259                    cos: node_offset(arena, node.inputs[1]),
4260                    sin: node_offset(arena, node.inputs[2]),
4261                    dx: node_offset(arena, node.id),
4262                    batch: batch as u32,
4263                    seq: seq as u32,
4264                    hidden: hidden as u32,
4265                    head_dim: *head_dim as u32,
4266                    n_rot: *n_rot as u32,
4267                    cos_len: cos_len as u32,
4268                }
4269            }
4270
4271            Op::CumsumBackward { exclusive, .. } => {
4272                let dy_shape = &graph.node(node.inputs[0]).shape;
4273                let rank = dy_shape.rank();
4274                let cols = dy_shape.dim(rank - 1).unwrap_static();
4275                let rows = dy_shape.num_elements().unwrap() / cols;
4276                Thunk::CumsumBackward {
4277                    dy: node_offset(arena, node.inputs[0]),
4278                    dx: node_offset(arena, node.id),
4279                    rows: rows as u32,
4280                    cols: cols as u32,
4281                    exclusive: *exclusive,
4282                }
4283            }
4284
4285            Op::GatherBackward { .. } => {
4286                let dy_shape = &graph.node(node.inputs[0]).shape;
4287                let idx_shape = &graph.node(node.inputs[1]).shape;
4288                let out_shape = &node.shape;
4289                let rank = out_shape.rank();
4290                let axis = match &node.op {
4291                    Op::GatherBackward { axis } => *axis,
4292                    _ => 0,
4293                };
4294                let axis_u = if axis < 0 {
4295                    (rank as i32 + axis) as usize
4296                } else {
4297                    axis as usize
4298                };
4299                let outer: usize = (0..axis_u)
4300                    .map(|i| dy_shape.dim(i).unwrap_static())
4301                    .product::<usize>()
4302                    .max(1);
4303                let num_idx = idx_shape.dim(axis_u).unwrap_static();
4304                let trailing: usize = (axis_u + 1..dy_shape.rank())
4305                    .map(|i| dy_shape.dim(i).unwrap_static())
4306                    .product::<usize>()
4307                    .max(1);
4308                let axis_dim = out_shape.dim(axis_u).unwrap_static();
4309                Thunk::GatherBackward {
4310                    dy: node_offset(arena, node.inputs[0]),
4311                    indices: node_offset(arena, node.inputs[1]),
4312                    dst: node_offset(arena, node.id),
4313                    outer: outer as u32,
4314                    axis_dim: axis_dim as u32,
4315                    num_idx: num_idx as u32,
4316                    trailing: trailing as u32,
4317                }
4318            }
4319
4320            Op::GroupNormBackwardInput { num_groups, eps }
4321            | Op::GroupNormBackwardGamma { num_groups, eps }
4322            | Op::GroupNormBackwardBeta { num_groups, eps } => {
4323                let x_shape = &graph.node(node.inputs[0]).shape;
4324                let n = x_shape.dim(0).unwrap_static() as u32;
4325                let c = x_shape.dim(1).unwrap_static() as u32;
4326                let h = x_shape.dim(2).unwrap_static() as u32;
4327                let w = x_shape.dim(3).unwrap_static() as u32;
4328                match &node.op {
4329                    Op::GroupNormBackwardInput { .. } => Thunk::GroupNormBackwardInput {
4330                        x: node_offset(arena, node.inputs[0]),
4331                        gamma: node_offset(arena, node.inputs[1]),
4332                        beta: node_offset(arena, node.inputs[2]),
4333                        dy: node_offset(arena, node.inputs[3]),
4334                        dx: node_offset(arena, node.id),
4335                        n,
4336                        c,
4337                        h,
4338                        w,
4339                        num_groups: *num_groups as u32,
4340                        eps: *eps,
4341                    },
4342                    Op::GroupNormBackwardGamma { .. } => Thunk::GroupNormBackwardGamma {
4343                        x: node_offset(arena, node.inputs[0]),
4344                        dy: node_offset(arena, node.inputs[1]),
4345                        dgamma: node_offset(arena, node.id),
4346                        n,
4347                        c,
4348                        h,
4349                        w,
4350                        num_groups: *num_groups as u32,
4351                        eps: *eps,
4352                    },
4353                    Op::GroupNormBackwardBeta { .. } => Thunk::GroupNormBackwardBeta {
4354                        dy: node_offset(arena, node.inputs[1]),
4355                        dbeta: node_offset(arena, node.id),
4356                        n,
4357                        c,
4358                        h,
4359                        w,
4360                    },
4361                    _ => unreachable!(),
4362                }
4363            }
4364
4365            Op::MaxPool2dBackward {
4366                kernel_size,
4367                stride,
4368                padding,
4369            } => {
4370                let x_shape = &graph.node(node.inputs[0]).shape;
4371                let dy_shape = &graph.node(node.inputs[1]).shape;
4372                if kernel_size.len() == 2 && x_shape.rank() == 4 && dy_shape.rank() == 4 {
4373                    Thunk::MaxPool2dBackward {
4374                        x: node_offset(arena, node.inputs[0]),
4375                        dy: node_offset(arena, node.inputs[1]),
4376                        dx: node_offset(arena, node.id),
4377                        n: x_shape.dim(0).unwrap_static() as u32,
4378                        c: x_shape.dim(1).unwrap_static() as u32,
4379                        h: x_shape.dim(2).unwrap_static() as u32,
4380                        w: x_shape.dim(3).unwrap_static() as u32,
4381                        h_out: dy_shape.dim(2).unwrap_static() as u32,
4382                        w_out: dy_shape.dim(3).unwrap_static() as u32,
4383                        kh: kernel_size[0] as u32,
4384                        kw: kernel_size[1] as u32,
4385                        sh: stride.first().copied().unwrap_or(1) as u32,
4386                        sw: stride.get(1).copied().unwrap_or(1) as u32,
4387                        ph: padding.first().copied().unwrap_or(0) as u32,
4388                        pw: padding.get(1).copied().unwrap_or(0) as u32,
4389                    }
4390                } else {
4391                    Thunk::Nop
4392                }
4393            }
4394
4395            Op::Conv2dBackwardInput {
4396                kernel_size,
4397                stride,
4398                padding,
4399                dilation,
4400                groups,
4401            } => {
4402                let dy_shape = &graph.node(node.inputs[0]).shape;
4403                let w_shape = &graph.node(node.inputs[1]).shape;
4404                let out_shape = &node.shape;
4405                if kernel_size.len() == 2
4406                    && dy_shape.rank() == 4
4407                    && w_shape.rank() == 4
4408                    && out_shape.rank() == 4
4409                {
4410                    Thunk::Conv2dBackwardInput {
4411                        dy: node_offset(arena, node.inputs[0]),
4412                        w: node_offset(arena, node.inputs[1]),
4413                        dx: node_offset(arena, node.id),
4414                        n: out_shape.dim(0).unwrap_static() as u32,
4415                        c_in: out_shape.dim(1).unwrap_static() as u32,
4416                        h: out_shape.dim(2).unwrap_static() as u32,
4417                        w_in: out_shape.dim(3).unwrap_static() as u32,
4418                        c_out: dy_shape.dim(1).unwrap_static() as u32,
4419                        h_out: dy_shape.dim(2).unwrap_static() as u32,
4420                        w_out: dy_shape.dim(3).unwrap_static() as u32,
4421                        kh: kernel_size[0] as u32,
4422                        kw: kernel_size[1] as u32,
4423                        sh: stride.first().copied().unwrap_or(1) as u32,
4424                        sw: stride.get(1).copied().unwrap_or(1) as u32,
4425                        ph: padding.first().copied().unwrap_or(0) as u32,
4426                        pw: padding.get(1).copied().unwrap_or(0) as u32,
4427                        dh: dilation.first().copied().unwrap_or(1) as u32,
4428                        dw: dilation.get(1).copied().unwrap_or(1) as u32,
4429                        groups: *groups as u32,
4430                    }
4431                } else {
4432                    Thunk::Nop
4433                }
4434            }
4435
4436            Op::Conv2dBackwardWeight {
4437                kernel_size,
4438                stride,
4439                padding,
4440                dilation,
4441                groups,
4442            } => {
4443                let x_shape = &graph.node(node.inputs[0]).shape;
4444                let dy_shape = &graph.node(node.inputs[1]).shape;
4445                let dw_shape = &node.shape;
4446                if kernel_size.len() == 2
4447                    && x_shape.rank() == 4
4448                    && dy_shape.rank() == 4
4449                    && dw_shape.rank() == 4
4450                {
4451                    Thunk::Conv2dBackwardWeight {
4452                        x: node_offset(arena, node.inputs[0]),
4453                        dy: node_offset(arena, node.inputs[1]),
4454                        dw: node_offset(arena, node.id),
4455                        n: x_shape.dim(0).unwrap_static() as u32,
4456                        c_in: x_shape.dim(1).unwrap_static() as u32,
4457                        h: x_shape.dim(2).unwrap_static() as u32,
4458                        w: x_shape.dim(3).unwrap_static() as u32,
4459                        c_out: dy_shape.dim(1).unwrap_static() as u32,
4460                        h_out: dy_shape.dim(2).unwrap_static() as u32,
4461                        w_out: dy_shape.dim(3).unwrap_static() as u32,
4462                        kh: kernel_size[0] as u32,
4463                        kw: kernel_size[1] as u32,
4464                        sh: stride.first().copied().unwrap_or(1) as u32,
4465                        sw: stride.get(1).copied().unwrap_or(1) as u32,
4466                        ph: padding.first().copied().unwrap_or(0) as u32,
4467                        pw: padding.get(1).copied().unwrap_or(0) as u32,
4468                        dh: dilation.first().copied().unwrap_or(1) as u32,
4469                        dw_dil: dilation.get(1).copied().unwrap_or(1) as u32,
4470                        groups: *groups as u32,
4471                    }
4472                } else {
4473                    Thunk::Nop
4474                }
4475            }
4476
4477            Op::Im2Col {
4478                kernel_size,
4479                stride,
4480                padding,
4481                dilation,
4482            } => {
4483                let x_shape = &graph.node(node.inputs[0]).shape;
4484                let out_shape = &node.shape;
4485                if kernel_size.len() == 2 && x_shape.rank() == 4 && out_shape.rank() == 2 {
4486                    let n = match x_shape.dim(0) {
4487                        rlx_ir::shape::Dim::Static(v) => v as u32,
4488                        _ => 0,
4489                    };
4490                    let c_in = x_shape.dim(1).unwrap_static() as u32;
4491                    let h = x_shape.dim(2).unwrap_static() as u32;
4492                    let w = x_shape.dim(3).unwrap_static() as u32;
4493                    let kh = kernel_size[0] as u32;
4494                    let kw = kernel_size[1] as u32;
4495                    let sh = stride.first().copied().unwrap_or(1) as u32;
4496                    let sw = stride.get(1).copied().unwrap_or(1) as u32;
4497                    let ph = padding.first().copied().unwrap_or(0) as u32;
4498                    let pw = padding.get(1).copied().unwrap_or(0) as u32;
4499                    let dh = dilation.first().copied().unwrap_or(1) as u32;
4500                    let dw_dil = dilation.get(1).copied().unwrap_or(1) as u32;
4501                    let h_out = rlx_ir::shape::conv2d_spatial_output(
4502                        h as usize,
4503                        kh as usize,
4504                        sh as usize,
4505                        ph as usize,
4506                        dh as usize,
4507                    ) as u32;
4508                    let w_out = rlx_ir::shape::conv2d_spatial_output(
4509                        w as usize,
4510                        kw as usize,
4511                        sw as usize,
4512                        pw as usize,
4513                        dw_dil as usize,
4514                    ) as u32;
4515                    Thunk::Im2Col {
4516                        x: node_offset(arena, node.inputs[0]),
4517                        col: node_offset(arena, node.id),
4518                        n,
4519                        c_in,
4520                        h,
4521                        w,
4522                        h_out,
4523                        w_out,
4524                        kh,
4525                        kw,
4526                        sh,
4527                        sw,
4528                        ph,
4529                        pw,
4530                        dh,
4531                        dw_dil,
4532                    }
4533                } else {
4534                    Thunk::Nop
4535                }
4536            }
4537
4538            Op::SoftmaxCrossEntropyWithLogits => {
4539                let logits_shape = &graph.node(node.inputs[0]).shape;
4540                if logits_shape.rank() == 2 {
4541                    Thunk::SoftmaxCrossEntropy {
4542                        logits: node_offset(arena, node.inputs[0]),
4543                        labels: node_offset(arena, node.inputs[1]),
4544                        dst: node_offset(arena, node.id),
4545                        n: logits_shape.dim(0).unwrap_static() as u32,
4546                        c: logits_shape.dim(1).unwrap_static() as u32,
4547                    }
4548                } else {
4549                    Thunk::Nop
4550                }
4551            }
4552
4553            Op::SoftmaxCrossEntropyBackward => {
4554                let logits_shape = &graph.node(node.inputs[0]).shape;
4555                if logits_shape.rank() == 2 {
4556                    Thunk::SoftmaxCrossEntropyBackward {
4557                        logits: node_offset(arena, node.inputs[0]),
4558                        labels: node_offset(arena, node.inputs[1]),
4559                        d_loss: node_offset(arena, node.inputs[2]),
4560                        dlogits: node_offset(arena, node.id),
4561                        n: logits_shape.dim(0).unwrap_static() as u32,
4562                        c: logits_shape.dim(1).unwrap_static() as u32,
4563                    }
4564                } else {
4565                    Thunk::Nop
4566                }
4567            }
4568
4569            Op::DenseSolve => {
4570                // A: [n, n], b: [n] or [n, nrhs]. Output matches b.
4571                let a_shape = &graph.node(node.inputs[0]).shape;
4572                let n = a_shape.dim(0).unwrap_static();
4573                debug_assert_eq!(
4574                    n,
4575                    a_shape.dim(1).unwrap_static(),
4576                    "DenseSolve: A must be square"
4577                );
4578                let b_elems = node.shape.num_elements().unwrap();
4579                let nrhs = b_elems / n;
4580                match node.shape.dtype() {
4581                    rlx_ir::DType::F64 => Thunk::DenseSolveF64 {
4582                        a: node_offset(arena, node.inputs[0]),
4583                        b: node_offset(arena, node.inputs[1]),
4584                        x: node_offset(arena, node.id),
4585                        n: n as u32,
4586                        nrhs: nrhs as u32,
4587                    },
4588                    rlx_ir::DType::F32 => Thunk::DenseSolveF32 {
4589                        a: node_offset(arena, node.inputs[0]),
4590                        b: node_offset(arena, node.inputs[1]),
4591                        x: node_offset(arena, node.id),
4592                        n: n as u32,
4593                        nrhs: nrhs as u32,
4594                    },
4595                    other => panic!(
4596                        "DenseSolve: F32 + F64 lowered; got {other:?}. \
4597                         Add another variant when needed."
4598                    ),
4599                }
4600            }
4601
4602            Op::BatchedDenseSolve => {
4603                // A: [B, N, N], b: [B, N] or [B, N, K]. Output matches b.
4604                let a_shape = &graph.node(node.inputs[0]).shape;
4605                assert_eq!(a_shape.rank(), 3, "BatchedDenseSolve: A rank must be 3");
4606                let batch = a_shape.dim(0).unwrap_static();
4607                let n = a_shape.dim(1).unwrap_static();
4608                debug_assert_eq!(
4609                    n,
4610                    a_shape.dim(2).unwrap_static(),
4611                    "BatchedDenseSolve: A's last two dims must match"
4612                );
4613                let total = node.shape.num_elements().unwrap();
4614                let nrhs = total / (batch * n);
4615                match node.shape.dtype() {
4616                    rlx_ir::DType::F32 => Thunk::BatchedDenseSolveF32 {
4617                        a: node_offset(arena, node.inputs[0]),
4618                        b: node_offset(arena, node.inputs[1]),
4619                        x: node_offset(arena, node.id),
4620                        batch: batch as u32,
4621                        n: n as u32,
4622                        nrhs: nrhs as u32,
4623                    },
4624                    rlx_ir::DType::F64 => Thunk::BatchedDenseSolveF64 {
4625                        a: node_offset(arena, node.inputs[0]),
4626                        b: node_offset(arena, node.inputs[1]),
4627                        x: node_offset(arena, node.id),
4628                        batch: batch as u32,
4629                        n: n as u32,
4630                        nrhs: nrhs as u32,
4631                    },
4632                    other => panic!("BatchedDenseSolve: F32 + F64 only, got {other:?}"),
4633                }
4634            }
4635
4636            Op::Scan {
4637                body,
4638                length,
4639                save_trajectory,
4640                num_bcast,
4641                num_xs,
4642                num_checkpoints,
4643            } => {
4644                assert!(
4645                    *num_checkpoints == 0 || *num_checkpoints <= *length,
4646                    "Op::Scan: num_checkpoints={} must be 0 or ≤ length={}",
4647                    *num_checkpoints,
4648                    *length
4649                );
4650                if *num_checkpoints != 0 && *num_checkpoints != *length {
4651                    assert!(
4652                        *save_trajectory,
4653                        "Op::Scan: num_checkpoints<length only meaningful when save_trajectory=true"
4654                    );
4655                }
4656                // Plan + compile the body sub-graph standalone. The body
4657                // gets its own Arena; per execution we clone its
4658                // pristine bytes, copy the outer carry (and per-step xs
4659                // slices, if any) into the body's Input slots, run the
4660                // body schedule N times, then copy the body's output
4661                // back to the outer arena.
4662                //
4663                // Body invariants: 1 + num_xs Op::Inputs in NodeId order
4664                // — first declared is the carry, rest are x_t_i. Single
4665                // graph output (the next carry), same shape as carry.
4666                let body_plan = rlx_opt::memory::plan_memory(body);
4667                let _body_arena_size = body_plan.arena_size;
4668                // Snapshot per-input byte offsets before plan_memory
4669                // moves into the Arena below.
4670                let body_offsets: HashMap<NodeId, usize> = body_plan
4671                    .assignments
4672                    .iter()
4673                    .map(|(id, slot)| (*id, slot.offset))
4674                    .collect();
4675
4676                // Collect body Input nodes in NodeId order; first is
4677                // carry, rest are per-step xs in matching order.
4678                let mut body_inputs: Vec<NodeId> = body
4679                    .nodes()
4680                    .iter()
4681                    .filter(|n| matches!(n.op, Op::Input { .. }))
4682                    .map(|n| n.id)
4683                    .collect();
4684                body_inputs.sort();
4685                let n_body_inputs = body_inputs.len();
4686                let expected = 1 + *num_bcast as usize + *num_xs as usize;
4687                if n_body_inputs != expected {
4688                    let names: Vec<String> = body
4689                        .nodes()
4690                        .iter()
4691                        .filter_map(|n| match &n.op {
4692                            Op::Input { name } => Some(format!("{}={}", n.id, name)),
4693                            _ => None,
4694                        })
4695                        .collect();
4696                    panic!(
4697                        "Op::Scan body has {} Op::Input nodes; expected {} \
4698                            (1 carry + {} bcast + {} xs). Inputs by NodeId: [{}]",
4699                        n_body_inputs,
4700                        expected,
4701                        *num_bcast,
4702                        *num_xs,
4703                        names.join(", ")
4704                    );
4705                }
4706
4707                let body_input_id = body_inputs[0];
4708                let body_input_off = body_offsets[&body_input_id];
4709                let body_output_id = body
4710                    .outputs
4711                    .first()
4712                    .copied()
4713                    .expect("Op::Scan body must declare one output");
4714                let body_output_off = body_offsets[&body_output_id];
4715
4716                let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4717                // Fill body Constant nodes — mirror the outer-graph logic
4718                // in rlx-runtime/src/backend.rs (dtype-aware).
4719                for n in body.nodes() {
4720                    if let Op::Constant { data } = &n.op
4721                        && body_arena.has_buffer(n.id)
4722                        && !data.is_empty()
4723                    {
4724                        match n.shape.dtype() {
4725                            rlx_ir::DType::F64 => {
4726                                let off = body_arena.byte_offset(n.id);
4727                                let buf = body_arena.raw_buf_mut();
4728                                let nbytes = (buf.len() - off).min(data.len());
4729                                buf[off..off + nbytes].copy_from_slice(&data[..nbytes]);
4730                            }
4731                            _ => {
4732                                let buf = body_arena.slice_mut(n.id);
4733                                let n_floats = data.len() / 4;
4734                                let n_lim = buf.len().min(n_floats);
4735                                for i in 0..n_lim {
4736                                    let bytes = [
4737                                        data[i * 4],
4738                                        data[i * 4 + 1],
4739                                        data[i * 4 + 2],
4740                                        data[i * 4 + 3],
4741                                    ];
4742                                    buf[i] = f32::from_le_bytes(bytes);
4743                                }
4744                            }
4745                        }
4746                    }
4747                }
4748                let body_init = body_arena.raw_buf().to_vec();
4749                let body_schedule = compile_thunks(body, &body_arena);
4750
4751                // Carry bytes — for trajectory mode, the outer node's
4752                // shape is [length, *carry_shape], so dividing by length
4753                // gives one row's bytes; the body's input slot still
4754                // holds carry_shape bytes.
4755                let carry_bytes = if *save_trajectory {
4756                    let total = node
4757                        .shape
4758                        .size_bytes()
4759                        .expect("Op::Scan trajectory output must have static shape");
4760                    total / *length as usize
4761                } else {
4762                    node.shape
4763                        .size_bytes()
4764                        .expect("Op::Scan carry must have static shape")
4765                };
4766
4767                // Bcast inputs occupy body_inputs[1..1+num_bcast] and
4768                // outer node.inputs[1..1+num_bcast]. They keep their
4769                // natural shape (no [length, ...] prefix) and are
4770                // copied into body_buf ONCE before the scan loop.
4771                let mut bcast_inputs: Vec<(usize, usize, u32)> =
4772                    Vec::with_capacity(*num_bcast as usize);
4773                for i in 0..*num_bcast as usize {
4774                    let body_b_id = body_inputs[1 + i];
4775                    let body_b_off = body_offsets[&body_b_id];
4776                    let outer_b_id = node.inputs[1 + i];
4777                    let outer_b_off = node_offset(arena, outer_b_id);
4778                    let outer_b_shape = &graph.node(outer_b_id).shape;
4779                    let total = outer_b_shape
4780                        .size_bytes()
4781                        .expect("Op::Scan bcast must have static shape");
4782                    bcast_inputs.push((body_b_off, outer_b_off, total as u32));
4783                }
4784
4785                // xs occupy body_inputs[1+num_bcast..] and node.inputs
4786                // [1+num_bcast..]. Each has shape [length, *per_step];
4787                // per-step bytes = total / length.
4788                let mut xs_inputs: Vec<(usize, usize, u32)> = Vec::with_capacity(*num_xs as usize);
4789                let xs_base = 1 + *num_bcast as usize;
4790                for i in 0..*num_xs as usize {
4791                    let body_x_id = body_inputs[xs_base + i];
4792                    let body_x_off = body_offsets[&body_x_id];
4793                    let outer_xs_id = node.inputs[xs_base + i];
4794                    let outer_xs_off = node_offset(arena, outer_xs_id);
4795                    let outer_xs_shape = &graph.node(outer_xs_id).shape;
4796                    let total = outer_xs_shape
4797                        .size_bytes()
4798                        .expect("Op::Scan xs must have static shape");
4799                    let per_step = total / *length as usize;
4800                    xs_inputs.push((body_x_off, outer_xs_off, per_step as u32));
4801                }
4802
4803                Thunk::Scan {
4804                    body: Arc::new(body_schedule),
4805                    body_init: Arc::new(body_init),
4806                    body_input_off,
4807                    body_output_off,
4808                    outer_init_off: node_offset(arena, node.inputs[0]),
4809                    outer_final_off: node_offset(arena, node.id),
4810                    length: *length,
4811                    carry_bytes: carry_bytes as u32,
4812                    save_trajectory: *save_trajectory,
4813                    xs_inputs: Arc::new(xs_inputs),
4814                    bcast_inputs: Arc::new(bcast_inputs),
4815                    num_checkpoints: *num_checkpoints,
4816                }
4817            }
4818
4819            Op::ScanBackward {
4820                body_vjp,
4821                length,
4822                save_trajectory,
4823                num_xs,
4824                num_checkpoints,
4825                forward_body,
4826            } => {
4827                let is_recursive = *num_checkpoints != 0 && *num_checkpoints != *length;
4828                if is_recursive {
4829                    assert!(
4830                        forward_body.is_some(),
4831                        "Op::ScanBackward with num_checkpoints<length requires forward_body"
4832                    );
4833                }
4834                // body_vjp has signature
4835                //   (carry, x_t_0, ..., x_t_{num_xs-1}, d_output) → dcarry
4836                // Identify slots:
4837                //   * "d_output" by exact name (AD-introduced seed Input).
4838                //   * Remaining Inputs sorted by NodeId — first is the
4839                //     carry mirror, rest are x_t_i mirrors in body's
4840                //     original Op::Input declaration order.
4841                let body_plan = rlx_opt::memory::plan_memory(body_vjp);
4842                let body_offsets: HashMap<NodeId, usize> = body_plan
4843                    .assignments
4844                    .iter()
4845                    .map(|(id, slot)| (*id, slot.offset))
4846                    .collect();
4847                let mut body_d_output_off: Option<usize> = None;
4848                let mut body_other_inputs: Vec<(NodeId, usize)> = Vec::new();
4849                for n in body_vjp.nodes() {
4850                    if let Op::Input { name } = &n.op {
4851                        let off = body_offsets[&n.id];
4852                        if name == "d_output" {
4853                            body_d_output_off = Some(off);
4854                        } else {
4855                            body_other_inputs.push((n.id, off));
4856                        }
4857                    }
4858                }
4859                body_other_inputs.sort_by_key(|(id, _)| *id);
4860                let body_d_output_off =
4861                    body_d_output_off.expect("ScanBackward body_vjp missing 'd_output' Input");
4862                let expected_others = 1 + *num_xs as usize;
4863                assert_eq!(
4864                    body_other_inputs.len(),
4865                    expected_others,
4866                    "ScanBackward body_vjp has {} non-d_output Inputs; \
4867                     expected {} (1 carry + {} xs)",
4868                    body_other_inputs.len(),
4869                    expected_others,
4870                    num_xs
4871                );
4872                let body_carry_in_off = body_other_inputs[0].1;
4873                let body_x_offs: Vec<usize> = body_other_inputs
4874                    .iter()
4875                    .skip(1)
4876                    .map(|(_, off)| *off)
4877                    .collect();
4878                let body_dcarry_out_off = body_offsets[&body_vjp.outputs[0]];
4879
4880                let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4881                // Fill body_vjp's Constants (mirrors the Scan lowering).
4882                for n in body_vjp.nodes() {
4883                    if let Op::Constant { data } = &n.op
4884                        && body_arena.has_buffer(n.id)
4885                        && !data.is_empty()
4886                    {
4887                        match n.shape.dtype() {
4888                            rlx_ir::DType::F64 => {
4889                                let off = body_arena.byte_offset(n.id);
4890                                let buf = body_arena.raw_buf_mut();
4891                                let nb = (buf.len() - off).min(data.len());
4892                                buf[off..off + nb].copy_from_slice(&data[..nb]);
4893                            }
4894                            _ => {
4895                                let buf = body_arena.slice_mut(n.id);
4896                                let nf = data.len() / 4;
4897                                let nl = buf.len().min(nf);
4898                                for i in 0..nl {
4899                                    let bytes = [
4900                                        data[i * 4],
4901                                        data[i * 4 + 1],
4902                                        data[i * 4 + 2],
4903                                        data[i * 4 + 3],
4904                                    ];
4905                                    buf[i] = f32::from_le_bytes(bytes);
4906                                }
4907                            }
4908                        }
4909                    }
4910                }
4911                let body_init = body_arena.raw_buf().to_vec();
4912                let body_schedule = compile_thunks(body_vjp, &body_arena);
4913
4914                // Carry bytes from the dcarry output node (== carry shape).
4915                let carry_bytes = body_vjp
4916                    .node(body_vjp.outputs[0])
4917                    .shape
4918                    .size_bytes()
4919                    .expect("ScanBackward dcarry must be statically shaped");
4920                let carry_elem_size = body_vjp
4921                    .node(body_vjp.outputs[0])
4922                    .shape
4923                    .dtype()
4924                    .size_bytes() as u32;
4925
4926                // For each xs input on the outer node:
4927                // (outer_xs_base, per_step_bytes).
4928                let mut outer_xs_offs: Vec<(usize, u32)> = Vec::with_capacity(*num_xs as usize);
4929                for i in 0..*num_xs as usize {
4930                    let outer_xs_id = node.inputs[3 + i];
4931                    let outer_xs_off = node_offset(arena, outer_xs_id);
4932                    let outer_xs_shape = &graph.node(outer_xs_id).shape;
4933                    let total = outer_xs_shape
4934                        .size_bytes()
4935                        .expect("ScanBackward xs must have static shape");
4936                    let per_step = total / *length as usize;
4937                    outer_xs_offs.push((outer_xs_off, per_step as u32));
4938                }
4939
4940                // If recursive checkpointing is active, we also compile
4941                // the forward body so the executor can recompute
4942                // intermediate carries. The forward body is supplied
4943                // by the AD pass via `forward_body: Some(_)`.
4944                let (fb_schedule, fb_init, fb_carry_in_off, fb_output_off, fb_x_offs) =
4945                    if is_recursive {
4946                        let fb = forward_body.as_ref().unwrap();
4947                        let fb_plan = rlx_opt::memory::plan_memory(fb);
4948                        let fb_offsets: HashMap<NodeId, usize> = fb_plan
4949                            .assignments
4950                            .iter()
4951                            .map(|(id, slot)| (*id, slot.offset))
4952                            .collect();
4953                        let mut fb_inputs: Vec<NodeId> = fb
4954                            .nodes()
4955                            .iter()
4956                            .filter(|n| matches!(n.op, Op::Input { .. }))
4957                            .map(|n| n.id)
4958                            .collect();
4959                        fb_inputs.sort();
4960                        let fb_carry = fb_offsets[&fb_inputs[0]];
4961                        let fb_xs: Vec<usize> = (1..fb_inputs.len())
4962                            .map(|i| fb_offsets[&fb_inputs[i]])
4963                            .collect();
4964                        let fb_out = fb_offsets[&fb.outputs[0]];
4965                        let mut fb_arena = crate::arena::Arena::from_plan(fb_plan);
4966                        for n in fb.nodes() {
4967                            if let Op::Constant { data } = &n.op
4968                                && fb_arena.has_buffer(n.id)
4969                                && !data.is_empty()
4970                            {
4971                                // Byte-copy works for any
4972                                // numeric dtype as long as the
4973                                // arena slot is sized to hold
4974                                // it — the Constant's `data`
4975                                // already encodes the right
4976                                // bytes per element.
4977                                let off = fb_arena.byte_offset(n.id);
4978                                let buf = fb_arena.raw_buf_mut();
4979                                let nb = (buf.len() - off).min(data.len());
4980                                buf[off..off + nb].copy_from_slice(&data[..nb]);
4981                            }
4982                        }
4983                        let fb_init_bytes = fb_arena.raw_buf().to_vec();
4984                        let fb_sched = compile_thunks(fb, &fb_arena);
4985                        (
4986                            Some(Arc::new(fb_sched)),
4987                            Some(Arc::new(fb_init_bytes)),
4988                            fb_carry,
4989                            fb_out,
4990                            fb_xs,
4991                        )
4992                    } else {
4993                        (None, None, 0, 0, Vec::new())
4994                    };
4995
4996                Thunk::ScanBackward {
4997                    body_vjp: Arc::new(body_schedule),
4998                    body_init: Arc::new(body_init),
4999                    body_carry_in_off,
5000                    body_x_offs: Arc::new(body_x_offs),
5001                    body_d_output_off,
5002                    body_dcarry_out_off,
5003                    outer_init_off: node_offset(arena, node.inputs[0]),
5004                    outer_traj_off: node_offset(arena, node.inputs[1]),
5005                    outer_upstream_off: node_offset(arena, node.inputs[2]),
5006                    outer_xs_offs: Arc::new(outer_xs_offs),
5007                    outer_dinit_off: node_offset(arena, node.id),
5008                    length: *length,
5009                    carry_bytes: carry_bytes as u32,
5010                    carry_elem_size,
5011                    save_trajectory: *save_trajectory,
5012                    num_checkpoints: *num_checkpoints,
5013                    forward_body: fb_schedule,
5014                    forward_body_init: fb_init,
5015                    forward_body_carry_in_off: fb_carry_in_off,
5016                    forward_body_output_off: fb_output_off,
5017                    forward_body_x_offs: Arc::new(fb_x_offs),
5018                }
5019            }
5020
5021            Op::ScanBackwardXs {
5022                body_vjp,
5023                length,
5024                save_trajectory,
5025                num_xs,
5026                xs_idx,
5027                num_checkpoints,
5028                forward_body,
5029            } => {
5030                assert!(
5031                    *num_checkpoints == 0 || *num_checkpoints <= *length,
5032                    "Op::ScanBackwardXs: num_checkpoints={} must be 0 or ≤ length={}",
5033                    *num_checkpoints,
5034                    *length
5035                );
5036                let is_recursive = *num_checkpoints != 0 && *num_checkpoints != *length;
5037                if is_recursive {
5038                    assert!(
5039                        forward_body.is_some(),
5040                        "Op::ScanBackwardXs with num_checkpoints<length \
5041                         requires forward_body"
5042                    );
5043                }
5044                // Mirror ScanBackward's body_vjp slot identification +
5045                // arena prep, then add: per-iteration extraction of the
5046                // body_vjp output that corresponds to the chosen xs.
5047                //
5048                // body_vjp's outputs (from `grad(body, [carry, xs_0, ..., xs_{num_xs-1}])`):
5049                //   outputs[0]      = dcarry
5050                //   outputs[1 + i]  = dx_t_i
5051                let body_plan = rlx_opt::memory::plan_memory(body_vjp);
5052                let body_offsets: HashMap<NodeId, usize> = body_plan
5053                    .assignments
5054                    .iter()
5055                    .map(|(id, slot)| (*id, slot.offset))
5056                    .collect();
5057                let mut body_d_output_off: Option<usize> = None;
5058                let mut body_other_inputs: Vec<(NodeId, usize)> = Vec::new();
5059                for n in body_vjp.nodes() {
5060                    if let Op::Input { name } = &n.op {
5061                        let off = body_offsets[&n.id];
5062                        if name == "d_output" {
5063                            body_d_output_off = Some(off);
5064                        } else {
5065                            body_other_inputs.push((n.id, off));
5066                        }
5067                    }
5068                }
5069                body_other_inputs.sort_by_key(|(id, _)| *id);
5070                let body_d_output_off =
5071                    body_d_output_off.expect("ScanBackwardXs body_vjp missing 'd_output' Input");
5072                let expected_others = 1 + *num_xs as usize;
5073                assert_eq!(
5074                    body_other_inputs.len(),
5075                    expected_others,
5076                    "ScanBackwardXs body_vjp has {} non-d_output Inputs; expected {}",
5077                    body_other_inputs.len(),
5078                    expected_others
5079                );
5080                let body_carry_in_off = body_other_inputs[0].1;
5081                let body_x_offs: Vec<usize> = body_other_inputs
5082                    .iter()
5083                    .skip(1)
5084                    .map(|(_, off)| *off)
5085                    .collect();
5086                let body_dcarry_out_off = body_offsets[&body_vjp.outputs[0]];
5087                let dxs_out_node = body_vjp.outputs[1 + *xs_idx as usize];
5088                let body_dxs_out_off = body_offsets[&dxs_out_node];
5089
5090                let mut body_arena = crate::arena::Arena::from_plan(body_plan);
5091                for n in body_vjp.nodes() {
5092                    if let Op::Constant { data } = &n.op
5093                        && body_arena.has_buffer(n.id)
5094                        && !data.is_empty()
5095                    {
5096                        match n.shape.dtype() {
5097                            rlx_ir::DType::F64 => {
5098                                let off = body_arena.byte_offset(n.id);
5099                                let buf = body_arena.raw_buf_mut();
5100                                let nb = (buf.len() - off).min(data.len());
5101                                buf[off..off + nb].copy_from_slice(&data[..nb]);
5102                            }
5103                            _ => {
5104                                let buf = body_arena.slice_mut(n.id);
5105                                let nf = data.len() / 4;
5106                                let nl = buf.len().min(nf);
5107                                for i in 0..nl {
5108                                    let bytes = [
5109                                        data[i * 4],
5110                                        data[i * 4 + 1],
5111                                        data[i * 4 + 2],
5112                                        data[i * 4 + 3],
5113                                    ];
5114                                    buf[i] = f32::from_le_bytes(bytes);
5115                                }
5116                            }
5117                        }
5118                    }
5119                }
5120                let body_init = body_arena.raw_buf().to_vec();
5121                let body_schedule = compile_thunks(body_vjp, &body_arena);
5122
5123                let carry_bytes = body_vjp
5124                    .node(body_vjp.outputs[0])
5125                    .shape
5126                    .size_bytes()
5127                    .expect("ScanBackwardXs dcarry must be statically shaped");
5128                let carry_elem_size = body_vjp
5129                    .node(body_vjp.outputs[0])
5130                    .shape
5131                    .dtype()
5132                    .size_bytes() as u32;
5133                let per_step_bytes = body_vjp
5134                    .node(dxs_out_node)
5135                    .shape
5136                    .size_bytes()
5137                    .expect("ScanBackwardXs dxs body output must be statically shaped");
5138
5139                let mut outer_xs_offs: Vec<(usize, u32)> = Vec::with_capacity(*num_xs as usize);
5140                for i in 0..*num_xs as usize {
5141                    let outer_xs_id = node.inputs[3 + i];
5142                    let outer_xs_off = node_offset(arena, outer_xs_id);
5143                    let outer_xs_shape = &graph.node(outer_xs_id).shape;
5144                    let total = outer_xs_shape
5145                        .size_bytes()
5146                        .expect("ScanBackwardXs xs must have static shape");
5147                    let per_step = total / *length as usize;
5148                    outer_xs_offs.push((outer_xs_off, per_step as u32));
5149                }
5150
5151                // Compile forward_body for recompute when checkpointed.
5152                // Mirrors the same code path in the ScanBackward arm.
5153                let (fb_schedule, fb_init, fb_carry_in_off, fb_output_off, fb_x_offs) =
5154                    if is_recursive {
5155                        let fb = forward_body.as_ref().unwrap();
5156                        let fb_plan = rlx_opt::memory::plan_memory(fb);
5157                        let fb_offsets: HashMap<NodeId, usize> = fb_plan
5158                            .assignments
5159                            .iter()
5160                            .map(|(id, slot)| (*id, slot.offset))
5161                            .collect();
5162                        let mut fb_inputs: Vec<NodeId> = fb
5163                            .nodes()
5164                            .iter()
5165                            .filter(|n| matches!(n.op, Op::Input { .. }))
5166                            .map(|n| n.id)
5167                            .collect();
5168                        fb_inputs.sort();
5169                        let fb_carry = fb_offsets[&fb_inputs[0]];
5170                        let fb_xs: Vec<usize> = (1..fb_inputs.len())
5171                            .map(|i| fb_offsets[&fb_inputs[i]])
5172                            .collect();
5173                        let fb_out = fb_offsets[&fb.outputs[0]];
5174                        let mut fb_arena = crate::arena::Arena::from_plan(fb_plan);
5175                        for n in fb.nodes() {
5176                            if let Op::Constant { data } = &n.op
5177                                && fb_arena.has_buffer(n.id)
5178                                && !data.is_empty()
5179                            {
5180                                // Byte-copy works for any
5181                                // numeric dtype as long as the
5182                                // arena slot is sized to hold
5183                                // it — the Constant's `data`
5184                                // already encodes the right
5185                                // bytes per element.
5186                                let off = fb_arena.byte_offset(n.id);
5187                                let buf = fb_arena.raw_buf_mut();
5188                                let nb = (buf.len() - off).min(data.len());
5189                                buf[off..off + nb].copy_from_slice(&data[..nb]);
5190                            }
5191                        }
5192                        let fb_init_bytes = fb_arena.raw_buf().to_vec();
5193                        let fb_sched = compile_thunks(fb, &fb_arena);
5194                        (
5195                            Some(Arc::new(fb_sched)),
5196                            Some(Arc::new(fb_init_bytes)),
5197                            fb_carry,
5198                            fb_out,
5199                            fb_xs,
5200                        )
5201                    } else {
5202                        (None, None, 0, 0, Vec::new())
5203                    };
5204
5205                Thunk::ScanBackwardXs {
5206                    body_vjp: Arc::new(body_schedule),
5207                    body_init: Arc::new(body_init),
5208                    body_carry_in_off,
5209                    body_x_offs: Arc::new(body_x_offs),
5210                    body_d_output_off,
5211                    body_dcarry_out_off,
5212                    body_dxs_out_off,
5213                    outer_init_off: node_offset(arena, node.inputs[0]),
5214                    outer_traj_off: node_offset(arena, node.inputs[1]),
5215                    outer_upstream_off: node_offset(arena, node.inputs[2]),
5216                    outer_xs_offs: Arc::new(outer_xs_offs),
5217                    outer_dxs_off: node_offset(arena, node.id),
5218                    length: *length,
5219                    carry_bytes: carry_bytes as u32,
5220                    carry_elem_size,
5221                    per_step_bytes: per_step_bytes as u32,
5222                    save_trajectory: *save_trajectory,
5223                    num_checkpoints: *num_checkpoints,
5224                    forward_body: fb_schedule,
5225                    forward_body_init: fb_init,
5226                    forward_body_carry_in_off: fb_carry_in_off,
5227                    forward_body_output_off: fb_output_off,
5228                    forward_body_x_offs: Arc::new(fb_x_offs),
5229                }
5230            }
5231
5232            Op::Concat { axis } => {
5233                // Compute outer/inner from the OUTPUT shape: all inputs share
5234                // the same shape except along `axis`. The output's leading
5235                // and trailing dims match.
5236                let out_shape = &node.shape;
5237                let rank = out_shape.rank();
5238                let outer: usize = (0..*axis)
5239                    .map(|i| out_shape.dim(i).unwrap_static())
5240                    .product::<usize>()
5241                    .max(1);
5242                let inner: usize = (*axis + 1..rank)
5243                    .map(|i| out_shape.dim(i).unwrap_static())
5244                    .product::<usize>()
5245                    .max(1);
5246                let total_axis = out_shape.dim(*axis).unwrap_static();
5247                let inputs: Vec<(usize, u32)> = node
5248                    .inputs
5249                    .iter()
5250                    .map(|&in_id| {
5251                        let in_shape = &graph.node(in_id).shape;
5252                        let in_axis = in_shape.dim(*axis).unwrap_static();
5253                        (node_offset(arena, in_id), in_axis as u32)
5254                    })
5255                    .collect();
5256                let dst = node_offset(arena, node.id);
5257                match out_shape.dtype() {
5258                    rlx_ir::DType::F64 => Thunk::ConcatF64 {
5259                        dst,
5260                        outer: outer as u32,
5261                        inner: inner as u32,
5262                        total_axis: total_axis as u32,
5263                        inputs,
5264                    },
5265                    _ => Thunk::Concat {
5266                        dst,
5267                        outer: outer as u32,
5268                        inner: inner as u32,
5269                        total_axis: total_axis as u32,
5270                        inputs,
5271                    },
5272                }
5273            }
5274
5275            Op::GaussianSplatRender {
5276                width,
5277                height,
5278                tile_size,
5279                radius_scale,
5280                alpha_cutoff,
5281                max_splat_steps,
5282                transmittance_threshold,
5283                max_list_entries,
5284            } => {
5285                let elem_len =
5286                    |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5287                Thunk::GaussianSplatRender {
5288                    positions_off: node_offset(arena, node.inputs[0]),
5289                    positions_len: elem_len(node.inputs[0]),
5290                    scales_off: node_offset(arena, node.inputs[1]),
5291                    scales_len: elem_len(node.inputs[1]),
5292                    rotations_off: node_offset(arena, node.inputs[2]),
5293                    rotations_len: elem_len(node.inputs[2]),
5294                    opacities_off: node_offset(arena, node.inputs[3]),
5295                    opacities_len: elem_len(node.inputs[3]),
5296                    colors_off: node_offset(arena, node.inputs[4]),
5297                    colors_len: elem_len(node.inputs[4]),
5298                    sh_coeffs_off: node_offset(arena, node.inputs[5]),
5299                    sh_coeffs_len: elem_len(node.inputs[5]),
5300                    meta_off: node_offset(arena, node.inputs[6]),
5301                    dst_off: node_offset(arena, node.id),
5302                    dst_len: node.shape.num_elements().unwrap_or(0),
5303                    width: *width,
5304                    height: *height,
5305                    tile_size: *tile_size,
5306                    radius_scale: *radius_scale,
5307                    alpha_cutoff: *alpha_cutoff,
5308                    max_splat_steps: *max_splat_steps,
5309                    transmittance_threshold: *transmittance_threshold,
5310                    max_list_entries: *max_list_entries,
5311                }
5312            }
5313
5314            Op::GaussianSplatRenderBackward {
5315                width,
5316                height,
5317                tile_size,
5318                radius_scale,
5319                alpha_cutoff,
5320                max_splat_steps,
5321                transmittance_threshold,
5322                max_list_entries,
5323                loss_grad_clip,
5324                sh_band,
5325                max_anisotropy,
5326            } => {
5327                let elem_len =
5328                    |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5329                Thunk::GaussianSplatRenderBackward {
5330                    positions_off: node_offset(arena, node.inputs[0]),
5331                    positions_len: elem_len(node.inputs[0]),
5332                    scales_off: node_offset(arena, node.inputs[1]),
5333                    scales_len: elem_len(node.inputs[1]),
5334                    rotations_off: node_offset(arena, node.inputs[2]),
5335                    rotations_len: elem_len(node.inputs[2]),
5336                    opacities_off: node_offset(arena, node.inputs[3]),
5337                    opacities_len: elem_len(node.inputs[3]),
5338                    colors_off: node_offset(arena, node.inputs[4]),
5339                    colors_len: elem_len(node.inputs[4]),
5340                    sh_coeffs_off: node_offset(arena, node.inputs[5]),
5341                    sh_coeffs_len: elem_len(node.inputs[5]),
5342                    meta_off: node_offset(arena, node.inputs[6]),
5343                    d_loss_off: node_offset(arena, node.inputs[7]),
5344                    d_loss_len: elem_len(node.inputs[7]),
5345                    packed_off: node_offset(arena, node.id),
5346                    packed_len: node.shape.num_elements().unwrap_or(0),
5347                    width: *width,
5348                    height: *height,
5349                    tile_size: *tile_size,
5350                    radius_scale: *radius_scale,
5351                    alpha_cutoff: *alpha_cutoff,
5352                    max_splat_steps: *max_splat_steps,
5353                    transmittance_threshold: *transmittance_threshold,
5354                    max_list_entries: *max_list_entries,
5355                    loss_grad_clip: *loss_grad_clip,
5356                    sh_band: *sh_band,
5357                    max_anisotropy: *max_anisotropy,
5358                }
5359            }
5360
5361            Op::GaussianSplatPrepare {
5362                width,
5363                height,
5364                tile_size,
5365                radius_scale,
5366                alpha_cutoff,
5367                max_splat_steps,
5368                transmittance_threshold,
5369                max_list_entries,
5370            } => {
5371                let elem_len =
5372                    |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5373                Thunk::GaussianSplatPrepare {
5374                    positions_off: node_offset(arena, node.inputs[0]),
5375                    positions_len: elem_len(node.inputs[0]),
5376                    scales_off: node_offset(arena, node.inputs[1]),
5377                    scales_len: elem_len(node.inputs[1]),
5378                    rotations_off: node_offset(arena, node.inputs[2]),
5379                    rotations_len: elem_len(node.inputs[2]),
5380                    opacities_off: node_offset(arena, node.inputs[3]),
5381                    opacities_len: elem_len(node.inputs[3]),
5382                    colors_off: node_offset(arena, node.inputs[4]),
5383                    colors_len: elem_len(node.inputs[4]),
5384                    sh_coeffs_off: node_offset(arena, node.inputs[5]),
5385                    sh_coeffs_len: elem_len(node.inputs[5]),
5386                    meta_off: node_offset(arena, node.inputs[6]),
5387                    meta_len: elem_len(node.inputs[6]),
5388                    prep_off: node_offset(arena, node.id),
5389                    prep_len: node.shape.num_elements().unwrap_or(0),
5390                    width: *width,
5391                    height: *height,
5392                    tile_size: *tile_size,
5393                    radius_scale: *radius_scale,
5394                    alpha_cutoff: *alpha_cutoff,
5395                    max_splat_steps: *max_splat_steps,
5396                    transmittance_threshold: *transmittance_threshold,
5397                    max_list_entries: *max_list_entries,
5398                }
5399            }
5400
5401            Op::GaussianSplatRasterize {
5402                width,
5403                height,
5404                tile_size,
5405                alpha_cutoff,
5406                max_splat_steps,
5407                transmittance_threshold,
5408                max_list_entries,
5409            } => {
5410                let elem_len =
5411                    |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5412                let prep_id = node.inputs[0];
5413                let count = match &graph.node(prep_id).op {
5414                    rlx_ir::Op::GaussianSplatPrepare { .. } => {
5415                        elem_len(graph.node(prep_id).inputs[0]) / 3
5416                    }
5417                    _ => 1,
5418                };
5419                Thunk::GaussianSplatRasterize {
5420                    prep_off: node_offset(arena, prep_id),
5421                    prep_len: elem_len(prep_id),
5422                    meta_off: node_offset(arena, node.inputs[1]),
5423                    meta_len: elem_len(node.inputs[1]),
5424                    dst_off: node_offset(arena, node.id),
5425                    dst_len: node.shape.num_elements().unwrap_or(0),
5426                    count,
5427                    width: *width,
5428                    height: *height,
5429                    tile_size: *tile_size,
5430                    alpha_cutoff: *alpha_cutoff,
5431                    max_splat_steps: *max_splat_steps,
5432                    transmittance_threshold: *transmittance_threshold,
5433                    max_list_entries: *max_list_entries,
5434                }
5435            }
5436
5437            Op::Custom { name, attrs, .. } => {
5438                let kernel = crate::op_registry::lookup_cpu_kernel(name).unwrap_or_else(|| {
5439                    panic!(
5440                        "compile_thunks: no CPU kernel registered for \
5441                         Op::Custom('{name}'). Register one via \
5442                         rlx_cpu::op_registry::register_cpu_kernel \
5443                         before compiling on the CPU backend."
5444                    )
5445                });
5446                let inputs_v: Vec<(usize, u32, Shape)> = node
5447                    .inputs
5448                    .iter()
5449                    .map(|&in_id| {
5450                        let s = graph.node(in_id).shape.clone();
5451                        let len = s.num_elements().unwrap_or(0) as u32;
5452                        (node_offset(arena, in_id), len, s)
5453                    })
5454                    .collect();
5455                let out_len = node.shape.num_elements().unwrap_or(0) as u32;
5456                Thunk::CustomOp {
5457                    kernel,
5458                    inputs: inputs_v,
5459                    output: (node_offset(arena, node.id), out_len, node.shape.clone()),
5460                    attrs: attrs.clone(),
5461                }
5462            }
5463
5464            Op::Fft { inverse, norm } => {
5465                let shape = &node.shape;
5466                let meta = rlx_ir::fft::fft_meta(shape);
5467                let dtype = shape.dtype();
5468                assert!(
5469                    matches!(
5470                        dtype,
5471                        rlx_ir::DType::F32 | rlx_ir::DType::F64 | rlx_ir::DType::C64
5472                    ),
5473                    "Op::Fft on CPU requires F32, F64, or C64, got {dtype:?}"
5474                );
5475                Thunk::Fft1d {
5476                    src: node_offset(arena, node.inputs[0]),
5477                    dst: node_offset(arena, node.id),
5478                    outer: meta.outer as u32,
5479                    n_complex: meta.n_complex as u32,
5480                    inverse: *inverse,
5481                    norm_tag: norm.tag(),
5482                    dtype,
5483                }
5484            }
5485
5486            Op::FftButterflyStage { stage, n_fft } => {
5487                let state_shape = graph.node(node.inputs[0]).shape.clone();
5488                assert_eq!(
5489                    state_shape.dtype(),
5490                    rlx_ir::DType::F32,
5491                    "Op::FftButterflyStage requires F32 state"
5492                );
5493                let batch = state_shape.dim(0).unwrap_static() as u32;
5494                Thunk::FftButterflyStage {
5495                    state_src: node_offset(arena, node.inputs[0]),
5496                    state_dst: node_offset(arena, node.id),
5497                    gate_src: node_offset(arena, node.inputs[1]),
5498                    rev_src: node_offset(arena, node.inputs[2]),
5499                    tw_re_src: node_offset(arena, node.inputs[3]),
5500                    tw_im_src: node_offset(arena, node.inputs[4]),
5501                    batch,
5502                    n_fft: *n_fft,
5503                    stage: *stage,
5504                }
5505            }
5506
5507            Op::LogMel => {
5508                let spec_shape = graph.node(node.inputs[0]).shape.clone();
5509                let filt_shape = graph.node(node.inputs[1]).shape.clone();
5510                let meta = rlx_ir::audio::log_mel_meta(&spec_shape, &filt_shape)
5511                    .unwrap_or_else(|e| panic!("Op::LogMel: {e}"));
5512                Thunk::LogMel {
5513                    spec: node_offset(arena, node.inputs[0]),
5514                    filters: node_offset(arena, node.inputs[1]),
5515                    dst: node_offset(arena, node.id),
5516                    outer: meta.outer as u32,
5517                    n_fft: meta.n_fft as u32,
5518                    n_bins: meta.n_bins as u32,
5519                    n_mels: meta.n_mels as u32,
5520                }
5521            }
5522
5523            Op::LogMelBackward => {
5524                let spec_shape = graph.node(node.inputs[0]).shape.clone();
5525                let filt_shape = graph.node(node.inputs[1]).shape.clone();
5526                let meta = rlx_ir::audio::log_mel_meta(&spec_shape, &filt_shape)
5527                    .unwrap_or_else(|e| panic!("Op::LogMelBackward: {e}"));
5528                Thunk::LogMelBackward {
5529                    spec: node_offset(arena, node.inputs[0]),
5530                    filters: node_offset(arena, node.inputs[1]),
5531                    dy: node_offset(arena, node.inputs[2]),
5532                    dst: node_offset(arena, node.id),
5533                    outer: meta.outer as u32,
5534                    n_fft: meta.n_fft as u32,
5535                    n_bins: meta.n_bins as u32,
5536                    n_mels: meta.n_mels as u32,
5537                }
5538            }
5539
5540            Op::WelchPeaks { k, n_segments } => {
5541                let spec_shape = graph.node(node.inputs[0]).shape.clone();
5542                let meta = rlx_ir::audio::welch_peaks_meta(&spec_shape, *k, *n_segments)
5543                    .unwrap_or_else(|e| panic!("Op::WelchPeaks: {e}"));
5544                Thunk::WelchPeaks {
5545                    spec: node_offset(arena, node.inputs[0]),
5546                    dst: node_offset(arena, node.id),
5547                    welch_batch: meta.welch_batch as u32,
5548                    n_fft: meta.n_fft as u32,
5549                    n_segments: meta.n_segments as u32,
5550                    k: meta.k as u32,
5551                }
5552            }
5553
5554            Op::CustomFn {
5555                fwd_body,
5556                num_inputs,
5557                ..
5558            } => {
5559                // Plan + compile the body sub-graph standalone, fill its
5560                // Constants (mirrors the Op::Scan body lowering), then
5561                // capture per-input copy specs and the output spec.
5562                // Body Inputs in NodeId order match the outer node's
5563                // operand vector by position.
5564                let body_plan = rlx_opt::memory::plan_memory(fwd_body);
5565                let body_offsets: HashMap<NodeId, usize> = body_plan
5566                    .assignments
5567                    .iter()
5568                    .map(|(id, slot)| (*id, slot.offset))
5569                    .collect();
5570
5571                let mut body_input_ids: Vec<NodeId> = fwd_body
5572                    .nodes()
5573                    .iter()
5574                    .filter(|n| matches!(n.op, Op::Input { .. }))
5575                    .map(|n| n.id)
5576                    .collect();
5577                body_input_ids.sort();
5578                assert_eq!(
5579                    body_input_ids.len(),
5580                    *num_inputs as usize,
5581                    "Op::CustomFn fwd_body has {} Op::Input(s); declared num_inputs={}",
5582                    body_input_ids.len(),
5583                    *num_inputs,
5584                );
5585
5586                let mut body_arena = crate::arena::Arena::from_plan(body_plan);
5587                for n in fwd_body.nodes() {
5588                    if let Op::Constant { data } = &n.op
5589                        && body_arena.has_buffer(n.id)
5590                        && !data.is_empty()
5591                    {
5592                        match n.shape.dtype() {
5593                            rlx_ir::DType::F64 => {
5594                                let off = body_arena.byte_offset(n.id);
5595                                let buf = body_arena.raw_buf_mut();
5596                                let nb = (buf.len() - off).min(data.len());
5597                                buf[off..off + nb].copy_from_slice(&data[..nb]);
5598                            }
5599                            _ => {
5600                                let buf = body_arena.slice_mut(n.id);
5601                                let nf = data.len() / 4;
5602                                let nl = buf.len().min(nf);
5603                                for i in 0..nl {
5604                                    let bytes = [
5605                                        data[i * 4],
5606                                        data[i * 4 + 1],
5607                                        data[i * 4 + 2],
5608                                        data[i * 4 + 3],
5609                                    ];
5610                                    buf[i] = f32::from_le_bytes(bytes);
5611                                }
5612                            }
5613                        }
5614                    }
5615                }
5616                let body_init = body_arena.raw_buf().to_vec();
5617                let body_schedule = compile_thunks(fwd_body, &body_arena);
5618
5619                // Per primal input: (body_input_off, outer_input_off, bytes).
5620                let inputs_v: Vec<(usize, usize, u32)> = (0..*num_inputs as usize)
5621                    .map(|i| {
5622                        let body_in = body_input_ids[i];
5623                        let body_off = body_offsets[&body_in];
5624                        let outer_in = node.inputs[i];
5625                        let outer_off = node_offset(arena, outer_in);
5626                        let bytes = graph
5627                            .node(outer_in)
5628                            .shape
5629                            .size_bytes()
5630                            .expect("Op::CustomFn primal input must have static shape");
5631                        (body_off, outer_off, bytes as u32)
5632                    })
5633                    .collect();
5634
5635                let body_output_id = fwd_body
5636                    .outputs
5637                    .first()
5638                    .copied()
5639                    .expect("Op::CustomFn fwd_body must declare exactly one output");
5640                let body_output_off = body_offsets[&body_output_id];
5641                let out_bytes = node
5642                    .shape
5643                    .size_bytes()
5644                    .expect("Op::CustomFn output must have static shape");
5645
5646                Thunk::CustomFn {
5647                    body: Arc::new(body_schedule),
5648                    body_init: Arc::new(body_init),
5649                    inputs: Arc::new(inputs_v),
5650                    body_output_off,
5651                    outer_output_off: node_offset(arena, node.id),
5652                    out_bytes: out_bytes as u32,
5653                }
5654            }
5655
5656            _ => Thunk::Nop,
5657        };
5658        thunks.push(t);
5659    }
5660
5661    let cfg = crate::config::RuntimeConfig::global();
5662    let mask_thr = cfg.mask_binary_threshold;
5663    let mask_neg = cfg.attn_mask_neg_inf;
5664    let score_skip = cfg.score_skip_threshold;
5665
5666    // Pre-compile closures (skip Nops — they're filtered out)
5667    let compiled_fns: Vec<Arc<dyn Fn(*mut u8) + Send + Sync>> = thunks
5668        .iter()
5669        .filter(|t| !matches!(t, Thunk::Nop))
5670        .map(|thunk| {
5671            match thunk.clone() {
5672                Thunk::Nop => Arc::new(|_: *mut u8| {}) as Arc<dyn Fn(*mut u8) + Send + Sync>,
5673
5674                Thunk::Sgemm { a, b, c, m, k, n } => {
5675                    let (m, k, n) = (m as usize, k as usize, n as usize);
5676                    Arc::new(move |base: *mut u8| unsafe {
5677                        crate::blas::sgemm(
5678                            sl(a, base, m * k),
5679                            sl(b, base, k * n),
5680                            sl_mut(c, base, m * n),
5681                            m,
5682                            k,
5683                            n,
5684                        );
5685                    })
5686                }
5687
5688                Thunk::DenseSolveF64 { a, b, x, n, nrhs } => {
5689                    let (n_, nrhs_) = (n as usize, nrhs as usize);
5690                    Arc::new(move |base: *mut u8| unsafe {
5691                        let a_src = sl_f64(a, base, n_ * n_);
5692                        let b_src = sl_f64(b, base, n_ * nrhs_);
5693                        let mut a_scratch: Vec<f64> = a_src.to_vec();
5694                        let mut x_buf: Vec<f64> = b_src.to_vec();
5695                        let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
5696                        if info != 0 {
5697                            panic!("DenseSolveF64: singular (info={info})");
5698                        }
5699                        sl_mut_f64(x, base, n_ * nrhs_).copy_from_slice(&x_buf);
5700                    })
5701                }
5702
5703                Thunk::DenseSolveF32 { a, b, x, n, nrhs } => {
5704                    let (n_, nrhs_) = (n as usize, nrhs as usize);
5705                    Arc::new(move |base: *mut u8| unsafe {
5706                        let a_src = sl(a, base, n_ * n_);
5707                        let b_src = sl(b, base, n_ * nrhs_);
5708                        let mut a_scratch: Vec<f32> = a_src.to_vec();
5709                        let mut x_buf: Vec<f32> = b_src.to_vec();
5710                        let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
5711                        if info != 0 {
5712                            panic!("DenseSolveF32: singular (info={info})");
5713                        }
5714                        sl_mut(x, base, n_ * nrhs_).copy_from_slice(&x_buf);
5715                    })
5716                }
5717
5718                Thunk::FusedMmBiasAct {
5719                    a,
5720                    w,
5721                    bias,
5722                    c,
5723                    m,
5724                    k,
5725                    n,
5726                    act,
5727                } => {
5728                    let (m, k, n) = (m as usize, k as usize, n as usize);
5729                    Arc::new(move |base: *mut u8| unsafe {
5730                        let out = sl_mut(c, base, m * n);
5731                        crate::blas::sgemm(sl(a, base, m * k), sl(w, base, k * n), out, m, k, n);
5732                        // Bias + activation epilogue. Gelu uses the fused
5733                        // `par_bias_gelu` kernel (bias add + Gelu in one
5734                        // pass). For everything else, do the bias add first
5735                        // and then apply the activation per-element. The
5736                        // pre-fix code dispatched `_ => bias_add` and dropped
5737                        // the activation entirely — silent correctness bug
5738                        // for Silu/Relu/Sigmoid/etc.
5739                        match act {
5740                            Some(Activation::Gelu) => {
5741                                crate::kernels::par_bias_gelu(out, sl(bias, base, n), m, n)
5742                            }
5743                            Some(other) => {
5744                                crate::blas::bias_add(out, sl(bias, base, n), m, n);
5745                                apply_activation_inplace(out, other);
5746                            }
5747                            None => crate::blas::bias_add(out, sl(bias, base, n), m, n),
5748                        }
5749                    })
5750                }
5751
5752                Thunk::FusedResidualLN {
5753                    x,
5754                    res,
5755                    bias,
5756                    g,
5757                    b,
5758                    out,
5759                    rows,
5760                    h,
5761                    eps,
5762                    has_bias,
5763                } => {
5764                    let (rows, h) = (rows as usize, h as usize);
5765                    Arc::new(move |base: *mut u8| unsafe {
5766                        let zero = vec![0f32; h]; // closure only — not hot path
5767                        let bi = if has_bias { sl(bias, base, h) } else { &zero };
5768                        let xp = sl(x, base, rows * h).as_ptr() as usize;
5769                        let rp = sl(res, base, rows * h).as_ptr() as usize;
5770                        let op = sl_mut(out, base, rows * h).as_mut_ptr() as usize;
5771                        let bp = bi.as_ptr() as usize;
5772                        let gp = sl(g, base, h).as_ptr() as usize;
5773                        let bbp = sl(b, base, h).as_ptr() as usize;
5774                        crate::pool::par_for(rows, 4, &|off, cnt| {
5775                            let xs = std::slice::from_raw_parts(
5776                                (xp as *const f32).add(off * h),
5777                                cnt * h,
5778                            );
5779                            let rs = std::slice::from_raw_parts(
5780                                (rp as *const f32).add(off * h),
5781                                cnt * h,
5782                            );
5783                            let os = std::slice::from_raw_parts_mut(
5784                                (op as *mut f32).add(off * h),
5785                                cnt * h,
5786                            );
5787                            let bi = std::slice::from_raw_parts(bp as *const f32, h);
5788                            let g = std::slice::from_raw_parts(gp as *const f32, h);
5789                            let b = std::slice::from_raw_parts(bbp as *const f32, h);
5790                            crate::kernels::residual_bias_layer_norm(
5791                                xs, rs, bi, g, b, os, cnt, h, eps,
5792                            );
5793                        });
5794                    })
5795                }
5796
5797                Thunk::BiasAdd {
5798                    src,
5799                    bias,
5800                    dst,
5801                    m,
5802                    n,
5803                } => {
5804                    let (m, n) = (m as usize, n as usize);
5805                    let len = m * n;
5806                    Arc::new(move |base: *mut u8| unsafe {
5807                        let out = sl_mut(dst, base, len);
5808                        if src != dst {
5809                            let src_ptr = base.add(src) as *const f32;
5810                            let dst_ptr = base.add(dst) as *mut f32;
5811                            if src_ptr != dst_ptr {
5812                                std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len);
5813                            }
5814                        }
5815                        crate::blas::bias_add(out, sl(bias, base, n), m, n);
5816                    })
5817                }
5818
5819                Thunk::Gather {
5820                    table,
5821                    table_len,
5822                    idx,
5823                    dst,
5824                    num_idx,
5825                    trailing,
5826                    idx_i64,
5827                    table_bytes,
5828                } => {
5829                    let (ni, tr, tl) = (num_idx as usize, trailing as usize, table_len as usize);
5830                    let rows = tl / tr.max(1);
5831                    let (idx_i64, table_bytes) = (idx_i64, table_bytes);
5832                    Arc::new(move |base: *mut u8| unsafe {
5833                        if table_bytes == 8 {
5834                            let tab = sl_i64(table, base, tl);
5835                            let out = sl_mut_i64(dst, base, ni * tr);
5836                            if idx_i64 != 0 {
5837                                let ids = sl_i64(idx, base, ni);
5838                                for i in 0..ni {
5839                                    let row = ids[i].max(0) as usize;
5840                                    if row < rows {
5841                                        out[i * tr..(i + 1) * tr]
5842                                            .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
5843                                    }
5844                                }
5845                            } else {
5846                                let ids = sl(idx, base, ni);
5847                                for i in 0..ni {
5848                                    let row = ids[i] as usize;
5849                                    if row < rows {
5850                                        out[i * tr..(i + 1) * tr]
5851                                            .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
5852                                    }
5853                                }
5854                            }
5855                        } else {
5856                            let tab = sl(table, base, tl);
5857                            let out = sl_mut(dst, base, ni * tr);
5858                            if idx_i64 != 0 {
5859                                let ids = sl_i64(idx, base, ni);
5860                                for i in 0..ni {
5861                                    let row = ids[i].max(0) as usize;
5862                                    if row < rows {
5863                                        out[i * tr..(i + 1) * tr]
5864                                            .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
5865                                    }
5866                                }
5867                            } else {
5868                                let ids = sl(idx, base, ni);
5869                                for i in 0..ni {
5870                                    let row = ids[i] as usize;
5871                                    if row < rows {
5872                                        out[i * tr..(i + 1) * tr]
5873                                            .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
5874                                    }
5875                                }
5876                            }
5877                        }
5878                    })
5879                }
5880
5881                Thunk::Narrow {
5882                    src,
5883                    dst,
5884                    outer,
5885                    src_stride,
5886                    dst_stride,
5887                    inner,
5888                    elem_bytes,
5889                } => {
5890                    narrow_thunk_closure(src, dst, outer, src_stride, dst_stride, inner, elem_bytes)
5891                }
5892
5893                Thunk::Copy { src, dst, len } => {
5894                    let len = len as usize;
5895                    Arc::new(move |base: *mut u8| unsafe {
5896                        if src == dst || len == 0 {
5897                            return;
5898                        }
5899                        let src_ptr = base.add(src) as *const f32;
5900                        let dst_ptr = base.add(dst) as *mut f32;
5901                        if src_ptr == dst_ptr {
5902                            return;
5903                        }
5904                        std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len);
5905                    })
5906                }
5907
5908                Thunk::Softmax { data, rows, cols } => {
5909                    let (rows, cols) = (rows as usize, cols as usize);
5910                    Arc::new(move |base: *mut u8| unsafe {
5911                        crate::naive::softmax(sl_mut(data, base, rows * cols), rows, cols);
5912                    })
5913                }
5914
5915                Thunk::Cumsum {
5916                    src,
5917                    dst,
5918                    rows,
5919                    cols,
5920                    exclusive,
5921                } => {
5922                    let (rows, cols) = (rows as usize, cols as usize);
5923                    Arc::new(move |base: *mut u8| unsafe {
5924                        let s = sl(src, base, rows * cols);
5925                        let d = sl_mut(dst, base, rows * cols);
5926                        if exclusive {
5927                            for r in 0..rows {
5928                                let mut acc = 0.0f32;
5929                                for c in 0..cols {
5930                                    d[r * cols + c] = acc;
5931                                    acc += s[r * cols + c];
5932                                }
5933                            }
5934                        } else {
5935                            for r in 0..rows {
5936                                let mut acc = 0.0f32;
5937                                for c in 0..cols {
5938                                    acc += s[r * cols + c];
5939                                    d[r * cols + c] = acc;
5940                                }
5941                            }
5942                        }
5943                    })
5944                }
5945
5946                Thunk::Sample {
5947                    logits,
5948                    dst,
5949                    batch,
5950                    vocab,
5951                    top_k,
5952                    top_p,
5953                    temperature,
5954                    seed,
5955                } => {
5956                    let (b, v) = (batch as usize, vocab as usize);
5957                    let k = (top_k as usize).min(v);
5958                    Arc::new(move |base: *mut u8| unsafe {
5959                        let lg = sl(logits, base, b * v);
5960                        let out = sl_mut(dst, base, b);
5961                        let mut rng =
5962                            rlx_ir::Philox4x32::new(if seed == 0 { 0xDEADBEEF } else { seed });
5963                        for bi in 0..b {
5964                            let row = &lg[bi * v..(bi + 1) * v];
5965                            out[bi] = sample_row(row, k, top_p, temperature, &mut rng) as f32;
5966                        }
5967                    })
5968                }
5969
5970                Thunk::DequantMatMul {
5971                    x,
5972                    w_q,
5973                    scale,
5974                    zp,
5975                    dst,
5976                    m,
5977                    k,
5978                    n,
5979                    block_size,
5980                    is_asymmetric,
5981                } => {
5982                    let (m, k, n, bs) = (m as usize, k as usize, n as usize, block_size as usize);
5983                    let n_blocks_per_col = k.div_ceil(bs);
5984                    Arc::new(move |base: *mut u8| unsafe {
5985                        let xs = sl(x, base, m * k);
5986                        // w_q is packed i8 — use raw byte slice + reinterpret.
5987                        let raw = base.add(w_q);
5988                        let w_bytes = std::slice::from_raw_parts(raw as *const i8, k * n);
5989                        let scales = sl(scale, base, n_blocks_per_col * n);
5990                        let zps = if is_asymmetric {
5991                            sl(zp, base, n_blocks_per_col * n)
5992                        } else {
5993                            &[][..]
5994                        };
5995                        let out = sl_mut(dst, base, m * n);
5996                        dequant_matmul_int8(
5997                            xs,
5998                            w_bytes,
5999                            scales,
6000                            zps,
6001                            out,
6002                            m,
6003                            k,
6004                            n,
6005                            bs,
6006                            is_asymmetric,
6007                        );
6008                    })
6009                }
6010
6011                Thunk::DequantMatMulGguf {
6012                    x,
6013                    w_q,
6014                    dst,
6015                    m,
6016                    k,
6017                    n,
6018                    scheme,
6019                } => {
6020                    let (m, k, n) = (m as usize, k as usize, n as usize);
6021                    let block_bytes = scheme.gguf_block_bytes() as usize;
6022                    let block_elems = scheme.gguf_block_size() as usize;
6023                    let total_bytes = (k * n) / block_elems * block_bytes;
6024                    Arc::new(move |base: *mut u8| unsafe {
6025                        let xs = sl(x, base, m * k);
6026                        let w_bytes =
6027                            std::slice::from_raw_parts(base.add(w_q) as *const u8, total_bytes);
6028                        let out = sl_mut(dst, base, m * n);
6029                        crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, scheme);
6030                    })
6031                }
6032
6033                Thunk::DequantMatMulInt4 {
6034                    x,
6035                    w_q,
6036                    scale,
6037                    zp,
6038                    dst,
6039                    m,
6040                    k,
6041                    n,
6042                    block_size,
6043                    is_asymmetric,
6044                } => {
6045                    let (m, k, n, bs) = (m as usize, k as usize, n as usize, block_size as usize);
6046                    let n_blocks = k.div_ceil(bs);
6047                    Arc::new(move |base: *mut u8| unsafe {
6048                        let xs = sl(x, base, m * k);
6049                        let w_bytes = std::slice::from_raw_parts(
6050                            base.add(w_q) as *const u8,
6051                            (k * n).div_ceil(2),
6052                        );
6053                        let scales = sl(scale, base, n_blocks * n);
6054                        let zps = if is_asymmetric {
6055                            sl(zp, base, n_blocks * n)
6056                        } else {
6057                            &[][..]
6058                        };
6059                        let out = sl_mut(dst, base, m * n);
6060                        dequant_matmul_int4(
6061                            xs,
6062                            w_bytes,
6063                            scales,
6064                            zps,
6065                            out,
6066                            m,
6067                            k,
6068                            n,
6069                            bs,
6070                            is_asymmetric,
6071                        );
6072                    })
6073                }
6074
6075                Thunk::DequantMatMulFp8 {
6076                    x,
6077                    w_q,
6078                    scale,
6079                    dst,
6080                    m,
6081                    k,
6082                    n,
6083                    e5m2,
6084                } => {
6085                    let (m, k, n) = (m as usize, k as usize, n as usize);
6086                    Arc::new(move |base: *mut u8| unsafe {
6087                        let xs = sl(x, base, m * k);
6088                        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, k * n);
6089                        let scales = sl(scale, base, n);
6090                        let out = sl_mut(dst, base, m * n);
6091                        dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, e5m2);
6092                    })
6093                }
6094
6095                Thunk::DequantMatMulNvfp4 {
6096                    x,
6097                    w_q,
6098                    scale,
6099                    global_scale,
6100                    dst,
6101                    m,
6102                    k,
6103                    n,
6104                } => {
6105                    let (m, k, n) = (m as usize, k as usize, n as usize);
6106                    let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
6107                    Arc::new(move |base: *mut u8| unsafe {
6108                        let xs = sl(x, base, m * k);
6109                        let w_bytes = std::slice::from_raw_parts(
6110                            base.add(w_q) as *const u8,
6111                            (k * n).div_ceil(2),
6112                        );
6113                        let scale_bytes =
6114                            std::slice::from_raw_parts(base.add(scale) as *const u8, n_scale);
6115                        let gs = sl(global_scale, base, 1)[0];
6116                        let out = sl_mut(dst, base, m * n);
6117                        dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
6118                    })
6119                }
6120
6121                Thunk::LoraMatMul {
6122                    x,
6123                    w,
6124                    a,
6125                    b,
6126                    dst,
6127                    m,
6128                    k,
6129                    n,
6130                    r,
6131                    scale,
6132                } => {
6133                    let (m, k, n, r) = (m as usize, k as usize, n as usize, r as usize);
6134                    Arc::new(move |base: *mut u8| unsafe {
6135                        let xs = sl(x, base, m * k);
6136                        let ws = sl(w, base, k * n);
6137                        let a_s = sl(a, base, k * r);
6138                        let bs = sl(b, base, r * n);
6139                        let out = sl_mut(dst, base, m * n);
6140                        // Step 1: out = x · W.
6141                        crate::blas::sgemm(xs, ws, out, m, k, n);
6142                        // Step 2: tmp = x · A (rank-r intermediate; tiny).
6143                        let mut tmp = vec![0f32; m * r];
6144                        crate::blas::sgemm(xs, a_s, &mut tmp, m, k, r);
6145                        // Step 3: out += scale * (tmp · B).
6146                        // sgemm_accumulate uses alpha=1.0 internally, so
6147                        // scale tmp first.
6148                        if scale != 1.0 {
6149                            for v in tmp.iter_mut() {
6150                                *v *= scale;
6151                            }
6152                        }
6153                        crate::blas::sgemm_accumulate(&tmp, bs, out, m, r, n);
6154                    })
6155                }
6156
6157                Thunk::LayerNorm {
6158                    src,
6159                    g,
6160                    b,
6161                    dst,
6162                    rows,
6163                    h,
6164                    eps,
6165                } => {
6166                    let (rows, h) = (rows as usize, h as usize);
6167                    Arc::new(move |base: *mut u8| unsafe {
6168                        let inp = sl(src, base, rows * h);
6169                        let gamma = sl(g, base, h);
6170                        let beta = sl(b, base, h);
6171                        let out = sl_mut(dst, base, rows * h);
6172                        for row in 0..rows {
6173                            crate::kernels::layer_norm_row(
6174                                &inp[row * h..(row + 1) * h],
6175                                gamma,
6176                                beta,
6177                                &mut out[row * h..(row + 1) * h],
6178                                h,
6179                                eps,
6180                            );
6181                        }
6182                    })
6183                }
6184
6185                Thunk::BatchNormInference {
6186                    src,
6187                    g,
6188                    b,
6189                    mean,
6190                    var,
6191                    dst,
6192                    count,
6193                    channels,
6194                    eps,
6195                } => {
6196                    let count = count as usize;
6197                    let c = channels as usize;
6198                    let n = count * c;
6199                    let (src, g, b, mean, var, dst) = (src, g, b, mean, var, dst);
6200                    Arc::new(move |base: *mut u8| unsafe {
6201                        crate::kernels::batch_norm_inference(
6202                            sl(src, base, n),
6203                            sl(g, base, c),
6204                            sl(b, base, c),
6205                            sl(mean, base, c),
6206                            sl(var, base, c),
6207                            sl_mut(dst, base, n),
6208                            c,
6209                            eps,
6210                        );
6211                    })
6212                }
6213
6214                Thunk::Attention {
6215                    q,
6216                    k,
6217                    v,
6218                    mask,
6219                    out,
6220                    batch,
6221                    seq,
6222                    kv_seq,
6223                    heads,
6224                    head_dim,
6225                    mask_kind,
6226                    q_row_stride,
6227                    k_row_stride,
6228                    v_row_stride,
6229                    bhsd,
6230                } => {
6231                    if std::env::var("RLX_ATTN_DEBUG").is_ok() {
6232                        eprintln!("[attn-compile] batch={batch} seq={seq} kv_seq={kv_seq} heads={heads} bhsd={bhsd}");
6233                    }
6234                    // Q seq length (`q_s`) and K/V seq length (`k_s`) differ
6235                    // during cached decode (`q_s=1`, `k_s=past_seq+1`). The
6236                    // earlier version of this kernel destructured
6237                    // `kv_seq: _` and used a single `s = seq` for both axes,
6238                    // so cached decode only scored 1×1 instead of 1×k_s —
6239                    // attention couldn't see the past K cache and decode
6240                    // collapsed into repetitive fragments
6241                    // (`Self-based on [1\nAnswer: Self-based on [1…`).
6242                    let (b, q_s, k_s, nh, dh) = (
6243                        batch as usize,
6244                        seq as usize,
6245                        kv_seq as usize,
6246                        heads as usize,
6247                        head_dim as usize,
6248                    );
6249                    let hs = nh * dh;
6250                    let qrs = q_row_stride as usize;
6251                    let krs = k_row_stride as usize;
6252                    let vrs = v_row_stride as usize;
6253                    let scale = (dh as f32).powf(-0.5);
6254                    Arc::new(move |base: *mut u8| unsafe {
6255                        if std::env::var("RLX_ATTN_DEBUG").is_ok() {
6256                            eprintln!("[attn] b={b} q_s={q_s} k_s={k_s} nh={nh} dh={dh} bhsd={bhsd} mask_kind={:?}", mask_kind);
6257                        }
6258                        // Slice lengths use the source's row stride so the
6259                        // compiler-emitted bounds checks cover the whole
6260                        // strided span (the kernel walks with q/k/v_rs).
6261                        // For [B, H, S, D] the buffer is dense B*H*S*D.
6262                        let (q_len, k_len, v_len, o_len) = if bhsd {
6263                            let qn = b * nh * q_s * dh;
6264                            let kn = b * nh * k_s * dh;
6265                            (qn, kn, kn, qn)
6266                        } else {
6267                            (b * q_s * qrs, b * k_s * krs, b * k_s * vrs, b * q_s * hs)
6268                        };
6269                        let q_d = sl(q, base, q_len);
6270                        let k_d = sl(k, base, k_len);
6271                        let v_d = sl(v, base, v_len);
6272                        let m_d: &[f32] = match mask_kind {
6273                            rlx_ir::op::MaskKind::Custom => sl(mask, base, b * k_s),
6274                            rlx_ir::op::MaskKind::Bias => sl(mask, base, b * nh * q_s * k_s),
6275                            _ => &[],
6276                        };
6277                        let o_d = sl_mut(out, base, o_len);
6278                        let mut qh = vec![0f32; q_s * dh];
6279                        let mut kh = vec![0f32; k_s * dh];
6280                        let mut vh = vec![0f32; k_s * dh];
6281                        let mut sc = vec![0f32; q_s * k_s];
6282                        let mut oh = vec![0f32; q_s * dh];
6283                        for bi in 0..b {
6284                            for hi in 0..nh {
6285                                // Gather per-head Q.
6286                                for si in 0..q_s {
6287                                    let q_off = if bhsd {
6288                                        bi * nh * q_s * dh + hi * q_s * dh + si * dh
6289                                    } else {
6290                                        bi * q_s * qrs + si * qrs + hi * dh
6291                                    };
6292                                    qh[si * dh..(si + 1) * dh]
6293                                        .copy_from_slice(&q_d[q_off..q_off + dh]);
6294                                }
6295                                // Gather per-head K, V.
6296                                for si in 0..k_s {
6297                                    let (k_off, v_off) = if bhsd {
6298                                        (
6299                                            bi * nh * k_s * dh + hi * k_s * dh + si * dh,
6300                                            bi * nh * k_s * dh + hi * k_s * dh + si * dh,
6301                                        )
6302                                    } else {
6303                                        (
6304                                            bi * k_s * krs + si * krs + hi * dh,
6305                                            bi * k_s * vrs + si * vrs + hi * dh,
6306                                        )
6307                                    };
6308                                    kh[si * dh..(si + 1) * dh]
6309                                        .copy_from_slice(&k_d[k_off..k_off + dh]);
6310                                    vh[si * dh..(si + 1) * dh]
6311                                        .copy_from_slice(&v_d[v_off..v_off + dh]);
6312                                }
6313                                for qi in 0..q_s {
6314                                    for ki in 0..k_s {
6315                                        let mut dot = 0f32;
6316                                        for d in 0..dh {
6317                                            dot += qh[qi * dh + d] * kh[ki * dh + d];
6318                                        }
6319                                        sc[qi * k_s + ki] = dot * scale;
6320                                    }
6321                                }
6322                                // Apply mask. Causal/SlidingWindow use absolute
6323                                // positions so they handle Lq != Lk (decode mode
6324                                // with cached K/V): q_offset = k_s - q_s.
6325                                let q_offset = k_s.saturating_sub(q_s);
6326                                match mask_kind {
6327                                    rlx_ir::op::MaskKind::None => {}
6328                                    rlx_ir::op::MaskKind::Causal => {
6329                                        for qi in 0..q_s {
6330                                            let abs_q = q_offset + qi;
6331                                            for ki in (abs_q + 1)..k_s {
6332                                                sc[qi * k_s + ki] = mask_neg;
6333                                            }
6334                                        }
6335                                    }
6336                                    rlx_ir::op::MaskKind::SlidingWindow(w) => {
6337                                        for qi in 0..q_s {
6338                                            let abs_q = q_offset + qi;
6339                                            let lo = abs_q.saturating_sub(w);
6340                                            for ki in 0..k_s {
6341                                                if ki < lo || ki > abs_q {
6342                                                    sc[qi * k_s + ki] = mask_neg;
6343                                                }
6344                                            }
6345                                        }
6346                                    }
6347                                    rlx_ir::op::MaskKind::Custom => {
6348                                        for qi in 0..q_s {
6349                                            for ki in 0..k_s {
6350                                                if m_d[bi * k_s + ki] < mask_thr {
6351                                                    sc[qi * k_s + ki] = mask_neg;
6352                                                }
6353                                            }
6354                                        }
6355                                    }
6356                                    rlx_ir::op::MaskKind::Bias => {
6357                                        let per_bh = q_s * k_s;
6358                                        let off = (bi * nh + hi) * per_bh;
6359                                        for i in 0..per_bh {
6360                                            sc[i] += m_d[off + i];
6361                                        }
6362                                    }
6363                                }
6364                                crate::naive::softmax(&mut sc, q_s, k_s);
6365                                oh.fill(0.0);
6366                                for qi in 0..q_s {
6367                                    for ki in 0..k_s {
6368                                        let w = sc[qi * k_s + ki];
6369                                        if w > score_skip {
6370                                            for d in 0..dh {
6371                                                oh[qi * dh + d] += w * vh[ki * dh + d];
6372                                            }
6373                                        }
6374                                    }
6375                                }
6376                                for si in 0..q_s {
6377                                    let off = if bhsd {
6378                                        bi * nh * q_s * dh + hi * q_s * dh + si * dh
6379                                    } else {
6380                                        bi * q_s * hs + si * hs + hi * dh
6381                                    };
6382                                    o_d[off..off + dh].copy_from_slice(&oh[si * dh..(si + 1) * dh]);
6383                                }
6384                            }
6385                        }
6386                    })
6387                }
6388
6389                Thunk::FusedSwiGLU {
6390                    src,
6391                    dst,
6392                    n_half,
6393                    total,
6394                    gate_first,
6395                } => {
6396                    let n = n_half as usize;
6397                    let t = total as usize;
6398                    let outer = t / n;
6399                    let in_total = outer * 2 * n;
6400                    Arc::new(move |base: *mut u8| unsafe {
6401                        let inp = sl(src, base, in_total);
6402                        let out = sl_mut(dst, base, t);
6403                        for o in 0..outer {
6404                            let in_row = &inp[o * 2 * n..(o + 1) * 2 * n];
6405                            let out_row = &mut out[o * n..(o + 1) * n];
6406                            for i in 0..n {
6407                                let (up, gate) = if gate_first {
6408                                    (in_row[n + i], in_row[i])
6409                                } else {
6410                                    (in_row[i], in_row[n + i])
6411                                };
6412                                out_row[i] = up * (gate / (1.0 + (-gate).exp()));
6413                            }
6414                        }
6415                    })
6416                }
6417
6418                Thunk::Concat {
6419                    dst,
6420                    outer,
6421                    inner,
6422                    total_axis,
6423                    inputs,
6424                } => {
6425                    let outer = outer as usize;
6426                    let inner = inner as usize;
6427                    let total_axis = total_axis as usize;
6428                    let out_total = outer * total_axis * inner;
6429                    // Pre-compute the destination row offset for each input
6430                    // (cumulative axis offsets times inner).
6431                    let mut layout: Vec<(usize, usize, usize)> = Vec::with_capacity(inputs.len());
6432                    let mut cum: usize = 0;
6433                    for (src_off, in_axis) in &inputs {
6434                        let in_axis = *in_axis as usize;
6435                        layout.push((*src_off, cum * inner, in_axis * inner));
6436                        cum += in_axis;
6437                    }
6438                    Arc::new(move |base: *mut u8| unsafe {
6439                        let out = sl_mut(dst, base, out_total);
6440                        let row_stride = total_axis * inner;
6441                        for (src_off, dst_col_off, copy_per_row) in &layout {
6442                            let in_total = outer * *copy_per_row;
6443                            let inp = sl(*src_off, base, in_total);
6444                            for o in 0..outer {
6445                                let dst_row_start = o * row_stride + *dst_col_off;
6446                                let src_row_start = o * *copy_per_row;
6447                                out[dst_row_start..dst_row_start + *copy_per_row].copy_from_slice(
6448                                    &inp[src_row_start..src_row_start + *copy_per_row],
6449                                );
6450                            }
6451                        }
6452                    })
6453                }
6454
6455                Thunk::CustomOp {
6456                    kernel,
6457                    inputs,
6458                    output,
6459                    attrs,
6460                } => {
6461                    // Capture-by-move: clone the Arc and Vecs once into the
6462                    // closure. Dispatch by output dtype each call (the
6463                    // dtype is fixed at compile time but it's cheaper to
6464                    // branch once per execution than to monomorphize a
6465                    // dozen closure variants).
6466                    let kernel = kernel.clone();
6467                    let attrs = attrs.clone();
6468                    let inputs = inputs.clone();
6469                    let (out_off, out_len, out_shape) = output.clone();
6470                    Arc::new(move |base: *mut u8| unsafe {
6471                        dispatch_custom_op(
6472                            &*kernel, &inputs, out_off, out_len, &out_shape, &attrs, base,
6473                        );
6474                    })
6475                }
6476
6477                Thunk::GaussianSplatRender {
6478                    positions_off,
6479                    positions_len,
6480                    scales_off,
6481                    scales_len,
6482                    rotations_off,
6483                    rotations_len,
6484                    opacities_off,
6485                    opacities_len,
6486                    colors_off,
6487                    colors_len,
6488                    sh_coeffs_off,
6489                    sh_coeffs_len,
6490                    meta_off,
6491                    dst_off,
6492                    dst_len,
6493                    width,
6494                    height,
6495                    tile_size,
6496                    radius_scale,
6497                    alpha_cutoff,
6498                    max_splat_steps,
6499                    transmittance_threshold,
6500                    max_list_entries,
6501                } => Arc::new(move |base: *mut u8| unsafe {
6502                    crate::splat::execute_gaussian_splat_render(
6503                        positions_off,
6504                        positions_len,
6505                        scales_off,
6506                        scales_len,
6507                        rotations_off,
6508                        rotations_len,
6509                        opacities_off,
6510                        opacities_len,
6511                        colors_off,
6512                        colors_len,
6513                        sh_coeffs_off,
6514                        sh_coeffs_len,
6515                        meta_off,
6516                        dst_off,
6517                        dst_len,
6518                        width,
6519                        height,
6520                        tile_size,
6521                        radius_scale,
6522                        alpha_cutoff,
6523                        max_splat_steps,
6524                        transmittance_threshold,
6525                        max_list_entries,
6526                        base,
6527                    );
6528                }),
6529
6530                Thunk::GaussianSplatRenderBackward {
6531                    positions_off,
6532                    positions_len,
6533                    scales_off,
6534                    scales_len,
6535                    rotations_off,
6536                    rotations_len,
6537                    opacities_off,
6538                    opacities_len,
6539                    colors_off,
6540                    colors_len,
6541                    sh_coeffs_off,
6542                    sh_coeffs_len,
6543                    meta_off,
6544                    d_loss_off,
6545                    d_loss_len,
6546                    packed_off,
6547                    packed_len,
6548                    width,
6549                    height,
6550                    tile_size,
6551                    radius_scale,
6552                    alpha_cutoff,
6553                    max_splat_steps,
6554                    transmittance_threshold,
6555                    max_list_entries,
6556                    loss_grad_clip,
6557                    sh_band,
6558                    max_anisotropy,
6559                } => Arc::new(move |base: *mut u8| unsafe {
6560                    crate::splat::execute_gaussian_splat_render_backward(
6561                        positions_off,
6562                        positions_len,
6563                        scales_off,
6564                        scales_len,
6565                        rotations_off,
6566                        rotations_len,
6567                        opacities_off,
6568                        opacities_len,
6569                        colors_off,
6570                        colors_len,
6571                        sh_coeffs_off,
6572                        sh_coeffs_len,
6573                        meta_off,
6574                        d_loss_off,
6575                        d_loss_len,
6576                        packed_off,
6577                        packed_len,
6578                        width,
6579                        height,
6580                        tile_size,
6581                        radius_scale,
6582                        alpha_cutoff,
6583                        max_splat_steps,
6584                        transmittance_threshold,
6585                        max_list_entries,
6586                        loss_grad_clip,
6587                        sh_band,
6588                        max_anisotropy,
6589                        base,
6590                    );
6591                }),
6592
6593                Thunk::GaussianSplatPrepare {
6594                    positions_off,
6595                    positions_len,
6596                    scales_off,
6597                    scales_len,
6598                    rotations_off,
6599                    rotations_len,
6600                    opacities_off,
6601                    opacities_len,
6602                    colors_off,
6603                    colors_len,
6604                    sh_coeffs_off,
6605                    sh_coeffs_len,
6606                    meta_off,
6607                    meta_len,
6608                    prep_off,
6609                    prep_len,
6610                    width,
6611                    height,
6612                    tile_size,
6613                    radius_scale,
6614                    alpha_cutoff,
6615                    max_splat_steps,
6616                    transmittance_threshold,
6617                    max_list_entries,
6618                } => Arc::new(move |base: *mut u8| unsafe {
6619                    crate::splat::execute_gaussian_splat_prepare(
6620                        positions_off,
6621                        positions_len,
6622                        scales_off,
6623                        scales_len,
6624                        rotations_off,
6625                        rotations_len,
6626                        opacities_off,
6627                        opacities_len,
6628                        colors_off,
6629                        colors_len,
6630                        sh_coeffs_off,
6631                        sh_coeffs_len,
6632                        meta_off,
6633                        meta_len,
6634                        prep_off,
6635                        prep_len,
6636                        width,
6637                        height,
6638                        tile_size,
6639                        radius_scale,
6640                        alpha_cutoff,
6641                        max_splat_steps,
6642                        transmittance_threshold,
6643                        max_list_entries,
6644                        base,
6645                    );
6646                }),
6647
6648                Thunk::GaussianSplatRasterize {
6649                    prep_off,
6650                    prep_len,
6651                    meta_off,
6652                    meta_len,
6653                    dst_off,
6654                    dst_len,
6655                    count,
6656                    width,
6657                    height,
6658                    tile_size,
6659                    alpha_cutoff,
6660                    max_splat_steps,
6661                    transmittance_threshold,
6662                    max_list_entries,
6663                } => Arc::new(move |base: *mut u8| unsafe {
6664                    crate::splat::execute_gaussian_splat_rasterize(
6665                        prep_off,
6666                        prep_len,
6667                        meta_off,
6668                        meta_len,
6669                        dst_off,
6670                        dst_len,
6671                        count,
6672                        width,
6673                        height,
6674                        tile_size,
6675                        alpha_cutoff,
6676                        max_splat_steps,
6677                        transmittance_threshold,
6678                        max_list_entries,
6679                        base,
6680                    );
6681                }),
6682
6683                Thunk::Fft1d {
6684                    src,
6685                    dst,
6686                    outer,
6687                    n_complex,
6688                    inverse,
6689                    norm_tag,
6690                    dtype,
6691                } => {
6692                    let f: Arc<dyn Fn(*mut u8) + Send + Sync> = match dtype {
6693                        rlx_ir::DType::F64 => Arc::new(move |base: *mut u8| unsafe {
6694                            execute_fft1d_f64(
6695                                src,
6696                                dst,
6697                                outer as usize,
6698                                n_complex as usize,
6699                                inverse,
6700                                norm_tag,
6701                                base,
6702                            );
6703                        }),
6704                        rlx_ir::DType::F32 => Arc::new(move |base: *mut u8| unsafe {
6705                            execute_fft1d_f32(
6706                                src,
6707                                dst,
6708                                outer as usize,
6709                                n_complex as usize,
6710                                inverse,
6711                                norm_tag,
6712                                base,
6713                            );
6714                        }),
6715                        rlx_ir::DType::C64 => Arc::new(move |base: *mut u8| unsafe {
6716                            execute_fft1d_c64(
6717                                src,
6718                                dst,
6719                                outer as usize,
6720                                n_complex as usize,
6721                                inverse,
6722                                norm_tag,
6723                                base,
6724                            );
6725                        }),
6726                        other => panic!("Op::Fft on CPU requires F32/F64/C64, got {other:?}"),
6727                    };
6728                    f
6729                }
6730
6731                Thunk::FftButterflyStage {
6732                    state_src,
6733                    state_dst,
6734                    gate_src,
6735                    rev_src,
6736                    tw_re_src,
6737                    tw_im_src,
6738                    batch,
6739                    n_fft,
6740                    stage,
6741                } => Arc::new(move |base: *mut u8| unsafe {
6742                    execute_fft_butterfly_stage_f32(
6743                        state_src,
6744                        state_dst,
6745                        gate_src,
6746                        rev_src,
6747                        tw_re_src,
6748                        tw_im_src,
6749                        batch as usize,
6750                        n_fft as usize,
6751                        stage as usize,
6752                        base,
6753                    );
6754                }),
6755
6756                Thunk::LogMel {
6757                    spec,
6758                    filters,
6759                    dst,
6760                    outer,
6761                    n_fft,
6762                    n_bins,
6763                    n_mels,
6764                } => Arc::new(move |base: *mut u8| unsafe {
6765                    execute_log_mel_f32(
6766                        spec,
6767                        filters,
6768                        dst,
6769                        outer as usize,
6770                        n_fft as usize,
6771                        n_bins as usize,
6772                        n_mels as usize,
6773                        base,
6774                    );
6775                }),
6776
6777                Thunk::LogMelBackward {
6778                    spec,
6779                    filters,
6780                    dy,
6781                    dst,
6782                    outer,
6783                    n_fft,
6784                    n_bins,
6785                    n_mels,
6786                } => Arc::new(move |base: *mut u8| unsafe {
6787                    execute_log_mel_backward_f32(
6788                        spec,
6789                        filters,
6790                        dy,
6791                        dst,
6792                        outer as usize,
6793                        n_fft as usize,
6794                        n_bins as usize,
6795                        n_mels as usize,
6796                        base,
6797                    );
6798                }),
6799
6800                Thunk::WelchPeaks {
6801                    spec,
6802                    dst,
6803                    welch_batch,
6804                    n_fft,
6805                    n_segments,
6806                    k,
6807                } => Arc::new(move |base: *mut u8| unsafe {
6808                    execute_welch_peaks_f32(
6809                        spec,
6810                        dst,
6811                        welch_batch as usize,
6812                        n_fft as usize,
6813                        n_segments as usize,
6814                        k as usize,
6815                        base,
6816                    );
6817                }),
6818
6819                _ => Arc::new(|_: *mut u8| {}),
6820            }
6821        })
6822        .collect();
6823
6824    // ── Thunk-level attention fusion ──────────────────────
6825    // For small batch*seq, fuse QKV→Narrow×3→[Rope×2]→Attention→OutProj
6826    // into a single FusedAttnBlock. Auto-detects from Attention thunks.
6827    let fuse_threshold: usize = rlx_ir::env::var("RLX_FUSE_ATTN_THRESHOLD")
6828        .and_then(|v| v.parse().ok())
6829        .unwrap_or(64);
6830    let should_fuse = thunks.iter().any(|t| match t {
6831        Thunk::Attention { batch, seq, .. } => {
6832            (*batch as usize) * (*seq as usize) <= fuse_threshold
6833        }
6834        _ => false,
6835    });
6836
6837    if should_fuse {
6838        // Build non-Nop index for pattern matching across Nop gaps
6839        let active: Vec<usize> = thunks
6840            .iter()
6841            .enumerate()
6842            .filter(|(_, t)| !matches!(t, Thunk::Nop))
6843            .map(|(i, _)| i)
6844            .collect();
6845
6846        let mut kill = vec![false; thunks.len()]; // mark thunks to remove
6847        let mut insertions: Vec<(usize, Thunk)> = Vec::new(); // (position, replacement)
6848
6849        let mut ai = 0;
6850        while ai < active.len() {
6851            // Helper: get active thunk at offset from current
6852            let a = |off: usize| -> Option<(usize, &Thunk)> {
6853                active.get(ai + off).map(|&idx| (idx, &thunks[idx]))
6854            };
6855
6856            // Try BERT pattern: FusedMmBiasAct(QKV) → Narrow×3 → Attention → FusedMmBiasAct(out)
6857            let matched = (|| {
6858                let (_i0, t0) = a(0)?;
6859                let (_, t1) = a(1)?;
6860                let (_, t2) = a(2)?;
6861                let (_, t3) = a(3)?;
6862
6863                // a[0] must be FusedMmBiasAct or Sgemm (QKV projection)
6864                let (hidden, qkv_w, qkv_b, has_b) = match t0 {
6865                    Thunk::FusedMmBiasAct {
6866                        a,
6867                        w,
6868                        bias,
6869                        n: _,
6870                        act: None,
6871                        ..
6872                    } => (*a, *w, *bias, true),
6873                    Thunk::Sgemm { a, b, n: _, .. } => (*a, *b, 0, false),
6874                    _ => return None,
6875                };
6876
6877                // a[1..3] must be Narrows
6878                if !matches!(t1, Thunk::Narrow { .. }) {
6879                    return None;
6880                }
6881                if !matches!(t2, Thunk::Narrow { .. }) {
6882                    return None;
6883                }
6884                if !matches!(t3, Thunk::Narrow { .. }) {
6885                    return None;
6886                }
6887
6888                // Look for optional Rope×2 then Attention
6889                let (has_rope, attn_ai, cos_off, sin_off, cl) = if let Some((
6890                    _,
6891                    Thunk::Rope {
6892                        cos, sin, cos_len, ..
6893                    },
6894                )) = a(4)
6895                {
6896                    if matches!(a(5).map(|x| x.1), Some(Thunk::Rope { .. })) {
6897                        if matches!(a(6).map(|x| x.1), Some(Thunk::Attention { .. })) {
6898                            (true, 6, *cos, *sin, *cos_len)
6899                        } else {
6900                            return None;
6901                        }
6902                    } else {
6903                        return None;
6904                    }
6905                } else if matches!(a(4).map(|x| x.1), Some(Thunk::Attention { .. })) {
6906                    (false, 4, 0, 0, 0)
6907                } else {
6908                    return None;
6909                };
6910
6911                let (_attn_real_idx, attn_t) = a(attn_ai)?;
6912                let (batch, seq, heads, head_dim, mask) = match attn_t {
6913                    Thunk::Attention {
6914                        batch,
6915                        seq,
6916                        heads,
6917                        head_dim,
6918                        mask,
6919                        ..
6920                    } => (*batch, *seq, *heads, *head_dim, *mask),
6921                    _ => return None,
6922                };
6923
6924                // Next active must be out projection (FusedMmBiasAct or Sgemm)
6925                let (_out_real_idx, out_t) = a(attn_ai + 1)?;
6926                let (out_w, out_b, out_dst) = match out_t {
6927                    Thunk::FusedMmBiasAct {
6928                        w,
6929                        bias,
6930                        c,
6931                        act: None,
6932                        ..
6933                    } => (*w, *bias, *c),
6934                    Thunk::Sgemm { b: w, c, .. } => (*w, 0, *c),
6935                    _ => return None,
6936                };
6937
6938                let hs = heads * head_dim;
6939                let total_active = attn_ai + 2; // number of active thunks consumed
6940
6941                Some((
6942                    total_active,
6943                    Thunk::FusedAttnBlock {
6944                        hidden,
6945                        qkv_w,
6946                        out_w,
6947                        mask,
6948                        out: out_dst,
6949                        qkv_b: if has_b { qkv_b } else { 0 },
6950                        out_b: if has_b { out_b } else { 0 },
6951                        cos: cos_off,
6952                        sin: sin_off,
6953                        cos_len: cl,
6954                        batch,
6955                        seq,
6956                        hs,
6957                        nh: heads,
6958                        dh: head_dim,
6959                        has_bias: has_b,
6960                        has_rope,
6961                    },
6962                ))
6963            })();
6964
6965            if let Some((count, fused_thunk)) = matched {
6966                // Mark consumed thunks for removal
6967                for off in 0..count {
6968                    if let Some(&idx) = active.get(ai + off) {
6969                        kill[idx] = true;
6970                    }
6971                }
6972                // Insert replacement at position of the QKV thunk
6973                insertions.push((active[ai], fused_thunk));
6974                ai += count;
6975            } else {
6976                ai += 1;
6977            }
6978        }
6979
6980        // Rebuild thunk list: keep non-killed, insert fused at right positions
6981        if !insertions.is_empty() {
6982            let mut new_thunks = Vec::with_capacity(thunks.len());
6983            let mut insert_idx = 0;
6984            for (i, t) in thunks.into_iter().enumerate() {
6985                if insert_idx < insertions.len() && insertions[insert_idx].0 == i {
6986                    new_thunks.push(insertions[insert_idx].1.clone());
6987                    insert_idx += 1;
6988                }
6989                if !kill[i] {
6990                    new_thunks.push(t);
6991                }
6992            }
6993            if cfg.verbose >= 1 {
6994                eprintln!(
6995                    "[rlx] fused_attention: {} attention blocks fused",
6996                    insertions.len()
6997                );
6998            }
6999            thunks = new_thunks;
7000        }
7001    }
7002
    // ── Full layer fusion ──────────────────────────────────
    // After attention blocks are fused, scan the non-Nop thunks for full
    // transformer-layer patterns (a positional match on thunk kinds):
    // BERT:  FusedAttnBlock(has_bias, no rope) → FusedResidualLN → FusedMmBiasAct(gelu)
    //        → FusedMmBiasAct(no act) → FusedResidualLN  ⇒ FusedBertLayer
    // Nomic: FusedAttnBlock(no bias, rope) → FusedResidualLN → Sgemm(fused fc11+fc12)
    //        → Narrow×2 → ActivationInPlace(silu) → BinaryFull(mul) → Sgemm(fc2)
    //        → FusedResidualLN  ⇒ FusedNomicLayer (currently disabled, see below)
    // Gated on the same `should_fuse` decision as the attention pass.
    if should_fuse {
        // Recompute the non-Nop index: the attention pass may have rebuilt `thunks`.
        let active: Vec<usize> = thunks
            .iter()
            .enumerate()
            .filter(|(_, t)| !matches!(t, Thunk::Nop))
            .map(|(i, _)| i)
            .collect();

        // Thunks consumed by a layer fusion; they are dropped when rebuilding.
        let mut kill = vec![false; thunks.len()];
        // (index of the FusedAttnBlock, fused layer). Pushed in increasing index order.
        let mut insertions: Vec<(usize, Thunk)> = Vec::new();

        // Active thunk at absolute active-index `ai` (None past the end).
        let a = |ai: usize| -> Option<&Thunk> { active.get(ai).map(|&i| &thunks[i]) };

        let mut ai = 0;
        while ai < active.len() {
            // BERT pattern: FusedAttnBlock → FusedResidualLN → FusedMmBiasAct(gelu) → FusedMmBiasAct(none) → FusedResidualLN
            // On success, marks the 5 thunks as killed, queues a FusedBertLayer,
            // and yields the number of active thunks consumed.
            let bert_match = (|| -> Option<usize> {
                let fab = a(ai)?;
                let rln1 = a(ai + 1)?;
                let ffn1 = a(ai + 2)?;
                let ffn2 = a(ai + 3)?;
                let rln2 = a(ai + 4)?;

                // Only biased, rope-free attention blocks (BERT-style).
                let (hidden, qkv_w, qkv_b, out_w, out_b, mask, batch, seq, hs, nh, dh) = match fab {
                    Thunk::FusedAttnBlock {
                        hidden,
                        qkv_w,
                        qkv_b,
                        out_w,
                        out_b,
                        mask,
                        batch,
                        seq,
                        hs,
                        nh,
                        dh,
                        has_bias: true,
                        has_rope: false,
                        ..
                    } => (
                        *hidden, *qkv_w, *qkv_b, *out_w, *out_b, *mask, *batch, *seq, *hs, *nh, *dh,
                    ),
                    _ => return None,
                };
                // Post-attention residual + LayerNorm.
                let (ln1_g, ln1_b, eps1) = match rln1 {
                    Thunk::FusedResidualLN { g, b, eps, .. } => (*g, *b, *eps),
                    _ => return None,
                };
                // FFN up-projection with GELU; its `n` becomes the intermediate width.
                let (fc1_w, fc1_b, int_dim) = match ffn1 {
                    Thunk::FusedMmBiasAct {
                        w,
                        bias,
                        n,
                        act: Some(Activation::Gelu),
                        ..
                    } => (*w, *bias, *n),
                    _ => return None,
                };
                // FFN down-projection (bias, no activation).
                let (fc2_w, fc2_b) = match ffn2 {
                    Thunk::FusedMmBiasAct {
                        w, bias, act: None, ..
                    } => (*w, *bias),
                    _ => return None,
                };
                // Final residual + LayerNorm; its `out` is the fused layer's output.
                let (ln2_g, ln2_b, eps2, out) = match rln2 {
                    Thunk::FusedResidualLN { g, b, eps, out, .. } => (*g, *b, *eps, *out),
                    _ => return None,
                };

                for off in 0..5 {
                    kill[active[ai + off]] = true;
                }
                insertions.push((
                    active[ai],
                    Thunk::FusedBertLayer {
                        hidden,
                        qkv_w,
                        qkv_b,
                        out_w,
                        out_b,
                        mask,
                        ln1_g,
                        ln1_b,
                        eps1,
                        fc1_w,
                        fc1_b,
                        fc2_w,
                        fc2_b,
                        ln2_g,
                        ln2_b,
                        eps2,
                        out,
                        batch,
                        seq,
                        hs,
                        nh,
                        dh,
                        int_dim,
                    },
                ));
                Some(5)
            })();
            if let Some(n) = bert_match {
                ai += n;
                continue;
            }

            // Nomic full layer fusion — disabled pending SwiGLU stride debugging.
            // Nomic still benefits from FusedAttnBlock (attention-level fusion).
            // The body below is kept as reference for when the stride bug is fixed.
            // The leading `return None` makes this closure always yield None, so
            // nothing below it runs and the loop falls through to `ai += 1`.
            #[allow(unreachable_code)]
            let nomic_match = (|| -> Option<usize> {
                return None; // TODO: fix SwiGLU strided fc2 output mismatch
                let fab = a(ai)?;
                // Only bias-free, rope attention blocks (Nomic-style).
                let (hidden, qkv_w, out_w, mask, cos, sin, cos_len, batch, seq, hs, nh, dh) =
                    match fab {
                        Thunk::FusedAttnBlock {
                            hidden,
                            qkv_w,
                            out_w,
                            mask,
                            cos,
                            sin,
                            cos_len,
                            batch,
                            seq,
                            hs,
                            nh,
                            dh,
                            has_bias: false,
                            has_rope: true,
                            ..
                        } => (
                            *hidden, *qkv_w, *out_w, *mask, *cos, *sin, *cos_len, *batch, *seq,
                            *hs, *nh, *dh,
                        ),
                        _ => return None,
                    };
                // FusedResidualLN for LN1
                let (ln1_g, ln1_b, eps1) = match a(ai + 1)? {
                    Thunk::FusedResidualLN { g, b, eps, .. } => (*g, *b, *eps),
                    _ => return None,
                };
                // Sgemm (fused fc11+fc12)
                let fused_fc_w = match a(ai + 2)? {
                    Thunk::Sgemm { b: w, .. } => *w,
                    _ => return None,
                };
                // Narrow×2 for split
                if !matches!(a(ai + 3)?, Thunk::Narrow { .. }) {
                    return None;
                }
                if !matches!(a(ai + 4)?, Thunk::Narrow { .. }) {
                    return None;
                }
                // SiLU
                if !matches!(
                    a(ai + 5)?,
                    Thunk::ActivationInPlace {
                        act: Activation::Silu,
                        ..
                    }
                ) {
                    return None;
                }
                // BinaryFull(Mul) for gate
                if !matches!(
                    a(ai + 6)?,
                    Thunk::BinaryFull {
                        op: BinaryOp::Mul,
                        ..
                    }
                ) {
                    return None;
                }
                // Sgemm (fc2)
                let fc2_w = match a(ai + 7)? {
                    Thunk::Sgemm { b: w, .. } => *w,
                    _ => return None,
                };
                // Get int_dim from the Narrow (inner = int_dim for last-axis narrow)
                let int_dim = match a(ai + 3)? {
                    Thunk::Narrow { inner, .. } => *inner,
                    _ => return None,
                };
                // FusedResidualLN for LN2
                let (ln2_g, ln2_b, eps2, out) = match a(ai + 8)? {
                    Thunk::FusedResidualLN { g, b, eps, out, .. } => (*g, *b, *eps, *out),
                    _ => return None,
                };

                for off in 0..9 {
                    kill[active[ai + off]] = true;
                }
                insertions.push((
                    active[ai],
                    Thunk::FusedNomicLayer {
                        hidden,
                        qkv_w,
                        out_w,
                        mask,
                        cos,
                        sin,
                        cos_len,
                        ln1_g,
                        ln1_b,
                        eps1,
                        // The single fused fc11+fc12 weight goes in fc11_w; fc12_w
                        // is passed as 0 (NOTE(review): presumably unused by the
                        // kernel in this configuration — confirm).
                        fc11_w: fused_fc_w,
                        fc12_w: 0,
                        fc2_w,
                        ln2_g,
                        ln2_b,
                        eps2,
                        out,
                        batch,
                        seq,
                        hs,
                        nh,
                        dh,
                        int_dim,
                    },
                ));
                Some(9)
            })();
            if let Some(n) = nomic_match {
                ai += n;
                continue;
            }

            ai += 1;
        }

        // Rebuild: drop consumed thunks and place each fused layer in the slot
        // of its FusedAttnBlock.
        if !insertions.is_empty() {
            let mut new_thunks = Vec::with_capacity(thunks.len());
            let mut ins_idx = 0;
            for (i, t) in thunks.into_iter().enumerate() {
                if ins_idx < insertions.len() && insertions[ins_idx].0 == i {
                    new_thunks.push(insertions[ins_idx].1.clone());
                    ins_idx += 1;
                }
                if !kill[i] {
                    new_thunks.push(t);
                }
            }
            if cfg.verbose >= 1 {
                eprintln!(
                    "[rlx] fused_layer: {} full transformer layers fused",
                    insertions.len()
                );
            }
            thunks = new_thunks;
        }
    }
7255                );
7256            }
7257            thunks = new_thunks;
7258        }
7259    }
7260
7261    // ── Narrow → Rope thunk fusion (plan #45) ──────────────
7262    // Runs *after* FusedAttnBlock fusion so it only catches the medium-
7263    // batch path (batch*seq > 64) where the bigger fusion didn't fire.
7264    // Pattern: a Rope thunk whose `src` is the dst of an immediately-
7265    // preceding Narrow whose dst has no other consumer in this schedule.
7266    // Rewrite Rope to read directly from the parent buffer with the
7267    // parent's row stride; the Narrow becomes a Nop.
7268    //
7269    // Skipping the Narrow's write saves one full pass over Q/K (B*S*hs
7270    // f32) per Rope. For Nomic h=768 / batch=8 / seq=15 / 12 layers
7271    // that's 2 ropes/layer × 369 KB = ~8.9 MB of write traffic gone.
7272    {
7273        // Collect every byte-offset that's read as a thunk's `src` so
7274        // we know whether a Narrow's dst has consumers other than Rope.
7275        let mut read_offsets: HashMap<usize, usize> = HashMap::new();
7276        for t in &thunks {
7277            for off in thunk_read_offsets(t) {
7278                *read_offsets.entry(off).or_insert(0) += 1;
7279            }
7280        }
7281
7282        let mut fused_count = 0usize;
7283        for i in 0..thunks.len().saturating_sub(1) {
7284            // Look for Rope at i+1 reading from Narrow at i (skip Nops
7285            // between them since the planner left them in place).
7286            let narrow = match &thunks[i] {
7287                Thunk::Narrow { .. } => i,
7288                _ => continue,
7289            };
7290            // Find the next non-Nop thunk
7291            let mut j = narrow + 1;
7292            while j < thunks.len() && matches!(thunks[j], Thunk::Nop) {
7293                j += 1;
7294            }
7295            if j >= thunks.len() {
7296                continue;
7297            }
7298            // Must be Rope reading Narrow's dst
7299            let (n_src, n_dst, n_src_stride) = match &thunks[narrow] {
7300                Thunk::Narrow {
7301                    src,
7302                    dst,
7303                    src_stride,
7304                    ..
7305                } => (*src, *dst, *src_stride),
7306                _ => continue,
7307            };
7308            let rope_reads_narrow = matches!(&thunks[j],
7309                Thunk::Rope { src, .. } if *src == n_dst);
7310            if !rope_reads_narrow {
7311                continue;
7312            }
7313            // Conservatively require that the Narrow's dst has exactly
7314            // one reader (the Rope). Anything else and rewriting would
7315            // skip a needed write.
7316            if read_offsets.get(&n_dst).copied().unwrap_or(0) != 1 {
7317                continue;
7318            }
7319
7320            // Rewire: Rope reads from Narrow's adjusted source with the
7321            // parent buffer's row stride.
7322            if let Thunk::Rope {
7323                src,
7324                src_row_stride,
7325                ..
7326            } = &mut thunks[j]
7327            {
7328                *src = n_src;
7329                *src_row_stride = n_src_stride;
7330            }
7331            thunks[narrow] = Thunk::Nop;
7332            fused_count += 1;
7333        }
7334
7335        if fused_count > 0 && cfg.verbose >= 1 {
7336            eprintln!(
7337                "[rlx] fused_qk_rope: {} Narrow→Rope pairs collapsed",
7338                fused_count
7339            );
7340        }
7341    }
7342
7343    // ── Narrow×3 → Attention thunk fusion (plan #46 deep) ────
7344    // For each Attention thunk in the schedule, look up the producers
7345    // of its q/k/v inputs. If each is a Narrow whose dst has exactly
7346    // one consumer (the Attention), rewire Attention to read directly
7347    // from the parent buffer with the parent's row stride. The three
7348    // Narrows become Nops.
7349    //
7350    // This catches the BERT/Nomic QKV split path that FusedAttnBlock
7351    // misses (batch*seq > 64) — eliminates Q/K/V copies entirely.
7352    // For minilm6 batch=32 seq=16 hs=384: 3 × 32*16*384*4 = 2.3 MB
7353    // per layer × 6 layers = ~14 MB of write traffic gone.
7354    {
7355        let mut read_counts: HashMap<usize, usize> = HashMap::new();
7356        for t in &thunks {
7357            for off in thunk_read_offsets(t) {
7358                *read_counts.entry(off).or_insert(0) += 1;
7359            }
7360        }
7361        // Build dst→index map for fast producer lookup.
7362        let mut dst_to_idx: HashMap<usize, usize> = HashMap::new();
7363        for (i, t) in thunks.iter().enumerate() {
7364            if let Thunk::Narrow { dst, .. } = t {
7365                dst_to_idx.insert(*dst, i);
7366            }
7367        }
7368
7369        let mut fused_count = 0usize;
7370        for i in 0..thunks.len() {
7371            let (q_off, k_off, v_off) = match &thunks[i] {
7372                Thunk::Attention { q, k, v, .. } => (*q, *k, *v),
7373                _ => continue,
7374            };
7375            // All three inputs must come from Narrows.
7376            let q_n = match dst_to_idx.get(&q_off).copied() {
7377                Some(x) => x,
7378                None => continue,
7379            };
7380            let k_n = match dst_to_idx.get(&k_off).copied() {
7381                Some(x) => x,
7382                None => continue,
7383            };
7384            let v_n = match dst_to_idx.get(&v_off).copied() {
7385                Some(x) => x,
7386                None => continue,
7387            };
7388            // Each Narrow's dst must have exactly one reader (this Attn).
7389            if read_counts.get(&q_off).copied().unwrap_or(0) != 1 {
7390                continue;
7391            }
7392            if read_counts.get(&k_off).copied().unwrap_or(0) != 1 {
7393                continue;
7394            }
7395            if read_counts.get(&v_off).copied().unwrap_or(0) != 1 {
7396                continue;
7397            }
7398
7399            let (q_src, q_stride) = match &thunks[q_n] {
7400                Thunk::Narrow {
7401                    src, src_stride, ..
7402                } => (*src, *src_stride),
7403                _ => continue,
7404            };
7405            let (k_src, k_stride) = match &thunks[k_n] {
7406                Thunk::Narrow {
7407                    src, src_stride, ..
7408                } => (*src, *src_stride),
7409                _ => continue,
7410            };
7411            let (v_src, v_stride) = match &thunks[v_n] {
7412                Thunk::Narrow {
7413                    src, src_stride, ..
7414                } => (*src, *src_stride),
7415                _ => continue,
7416            };
7417
7418            if let Thunk::Attention {
7419                q,
7420                k,
7421                v,
7422                q_row_stride,
7423                k_row_stride,
7424                v_row_stride,
7425                ..
7426            } = &mut thunks[i]
7427            {
7428                *q = q_src;
7429                *k = k_src;
7430                *v = v_src;
7431                *q_row_stride = q_stride;
7432                *k_row_stride = k_stride;
7433                *v_row_stride = v_stride;
7434            }
7435            thunks[q_n] = Thunk::Nop;
7436            thunks[k_n] = Thunk::Nop;
7437            thunks[v_n] = Thunk::Nop;
7438            fused_count += 1;
7439        }
7440
7441        if fused_count > 0 && cfg.verbose >= 1 {
7442            eprintln!(
7443                "[rlx] fused_strided_attn: {} Narrow×3→Attention rewrites",
7444                fused_count
7445            );
7446        }
7447    }
7448
7449    ThunkSchedule {
7450        thunks,
7451        moe_resident: None,
7452        moe_resident_layers: None,
7453        moe_topk_capture: None,
7454        mask_threshold: cfg.mask_binary_threshold,
7455        mask_neg_inf: cfg.attn_mask_neg_inf,
7456        score_skip: cfg.score_skip_threshold,
7457        compiled_fns,
7458    }
7459}
7460
7461fn get_len(graph: &Graph, id: NodeId) -> usize {
7462    graph.node(id).shape.num_elements().unwrap_or(0)
7463}
7464
7465/// Static `usize` dims of a node's shape, or empty if any dim is dynamic.
7466fn get_static_dims(graph: &Graph, id: NodeId) -> Vec<usize> {
7467    let dims = graph.node(id).shape.dims();
7468    let mut out = Vec::with_capacity(dims.len());
7469    for d in dims {
7470        if let Some(s) = match d {
7471            rlx_ir::Dim::Static(s) => Some(*s),
7472            _ => None,
7473        } {
7474            out.push(s);
7475        } else {
7476            return Vec::new();
7477        }
7478    }
7479    out
7480}
7481
7482/// NumPy-style broadcast strides for one operand into the flat output
7483/// buffer. Returns a length-`out_dims.len()` `Vec<u32>` where entry
7484/// `d` is `0` if the input is size-1 (broadcast) at output dim `d`
7485/// (after left-padding with size-1 to match ranks), otherwise the
7486/// natural row-major stride into the *input* buffer.
7487///
7488/// Caller iterates output flat index `i` → output coords (row-major)
7489/// → input flat index = dot(coords, strides). The result is correct
7490/// for any broadcast pattern (scalar, last-axis, middle-axis,
7491/// bidirectional).
7492/// True when `rhs_dims` describes a *trailing* broadcast of `out_dims`
7493/// — i.e. every rhs dim either equals the corresponding output dim
7494/// (counting from the right) or rhs is shorter (left-padded with 1s).
7495/// Mid-shape singletons (e.g. rhs `[a, b, 1, d]` into out `[a, b, c, d]`
7496/// where `c > 1`) are NOT trailing broadcasts and require the
7497/// shape-aware `BinaryFull` slow path — `BiasAdd`'s linear bias-replicated
7498/// kernel silently miscomputes them.
7499fn is_trailing_bias_broadcast(rhs_dims: &[rlx_ir::Dim], out_dims: &[rlx_ir::Dim]) -> bool {
7500    if rhs_dims.len() > out_dims.len() {
7501        return false;
7502    }
7503    let off = out_dims.len() - rhs_dims.len();
7504    for i in 0..rhs_dims.len() {
7505        let r = match rhs_dims[i] {
7506            rlx_ir::Dim::Static(n) => n,
7507            _ => return false,
7508        };
7509        let o = match out_dims[off + i] {
7510            rlx_ir::Dim::Static(n) => n,
7511            _ => return false,
7512        };
7513        if r != o {
7514            return false;
7515        }
7516    }
7517    true
7518}
7519
7520fn broadcast_strides(in_dims: &[usize], out_dims: &[usize]) -> Vec<u32> {
7521    let r_out = out_dims.len();
7522    let r_in = in_dims.len();
7523    assert!(
7524        r_in <= r_out,
7525        "broadcast: input rank {r_in} > output rank {r_out}"
7526    );
7527    let pad = r_out - r_in;
7528    let mut strides = vec![0u32; r_out];
7529    let mut acc: usize = 1;
7530    for d in (0..r_out).rev() {
7531        let in_size = if d < pad { 1 } else { in_dims[d - pad] };
7532        if in_size == 1 {
7533            strides[d] = 0;
7534        } else {
7535            assert_eq!(
7536                in_size, out_dims[d],
7537                "broadcast: input dim {in_size} doesn't match output dim {} at axis {d}",
7538                out_dims[d]
7539            );
7540            strides[d] = acc as u32;
7541            acc *= in_size;
7542        }
7543    }
7544    strides
7545}
7546
7547/// Execute a thunk schedule on a raw arena buffer.
7548/// Fastest executor: call pre-compiled closures sequentially.
7549/// Zero match dispatch — each closure is a direct kernel call.
7550pub fn execute_compiled(schedule: &ThunkSchedule, arena_buf: &mut [u8]) {
7551    let base = arena_buf.as_mut_ptr();
7552    for f in &schedule.compiled_fns {
7553        f(base);
7554    }
7555}
7556
7557/// Active-extent execution stub. The runtime calls this when it has an
7558/// active-extent hint set. CPU doesn't implement per-thunk active-extent
7559/// scaling yet — return false so the caller falls back to the full
7560/// `execute_thunks` path.
7561pub fn execute_thunks_active(
7562    schedule: &ThunkSchedule,
7563    _arena_buf: &mut [u8],
7564    _actual: usize,
7565    _upper: usize,
7566) -> bool {
7567    let _ = schedule;
7568    false
7569}
7570
7571/// Match-based executor (fallback, used by tests).
7572struct MoeResidencyGuard;
7573impl Drop for MoeResidencyGuard {
7574    fn drop(&mut self) {
7575        if let Some(stats) = crate::moe_residency::take_stats() {
7576            crate::moe_residency::stash_last_forward_stats(stats);
7577        } else {
7578            crate::moe_residency::clear_mask();
7579        }
7580    }
7581}
7582
7583fn thunk_kind_name(t: &Thunk) -> &'static str {
7584    match t {
7585        Thunk::Nop => "Nop",
7586        Thunk::Gather { .. } => "Gather",
7587        Thunk::GatherAxis { .. } => "GatherAxis",
7588        Thunk::TopK { .. } => "TopK",
7589        Thunk::Copy { .. } => "Copy",
7590        Thunk::CopyF64 { .. } => "CopyF64",
7591        Thunk::CopyI64 { .. } => "CopyI64",
7592        Thunk::CastF32ToI64 { .. } => "CastF32ToI64",
7593        Thunk::CastI64ToF32 { .. } => "CastI64ToF32",
7594        Thunk::CastBoolToI32 { .. } => "CastBoolToI32",
7595        Thunk::CastI32ToF32 { .. } => "CastI32ToF32",
7596        Thunk::Transpose { .. } => "Transpose",
7597        Thunk::TransposeF64 { .. } => "TransposeF64",
7598        Thunk::Where { .. } => "Where",
7599        Thunk::Compare { .. } => "Compare",
7600        Thunk::BinaryFull { .. } => "BinaryFull",
7601        Thunk::BinaryFullF64 { .. } => "BinaryFullF64",
7602        Thunk::Sgemm { .. } => "Sgemm",
7603        Thunk::Dgemm { .. } => "Dgemm",
7604        Thunk::FusedMmBiasAct { .. } => "FusedMmBiasAct",
7605        Thunk::BiasAdd { .. } => "BiasAdd",
7606        Thunk::LayerNorm { .. } => "LayerNorm",
7607        Thunk::Softmax { .. } => "Softmax",
7608        Thunk::Conv2D { .. } => "Conv2D",
7609        Thunk::Conv2D1x1 { .. } => "Conv2D1x1",
7610        Thunk::CustomOp { .. } => "CustomOp",
7611        Thunk::ActivationInPlace { .. } => "ActivationInPlace",
7612        Thunk::Narrow { .. } => "Narrow",
7613        Thunk::Cumsum { .. } => "Cumsum",
7614        Thunk::Reduce { .. } => "Reduce",
7615        Thunk::BatchedSgemm { .. } => "BatchedSgemm",
7616        Thunk::DequantMatMul { .. } => "DequantMatMul",
7617        Thunk::Quantize { .. } => "Quantize",
7618        Thunk::Dequantize { .. } => "Dequantize",
7619        Thunk::ConvTranspose2d { .. } => "ConvTranspose2d",
7620        Thunk::ResizeNearest2x { .. } => "ResizeNearest2x",
7621        _ => "Other",
7622    }
7623}
7624
7625pub fn execute_thunks(schedule: &ThunkSchedule, arena_buf: &mut [u8]) {
7626    crate::moe_residency::reset_gmm_counters();
7627    if let Some(layers) = schedule.moe_resident_layers.clone() {
7628        crate::moe_residency::set_per_layer_masks(Some(layers));
7629    } else {
7630        crate::moe_residency::set_mask(schedule.moe_resident.clone());
7631    }
7632    if let Some(cap) = schedule.moe_topk_capture.as_ref() {
7633        cap.clear();
7634    }
7635    let _moe_guard = MoeResidencyGuard;
7636    let base = arena_buf.as_mut_ptr();
7637    let mask_thr = schedule.mask_threshold;
7638    let mask_neg = schedule.mask_neg_inf;
7639    let score_thr = schedule.score_skip;
7640    let thunks = &schedule.thunks;
7641    let len = thunks.len();
7642
7643    // Pre-allocate ALL reusable buffers once (zero per-call allocation)
7644    let max_h = thunks
7645        .iter()
7646        .filter_map(|t| match t {
7647            Thunk::FusedResidualLN { h, .. }
7648            | Thunk::FusedResidualRmsNorm { h, .. }
7649            | Thunk::LayerNorm { h, .. } => Some(*h as usize),
7650            _ => None,
7651        })
7652        .max()
7653        .unwrap_or(0);
7654    let zero_bias = vec![0f32; max_h];
7655
7656    // Pre-allocate per-(batch,head) score buffers for parallel SDPA.
7657    // Q/K/V/out are accessed via strided BLAS — no deinterleave copy needed.
7658    let max_sdpa = thunks
7659        .iter()
7660        .filter_map(|t| match t {
7661            Thunk::Attention {
7662                batch,
7663                seq,
7664                kv_seq,
7665                heads,
7666                head_dim,
7667                ..
7668            } => Some((
7669                *batch as usize,
7670                (*seq as usize).max(*kv_seq as usize),
7671                *heads as usize,
7672                *head_dim as usize,
7673            )),
7674            _ => None,
7675        })
7676        .fold((0, 0, 0, 0), |(mb, ms, mh, md), (b, s, h, d)| {
7677            (mb.max(b), ms.max(s), mh.max(h), md.max(d))
7678        });
7679    let (max_batch, max_seq, max_heads, _max_dh) = max_sdpa;
7680    let max_units = max_batch * max_heads;
7681    let mut sdpa_scores = vec![0f32; max_units * max_seq * max_seq];
7682
7683    // Pre-allocate fused layer buffers (reused across all 12+ layers — zero malloc per layer)
7684    let fl = thunks
7685        .iter()
7686        .filter_map(|t| match t {
7687            Thunk::FusedBertLayer {
7688                batch,
7689                seq,
7690                hs,
7691                int_dim,
7692                ..
7693            } => {
7694                let m = (*batch as usize) * (*seq as usize);
7695                let h = *hs as usize;
7696                let id = *int_dim as usize;
7697                Some((m, h, id, m * (*seq as usize)))
7698            }
7699            Thunk::FusedNomicLayer {
7700                batch,
7701                seq,
7702                hs,
7703                int_dim,
7704                ..
7705            } => {
7706                let m = (*batch as usize) * (*seq as usize);
7707                let h = *hs as usize;
7708                let id = *int_dim as usize;
7709                Some((m, h, id, m * (*seq as usize)))
7710            }
7711            _ => None,
7712        })
7713        .fold((0, 0, 0, 0), |(mm, mh, mi, ms), (m, h, id, ss)| {
7714            (mm.max(m), mh.max(h), mi.max(id), ms.max(ss))
7715        });
7716    let (fl_m, fl_h, fl_int, fl_ss) = fl;
7717    let mut fl_qkv = vec![0f32; fl_m * 3 * fl_h];
7718    let mut fl_attn = vec![0f32; fl_m * fl_h];
7719    let mut fl_res = vec![0f32; fl_m * fl_h];
7720    let mut fl_normed = vec![0f32; fl_m * fl_h];
7721    let mut fl_ffn = vec![0f32; fl_m * fl_int.max(2 * fl_int)]; // Nomic needs 2×int for fused fc11+fc12
7722    let mut fl_sc = vec![0f32; fl_ss.max(1)];
7723
7724    let trace_thunks = std::env::var_os("RLX_TRACE_THUNK").is_some();
7725    if trace_thunks {
7726        eprintln!(
7727            "[thunk] prealloc max_h={max_h} sdpa={} fl_m={fl_m} fl_h={fl_h} fl_int={fl_int}",
7728            max_units * max_seq * max_seq
7729        );
7730    }
7731    for i in 0..len {
7732        let thunk = unsafe { thunks.get_unchecked(i) };
7733        if trace_thunks && (i < 120 || i % 200 == 0 || i + 1 == len) {
7734            eprintln!("[thunk {i}/{len}] {}", thunk_kind_name(thunk));
7735        }
7736        let trace_done = trace_thunks && i < 120;
7737        match thunk {
7738            Thunk::Nop => {}
7739
7740            Thunk::GaussianSplatRender {
7741                positions_off,
7742                positions_len,
7743                scales_off,
7744                scales_len,
7745                rotations_off,
7746                rotations_len,
7747                opacities_off,
7748                opacities_len,
7749                colors_off,
7750                colors_len,
7751                sh_coeffs_off,
7752                sh_coeffs_len,
7753                meta_off,
7754                dst_off,
7755                dst_len,
7756                width,
7757                height,
7758                tile_size,
7759                radius_scale,
7760                alpha_cutoff,
7761                max_splat_steps,
7762                transmittance_threshold,
7763                max_list_entries,
7764            } => unsafe {
7765                crate::splat::execute_gaussian_splat_render(
7766                    *positions_off,
7767                    *positions_len,
7768                    *scales_off,
7769                    *scales_len,
7770                    *rotations_off,
7771                    *rotations_len,
7772                    *opacities_off,
7773                    *opacities_len,
7774                    *colors_off,
7775                    *colors_len,
7776                    *sh_coeffs_off,
7777                    *sh_coeffs_len,
7778                    *meta_off,
7779                    *dst_off,
7780                    *dst_len,
7781                    *width,
7782                    *height,
7783                    *tile_size,
7784                    *radius_scale,
7785                    *alpha_cutoff,
7786                    *max_splat_steps,
7787                    *transmittance_threshold,
7788                    *max_list_entries,
7789                    base,
7790                );
7791            },
7792
7793            Thunk::GaussianSplatRenderBackward {
7794                positions_off,
7795                positions_len,
7796                scales_off,
7797                scales_len,
7798                rotations_off,
7799                rotations_len,
7800                opacities_off,
7801                opacities_len,
7802                colors_off,
7803                colors_len,
7804                sh_coeffs_off,
7805                sh_coeffs_len,
7806                meta_off,
7807                d_loss_off,
7808                d_loss_len,
7809                packed_off,
7810                packed_len,
7811                width,
7812                height,
7813                tile_size,
7814                radius_scale,
7815                alpha_cutoff,
7816                max_splat_steps,
7817                transmittance_threshold,
7818                max_list_entries,
7819                loss_grad_clip,
7820                sh_band,
7821                max_anisotropy,
7822            } => unsafe {
7823                crate::splat::execute_gaussian_splat_render_backward(
7824                    *positions_off,
7825                    *positions_len,
7826                    *scales_off,
7827                    *scales_len,
7828                    *rotations_off,
7829                    *rotations_len,
7830                    *opacities_off,
7831                    *opacities_len,
7832                    *colors_off,
7833                    *colors_len,
7834                    *sh_coeffs_off,
7835                    *sh_coeffs_len,
7836                    *meta_off,
7837                    *d_loss_off,
7838                    *d_loss_len,
7839                    *packed_off,
7840                    *packed_len,
7841                    *width,
7842                    *height,
7843                    *tile_size,
7844                    *radius_scale,
7845                    *alpha_cutoff,
7846                    *max_splat_steps,
7847                    *transmittance_threshold,
7848                    *max_list_entries,
7849                    *loss_grad_clip,
7850                    *sh_band,
7851                    *max_anisotropy,
7852                    base,
7853                );
7854            },
7855
7856            Thunk::GaussianSplatPrepare {
7857                positions_off,
7858                positions_len,
7859                scales_off,
7860                scales_len,
7861                rotations_off,
7862                rotations_len,
7863                opacities_off,
7864                opacities_len,
7865                colors_off,
7866                colors_len,
7867                sh_coeffs_off,
7868                sh_coeffs_len,
7869                meta_off,
7870                meta_len,
7871                prep_off,
7872                prep_len,
7873                width,
7874                height,
7875                tile_size,
7876                radius_scale,
7877                alpha_cutoff,
7878                max_splat_steps,
7879                transmittance_threshold,
7880                max_list_entries,
7881            } => unsafe {
7882                crate::splat::execute_gaussian_splat_prepare(
7883                    *positions_off,
7884                    *positions_len,
7885                    *scales_off,
7886                    *scales_len,
7887                    *rotations_off,
7888                    *rotations_len,
7889                    *opacities_off,
7890                    *opacities_len,
7891                    *colors_off,
7892                    *colors_len,
7893                    *sh_coeffs_off,
7894                    *sh_coeffs_len,
7895                    *meta_off,
7896                    *meta_len,
7897                    *prep_off,
7898                    *prep_len,
7899                    *width,
7900                    *height,
7901                    *tile_size,
7902                    *radius_scale,
7903                    *alpha_cutoff,
7904                    *max_splat_steps,
7905                    *transmittance_threshold,
7906                    *max_list_entries,
7907                    base,
7908                );
7909            },
7910
7911            Thunk::GaussianSplatRasterize {
7912                prep_off,
7913                prep_len,
7914                meta_off,
7915                meta_len,
7916                dst_off,
7917                dst_len,
7918                count,
7919                width,
7920                height,
7921                tile_size,
7922                alpha_cutoff,
7923                max_splat_steps,
7924                transmittance_threshold,
7925                max_list_entries,
7926            } => unsafe {
7927                crate::splat::execute_gaussian_splat_rasterize(
7928                    *prep_off,
7929                    *prep_len,
7930                    *meta_off,
7931                    *meta_len,
7932                    *dst_off,
7933                    *dst_len,
7934                    *count,
7935                    *width,
7936                    *height,
7937                    *tile_size,
7938                    *alpha_cutoff,
7939                    *max_splat_steps,
7940                    *transmittance_threshold,
7941                    *max_list_entries,
7942                    base,
7943                );
7944            },
7945
7946            Thunk::Fft1d {
7947                src,
7948                dst,
7949                outer,
7950                n_complex,
7951                inverse,
7952                norm_tag,
7953                dtype,
7954            } => unsafe {
7955                match dtype {
7956                    rlx_ir::DType::F64 => execute_fft1d_f64(
7957                        *src,
7958                        *dst,
7959                        *outer as usize,
7960                        *n_complex as usize,
7961                        *inverse,
7962                        *norm_tag,
7963                        base,
7964                    ),
7965                    rlx_ir::DType::F32 => execute_fft1d_f32(
7966                        *src,
7967                        *dst,
7968                        *outer as usize,
7969                        *n_complex as usize,
7970                        *inverse,
7971                        *norm_tag,
7972                        base,
7973                    ),
7974                    rlx_ir::DType::C64 => execute_fft1d_c64(
7975                        *src,
7976                        *dst,
7977                        *outer as usize,
7978                        *n_complex as usize,
7979                        *inverse,
7980                        *norm_tag,
7981                        base,
7982                    ),
7983                    other => panic!("Op::Fft on CPU requires F32/F64/C64, got {other:?}"),
7984                }
7985            },
7986
7987            Thunk::FftButterflyStage {
7988                state_src,
7989                state_dst,
7990                gate_src,
7991                rev_src,
7992                tw_re_src,
7993                tw_im_src,
7994                batch,
7995                n_fft,
7996                stage,
7997            } => unsafe {
7998                execute_fft_butterfly_stage_f32(
7999                    *state_src,
8000                    *state_dst,
8001                    *gate_src,
8002                    *rev_src,
8003                    *tw_re_src,
8004                    *tw_im_src,
8005                    *batch as usize,
8006                    *n_fft as usize,
8007                    *stage as usize,
8008                    base,
8009                );
8010            },
8011
8012            Thunk::LogMel {
8013                spec,
8014                filters,
8015                dst,
8016                outer,
8017                n_fft,
8018                n_bins,
8019                n_mels,
8020            } => unsafe {
8021                execute_log_mel_f32(
8022                    *spec,
8023                    *filters,
8024                    *dst,
8025                    *outer as usize,
8026                    *n_fft as usize,
8027                    *n_bins as usize,
8028                    *n_mels as usize,
8029                    base,
8030                );
8031            },
8032
8033            Thunk::LogMelBackward {
8034                spec,
8035                filters,
8036                dy,
8037                dst,
8038                outer,
8039                n_fft,
8040                n_bins,
8041                n_mels,
8042            } => unsafe {
8043                execute_log_mel_backward_f32(
8044                    *spec,
8045                    *filters,
8046                    *dy,
8047                    *dst,
8048                    *outer as usize,
8049                    *n_fft as usize,
8050                    *n_bins as usize,
8051                    *n_mels as usize,
8052                    base,
8053                );
8054            },
8055
8056            Thunk::WelchPeaks {
8057                spec,
8058                dst,
8059                welch_batch,
8060                n_fft,
8061                n_segments,
8062                k,
8063            } => unsafe {
8064                execute_welch_peaks_f32(
8065                    *spec,
8066                    *dst,
8067                    *welch_batch as usize,
8068                    *n_fft as usize,
8069                    *n_segments as usize,
8070                    *k as usize,
8071                    base,
8072                );
8073            },
8074
8075            // CustomFn dispatch (interpreted path). Mirrors the
8076            // pre-compiled-closure variant elsewhere in this file.
8077            // Patched by rlx-eda.
8078            Thunk::CustomFn {
8079                body,
8080                body_init,
8081                inputs,
8082                body_output_off,
8083                outer_output_off,
8084                out_bytes,
8085            } => {
8086                let mut body_buf: Vec<u8> = (**body_init).clone();
8087                unsafe {
8088                    for (body_in_off, outer_in_off, n_bytes) in inputs.iter() {
8089                        let src = (base as *const u8).add(*outer_in_off);
8090                        let dst = body_buf.as_mut_ptr().add(*body_in_off);
8091                        std::ptr::copy_nonoverlapping(src, dst, *n_bytes as usize);
8092                    }
8093                }
8094                execute_thunks(body, &mut body_buf);
8095                unsafe {
8096                    let src = body_buf.as_ptr().add(*body_output_off);
8097                    let dst = base.add(*outer_output_off);
8098                    std::ptr::copy_nonoverlapping(src, dst, *out_bytes as usize);
8099                }
8100            }
8101
8102            Thunk::Sgemm { a, b, c, m, k, n } => {
8103                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
8104                if trace_thunks {
8105                    eprintln!("[sgemm] m={m} k={k} n={n} a={} b={} c={}", *a, *b, *c);
8106                }
8107                let c_len = m.saturating_mul(n);
8108                let a_len = m.saturating_mul(k);
8109                let b_len = k.saturating_mul(n);
8110                let arena_len = arena_buf.len();
8111                let max_a = (arena_len.saturating_sub(*a)) / 4;
8112                let max_b = (arena_len.saturating_sub(*b)) / 4;
8113                let max_c = (arena_len.saturating_sub(*c)) / 4;
8114                let a_len = a_len.min(max_a);
8115                let b_len = b_len.min(max_b);
8116                let c_len = c_len.min(max_c);
8117                unsafe {
8118                    let a_sl = sl(*a, base, a_len);
8119                    let b_sl = sl(*b, base, b_len);
8120                    let c_sl = sl_mut(*c, base, c_len);
8121                    if std::ptr::eq(a_sl.as_ptr(), c_sl.as_ptr())
8122                        || std::ptr::eq(b_sl.as_ptr(), c_sl.as_ptr())
8123                    {
8124                        let mut tmp = vec![0.0f32; c_len];
8125                        crate::blas::sgemm_auto(a_sl, b_sl, &mut tmp, m, k, n);
8126                        c_sl.copy_from_slice(&tmp);
8127                    } else {
8128                        crate::blas::sgemm_auto(a_sl, b_sl, c_sl, m, k, n);
8129                    }
8130                }
8131            }
8132
8133            Thunk::DenseSolveF64 { a, b, x, n, nrhs } => {
8134                let (n_, nrhs_) = (*n as usize, *nrhs as usize);
8135                // LAPACK overwrites both A and B; clone into scratch
8136                // each call. Caller's A and b must be preserved for
8137                // VJP recompute. (Eventually: swap to a factor-once /
8138                // solve-many scheme; that's the symbolic-reuse story
8139                // and lives with the sparse path.)
8140                unsafe {
8141                    let a_src = sl_f64(*a, base, n_ * n_);
8142                    let b_src = sl_f64(*b, base, n_ * nrhs_);
8143                    let mut a_scratch: Vec<f64> = a_src.to_vec();
8144                    let mut x_buf: Vec<f64> = b_src.to_vec();
8145                    let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
8146                    if info != 0 {
8147                        panic!(
8148                            "DenseSolveF64: dgesv reported singular matrix \
8149                                (info={info}, n={n_}, nrhs={nrhs_})"
8150                        );
8151                    }
8152                    let dst = sl_mut_f64(*x, base, n_ * nrhs_);
8153                    dst.copy_from_slice(&x_buf);
8154                }
8155            }
8156
8157            Thunk::DenseSolveF32 { a, b, x, n, nrhs } => {
8158                let (n_, nrhs_) = (*n as usize, *nrhs as usize);
8159                unsafe {
8160                    let a_src = sl(*a, base, n_ * n_);
8161                    let b_src = sl(*b, base, n_ * nrhs_);
8162                    let mut a_scratch: Vec<f32> = a_src.to_vec();
8163                    let mut x_buf: Vec<f32> = b_src.to_vec();
8164                    let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
8165                    if info != 0 {
8166                        panic!(
8167                            "DenseSolveF32: sgesv reported singular matrix \
8168                             (info={info}, n={n_}, nrhs={nrhs_})"
8169                        );
8170                    }
8171                    let dst = sl_mut(*x, base, n_ * nrhs_);
8172                    dst.copy_from_slice(&x_buf);
8173                }
8174            }
8175
8176            Thunk::BatchedDenseSolveF64 {
8177                a,
8178                b,
8179                x,
8180                batch,
8181                n,
8182                nrhs,
8183            } => {
8184                // Per slice: extract A_i and b_i, dgesv, write x_i.
8185                // LAPACK has no batched dgesv on Accelerate, so this
8186                // is a serial loop over the batch axis. cuSOLVER /
8187                // hipSOLVER expose `getrfBatched` / `getrsBatched` for
8188                // the GPU path — we'll wire that in rlx-cuda when
8189                // someone needs Linux+CUDA.
8190                let (b_, n_, nrhs_) = (*batch as usize, *n as usize, *nrhs as usize);
8191                let a_stride = n_ * n_;
8192                let b_stride = n_ * nrhs_;
8193                unsafe {
8194                    let a_full = sl_f64(*a, base, b_ * a_stride);
8195                    let b_full = sl_f64(*b, base, b_ * b_stride);
8196                    let x_full = sl_mut_f64(*x, base, b_ * b_stride);
8197                    for bi in 0..b_ {
8198                        let mut a_scratch: Vec<f64> =
8199                            a_full[bi * a_stride..(bi + 1) * a_stride].to_vec();
8200                        let mut x_buf: Vec<f64> =
8201                            b_full[bi * b_stride..(bi + 1) * b_stride].to_vec();
8202                        let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
8203                        if info != 0 {
8204                            panic!(
8205                                "BatchedDenseSolveF64: slice {bi} \
8206                                    singular (info={info}, n={n_}, nrhs={nrhs_})"
8207                            );
8208                        }
8209                        x_full[bi * b_stride..(bi + 1) * b_stride].copy_from_slice(&x_buf);
8210                    }
8211                }
8212            }
8213
8214            Thunk::BatchedDenseSolveF32 {
8215                a,
8216                b,
8217                x,
8218                batch,
8219                n,
8220                nrhs,
8221            } => {
8222                let (b_, n_, nrhs_) = (*batch as usize, *n as usize, *nrhs as usize);
8223                let a_stride = n_ * n_;
8224                let b_stride = n_ * nrhs_;
8225                unsafe {
8226                    let a_full = sl(*a, base, b_ * a_stride);
8227                    let b_full = sl(*b, base, b_ * b_stride);
8228                    let x_full = sl_mut(*x, base, b_ * b_stride);
8229                    for bi in 0..b_ {
8230                        let mut a_scratch = a_full[bi * a_stride..(bi + 1) * a_stride].to_vec();
8231                        let mut x_buf = b_full[bi * b_stride..(bi + 1) * b_stride].to_vec();
8232                        let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
8233                        if info != 0 {
8234                            panic!("BatchedDenseSolveF32: slice {bi} singular (info={info})");
8235                        }
8236                        x_full[bi * b_stride..(bi + 1) * b_stride].copy_from_slice(&x_buf);
8237                    }
8238                }
8239            }
8240
8241            Thunk::BatchedDgemmF64 {
8242                a,
8243                b,
8244                c,
8245                batch,
8246                m,
8247                k,
8248                n,
8249            } => {
8250                let (b_, m_, k_, n_) = (*batch as usize, *m as usize, *k as usize, *n as usize);
8251                let a_stride = m_ * k_;
8252                let b_stride = k_ * n_;
8253                let c_stride = m_ * n_;
8254                unsafe {
8255                    let a_full = sl_f64(*a, base, b_ * a_stride);
8256                    let b_full = sl_f64(*b, base, b_ * b_stride);
8257                    let c_full = sl_mut_f64(*c, base, b_ * c_stride);
8258                    for bi in 0..b_ {
8259                        let a_slice = &a_full[bi * a_stride..(bi + 1) * a_stride];
8260                        let b_slice = &b_full[bi * b_stride..(bi + 1) * b_stride];
8261                        let c_slice = &mut c_full[bi * c_stride..(bi + 1) * c_stride];
8262                        crate::blas::dgemm(a_slice, b_slice, c_slice, m_, k_, n_);
8263                    }
8264                }
8265            }
8266
            Thunk::BatchedSgemm {
                a,
                b,
                c,
                batch,
                m,
                k,
                n,
            } => {
                // `batch` independent f32 GEMMs. For each slice bi,
                // C[bi] (m×n) = A[bi] (m×k) · B[bi] (k×n), with the slices
                // packed back-to-back at strides m·k, k·n and m·n elements.
                let (b_, m_, k_, n_) = (*batch as usize, *m as usize, *k as usize, *n as usize);
                if trace_thunks {
                    eprintln!(
                        "[batched-sgemm] batch={b_} m={m_} k={k_} n={n_} a={} b={} c={}",
                        *a, *b, *c
                    );
                }
                let a_stride = m_.saturating_mul(k_);
                let b_stride = k_.saturating_mul(n_);
                let c_stride = m_.saturating_mul(n_);
                // Clamp each operand to the number of f32 values that fit
                // between its byte offset and the end of the arena.
                // NOTE(review): `b_ * stride` is not saturating, unlike the
                // strides themselves — an overflow here panics in debug builds.
                let arena_len = arena_buf.len();
                let a_cap = (arena_len.saturating_sub(*a)) / 4;
                let b_cap = (arena_len.saturating_sub(*b)) / 4;
                let c_cap = (arena_len.saturating_sub(*c)) / 4;
                let a_elems = (b_ * a_stride).min(a_cap);
                let b_elems = (b_ * b_stride).min(b_cap);
                let c_elems = (b_ * c_stride).min(c_cap);
                // Process only the batch slices that fit completely in all
                // three operands. Any zero stride (m, k or n == 0) makes
                // `checked_div` return None, so b_eff == 0 and C is left
                // untouched.
                // NOTE(review): for k == 0 the mathematically correct C is all
                // zeros; confirm lowering never emits k == 0 here.
                let b_eff = b_
                    .min(a_elems.checked_div(a_stride).unwrap_or(0))
                    .min(b_elems.checked_div(b_stride).unwrap_or(0))
                    .min(c_elems.checked_div(c_stride).unwrap_or(0));
                unsafe {
                    let a_full = sl(*a, base, a_elems);
                    let b_full = sl(*b, base, b_elems);
                    let c_full = sl_mut(*c, base, c_elems);
                    for bi in 0..b_eff {
                        let a0 = bi * a_stride;
                        let b0 = bi * b_stride;
                        let c0 = bi * c_stride;
                        // Defensive guard. Given how `b_eff` is computed, this
                        // should never trigger.
                        if a0 + a_stride > a_full.len()
                            || b0 + b_stride > b_full.len()
                            || c0 + c_stride > c_full.len()
                        {
                            break;
                        }
                        let a_slice = &a_full[a0..a0 + a_stride];
                        let b_slice = &b_full[b0..b0 + b_stride];
                        let c_slice = &mut c_full[c0..c0 + c_stride];
                        // In-place guard. It detects only an exact start-address
                        // match between an input slice and C. A partially
                        // overlapping operand still goes straight to the kernel.
                        // NOTE(review): consider a byte-range overlap test.
                        if std::ptr::eq(a_slice.as_ptr(), c_slice.as_mut_ptr())
                            || std::ptr::eq(b_slice.as_ptr(), c_slice.as_mut_ptr())
                        {
                            let mut tmp = vec![0.0f32; c_stride];
                            crate::blas::sgemm_auto(a_slice, b_slice, &mut tmp, m_, k_, n_);
                            c_slice.copy_from_slice(&tmp);
                        } else {
                            crate::blas::sgemm_auto(a_slice, b_slice, c_slice, m_, k_, n_);
                        }
                    }
                }
            }
8326
8327            Thunk::Dgemm { a, b, c, m, k, n } => {
8328                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
8329                unsafe {
8330                    crate::blas::dgemm(
8331                        sl_f64(*a, base, m * k),
8332                        sl_f64(*b, base, k * n),
8333                        sl_mut_f64(*c, base, m * n),
8334                        m,
8335                        k,
8336                        n,
8337                    );
8338                }
8339            }
8340
8341            Thunk::TransposeF64 {
8342                src,
8343                dst,
8344                in_total,
8345                out_dims,
8346                in_strides,
8347            } => unsafe {
8348                let inp = sl_f64(*src, base, *in_total as usize);
8349                let out_total: usize = out_dims.iter().map(|d| *d as usize).product();
8350                let out = sl_mut_f64(*dst, base, out_total);
8351                transpose_walk_f64(inp, out, out_dims, in_strides);
8352            },
8353
8354            Thunk::ActivationF64 {
8355                src,
8356                dst,
8357                len,
8358                kind,
8359            } => {
8360                let len = *len as usize;
8361                unsafe {
8362                    let inp = sl_f64(*src, base, len);
8363                    let out = sl_mut_f64(*dst, base, len);
8364                    apply_activation_f64(inp, out, *kind);
8365                }
8366            }
8367
8368            Thunk::ReduceSumF64 {
8369                src,
8370                dst,
8371                outer,
8372                reduced,
8373                inner,
8374            } => {
8375                let (o, r, n) = (*outer as usize, *reduced as usize, *inner as usize);
8376                unsafe {
8377                    let inp = sl_f64(*src, base, o * r * n);
8378                    let out = sl_mut_f64(*dst, base, o * n);
8379                    reduce_sum_f64(inp, out, o, r, n);
8380                }
8381            }
8382
8383            Thunk::CopyF64 { src, dst, len } => {
8384                let mut len = *len as usize;
8385                if *src == *dst || len == 0 {
8386                    continue;
8387                }
8388                let arena_len = arena_buf.len();
8389                let max_from_src = (arena_len.saturating_sub(*src)) / 8;
8390                let max_from_dst = (arena_len.saturating_sub(*dst)) / 8;
8391                len = len.min(max_from_src).min(max_from_dst);
8392                if len == 0 {
8393                    continue;
8394                }
8395                let byte_len = len.saturating_mul(8);
8396                unsafe {
8397                    std::ptr::copy(base.add(*src), base.add(*dst), byte_len);
8398                }
8399            }
8400
8401            Thunk::CopyI64 { src, dst, len } => {
8402                let mut len = *len as usize;
8403                if *src == *dst || len == 0 {
8404                    continue;
8405                }
8406                let arena_len = arena_buf.len();
8407                let max_from_src = (arena_len.saturating_sub(*src)) / 8;
8408                let max_from_dst = (arena_len.saturating_sub(*dst)) / 8;
8409                len = len.min(max_from_src).min(max_from_dst);
8410                if len == 0 {
8411                    continue;
8412                }
8413                let byte_len = len.saturating_mul(8);
8414                unsafe {
8415                    std::ptr::copy(base.add(*src), base.add(*dst), byte_len);
8416                }
8417            }
8418
8419            Thunk::CastF32ToI64 { src, dst, len } => {
8420                let len = *len as usize;
8421                if len == 0 {
8422                    continue;
8423                }
8424                unsafe {
8425                    let inp = sl(*src, base, len);
8426                    let out = sl_mut_i64(*dst, base, len);
8427                    for i in 0..len {
8428                        out[i] = inp[i].round() as i64;
8429                    }
8430                }
8431            }
8432
8433            Thunk::CastI64ToF32 { src, dst, len } => {
8434                let len = *len as usize;
8435                if len == 0 {
8436                    continue;
8437                }
8438                unsafe {
8439                    let inp = sl_i64(*src, base, len);
8440                    let out = sl_mut(*dst, base, len);
8441                    for i in 0..len {
8442                        out[i] = inp[i] as f32;
8443                    }
8444                }
8445            }
8446
8447            Thunk::CastBoolToI32 { src, dst, len } => {
8448                let len = *len as usize;
8449                if len == 0 {
8450                    continue;
8451                }
8452                unsafe {
8453                    let inp = &arena_buf[*src..*src + len];
8454                    let out = sl_mut_i32(*dst, base, len);
8455                    for i in 0..len {
8456                        out[i] = i32::from(inp[i] != 0);
8457                    }
8458                }
8459            }
8460
8461            Thunk::CastI32ToF32 { src, dst, len } => {
8462                let len = *len as usize;
8463                if len == 0 {
8464                    continue;
8465                }
8466                unsafe {
8467                    let inp = sl_i32(*src, base, len);
8468                    let out = sl_mut(*dst, base, len);
8469                    for i in 0..len {
8470                        out[i] = inp[i] as f32;
8471                    }
8472                }
8473            }
8474
8475            Thunk::BinaryFullF64 {
8476                lhs,
8477                rhs,
8478                dst,
8479                len,
8480                lhs_len,
8481                rhs_len,
8482                op,
8483                out_dims_bcast,
8484                bcast_lhs_strides,
8485                bcast_rhs_strides,
8486            } => {
8487                let len = *len as usize;
8488                let lhs_len = *lhs_len as usize;
8489                let rhs_len = *rhs_len as usize;
8490                unsafe {
8491                    let l = sl_f64(*lhs, base, lhs_len);
8492                    let r = sl_f64(*rhs, base, rhs_len);
8493                    let d = sl_mut_f64(*dst, base, len);
8494                    if lhs_len == len && rhs_len == len {
8495                        for i in 0..len {
8496                            d[i] = binary_op_f64(*op, l[i], r[i]);
8497                        }
8498                    } else if !out_dims_bcast.is_empty() {
8499                        // Shape-aware broadcast path: correct for
8500                        // arbitrary NumPy-style broadcasts including
8501                        // bidirectional `[N,1] op [1,S]`.
8502                        let rank = out_dims_bcast.len();
8503                        let mut coords = vec![0u32; rank];
8504                        for i in 0..len {
8505                            let mut rem = i;
8506                            for ax in (0..rank).rev() {
8507                                let sz = out_dims_bcast[ax] as usize;
8508                                coords[ax] = (rem % sz) as u32;
8509                                rem /= sz;
8510                            }
8511                            let mut li: usize = 0;
8512                            let mut ri: usize = 0;
8513                            for ax in 0..rank {
8514                                li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
8515                                ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
8516                            }
8517                            d[i] = binary_op_f64(*op, l[li], r[ri]);
8518                        }
8519                    } else {
8520                        // Fallback: legacy modulo path (preserved for
8521                        // dynamic-shape graphs where strides can't be
8522                        // precomputed). Only correct for scalar /
8523                        // last-axis broadcast.
8524                        for i in 0..len {
8525                            d[i] = binary_op_f64(*op, l[i % lhs_len], r[i % rhs_len]);
8526                        }
8527                    }
8528                }
8529            }
8530
            Thunk::BinaryFullC64 {
                lhs,
                rhs,
                dst,
                len,
                lhs_len,
                rhs_len,
                op,
                out_dims_bcast,
                bcast_lhs_strides,
                bcast_rhs_strides,
            } => {
                // Element-wise complex binary op (Add/Sub/Mul/Div) producing
                // `len` complex outputs. It uses the same three strategies as
                // the f64 path: equal lengths, strided broadcast, and a modulo
                // fallback.
                //
                // Complex element layout: [re_0, im_0, re_1, im_1, ...]
                // Underlying f32 buffer length is 2·N (N = complex
                // element count). All offsets are byte offsets; the
                // `sl` helper reads as f32 starting at the byte
                // offset, so f32-length = 2·complex-len.
                let n_out = *len as usize;
                let n_l = *lhs_len as usize;
                let n_r = *rhs_len as usize;
                unsafe {
                    let l = sl(*lhs, base, 2 * n_l);
                    let r = sl(*rhs, base, 2 * n_r);
                    let d = sl_mut(*dst, base, 2 * n_out);
                    // (a_re + i·a_im) op (b_re + i·b_im) -> (re, im).
                    // NOTE(review): Div has no zero-denominator guard, so
                    // dividing by 0+0i yields NaN/inf components.
                    let do_c64 = |a_re: f32, a_im: f32, b_re: f32, b_im: f32| -> (f32, f32) {
                        match op {
                            BinaryOp::Add => (a_re + b_re, a_im + b_im),
                            BinaryOp::Sub => (a_re - b_re, a_im - b_im),
                            BinaryOp::Mul => (a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re),
                            BinaryOp::Div => {
                                let denom = b_re * b_re + b_im * b_im;
                                (
                                    (a_re * b_re + a_im * b_im) / denom,
                                    (a_im * b_re - a_re * b_im) / denom,
                                )
                            }
                            BinaryOp::Max | BinaryOp::Min | BinaryOp::Pow => {
                                unreachable!("C64 max/min/pow rejected at lowering")
                            }
                        }
                    };
                    if n_l == n_out && n_r == n_out {
                        // Same-length operands: pairwise over complex elements.
                        for i in 0..n_out {
                            let (re, im) = do_c64(l[2 * i], l[2 * i + 1], r[2 * i], r[2 * i + 1]);
                            d[2 * i] = re;
                            d[2 * i + 1] = im;
                        }
                    } else if !out_dims_bcast.is_empty() {
                        // Strided complex broadcast: strides are in
                        // *complex element* units; multiply by 2 when
                        // indexing into the f32 buffer.
                        let rank = out_dims_bcast.len();
                        let mut coords = vec![0u32; rank];
                        for i in 0..n_out {
                            // Decompose flat index i into per-axis coordinates,
                            // innermost axis last.
                            let mut rem = i;
                            for ax in (0..rank).rev() {
                                let sz = out_dims_bcast[ax] as usize;
                                coords[ax] = (rem % sz) as u32;
                                rem /= sz;
                            }
                            let mut li: usize = 0;
                            let mut ri: usize = 0;
                            for ax in 0..rank {
                                li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
                                ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
                            }
                            let (re, im) =
                                do_c64(l[2 * li], l[2 * li + 1], r[2 * ri], r[2 * ri + 1]);
                            d[2 * i] = re;
                            d[2 * i + 1] = im;
                        }
                    } else {
                        // Modulo fallback (scalar / last-axis broadcast).
                        for i in 0..n_out {
                            let li = if n_l == 1 { 0 } else { i % n_l };
                            let ri = if n_r == 1 { 0 } else { i % n_r };
                            let (re, im) =
                                do_c64(l[2 * li], l[2 * li + 1], r[2 * ri], r[2 * ri + 1]);
                            d[2 * i] = re;
                            d[2 * i + 1] = im;
                        }
                    }
                }
            }
8615
8616            Thunk::ComplexNormSqF32 { src, dst, len } => {
8617                let n = *len as usize;
8618                unsafe {
8619                    let s = sl(*src, base, 2 * n);
8620                    let d = sl_mut(*dst, base, n);
8621                    for i in 0..n {
8622                        let re = s[2 * i];
8623                        let im = s[2 * i + 1];
8624                        d[i] = re * re + im * im;
8625                    }
8626                }
8627            }
8628
8629            Thunk::ComplexNormSqBackwardF32 { z, g, dz, len } => {
8630                // Wirtinger: dz = g · z, element-wise complex
8631                // (g is real, z is complex).
8632                let n = *len as usize;
8633                unsafe {
8634                    let zb = sl(*z, base, 2 * n);
8635                    let gb = sl(*g, base, n);
8636                    let db = sl_mut(*dz, base, 2 * n);
8637                    for i in 0..n {
8638                        let re = zb[2 * i];
8639                        let im = zb[2 * i + 1];
8640                        let gv = gb[i];
8641                        db[2 * i] = gv * re;
8642                        db[2 * i + 1] = gv * im;
8643                    }
8644                }
8645            }
8646
8647            Thunk::ConjugateC64 { src, dst, len } => {
8648                let n = *len as usize;
8649                unsafe {
8650                    let s = sl(*src, base, 2 * n);
8651                    let d = sl_mut(*dst, base, 2 * n);
8652                    for i in 0..n {
8653                        d[2 * i] = s[2 * i];
8654                        d[2 * i + 1] = -s[2 * i + 1];
8655                    }
8656                }
8657            }
8658
8659            Thunk::ActivationC64 {
8660                src,
8661                dst,
8662                len,
8663                kind,
8664            } => {
8665                let n = *len as usize;
8666                unsafe {
8667                    let s = sl(*src, base, 2 * n);
8668                    let d = sl_mut(*dst, base, 2 * n);
8669                    for i in 0..n {
8670                        let a = s[2 * i];
8671                        let b = s[2 * i + 1];
8672                        let (re, im) = match kind {
8673                            Activation::Neg => (-a, -b),
8674                            Activation::Exp => {
8675                                // exp(a + bi) = e^a · (cos b + i·sin b)
8676                                let ea = a.exp();
8677                                (ea * b.cos(), ea * b.sin())
8678                            }
8679                            Activation::Log => {
8680                                // log(z) = log|z| + i·arg(z), principal branch
8681                                let r = (a * a + b * b).sqrt();
8682                                (r.ln(), b.atan2(a))
8683                            }
8684                            Activation::Sqrt => {
8685                                // sqrt(a+bi) = sqrt((|z|+a)/2) + sign(b)·i·sqrt((|z|-a)/2)
8686                                // Principal branch; for b == 0 and a < 0 returns +i·sqrt(|a|).
8687                                let r = (a * a + b * b).sqrt();
8688                                let re = ((r + a) * 0.5).max(0.0).sqrt();
8689                                let im_mag = ((r - a) * 0.5).max(0.0).sqrt();
8690                                let im = if b >= 0.0 { im_mag } else { -im_mag };
8691                                (re, im)
8692                            }
8693                            _ => unreachable!("non-C64 activation kind survived lowering"),
8694                        };
8695                        d[2 * i] = re;
8696                        d[2 * i + 1] = im;
8697                    }
8698                }
8699            }
8700
            Thunk::Scan {
                body,
                body_init,
                body_input_off,
                body_output_off,
                outer_init_off,
                outer_final_off,
                length,
                carry_bytes,
                save_trajectory,
                xs_inputs,
                bcast_inputs,
                num_checkpoints,
            } => {
                // Runs the `body` thunks `length` times in a private scratch
                // buffer, threading a `carry_bytes`-sized carry.
                // - The carry starts as a copy of `outer_init_off`.
                // - Broadcast inputs are copied once.
                // - Each `xs` input contributes its step-t slice before every
                //   iteration.
                // - After each iteration, the body output is moved back into
                //   the carry-input slot.
                // The result is either the final carry or, when
                // `save_trajectory`, one carry per checkpoint written
                // consecutively at `outer_final_off`.
                let cb = *carry_bytes as usize;
                let n_steps = *length as usize;
                // Checkpoint mode: when 0 < K < length, save trajectory[k]
                // only when t == c_k = ceil((k+1) * length / K) - 1, clamped to
                // length - 1. This matches the `div_ceil` below.
                // The last index c_{K-1} = length - 1 always.
                // NOTE(review): if K > length, consecutive k can map to the
                // same t. `next_k` advances at most once per step, so the
                // later trajectory slots are never written. Presumably
                // K > length is rejected upstream; confirm.
                let k_total = if *num_checkpoints == 0 || *num_checkpoints == *length {
                    n_steps // save every step
                } else {
                    *num_checkpoints as usize
                };
                let checkpoint_t_for_k = |k: usize| -> usize {
                    if k_total == n_steps {
                        k
                    } else {
                        ((k + 1) * n_steps)
                            .div_ceil(k_total)
                            .saturating_sub(1)
                            .min(n_steps - 1)
                    }
                };
                // Index of the next trajectory slot to fill.
                let mut next_k = 0usize;

                let mut body_buf: Vec<u8> = (**body_init).clone();
                unsafe {
                    // Seed the carry-input slot with the outer initial carry.
                    std::ptr::copy_nonoverlapping(
                        base.add(*outer_init_off),
                        body_buf.as_mut_ptr().add(*body_input_off),
                        cb,
                    );
                    // Broadcast inputs: copy each one into the body's
                    // input slot ONCE. They aren't touched in the
                    // iteration loop below (in contrast to xs).
                    for (body_b_off, outer_b_off, total_bytes) in bcast_inputs.iter() {
                        std::ptr::copy_nonoverlapping(
                            base.add(*outer_b_off),
                            body_buf.as_mut_ptr().add(*body_b_off),
                            *total_bytes as usize,
                        );
                    }
                }
                for t in 0..n_steps {
                    // Per-step inputs: slice t of each xs buffer (psb bytes each).
                    for (body_x_off, outer_xs_off, per_step_bytes) in xs_inputs.iter() {
                        let psb = *per_step_bytes as usize;
                        unsafe {
                            std::ptr::copy_nonoverlapping(
                                base.add(*outer_xs_off + t * psb),
                                body_buf.as_mut_ptr().add(*body_x_off),
                                psb,
                            );
                        }
                    }

                    execute_thunks(body, &mut body_buf);

                    if *save_trajectory && next_k < k_total && t == checkpoint_t_for_k(next_k) {
                        unsafe {
                            std::ptr::copy_nonoverlapping(
                                body_buf.as_ptr().add(*body_output_off),
                                base.add(*outer_final_off + next_k * cb),
                                cb,
                            );
                        }
                        next_k += 1;
                    }

                    // Thread the carry: this step's output becomes the next
                    // step's input (in-buffer move; ranges may overlap).
                    if *body_output_off != *body_input_off {
                        body_buf
                            .copy_within(*body_output_off..*body_output_off + cb, *body_input_off);
                    }
                }

                if !*save_trajectory {
                    // Single final-carry write.
                    // NOTE(review): when length == 0 and the output slot
                    // differs from the input slot, this copies `body_init`'s
                    // output bytes rather than the initial carry. Confirm that
                    // zero-length scans cannot reach this path.
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            body_buf.as_ptr().add(*body_output_off),
                            base.add(*outer_final_off),
                            cb,
                        );
                    }
                }
            }
8797
8798            Thunk::ScanBackward {
8799                body_vjp,
8800                body_init,
8801                body_carry_in_off,
8802                body_x_offs,
8803                body_d_output_off,
8804                body_dcarry_out_off,
8805                outer_init_off,
8806                outer_traj_off,
8807                outer_upstream_off,
8808                outer_xs_offs,
8809                outer_dinit_off,
8810                length,
8811                carry_bytes,
8812                save_trajectory,
8813                num_checkpoints,
8814                forward_body,
8815                forward_body_init,
8816                forward_body_carry_in_off,
8817                forward_body_output_off,
8818                forward_body_x_offs,
8819                carry_elem_size,
8820            } => {
8821                // Two backward paths share the same per-iteration body
8822                // (body_vjp run + dcarry threading). The "All" path
8823                // reads the carry directly from the saved trajectory
8824                // each step. The "Recursive checkpointing" path stores
8825                // only K saved checkpoints and reconstructs intermediate
8826                // carries via Griewank-style recursive subdivision —
8827                // see [`griewank_process_segment`]. Auxiliary memory
8828                // is `O(log(segment_size) · carry_bytes)` for the
8829                // recursion stack, vs the old segment-cache scheme's
8830                // `O(segment_size · carry_bytes)`. Total recompute work
8831                // grows from `O(length)` to `O(length · log)`, which
8832                // is the canonical Griewank trade.
8833                let cb = *carry_bytes as usize;
8834                let n_steps = *length as usize;
8835                let k_total = *num_checkpoints as usize;
8836                let is_recursive = k_total != 0 && k_total != n_steps;
8837                let checkpoint_t_for_k = |k: usize| -> usize {
8838                    ((k + 1) * n_steps)
8839                        .div_ceil(k_total)
8840                        .saturating_sub(1)
8841                        .min(n_steps - 1)
8842                };
8843
8844                let mut fwd_buf: Vec<u8> = if is_recursive {
8845                    (**forward_body_init.as_ref().unwrap()).clone()
8846                } else {
8847                    Vec::new()
8848                };
8849
8850                let mut dcarry: Vec<u8> = vec![0u8; cb];
8851                if !*save_trajectory {
8852                    unsafe {
8853                        std::ptr::copy_nonoverlapping(
8854                            base.add(*outer_upstream_off),
8855                            dcarry.as_mut_ptr(),
8856                            cb,
8857                        );
8858                    }
8859                }
8860
8861                let mut body_buf: Vec<u8> = (**body_init).clone();
8862
8863                // Per-iteration backward action — shared between the
8864                // direct-trajectory (All) and Griewank (Recursive) paths.
8865                // Both feed the same body_vjp run with carry-at-t,
8866                // x_t_i, and d_output, then thread dcarry backward.
8867                let process_iter =
8868                    |t: usize, carry_in: &[u8], dcarry: &mut Vec<u8>, body_buf: &mut Vec<u8>| {
8869                        if *save_trajectory {
8870                            unsafe {
8871                                let up_off = *outer_upstream_off + t * cb;
8872                                match *carry_elem_size {
8873                                    4 => {
8874                                        let up_ptr = base.add(up_off) as *const f32;
8875                                        let dc_ptr = dcarry.as_mut_ptr() as *mut f32;
8876                                        let n_elems = cb / 4;
8877                                        for i in 0..n_elems {
8878                                            *dc_ptr.add(i) += *up_ptr.add(i);
8879                                        }
8880                                    }
8881                                    8 => {
8882                                        let up_ptr = base.add(up_off) as *const f64;
8883                                        let dc_ptr = dcarry.as_mut_ptr() as *mut f64;
8884                                        let n_elems = cb / 8;
8885                                        for i in 0..n_elems {
8886                                            *dc_ptr.add(i) += *up_ptr.add(i);
8887                                        }
8888                                    }
8889                                    other => panic!(
8890                                        "ScanBackward: unsupported carry elem size {other} \
8891                                     (only f32/f64 carries are supported today)"
8892                                    ),
8893                                }
8894                            }
8895                        }
8896                        body_buf[*body_carry_in_off..*body_carry_in_off + cb]
8897                            .copy_from_slice(carry_in);
8898                        unsafe {
8899                            for (i, body_x_off) in body_x_offs.iter().enumerate() {
8900                                let (outer_xs_off, per_step_bytes) = outer_xs_offs[i];
8901                                let psb = per_step_bytes as usize;
8902                                std::ptr::copy_nonoverlapping(
8903                                    base.add(outer_xs_off + t * psb),
8904                                    body_buf.as_mut_ptr().add(*body_x_off),
8905                                    psb,
8906                                );
8907                            }
8908                            std::ptr::copy_nonoverlapping(
8909                                dcarry.as_ptr(),
8910                                body_buf.as_mut_ptr().add(*body_d_output_off),
8911                                cb,
8912                            );
8913                        }
8914                        execute_thunks(body_vjp, body_buf);
8915                        unsafe {
8916                            std::ptr::copy_nonoverlapping(
8917                                body_buf.as_ptr().add(*body_dcarry_out_off),
8918                                dcarry.as_mut_ptr(),
8919                                cb,
8920                            );
8921                        }
8922                    };
8923
8924                if is_recursive {
8925                    // Griewank treeverse path. Process saved-checkpoint
8926                    // segments from highest-t to lowest-t; within each,
8927                    // recursive binary subdivision via
8928                    // `griewank_process_segment`. Auxiliary memory:
8929                    // O(log(seg_size) · cb) for the recursion stack
8930                    // (vs O(seg_size · cb) for the older segment-cache
8931                    // scheme); recompute work: O(seg_size · log).
8932                    let leaf_threshold = 4usize;
8933                    let fb_sched = forward_body.as_ref().unwrap();
8934                    let fb_init = forward_body_init.as_ref().unwrap().as_slice();
8935                    let mut segment_end = n_steps - 1;
8936                    for seg_k in (0..k_total).rev() {
8937                        let segment_start = if seg_k == 0 {
8938                            0
8939                        } else {
8940                            checkpoint_t_for_k(seg_k - 1) + 1
8941                        };
8942                        let mut anchor: Vec<u8> = vec![0u8; cb];
8943                        unsafe {
8944                            let src = if seg_k == 0 {
8945                                base.add(*outer_init_off)
8946                            } else {
8947                                base.add(*outer_traj_off + (seg_k - 1) * cb)
8948                            };
8949                            std::ptr::copy_nonoverlapping(src, anchor.as_mut_ptr(), cb);
8950                        }
8951                        // Closure adapter for the helper's signature
8952                        // (mutably re-borrows dcarry / body_buf each call).
8953                        let mut leaf_action = |t: usize, carry_in: &[u8]| {
8954                            process_iter(t, carry_in, &mut dcarry, &mut body_buf);
8955                        };
8956                        unsafe {
8957                            griewank_process_segment(
8958                                segment_start,
8959                                segment_end,
8960                                &anchor,
8961                                cb,
8962                                fb_sched,
8963                                fb_init,
8964                                *forward_body_carry_in_off,
8965                                *forward_body_output_off,
8966                                forward_body_x_offs,
8967                                base,
8968                                outer_xs_offs,
8969                                &mut fwd_buf,
8970                                leaf_threshold,
8971                                &mut leaf_action,
8972                            );
8973                        }
8974                        if seg_k == 0 {
8975                            break;
8976                        }
8977                        segment_end = segment_start - 1;
8978                    }
8979                } else {
8980                    // All-trajectory path: read each carry directly
8981                    // from the saved trajectory buffer.
8982                    let mut carry_buf: Vec<u8> = vec![0u8; cb];
8983                    for t in (0..n_steps).rev() {
8984                        unsafe {
8985                            let src = if t == 0 {
8986                                base.add(*outer_init_off)
8987                            } else {
8988                                base.add(*outer_traj_off + (t - 1) * cb)
8989                            };
8990                            std::ptr::copy_nonoverlapping(src, carry_buf.as_mut_ptr(), cb);
8991                        }
8992                        process_iter(t, &carry_buf, &mut dcarry, &mut body_buf);
8993                    }
8994                }
8995
8996                unsafe {
8997                    std::ptr::copy_nonoverlapping(dcarry.as_ptr(), base.add(*outer_dinit_off), cb);
8998                }
8999            }
9000
9001            Thunk::ScanBackwardXs {
9002                body_vjp,
9003                body_init,
9004                body_carry_in_off,
9005                body_x_offs,
9006                body_d_output_off,
9007                body_dcarry_out_off,
9008                body_dxs_out_off,
9009                outer_init_off,
9010                outer_traj_off,
9011                outer_upstream_off,
9012                outer_xs_offs,
9013                outer_dxs_off,
9014                length,
9015                carry_bytes,
9016                carry_elem_size,
9017                per_step_bytes,
9018                save_trajectory,
9019                num_checkpoints,
9020                forward_body,
9021                forward_body_init,
9022                forward_body_carry_in_off,
9023                forward_body_output_off,
9024                forward_body_x_offs,
9025            } => {
9026                let cb = *carry_bytes as usize;
9027                let psb = *per_step_bytes as usize;
9028                let n_steps = *length as usize;
9029                let k_total = *num_checkpoints as usize;
9030                let is_recursive = k_total != 0 && k_total != n_steps;
9031                let checkpoint_t_for_k = |k: usize| -> usize {
9032                    ((k + 1) * n_steps)
9033                        .div_ceil(k_total)
9034                        .saturating_sub(1)
9035                        .min(n_steps - 1)
9036                };
9037
9038                // Forward-body recompute scratch + segment cache —
9039                // exact mirror of the ScanBackward path. With ≈√length
9040                // checkpoints, total recompute work is O(length).
9041                let mut fwd_buf: Vec<u8> = if is_recursive {
9042                    (**forward_body_init.as_ref().unwrap()).clone()
9043                } else {
9044                    Vec::new()
9045                };
9046                let mut seg_cache: Vec<u8> = Vec::new();
9047                let mut seg_start_t: usize = usize::MAX;
9048                let mut seg_count: usize = 0;
9049                let recompute_carry_t =
9050                    |t: usize,
9051                     dst: &mut [u8],
9052                     fwd_buf: &mut Vec<u8>,
9053                     seg_cache: &mut Vec<u8>,
9054                     seg_start_t: &mut usize,
9055                     seg_count: &mut usize| {
9056                        if !is_recursive {
9057                            unsafe {
9058                                let src = if t == 0 {
9059                                    base.add(*outer_init_off)
9060                                } else {
9061                                    base.add(*outer_traj_off + (t - 1) * cb)
9062                                };
9063                                std::ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), cb);
9064                            }
9065                            return;
9066                        }
9067                        if *seg_start_t != usize::MAX
9068                            && t >= *seg_start_t
9069                            && t < *seg_start_t + *seg_count
9070                        {
9071                            let off = (t - *seg_start_t) * cb;
9072                            dst.copy_from_slice(&seg_cache[off..off + cb]);
9073                            return;
9074                        }
9075                        let seg_k = (0..k_total)
9076                            .find(|&k| t <= checkpoint_t_for_k(k))
9077                            .unwrap_or(k_total - 1);
9078                        let (anchor_t, anchor_ptr): (usize, *const u8) = if seg_k == 0 {
9079                            (0, unsafe { base.add(*outer_init_off) as *const u8 })
9080                        } else {
9081                            let prev_ck = checkpoint_t_for_k(seg_k - 1);
9082                            (prev_ck + 1, unsafe {
9083                                base.add(*outer_traj_off + (seg_k - 1) * cb) as *const u8
9084                            })
9085                        };
9086                        let seg_end_t = checkpoint_t_for_k(seg_k);
9087                        let seg_size = seg_end_t - anchor_t + 1;
9088
9089                        fwd_buf.copy_from_slice(forward_body_init.as_ref().unwrap());
9090                        unsafe {
9091                            std::ptr::copy_nonoverlapping(
9092                                anchor_ptr,
9093                                fwd_buf.as_mut_ptr().add(*forward_body_carry_in_off),
9094                                cb,
9095                            );
9096                        }
9097                        seg_cache.resize(seg_size * cb, 0u8);
9098                        seg_cache[0..cb].copy_from_slice(
9099                            &fwd_buf[*forward_body_carry_in_off..*forward_body_carry_in_off + cb],
9100                        );
9101                        let fb_sched = forward_body.as_ref().unwrap();
9102                        for i in 1..seg_size {
9103                            let cur_iter = anchor_t + i - 1;
9104                            for (idx, fb_x_off) in forward_body_x_offs.iter().enumerate() {
9105                                let (outer_xs_off, x_psb) = outer_xs_offs[idx];
9106                                let xb = x_psb as usize;
9107                                unsafe {
9108                                    std::ptr::copy_nonoverlapping(
9109                                        base.add(outer_xs_off + cur_iter * xb),
9110                                        fwd_buf.as_mut_ptr().add(*fb_x_off),
9111                                        xb,
9112                                    );
9113                                }
9114                            }
9115                            execute_thunks(fb_sched, fwd_buf);
9116                            if *forward_body_output_off != *forward_body_carry_in_off {
9117                                fwd_buf.copy_within(
9118                                    *forward_body_output_off..*forward_body_output_off + cb,
9119                                    *forward_body_carry_in_off,
9120                                );
9121                            }
9122                            let cache_off = i * cb;
9123                            seg_cache[cache_off..cache_off + cb].copy_from_slice(
9124                                &fwd_buf
9125                                    [*forward_body_carry_in_off..*forward_body_carry_in_off + cb],
9126                            );
9127                        }
9128                        *seg_start_t = anchor_t;
9129                        *seg_count = seg_size;
9130
9131                        let off = (t - anchor_t) * cb;
9132                        dst.copy_from_slice(&seg_cache[off..off + cb]);
9133                    };
9134
9135                let mut dcarry: Vec<u8> = vec![0u8; cb];
9136                if !*save_trajectory {
9137                    unsafe {
9138                        std::ptr::copy_nonoverlapping(
9139                            base.add(*outer_upstream_off),
9140                            dcarry.as_mut_ptr(),
9141                            cb,
9142                        );
9143                    }
9144                }
9145
9146                let mut body_buf: Vec<u8> = (**body_init).clone();
9147
9148                for t in (0..n_steps).rev() {
9149                    if *save_trajectory {
9150                        unsafe {
9151                            let up_off = *outer_upstream_off + t * cb;
9152                            match *carry_elem_size {
9153                                4 => {
9154                                    let up_ptr = base.add(up_off) as *const f32;
9155                                    let dc_ptr = dcarry.as_mut_ptr() as *mut f32;
9156                                    let n_elems = cb / 4;
9157                                    for i in 0..n_elems {
9158                                        *dc_ptr.add(i) += *up_ptr.add(i);
9159                                    }
9160                                }
9161                                8 => {
9162                                    let up_ptr = base.add(up_off) as *const f64;
9163                                    let dc_ptr = dcarry.as_mut_ptr() as *mut f64;
9164                                    let n_elems = cb / 8;
9165                                    for i in 0..n_elems {
9166                                        *dc_ptr.add(i) += *up_ptr.add(i);
9167                                    }
9168                                }
9169                                other => panic!(
9170                                    "ScanBackwardXs: unsupported carry elem size {other} \
9171                                     (only f32/f64 carries are supported today)"
9172                                ),
9173                            }
9174                        }
9175                    }
9176
9177                    // Seed body_vjp's carry input via the recompute
9178                    // helper (works for both All and Recursive modes),
9179                    // then x_t_i + d_output.
9180                    let carry_dst_start = *body_carry_in_off;
9181                    {
9182                        let carry_slice = &mut body_buf[carry_dst_start..carry_dst_start + cb];
9183                        recompute_carry_t(
9184                            t,
9185                            carry_slice,
9186                            &mut fwd_buf,
9187                            &mut seg_cache,
9188                            &mut seg_start_t,
9189                            &mut seg_count,
9190                        );
9191                    }
9192                    unsafe {
9193                        for (i, body_x_off) in body_x_offs.iter().enumerate() {
9194                            let (outer_xs_off, x_psb) = outer_xs_offs[i];
9195                            let xb = x_psb as usize;
9196                            std::ptr::copy_nonoverlapping(
9197                                base.add(outer_xs_off + t * xb),
9198                                body_buf.as_mut_ptr().add(*body_x_off),
9199                                xb,
9200                            );
9201                        }
9202                        std::ptr::copy_nonoverlapping(
9203                            dcarry.as_ptr(),
9204                            body_buf.as_mut_ptr().add(*body_d_output_off),
9205                            cb,
9206                        );
9207                    }
9208
9209                    execute_thunks(body_vjp, &mut body_buf);
9210
9211                    // Stash this step's dxs into row `t` of the outer
9212                    // [length, *per_step_xs] output.
9213                    unsafe {
9214                        std::ptr::copy_nonoverlapping(
9215                            body_buf.as_ptr().add(*body_dxs_out_off),
9216                            base.add(*outer_dxs_off + t * psb),
9217                            psb,
9218                        );
9219                    }
9220
9221                    // Update dcarry for next backward iteration.
9222                    unsafe {
9223                        std::ptr::copy_nonoverlapping(
9224                            body_buf.as_ptr().add(*body_dcarry_out_off),
9225                            dcarry.as_mut_ptr(),
9226                            cb,
9227                        );
9228                    }
9229                }
9230            }
9231
9232            Thunk::FusedMmBiasAct {
9233                a,
9234                w,
9235                bias,
9236                c,
9237                m,
9238                k,
9239                n,
9240                act,
9241            } => {
9242                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
9243                unsafe {
9244                    let out = sl_mut(*c, base, m * n);
9245                    crate::blas::sgemm_auto(sl(*a, base, m * k), sl(*w, base, k * n), out, m, k, n);
9246                    match act {
9247                        Some(Activation::Gelu) => {
9248                            crate::kernels::par_bias_gelu(out, sl(*bias, base, n), m, n)
9249                        }
9250                        Some(other) => {
9251                            crate::blas::bias_add(out, sl(*bias, base, n), m, n);
9252                            apply_activation_inplace(out, *other);
9253                        }
9254                        None => crate::blas::bias_add(out, sl(*bias, base, n), m, n),
9255                    }
9256                }
9257            }
9258
9259            Thunk::FusedResidualLN {
9260                x,
9261                res,
9262                bias,
9263                g,
9264                b,
9265                out,
9266                rows,
9267                h,
9268                eps,
9269                has_bias,
9270            } => {
9271                let (rows, h) = (*rows as usize, *h as usize);
9272                unsafe {
9273                    let zero = &zero_bias[..h];
9274                    let bi = if *has_bias { sl(*bias, base, h) } else { zero };
9275                    let x_ptr = sl(*x, base, rows * h).as_ptr() as usize;
9276                    let r_ptr = sl(*res, base, rows * h).as_ptr() as usize;
9277                    let o_ptr = sl_mut(*out, base, rows * h).as_mut_ptr() as usize;
9278                    let bi_ptr = bi.as_ptr() as usize;
9279                    let g_ptr = sl(*g, base, h).as_ptr() as usize;
9280                    let b_ptr = sl(*b, base, h).as_ptr() as usize;
9281                    let e = *eps;
9282                    crate::pool::par_for(rows, 4, &|off, cnt| {
9283                        let xs =
9284                            std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
9285                        let rs =
9286                            std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
9287                        let os = std::slice::from_raw_parts_mut(
9288                            (o_ptr as *mut f32).add(off * h),
9289                            cnt * h,
9290                        );
9291                        let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
9292                        let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
9293                        let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
9294                        crate::kernels::residual_bias_layer_norm(xs, rs, bi, g, b, os, cnt, h, e);
9295                    });
9296                }
9297            }
9298
9299            Thunk::FusedResidualRmsNorm {
9300                x,
9301                res,
9302                bias,
9303                g,
9304                b,
9305                out,
9306                rows,
9307                h,
9308                eps,
9309                has_bias,
9310            } => {
9311                let (rows, h) = (*rows as usize, *h as usize);
9312                unsafe {
9313                    let zero = &zero_bias[..h];
9314                    let bi = if *has_bias { sl(*bias, base, h) } else { zero };
9315                    let x_ptr = sl(*x, base, rows * h).as_ptr() as usize;
9316                    let r_ptr = sl(*res, base, rows * h).as_ptr() as usize;
9317                    let o_ptr = sl_mut(*out, base, rows * h).as_mut_ptr() as usize;
9318                    let bi_ptr = bi.as_ptr() as usize;
9319                    let g_ptr = sl(*g, base, h).as_ptr() as usize;
9320                    let b_ptr = sl(*b, base, h).as_ptr() as usize;
9321                    let e = *eps;
9322                    crate::pool::par_for(rows, 4, &|off, cnt| {
9323                        let xs =
9324                            std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
9325                        let rs =
9326                            std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
9327                        let os = std::slice::from_raw_parts_mut(
9328                            (o_ptr as *mut f32).add(off * h),
9329                            cnt * h,
9330                        );
9331                        let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
9332                        let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
9333                        let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
9334                        crate::kernels::residual_bias_rms_norm(xs, rs, bi, g, b, os, cnt, h, e);
9335                    });
9336                }
9337            }
9338
9339            Thunk::BiasAdd {
9340                src,
9341                bias,
9342                dst,
9343                m,
9344                n,
9345            } => {
9346                let (m, n) = (*m as usize, *n as usize);
9347                let len = m * n;
9348                unsafe {
9349                    let out = sl_mut(*dst, base, len);
9350                    if *src != *dst {
9351                        let src_ptr = base.add(*src) as *const f32;
9352                        let dst_ptr = base.add(*dst) as *mut f32;
9353                        if src_ptr != dst_ptr {
9354                            std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len);
9355                        }
9356                    }
9357                    crate::blas::bias_add(out, sl(*bias, base, n), m, n);
9358                }
9359            }
9360
9361            Thunk::BinaryFull {
9362                lhs,
9363                rhs,
9364                dst,
9365                len,
9366                lhs_len,
9367                rhs_len,
9368                op,
9369                out_dims_bcast,
9370                bcast_lhs_strides,
9371                bcast_rhs_strides,
9372                elem_bytes,
9373            } => {
9374                let len = *len as usize;
9375                let ll = (*lhs_len as usize).max(1);
9376                let rl = (*rhs_len as usize).max(1);
9377                let eb = (*elem_bytes).max(1) as usize;
9378                let arena_len = arena_buf.len();
9379                let ll = ll.min((arena_len.saturating_sub(*lhs)) / eb);
9380                let rl = rl.min((arena_len.saturating_sub(*rhs)) / eb);
9381                let len = len.min((arena_len.saturating_sub(*dst)) / eb);
9382                unsafe {
9383                    if eb == 8 {
9384                        let l = sl_i64(*lhs, base, ll);
9385                        let r = sl_i64(*rhs, base, rl);
9386                        let o = sl_mut_i64(*dst, base, len);
9387                        if !out_dims_bcast.is_empty() {
9388                            let rank = out_dims_bcast.len();
9389                            let mut coords = vec![0u32; rank];
9390                            for i in 0..len {
9391                                let mut rem = i;
9392                                for ax in (0..rank).rev() {
9393                                    let sz = out_dims_bcast[ax] as usize;
9394                                    coords[ax] = (rem % sz) as u32;
9395                                    rem /= sz;
9396                                }
9397                                let mut li = 0usize;
9398                                let mut ri = 0usize;
9399                                for ax in 0..rank {
9400                                    li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
9401                                    ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
9402                                }
9403                                o[i] = match op {
9404                                    BinaryOp::Add => l[li].wrapping_add(r[ri]),
9405                                    BinaryOp::Sub => l[li].wrapping_sub(r[ri]),
9406                                    BinaryOp::Mul => l[li].wrapping_mul(r[ri]),
9407                                    BinaryOp::Div => {
9408                                        if r[ri] == 0 {
9409                                            0
9410                                        } else {
9411                                            l[li] / r[ri]
9412                                        }
9413                                    }
9414                                    BinaryOp::Max => l[li].max(r[ri]),
9415                                    BinaryOp::Min => l[li].min(r[ri]),
9416                                    BinaryOp::Pow => l[li].pow(r[ri].max(0) as u32),
9417                                };
9418                            }
9419                        } else {
9420                            for i in 0..len {
9421                                let li = if ll == 1 { 0 } else { i % ll };
9422                                let ri = if rl == 1 { 0 } else { i % rl };
9423                                o[i] = match op {
9424                                    BinaryOp::Add => l[li].wrapping_add(r[ri]),
9425                                    BinaryOp::Sub => l[li].wrapping_sub(r[ri]),
9426                                    BinaryOp::Mul => l[li].wrapping_mul(r[ri]),
9427                                    BinaryOp::Div => {
9428                                        if r[ri] == 0 {
9429                                            0
9430                                        } else {
9431                                            l[li] / r[ri]
9432                                        }
9433                                    }
9434                                    BinaryOp::Max => l[li].max(r[ri]),
9435                                    BinaryOp::Min => l[li].min(r[ri]),
9436                                    BinaryOp::Pow => l[li].pow(r[ri].max(0) as u32),
9437                                };
9438                            }
9439                        }
9440                    } else {
9441                        let l = sl(*lhs, base, ll);
9442                        let r = sl(*rhs, base, rl);
9443                        let o = sl_mut(*dst, base, len);
9444                        if ll == len && rl == len {
9445                            #[cfg(target_arch = "aarch64")]
9446                            if matches!(op, BinaryOp::Add | BinaryOp::Mul) {
9447                                use std::arch::aarch64::*;
9448                                let chunks = len / 4;
9449                                for c in 0..chunks {
9450                                    let off = c * 4;
9451                                    let vl = vld1q_f32(l.as_ptr().add(off));
9452                                    let vr = vld1q_f32(r.as_ptr().add(off));
9453                                    let res = match op {
9454                                        BinaryOp::Add => vaddq_f32(vl, vr),
9455                                        BinaryOp::Mul => vmulq_f32(vl, vr),
9456                                        _ => unreachable!(),
9457                                    };
9458                                    vst1q_f32(o.as_mut_ptr().add(off), res);
9459                                }
9460                                for i in (chunks * 4)..len {
9461                                    o[i] = match op {
9462                                        BinaryOp::Add => l[i] + r[i],
9463                                        BinaryOp::Mul => l[i] * r[i],
9464                                        _ => unreachable!(),
9465                                    };
9466                                }
9467                                continue;
9468                            }
9469                        }
9470                        if !out_dims_bcast.is_empty() {
9471                            let rank = out_dims_bcast.len();
9472                            let mut coords = vec![0u32; rank];
9473                            for i in 0..len {
9474                                let mut rem = i;
9475                                for ax in (0..rank).rev() {
9476                                    let sz = out_dims_bcast[ax] as usize;
9477                                    coords[ax] = (rem % sz) as u32;
9478                                    rem /= sz;
9479                                }
9480                                let mut li = 0usize;
9481                                let mut ri = 0usize;
9482                                for ax in 0..rank {
9483                                    li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
9484                                    ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
9485                                }
9486                                o[i] = match op {
9487                                    BinaryOp::Add => l[li] + r[ri],
9488                                    BinaryOp::Sub => l[li] - r[ri],
9489                                    BinaryOp::Mul => l[li] * r[ri],
9490                                    BinaryOp::Div => l[li] / r[ri],
9491                                    BinaryOp::Max => l[li].max(r[ri]),
9492                                    BinaryOp::Min => l[li].min(r[ri]),
9493                                    BinaryOp::Pow => l[li].powf(r[ri]),
9494                                };
9495                            }
9496                        } else {
9497                            for i in 0..len {
9498                                let li = if ll == 1 { 0 } else { i % ll };
9499                                let ri = if rl == 1 { 0 } else { i % rl };
9500                                o[i] = match op {
9501                                    BinaryOp::Add => l[li] + r[ri],
9502                                    BinaryOp::Sub => l[li] - r[ri],
9503                                    BinaryOp::Mul => l[li] * r[ri],
9504                                    BinaryOp::Div => l[li] / r[ri],
9505                                    BinaryOp::Max => l[li].max(r[ri]),
9506                                    BinaryOp::Min => l[li].min(r[ri]),
9507                                    BinaryOp::Pow => l[li].powf(r[ri]),
9508                                };
9509                            }
9510                        }
9511                    }
9512                }
9513            }
9514
9515            Thunk::Gather {
9516                table,
9517                table_len,
9518                idx,
9519                dst,
9520                num_idx,
9521                trailing,
9522                idx_i64,
9523                table_bytes,
9524            } => {
9525                let (ni, tr) = (*num_idx as usize, *trailing as usize);
9526                let rows = *table_len as usize / tr.max(1);
9527                unsafe {
9528                    if *table_bytes == 8 {
9529                        let tab = sl_i64(*table, base, *table_len as usize);
9530                        let out = sl_mut_i64(*dst, base, ni * tr);
9531                        if *idx_i64 != 0 {
9532                            let ids = sl_i64(*idx, base, ni);
9533                            for i in 0..ni {
9534                                let row = ids[i].max(0) as usize;
9535                                if row < rows {
9536                                    out[i * tr..(i + 1) * tr]
9537                                        .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
9538                                }
9539                            }
9540                        } else {
9541                            let ids = sl(*idx, base, ni);
9542                            for i in 0..ni {
9543                                let row = ids[i] as usize;
9544                                if row < rows {
9545                                    out[i * tr..(i + 1) * tr]
9546                                        .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
9547                                }
9548                            }
9549                        }
9550                    } else {
9551                        let tab = sl(*table, base, *table_len as usize);
9552                        let out = sl_mut(*dst, base, ni * tr);
9553                        if *idx_i64 != 0 {
9554                            let ids = sl_i64(*idx, base, ni);
9555                            for i in 0..ni {
9556                                let row = ids[i].max(0) as usize;
9557                                if row < rows {
9558                                    out[i * tr..(i + 1) * tr]
9559                                        .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
9560                                }
9561                            }
9562                        } else {
9563                            let ids = sl(*idx, base, ni);
9564                            for i in 0..ni {
9565                                let row = ids[i] as usize;
9566                                if row < rows {
9567                                    out[i * tr..(i + 1) * tr]
9568                                        .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
9569                                }
9570                            }
9571                        }
9572                    }
9573                }
9574            }
9575
9576            Thunk::Narrow {
9577                src,
9578                dst,
9579                outer,
9580                src_stride,
9581                dst_stride,
9582                inner,
9583                elem_bytes,
9584            } => {
9585                let (outer, ss, ds, inner, eb) = (
9586                    *outer as usize,
9587                    *src_stride as usize,
9588                    *dst_stride as usize,
9589                    *inner as usize,
9590                    *elem_bytes as usize,
9591                );
9592                let row_bytes = inner.saturating_mul(eb);
9593                let src_row_stride = ss.saturating_mul(eb);
9594                let dst_row_stride = ds.saturating_mul(eb);
9595                if trace_thunks {
9596                    eprintln!(
9597                        "[narrow] src={} dst={} outer={outer} ss={ss} ds={ds} inner={inner} eb={eb} row={row_bytes} arena={}",
9598                        *src,
9599                        *dst,
9600                        arena_buf.len()
9601                    );
9602                }
9603                if row_bytes > 0 && *src != *dst {
9604                    let arena_len = arena_buf.len();
9605                    for o in 0..outer {
9606                        let s_off = *src + o * src_row_stride;
9607                        let d_off = *dst + o * dst_row_stride;
9608                        if s_off == d_off {
9609                            continue;
9610                        }
9611                        if s_off.saturating_add(row_bytes) > arena_len
9612                            || d_off.saturating_add(row_bytes) > arena_len
9613                        {
9614                            break;
9615                        }
9616                        unsafe {
9617                            std::ptr::copy_nonoverlapping(
9618                                base.add(s_off),
9619                                base.add(d_off),
9620                                row_bytes,
9621                            );
9622                        }
9623                    }
9624                }
9625            }
9626
9627            Thunk::Copy { src, dst, len } => {
9628                let mut len = *len as usize;
9629                if *src == *dst || len == 0 {
9630                    continue;
9631                }
9632                let arena_len = arena_buf.len();
9633                let max_from_src = (arena_len.saturating_sub(*src)) / 4;
9634                let max_from_dst = (arena_len.saturating_sub(*dst)) / 4;
9635                len = len.min(max_from_src).min(max_from_dst);
9636                if len == 0 {
9637                    continue;
9638                }
9639                let byte_len = len.saturating_mul(4);
9640                unsafe {
9641                    std::ptr::copy(base.add(*src), base.add(*dst), byte_len);
9642                }
9643            }
9644
9645            Thunk::LayerNorm {
9646                src,
9647                g,
9648                b,
9649                dst,
9650                rows,
9651                h,
9652                eps,
9653            } => {
9654                let (rows, h) = (*rows as usize, *h as usize);
9655                unsafe {
9656                    let input = sl(*src, base, rows * h);
9657                    let gamma = sl(*g, base, h);
9658                    let beta = sl(*b, base, h);
9659                    let output = sl_mut(*dst, base, rows * h);
9660                    // Parallelize across rows (same pattern as FusedResidualLN)
9661                    if rows >= 4 && rows * h >= 30_000 {
9662                        let i_ptr = input.as_ptr() as usize;
9663                        let o_ptr = output.as_mut_ptr() as usize;
9664                        let g_ptr = gamma.as_ptr() as usize;
9665                        let b_ptr = beta.as_ptr() as usize;
9666                        let e = *eps;
9667                        crate::pool::par_for(rows, 4, &|off, cnt| {
9668                            let inp = std::slice::from_raw_parts(
9669                                (i_ptr as *const f32).add(off * h),
9670                                cnt * h,
9671                            );
9672                            let out = std::slice::from_raw_parts_mut(
9673                                (o_ptr as *mut f32).add(off * h),
9674                                cnt * h,
9675                            );
9676                            let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
9677                            let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
9678                            for row in 0..cnt {
9679                                crate::kernels::layer_norm_row(
9680                                    &inp[row * h..(row + 1) * h],
9681                                    g,
9682                                    b,
9683                                    &mut out[row * h..(row + 1) * h],
9684                                    h,
9685                                    e,
9686                                );
9687                            }
9688                        });
9689                    } else {
9690                        for row in 0..rows {
9691                            crate::kernels::layer_norm_row(
9692                                &input[row * h..(row + 1) * h],
9693                                gamma,
9694                                beta,
9695                                &mut output[row * h..(row + 1) * h],
9696                                h,
9697                                *eps,
9698                            );
9699                        }
9700                    }
9701                }
9702            }
9703
9704            Thunk::GroupNorm {
9705                src,
9706                g,
9707                b,
9708                dst,
9709                n,
9710                c,
9711                h,
9712                w,
9713                num_groups,
9714                eps,
9715            } => {
9716                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
9717                let plane = c * h * w;
9718                unsafe {
9719                    for ni in 0..n {
9720                        let input = sl(*src, base.add(ni * plane), plane);
9721                        let gamma = sl(*g, base, c);
9722                        let beta = sl(*b, base, c);
9723                        let output = sl_mut(*dst, base.add(ni * plane), plane);
9724                        crate::kernels::group_norm_nchw(
9725                            input,
9726                            gamma,
9727                            beta,
9728                            output,
9729                            1,
9730                            c,
9731                            h,
9732                            w,
9733                            *num_groups as usize,
9734                            *eps,
9735                        );
9736                    }
9737                }
9738            }
9739
9740            Thunk::BatchNormInference {
9741                src,
9742                g,
9743                b,
9744                mean,
9745                var,
9746                dst,
9747                count,
9748                channels,
9749                eps,
9750            } => {
9751                let count = *count as usize;
9752                let c = *channels as usize;
9753                let n = count * c;
9754                unsafe {
9755                    crate::kernels::batch_norm_inference(
9756                        sl(*src, base, n),
9757                        sl(*g, base, c),
9758                        sl(*b, base, c),
9759                        sl(*mean, base, c),
9760                        sl(*var, base, c),
9761                        sl_mut(*dst, base, n),
9762                        c,
9763                        *eps,
9764                    );
9765                }
9766            }
9767
9768            Thunk::LayerNorm2d {
9769                src,
9770                g,
9771                b,
9772                dst,
9773                n,
9774                c,
9775                h,
9776                w,
9777                eps,
9778            } => {
9779                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
9780                let plane = c * h * w;
9781                unsafe {
9782                    let input = sl(*src, base, n * plane);
9783                    let gamma = sl(*g, base, c);
9784                    let beta = sl(*b, base, c);
9785                    let output = sl_mut(*dst, base, n * plane);
9786                    crate::kernels::layer_norm2d_nchw(input, gamma, beta, output, n, c, h, w, *eps);
9787                }
9788            }
9789
9790            Thunk::ConvTranspose2d {
9791                src,
9792                weight,
9793                dst,
9794                n,
9795                c_in,
9796                h,
9797                w_in,
9798                c_out,
9799                h_out,
9800                w_out,
9801                kh,
9802                kw,
9803                sh,
9804                sw,
9805                ph,
9806                pw,
9807                dh,
9808                dw,
9809                groups,
9810            } => {
9811                let n = *n as usize;
9812                let c_in = *c_in as usize;
9813                let h = *h as usize;
9814                let w_in = *w_in as usize;
9815                let c_out = *c_out as usize;
9816                let h_out = *h_out as usize;
9817                let w_out = *w_out as usize;
9818                unsafe {
9819                    let inp = sl(*src, base, n * c_in * h * w_in);
9820                    let wt = sl(
9821                        *weight,
9822                        base,
9823                        c_in * (c_out / *groups as usize) * (*kh as usize) * (*kw as usize),
9824                    );
9825                    let out = sl_mut(*dst, base, n * c_out * h_out * w_out);
9826                    crate::kernels::conv_transpose2d_nchw(
9827                        inp,
9828                        wt,
9829                        out,
9830                        n,
9831                        c_in,
9832                        h,
9833                        w_in,
9834                        c_out,
9835                        h_out,
9836                        w_out,
9837                        *kh as usize,
9838                        *kw as usize,
9839                        *sh as usize,
9840                        *sw as usize,
9841                        *ph as usize,
9842                        *pw as usize,
9843                        *dh as usize,
9844                        *dw as usize,
9845                        *groups as usize,
9846                    );
9847                }
9848            }
9849
9850            Thunk::ResizeNearest2x {
9851                src,
9852                dst,
9853                n,
9854                c,
9855                h,
9856                w,
9857            } => {
9858                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
9859                let in_plane = c * h * w;
9860                let out_plane = c * h * 2 * w * 2;
9861                unsafe {
9862                    for ni in 0..n {
9863                        let input = sl(*src, base.add(ni * in_plane), in_plane);
9864                        let output = sl_mut(*dst, base.add(ni * out_plane), out_plane);
9865                        crate::kernels::resize_nearest_2x_nchw(input, output, c, h, w);
9866                    }
9867                }
9868            }
9869
9870            Thunk::AxialRope2d {
9871                src,
9872                dst,
9873                batch,
9874                seq,
9875                hidden,
9876                end_x,
9877                end_y,
9878                head_dim,
9879                num_heads,
9880                theta,
9881                repeat_factor,
9882            } => {
9883                let b = *batch as usize;
9884                let s = *seq as usize;
9885                let hdim = *head_dim as usize;
9886                let nh = *num_heads as usize;
9887                let plane = s * (*hidden as usize);
9888                unsafe {
9889                    for bi in 0..b {
9890                        let input = sl(*src, base.add(bi * plane), plane);
9891                        let output = sl_mut(*dst, base.add(bi * plane), plane);
9892                        let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
9893                            input,
9894                            nh,
9895                            s,
9896                            hdim,
9897                            *end_x as usize,
9898                            *end_y as usize,
9899                            *theta,
9900                            *repeat_factor as usize,
9901                        );
9902                        output.copy_from_slice(&rotated);
9903                    }
9904                }
9905            }
9906
9907            Thunk::RmsNorm {
9908                src,
9909                g,
9910                b,
9911                dst,
9912                rows,
9913                h,
9914                eps,
9915            } => {
9916                let (rows, h) = (*rows as usize, *h as usize);
9917                unsafe {
9918                    let input = sl(*src, base, rows * h);
9919                    let gamma = sl(*g, base, h);
9920                    let beta = sl(*b, base, h);
9921                    let output = sl_mut(*dst, base, rows * h);
9922                    let inv_h = 1.0 / h as f32;
9923                    for row in 0..rows {
9924                        let in_row = &input[row * h..(row + 1) * h];
9925                        let out_row = &mut output[row * h..(row + 1) * h];
9926                        // RMS = sqrt(mean(x^2) + eps); scale = 1/RMS.
9927                        let mut sumsq = 0f32;
9928                        for &v in in_row {
9929                            sumsq += v * v;
9930                        }
9931                        let inv_rms = (sumsq * inv_h + *eps).sqrt().recip();
9932                        for i in 0..h {
9933                            out_row[i] = in_row[i] * inv_rms * gamma[i] + beta[i];
9934                        }
9935                    }
9936                }
9937            }
9938
9939            Thunk::Softmax { data, rows, cols } => {
9940                let (rows, cols) = (*rows as usize, *cols as usize);
9941                unsafe {
9942                    crate::kernels::neon_softmax(sl_mut(*data, base, rows * cols), rows, cols);
9943                }
9944            }
9945
9946            Thunk::Cumsum {
9947                src,
9948                dst,
9949                rows,
9950                cols,
9951                exclusive,
9952            } => {
9953                let (rows, cols) = (*rows as usize, *cols as usize);
9954                unsafe {
9955                    let s = sl(*src, base, rows * cols);
9956                    let d = sl_mut(*dst, base, rows * cols);
9957                    if *exclusive {
9958                        for r in 0..rows {
9959                            let mut acc = 0.0f32;
9960                            for c in 0..cols {
9961                                d[r * cols + c] = acc;
9962                                acc += s[r * cols + c];
9963                            }
9964                        }
9965                    } else {
9966                        for r in 0..rows {
9967                            let mut acc = 0.0f32;
9968                            for c in 0..cols {
9969                                acc += s[r * cols + c];
9970                                d[r * cols + c] = acc;
9971                            }
9972                        }
9973                    }
9974                }
9975            }
9976
9977            Thunk::Sample {
9978                logits,
9979                dst,
9980                batch,
9981                vocab,
9982                top_k,
9983                top_p,
9984                temperature,
9985                seed,
9986            } => {
9987                let (b, v) = (*batch as usize, *vocab as usize);
9988                let k = (*top_k as usize).min(v);
9989                unsafe {
9990                    let lg = sl(*logits, base, b * v);
9991                    let out = sl_mut(*dst, base, b);
9992                    let mut rng =
9993                        rlx_ir::Philox4x32::new(if *seed == 0 { 0xDEADBEEF } else { *seed });
9994                    for bi in 0..b {
9995                        let row = &lg[bi * v..(bi + 1) * v];
9996                        out[bi] = sample_row(row, k, *top_p, *temperature, &mut rng) as f32;
9997                    }
9998                }
9999            }
10000
10001            Thunk::GatedDeltaNet {
10002                q,
10003                k,
10004                v,
10005                g,
10006                beta,
10007                state,
10008                dst,
10009                batch,
10010                seq,
10011                heads,
10012                state_size,
10013            } => unsafe {
10014                execute_gated_delta_net_f32(
10015                    *q,
10016                    *k,
10017                    *v,
10018                    *g,
10019                    *beta,
10020                    *state,
10021                    *dst,
10022                    *batch as usize,
10023                    *seq as usize,
10024                    *heads as usize,
10025                    *state_size as usize,
10026                    base,
10027                );
10028            },
10029
10030            Thunk::SelectiveScan {
10031                x,
10032                delta,
10033                a,
10034                b: bp,
10035                c: cp,
10036                dst,
10037                batch,
10038                seq,
10039                hidden,
10040                state_size,
10041            } => {
10042                let (b, s, h, n) = (
10043                    *batch as usize,
10044                    *seq as usize,
10045                    *hidden as usize,
10046                    *state_size as usize,
10047                );
10048                unsafe {
10049                    let xs = sl(*x, base, b * s * h);
10050                    let dt = sl(*delta, base, b * s * h);
10051                    let am = sl(*a, base, h * n);
10052                    let bm = sl(*bp, base, b * s * n);
10053                    let cm = sl(*cp, base, b * s * n);
10054                    let out = sl_mut(*dst, base, b * s * h);
10055
10056                    // State buffer per-batch: h channels × n state.
10057                    // Sequential along the seq dimension; could
10058                    // parallelize over batch+channel later.
10059                    let mut state = vec![0f32; h * n];
10060                    for bi in 0..b {
10061                        // Reset state at the start of each batch row.
10062                        for v in state.iter_mut() {
10063                            *v = 0.0;
10064                        }
10065                        for si in 0..s {
10066                            let x_row = &xs[bi * s * h + si * h..bi * s * h + (si + 1) * h];
10067                            let dt_row = &dt[bi * s * h + si * h..bi * s * h + (si + 1) * h];
10068                            let b_row = &bm[bi * s * n + si * n..bi * s * n + (si + 1) * n];
10069                            let c_row = &cm[bi * s * n + si * n..bi * s * n + (si + 1) * n];
10070                            let out_row = &mut out[bi * s * h + si * h..bi * s * h + (si + 1) * h];
10071
10072                            for ci in 0..h {
10073                                let d = dt_row[ci];
10074                                let xv = x_row[ci];
10075                                let mut acc = 0f32;
10076                                for ni in 0..n {
10077                                    // Discretize: exp(d * a) and d * b.
10078                                    let da = (d * am[ci * n + ni]).exp();
10079                                    state[ci * n + ni] =
10080                                        da * state[ci * n + ni] + d * b_row[ni] * xv;
10081                                    acc += c_row[ni] * state[ci * n + ni];
10082                                }
10083                                out_row[ci] = acc;
10084                            }
10085                        }
10086                    }
10087                }
10088            }
10089
10090            Thunk::DequantMatMul {
10091                x,
10092                w_q,
10093                scale,
10094                zp,
10095                dst,
10096                m,
10097                k,
10098                n,
10099                block_size,
10100                is_asymmetric,
10101            } => {
10102                let (m, k, n, bs) = (*m as usize, *k as usize, *n as usize, *block_size as usize);
10103                let n_blocks = k.div_ceil(bs);
10104                unsafe {
10105                    let xs = sl(*x, base, m * k);
10106                    let w_bytes = std::slice::from_raw_parts(base.add(*w_q) as *const i8, k * n);
10107                    let scales = sl(*scale, base, n_blocks * n);
10108                    let zps = if *is_asymmetric {
10109                        sl(*zp, base, n_blocks * n)
10110                    } else {
10111                        &[][..]
10112                    };
10113                    let out = sl_mut(*dst, base, m * n);
10114                    dequant_matmul_int8(xs, w_bytes, scales, zps, out, m, k, n, bs, *is_asymmetric);
10115                }
10116            }
10117
10118            Thunk::DequantMatMulGguf {
10119                x,
10120                w_q,
10121                dst,
10122                m,
10123                k,
10124                n,
10125                scheme,
10126            } => {
10127                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
10128                let block_bytes = scheme.gguf_block_bytes() as usize;
10129                let block_elems = scheme.gguf_block_size() as usize;
10130                debug_assert!(
10131                    block_bytes > 0 && block_elems > 0,
10132                    "non-GGUF scheme in GGUF arm"
10133                );
10134                debug_assert!(
10135                    (k * n).is_multiple_of(block_elems),
10136                    "k*n={} not aligned to GGUF block size {}",
10137                    k * n,
10138                    block_elems
10139                );
10140                let total_bytes = (k * n) / block_elems * block_bytes;
10141                unsafe {
10142                    let xs = sl(*x, base, m * k);
10143                    let w_bytes_ptr = base.add(*w_q) as *const u8;
10144                    let w_bytes = std::slice::from_raw_parts(w_bytes_ptr, total_bytes);
10145                    let out = sl_mut(*dst, base, m * n);
10146                    crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, *scheme);
10147                }
10148            }
10149
10150            Thunk::DequantMatMulInt4 {
10151                x,
10152                w_q,
10153                scale,
10154                zp,
10155                dst,
10156                m,
10157                k,
10158                n,
10159                block_size,
10160                is_asymmetric,
10161            } => {
10162                let (m, k, n, bs) = (*m as usize, *k as usize, *n as usize, *block_size as usize);
10163                let n_blocks = k.div_ceil(bs);
10164                unsafe {
10165                    let xs = sl(*x, base, m * k);
10166                    let w_bytes = std::slice::from_raw_parts(
10167                        base.add(*w_q) as *const u8,
10168                        (k * n).div_ceil(2),
10169                    );
10170                    let scales = sl(*scale, base, n_blocks * n);
10171                    let zps = if *is_asymmetric {
10172                        sl(*zp, base, n_blocks * n)
10173                    } else {
10174                        &[][..]
10175                    };
10176                    let out = sl_mut(*dst, base, m * n);
10177                    dequant_matmul_int4(xs, w_bytes, scales, zps, out, m, k, n, bs, *is_asymmetric);
10178                }
10179            }
10180
10181            Thunk::DequantMatMulFp8 {
10182                x,
10183                w_q,
10184                scale,
10185                dst,
10186                m,
10187                k,
10188                n,
10189                e5m2,
10190            } => {
10191                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
10192                unsafe {
10193                    let xs = sl(*x, base, m * k);
10194                    let w_bytes = std::slice::from_raw_parts(base.add(*w_q) as *const u8, k * n);
10195                    let scales = sl(*scale, base, n);
10196                    let out = sl_mut(*dst, base, m * n);
10197                    dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, *e5m2);
10198                }
10199            }
10200
10201            Thunk::DequantMatMulNvfp4 {
10202                x,
10203                w_q,
10204                scale,
10205                global_scale,
10206                dst,
10207                m,
10208                k,
10209                n,
10210            } => {
10211                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
10212                let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
10213                unsafe {
10214                    let xs = sl(*x, base, m * k);
10215                    let w_bytes = std::slice::from_raw_parts(
10216                        base.add(*w_q) as *const u8,
10217                        (k * n).div_ceil(2),
10218                    );
10219                    let scale_bytes =
10220                        std::slice::from_raw_parts(base.add(*scale) as *const u8, n_scale);
10221                    let gs = sl(*global_scale, base, 1)[0];
10222                    let out = sl_mut(*dst, base, m * n);
10223                    dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
10224                }
10225            }
10226
10227            Thunk::LoraMatMul {
10228                x,
10229                w,
10230                a,
10231                b,
10232                dst,
10233                m,
10234                k,
10235                n,
10236                r,
10237                scale,
10238            } => {
10239                let (m, k, n, r) = (*m as usize, *k as usize, *n as usize, *r as usize);
10240                unsafe {
10241                    let xs = sl(*x, base, m * k);
10242                    let ws = sl(*w, base, k * n);
10243                    let a_s = sl(*a, base, k * r);
10244                    let bs = sl(*b, base, r * n);
10245                    let out = sl_mut(*dst, base, m * n);
10246                    crate::blas::sgemm(xs, ws, out, m, k, n);
10247                    let mut tmp = vec![0f32; m * r];
10248                    crate::blas::sgemm(xs, a_s, &mut tmp, m, k, r);
10249                    if *scale != 1.0 {
10250                        for v in tmp.iter_mut() {
10251                            *v *= *scale;
10252                        }
10253                    }
10254                    crate::blas::sgemm_accumulate(&tmp, bs, out, m, r, n);
10255                }
10256            }
10257
10258            Thunk::Attention {
10259                q,
10260                k,
10261                v,
10262                mask,
10263                out,
10264                batch,
10265                seq,
10266                kv_seq,
10267                heads,
10268                head_dim,
10269                mask_kind,
10270                q_row_stride,
10271                k_row_stride,
10272                v_row_stride,
10273                bhsd,
10274            } => {
10275                let (b, q_s, k_s, nh, dh) = (
10276                    *batch as usize,
10277                    *seq as usize,
10278                    *kv_seq as usize,
10279                    *heads as usize,
10280                    *head_dim as usize,
10281                );
10282                let hs = nh * dh;
10283                // For [B, H, S, D] layout each (b, h) tile is dense
10284                // contiguous; the qrs/krs/vrs strides are not used.
10285                let (qrs, krs, vrs) = if *bhsd {
10286                    (dh, dh, dh)
10287                } else {
10288                    (
10289                        *q_row_stride as usize,
10290                        *k_row_stride as usize,
10291                        *v_row_stride as usize,
10292                    )
10293                };
10294                let bhsd = *bhsd;
10295                let _ = (q_row_stride, k_row_stride, v_row_stride);
10296                let scale = (dh as f32).powf(-0.5);
10297                let ss = q_s * k_s;
10298                let cfg = crate::config::RuntimeConfig::global();
10299                unsafe {
10300                    // Slice lengths cover the strided span. When Q/K/V
10301                    // alias the parent QKV (post-#46-fusion), the same
10302                    // bytes back all three slices — compiler bounds
10303                    // checks see the right size. For [B, H, S, D] the
10304                    // buffer is densely B*H*S*D elements; the row
10305                    // strides aren't used.
10306                    let q_len = if bhsd {
10307                        b * nh * q_s * dh
10308                    } else {
10309                        b * q_s * qrs
10310                    };
10311                    let k_len = if bhsd {
10312                        b * nh * k_s * dh
10313                    } else {
10314                        b * k_s * krs
10315                    };
10316                    let v_len = if bhsd {
10317                        b * nh * k_s * dh
10318                    } else {
10319                        b * k_s * vrs
10320                    };
10321                    let q_data = sl(*q, base, q_len);
10322                    let k_data = sl(*k, base, k_len);
10323                    let v_data = sl(*v, base, v_len);
10324                    let mask_data: &[f32] = match mask_kind {
10325                        rlx_ir::op::MaskKind::Custom => sl(*mask, base, b * k_s),
10326                        rlx_ir::op::MaskKind::Bias => sl(*mask, base, b * nh * q_s * k_s),
10327                        _ => &[],
10328                    };
10329                    let out_len = if bhsd {
10330                        b * nh * q_s * dh
10331                    } else {
10332                        b * q_s * hs
10333                    };
10334                    let out_data = sl_mut(*out, base, out_len);
10335
10336                    // ── [B, H, S, D] fallback ──────────────────────
10337                    // The NEON / strided-BLAS specializations below
10338                    // are written for the [B, S, H, D] layout. When
10339                    // the input is head-major ([B, H, S, D] —
10340                    // matching rlx-cuda / rlx-rocm / rlx-tpu), bypass
10341                    // them and run a simple (correct but slower)
10342                    // scalar implementation. Production-CPU inference
10343                    // graphs use [B, S, H, D] so they still hit the
10344                    // hot path; cross-backend parity tests use
10345                    // [B, H, S, D] and land here.
10346                    if bhsd {
10347                        let scores = &mut sdpa_scores[..ss];
10348                        for bi in 0..b {
10349                            for hi in 0..nh {
10350                                let q_head_base = bi * nh * q_s * dh + hi * q_s * dh;
10351                                let k_head_base = bi * nh * k_s * dh + hi * k_s * dh;
10352                                // Q@K^T
10353                                for qi in 0..q_s {
10354                                    let q_base = q_head_base + qi * dh;
10355                                    for ki in 0..k_s {
10356                                        let k_base = k_head_base + ki * dh;
10357                                        let mut dot = 0f32;
10358                                        for d in 0..dh {
10359                                            dot += q_data[q_base + d] * k_data[k_base + d];
10360                                        }
10361                                        scores[qi * k_s + ki] = dot * scale;
10362                                        if matches!(mask_kind, rlx_ir::op::MaskKind::Custom)
10363                                            && !mask_data.is_empty()
10364                                            && mask_data[bi * k_s + ki] < mask_thr
10365                                        {
10366                                            scores[qi * k_s + ki] = mask_neg;
10367                                        }
10368                                    }
10369                                }
10370                                if matches!(mask_kind, rlx_ir::op::MaskKind::Bias) {
10371                                    let off = (bi * nh + hi) * q_s * k_s;
10372                                    for i in 0..q_s * k_s {
10373                                        scores[i] += mask_data[off + i];
10374                                    }
10375                                }
10376                                apply_synthetic_mask(scores, q_s, k_s, *mask_kind);
10377                                crate::kernels::neon_softmax(scores, q_s, k_s);
10378                                // score @ V
10379                                for qi in 0..q_s {
10380                                    let o_base = q_head_base + qi * dh;
10381                                    for d in 0..dh {
10382                                        out_data[o_base + d] = 0.0;
10383                                    }
10384                                    for ki in 0..k_s {
10385                                        let sc = scores[qi * k_s + ki];
10386                                        if sc > score_thr {
10387                                            let v_base = k_head_base + ki * dh;
10388                                            for d in 0..dh {
10389                                                out_data[o_base + d] += sc * v_data[v_base + d];
10390                                            }
10391                                        }
10392                                    }
10393                                }
10394                            }
10395                        }
10396                        continue;
10397                    }
10398
10399                    // ── Auto-select kernel: NEON dots vs strided BLAS ───
10400                    // For tiny inputs (batch=1, short seq), per-head BLAS call
10401                    // overhead (~0.5µs × 2 calls × num_heads × num_layers)
10402                    // exceeds the NEON compute cost. Use direct strided NEON
10403                    // with zero dispatch overhead.
10404                    // For batch≥2: always BLAS + par_for (parallelism wins).
10405                    if b == 1 && q_s.max(k_s) <= cfg.sdpa_seq_threshold {
10406                        // ── Sequential NEON path (zero overhead) ──
10407                        let scores = &mut sdpa_scores[..ss];
10408                        #[cfg(target_arch = "aarch64")]
10409                        let neon_chunks = dh / 4;
10410
10411                        for bi in 0..b {
10412                            for hi in 0..nh {
10413                                // Q@K^T via strided NEON dot products
10414                                for qi in 0..q_s {
10415                                    let q_off = bi * q_s * qrs + qi * qrs + hi * dh;
10416                                    for ki in 0..k_s {
10417                                        let k_off = bi * k_s * krs + ki * krs + hi * dh;
10418                                        #[cfg(target_arch = "aarch64")]
10419                                        let mut dot;
10420                                        #[cfg(not(target_arch = "aarch64"))]
10421                                        let mut dot = 0f32;
10422                                        #[cfg(target_arch = "aarch64")]
10423                                        {
10424                                            use std::arch::aarch64::*;
10425                                            let mut acc = vdupq_n_f32(0.0);
10426                                            for c in 0..neon_chunks {
10427                                                let vq =
10428                                                    vld1q_f32(q_data.as_ptr().add(q_off + c * 4));
10429                                                let vk =
10430                                                    vld1q_f32(k_data.as_ptr().add(k_off + c * 4));
10431                                                acc = vfmaq_f32(acc, vq, vk);
10432                                            }
10433                                            dot = vaddvq_f32(acc);
10434                                            for d in (neon_chunks * 4)..dh {
10435                                                dot += q_data[q_off + d] * k_data[k_off + d];
10436                                            }
10437                                        }
10438                                        #[cfg(not(target_arch = "aarch64"))]
10439                                        for d in 0..dh {
10440                                            dot += q_data[q_off + d] * k_data[k_off + d];
10441                                        }
10442                                        scores[qi * k_s + ki] = dot * scale;
10443                                        // Inner-loop Custom mask check —
10444                                        // Causal / SlidingWindow / None
10445                                        // apply outside the loop below.
10446                                        // Skip for Bias — that mask is a
10447                                        // per-head additive tensor, not a
10448                                        // 0/1 key-padding mask.
10449                                        if matches!(mask_kind, rlx_ir::op::MaskKind::Custom)
10450                                            && !mask_data.is_empty()
10451                                            && mask_data[bi * k_s + ki] < mask_thr
10452                                        {
10453                                            scores[qi * k_s + ki] = mask_neg;
10454                                        }
10455                                    }
10456                                }
10457
10458                                if matches!(mask_kind, rlx_ir::op::MaskKind::Bias) {
10459                                    let off = (bi * nh + hi) * q_s * k_s;
10460                                    for i in 0..q_s * k_s {
10461                                        scores[i] += mask_data[off + i];
10462                                    }
10463                                }
10464                                apply_synthetic_mask(scores, q_s, k_s, *mask_kind);
10465                                crate::kernels::neon_softmax(scores, q_s, k_s);
10466
10467                                // Score@V via strided NEON accumulation (zero-copy)
10468                                for qi in 0..q_s {
10469                                    let o_off = bi * q_s * hs + qi * hs + hi * dh;
10470                                    // Zero output for this head position
10471                                    for d in 0..dh {
10472                                        out_data[o_off + d] = 0.0;
10473                                    }
10474                                    for ki in 0..k_s {
10475                                        let sc = scores[qi * k_s + ki];
10476                                        if sc > score_thr {
10477                                            let v_off = bi * k_s * vrs + ki * vrs + hi * dh;
10478                                            #[cfg(target_arch = "aarch64")]
10479                                            {
10480                                                use std::arch::aarch64::*;
10481                                                let vsc = vdupq_n_f32(sc);
10482                                                for c in 0..neon_chunks {
10483                                                    let off = c * 4;
10484                                                    let vo = vld1q_f32(
10485                                                        out_data.as_ptr().add(o_off + off),
10486                                                    );
10487                                                    let vv =
10488                                                        vld1q_f32(v_data.as_ptr().add(v_off + off));
10489                                                    vst1q_f32(
10490                                                        out_data.as_mut_ptr().add(o_off + off),
10491                                                        vfmaq_f32(vo, vsc, vv),
10492                                                    );
10493                                                }
10494                                            }
10495                                            #[cfg(not(target_arch = "aarch64"))]
10496                                            for d in 0..dh {
10497                                                out_data[o_off + d] += sc * v_data[v_off + d];
10498                                            }
10499                                        }
10500                                    }
10501                                }
10502                            }
10503                        }
10504                    } else {
10505                        // ── Parallel strided BLAS path (high throughput) ──
10506                        let total_work = b * nh;
10507                        let q_addr = q_data.as_ptr() as usize;
10508                        let k_addr = k_data.as_ptr() as usize;
10509                        let v_addr = v_data.as_ptr() as usize;
10510                        let m_addr = mask_data.as_ptr() as usize;
10511                        let o_addr = out_data.as_mut_ptr() as usize;
10512                        let sc_addr = sdpa_scores.as_mut_ptr() as usize;
10513
10514                        crate::pool::par_for(total_work, 1, &|off, cnt| {
10515                            for idx in off..off + cnt {
10516                                let bi = idx / nh;
10517                                let hi = idx % nh;
10518
10519                                let q_start = (q_addr as *const f32).add(bi * q_s * qrs + hi * dh);
10520                                let k_start = (k_addr as *const f32).add(bi * k_s * krs + hi * dh);
10521                                let v_start = (v_addr as *const f32).add(bi * k_s * vrs + hi * dh);
10522                                let o_start = (o_addr as *mut f32).add(bi * q_s * hs + hi * dh);
10523                                let sc = std::slice::from_raw_parts_mut(
10524                                    (sc_addr as *mut f32).add(idx * ss),
10525                                    ss,
10526                                );
10527
10528                                // LDA = qrs, LDB = krs (parent row strides
10529                                // when fused; hs otherwise).
10530                                crate::blas::sgemm_general(
10531                                    q_start,
10532                                    k_start,
10533                                    sc.as_mut_ptr(),
10534                                    q_s,
10535                                    k_s,
10536                                    dh,
10537                                    scale,
10538                                    0.0,
10539                                    qrs,
10540                                    krs,
10541                                    k_s,
10542                                    false,
10543                                    true,
10544                                );
10545
10546                                match mask_kind {
10547                                    rlx_ir::op::MaskKind::Custom => {
10548                                        let mask_bi = std::slice::from_raw_parts(
10549                                            (m_addr as *const f32).add(bi * k_s),
10550                                            k_s,
10551                                        );
10552                                        for ki in 0..k_s {
10553                                            if mask_bi[ki] < mask_thr {
10554                                                for qi in 0..q_s {
10555                                                    sc[qi * k_s + ki] = mask_neg;
10556                                                }
10557                                            }
10558                                        }
10559                                    }
10560                                    rlx_ir::op::MaskKind::Bias => {
10561                                        // Per-head additive bias slice.
10562                                        let bias = std::slice::from_raw_parts(
10563                                            (m_addr as *const f32).add((bi * nh + hi) * q_s * k_s),
10564                                            q_s * k_s,
10565                                        );
10566                                        for i in 0..q_s * k_s {
10567                                            sc[i] += bias[i];
10568                                        }
10569                                    }
10570                                    _ => apply_synthetic_mask(sc, q_s, k_s, *mask_kind),
10571                                }
10572
10573                                crate::kernels::neon_softmax(sc, q_s, k_s);
10574
10575                                // LDB = vrs (parent row stride when
10576                                // fused; hs otherwise). LDC stays hs —
10577                                // output is its own contiguous buffer.
10578                                crate::blas::sgemm_general(
10579                                    sc.as_ptr(),
10580                                    v_start,
10581                                    o_start,
10582                                    q_s,
10583                                    dh,
10584                                    k_s,
10585                                    1.0,
10586                                    0.0,
10587                                    k_s,
10588                                    vrs,
10589                                    hs,
10590                                    false,
10591                                    false,
10592                                );
10593                            }
10594                        });
10595                    }
10596                }
10597            }
10598
10599            Thunk::AttentionBackward {
10600                q,
10601                k,
10602                v,
10603                dy,
10604                mask,
10605                out,
10606                batch,
10607                seq,
10608                kv_seq,
10609                heads,
10610                head_dim,
10611                mask_kind,
10612                wrt,
10613                bhsd,
10614            } => {
10615                let (b, q_s, k_s, nh, dh) = (
10616                    *batch as usize,
10617                    *seq as usize,
10618                    *kv_seq as usize,
10619                    *heads as usize,
10620                    *head_dim as usize,
10621                );
10622                unsafe {
10623                    let q_len = if *bhsd {
10624                        b * nh * q_s * dh
10625                    } else {
10626                        b * q_s * nh * dh
10627                    };
10628                    let k_len = if *bhsd {
10629                        b * nh * k_s * dh
10630                    } else {
10631                        b * k_s * nh * dh
10632                    };
10633                    let out_len = match wrt {
10634                        rlx_ir::op::AttentionBwdWrt::Key | rlx_ir::op::AttentionBwdWrt::Value => {
10635                            k_len
10636                        }
10637                        rlx_ir::op::AttentionBwdWrt::Query => q_len,
10638                    };
10639                    let q_data = sl(*q, base, q_len);
10640                    let k_data = sl(*k, base, k_len);
10641                    let v_data = sl(*v, base, k_len);
10642                    let dy_data = sl(*dy, base, q_len);
10643                    let out_data = sl_mut(*out, base, out_len);
10644                    let mask_data: &[f32] = if *mask != 0 {
10645                        let ml = match mask_kind {
10646                            rlx_ir::op::MaskKind::Custom => b * k_s,
10647                            rlx_ir::op::MaskKind::Bias => b * nh * q_s * k_s,
10648                            _ => 0,
10649                        };
10650                        sl(*mask, base, ml)
10651                    } else {
10652                        &[]
10653                    };
10654                    crate::attention_bwd::attention_backward(
10655                        *wrt, q_data, k_data, v_data, dy_data, out_data, b, nh, q_s, k_s, dh,
10656                        *mask_kind, mask_data, *bhsd,
10657                    );
10658                }
10659            }
10660
10661            Thunk::ActivationInPlace { data, len, act } => {
10662                let len = *len as usize;
10663                unsafe {
10664                    let d = sl_mut(*data, base, len);
10665                    match act {
10666                        Activation::Gelu => crate::kernels::par_gelu_inplace(d),
10667                        Activation::GeluApprox => crate::kernels::par_gelu_approx_inplace(d),
10668                        Activation::Silu => crate::kernels::par_silu_inplace(d),
10669                        Activation::Relu => {
10670                            for v in d.iter_mut() {
10671                                *v = v.max(0.0);
10672                            }
10673                        }
10674                        Activation::Sigmoid => {
10675                            for v in d.iter_mut() {
10676                                *v = 1.0 / (1.0 + (-*v).exp());
10677                            }
10678                        }
10679                        Activation::Tanh => {
10680                            for v in d.iter_mut() {
10681                                *v = v.tanh();
10682                            }
10683                        }
10684                        Activation::Exp => {
10685                            for v in d.iter_mut() {
10686                                *v = v.exp();
10687                            }
10688                        }
10689                        Activation::Log => {
10690                            for v in d.iter_mut() {
10691                                *v = v.ln();
10692                            }
10693                        }
10694                        Activation::Sqrt => {
10695                            for v in d.iter_mut() {
10696                                *v = v.sqrt();
10697                            }
10698                        }
10699                        Activation::Rsqrt => {
10700                            for v in d.iter_mut() {
10701                                *v = 1.0 / v.sqrt();
10702                            }
10703                        }
10704                        Activation::Neg => {
10705                            for v in d.iter_mut() {
10706                                *v = -*v;
10707                            }
10708                        }
10709                        Activation::Abs => {
10710                            for v in d.iter_mut() {
10711                                *v = v.abs();
10712                            }
10713                        }
10714                        Activation::Round => {
10715                            for v in d.iter_mut() {
10716                                *v = v.round();
10717                            }
10718                        }
10719                        Activation::Sin => {
10720                            for v in d.iter_mut() {
10721                                *v = v.sin();
10722                            }
10723                        }
10724                        Activation::Cos => {
10725                            for v in d.iter_mut() {
10726                                *v = v.cos();
10727                            }
10728                        }
10729                        Activation::Tan => {
10730                            for v in d.iter_mut() {
10731                                *v = v.tan();
10732                            }
10733                        }
10734                        Activation::Atan => {
10735                            for v in d.iter_mut() {
10736                                *v = v.atan();
10737                            }
10738                        }
10739                    }
10740                }
10741            }
10742
10743            Thunk::FusedAttnBlock {
10744                hidden,
10745                qkv_w,
10746                out_w,
10747                mask,
10748                out,
10749                qkv_b,
10750                out_b,
10751                cos,
10752                sin,
10753                cos_len,
10754                batch,
10755                seq,
10756                hs,
10757                nh,
10758                dh,
10759                has_bias,
10760                has_rope,
10761            } => {
10762                let (b, s) = (*batch as usize, *seq as usize);
10763                let (h, n_h, d_h) = (*hs as usize, *nh as usize, *dh as usize);
10764                let m = b * s;
10765                let scale = (d_h as f32).powf(-0.5);
10766                let half = d_h / 2;
10767                unsafe {
10768                    let inp = sl(*hidden, base, m * h);
10769                    let wq = sl(*qkv_w, base, h * 3 * h);
10770                    let wo = sl(*out_w, base, h * h);
10771                    let mk = sl(*mask, base, b * s);
10772                    let dst = sl_mut(*out, base, m * h);
10773
10774                    // Stack-allocated intermediates — all fit in L1 cache for small batch
10775                    let mut qkv = vec![0f32; m * 3 * h];
10776                    let mut attn_out = vec![0f32; m * h];
10777                    let mut scores_buf = vec![0f32; s * s]; // one head at a time
10778
10779                    // 1. QKV projection: [m, h] @ [h, 3h] → [m, 3h]
10780                    crate::blas::sgemm(inp, wq, &mut qkv, m, h, 3 * h);
10781                    if *has_bias {
10782                        let bias = sl(*qkv_b, base, 3 * h);
10783                        crate::blas::bias_add(&mut qkv, bias, m, 3 * h);
10784                    }
10785
10786                    // 2. Multi-head SDPA (Q/K/V are views into qkv at offsets 0, h, 2h)
10787                    //    Process heads sequentially with inline RoPE — zero copy.
10788                    #[cfg(target_arch = "aarch64")]
10789                    let neon_chunks = d_h / 4;
10790                    #[cfg(target_arch = "aarch64")]
10791                    let _rope_chunks = half / 4;
10792
10793                    for bi in 0..b {
10794                        for hi in 0..n_h {
10795                            // For each (query_pos, key_pos): compute Q@K^T with inline RoPE
10796                            for qi in 0..s {
10797                                let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
10798                                for ki in 0..s {
10799                                    let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
10800                                    let mut dot = 0f32;
10801
10802                                    if *has_rope {
10803                                        // Apply RoPE inline during dot product
10804                                        let q_cos = qi * half;
10805                                        let k_cos = ki * half;
10806                                        let cos_tab = sl(*cos, base, *cos_len as usize);
10807                                        let sin_tab = sl(*sin, base, *cos_len as usize);
10808                                        // First half: (q1*c - q2*s) * (k1*c - k2*s)
10809                                        // Second half: (q2*c + q1*s) * (k2*c + k1*s)
10810                                        for i in 0..half {
10811                                            let q1 = qkv[q_base + i];
10812                                            let q2 = qkv[q_base + half + i];
10813                                            let k1 = qkv[k_base + i];
10814                                            let k2 = qkv[k_base + half + i];
10815                                            let c_q = cos_tab[q_cos + i];
10816                                            let s_q = sin_tab[q_cos + i];
10817                                            let c_k = cos_tab[k_cos + i];
10818                                            let s_k = sin_tab[k_cos + i];
10819                                            let qr1 = q1 * c_q - q2 * s_q;
10820                                            let kr1 = k1 * c_k - k2 * s_k;
10821                                            let qr2 = q2 * c_q + q1 * s_q;
10822                                            let kr2 = k2 * c_k + k1 * s_k;
10823                                            dot += qr1 * kr1 + qr2 * kr2;
10824                                        }
10825                                    } else {
10826                                        // Standard dot product
10827                                        #[cfg(target_arch = "aarch64")]
10828                                        {
10829                                            use std::arch::aarch64::*;
10830                                            let mut acc = vdupq_n_f32(0.0);
10831                                            for c in 0..neon_chunks {
10832                                                let vq =
10833                                                    vld1q_f32(qkv.as_ptr().add(q_base + c * 4));
10834                                                let vk =
10835                                                    vld1q_f32(qkv.as_ptr().add(k_base + c * 4));
10836                                                acc = vfmaq_f32(acc, vq, vk);
10837                                            }
10838                                            dot = vaddvq_f32(acc);
10839                                            for d in (neon_chunks * 4)..d_h {
10840                                                dot += qkv[q_base + d] * qkv[k_base + d];
10841                                            }
10842                                        }
10843                                        #[cfg(not(target_arch = "aarch64"))]
10844                                        for d in 0..d_h {
10845                                            dot += qkv[q_base + d] * qkv[k_base + d];
10846                                        }
10847                                    }
10848
10849                                    scores_buf[qi * s + ki] = dot * scale;
10850                                    if mk[bi * s + ki] < mask_thr {
10851                                        scores_buf[qi * s + ki] = mask_neg;
10852                                    }
10853                                }
10854                            }
10855
10856                            // Softmax
10857                            crate::kernels::neon_softmax(&mut scores_buf[..s * s], s, s);
10858
10859                            // Score @ V accumulation (V at offset 2h in QKV)
10860                            for qi in 0..s {
10861                                let o_base = bi * s * h + qi * h + hi * d_h;
10862                                for d in 0..d_h {
10863                                    attn_out[o_base + d] = 0.0;
10864                                }
10865                                for ki in 0..s {
10866                                    let sc = scores_buf[qi * s + ki];
10867                                    if sc > score_thr {
10868                                        let v_base = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
10869                                        #[cfg(target_arch = "aarch64")]
10870                                        {
10871                                            use std::arch::aarch64::*;
10872                                            let vsc = vdupq_n_f32(sc);
10873                                            for c in 0..neon_chunks {
10874                                                let off = c * 4;
10875                                                let vo =
10876                                                    vld1q_f32(attn_out.as_ptr().add(o_base + off));
10877                                                let vv = vld1q_f32(qkv.as_ptr().add(v_base + off));
10878                                                vst1q_f32(
10879                                                    attn_out.as_mut_ptr().add(o_base + off),
10880                                                    vfmaq_f32(vo, vsc, vv),
10881                                                );
10882                                            }
10883                                        }
10884                                        #[cfg(not(target_arch = "aarch64"))]
10885                                        for d in 0..d_h {
10886                                            attn_out[o_base + d] += sc * qkv[v_base + d];
10887                                        }
10888                                    }
10889                                }
10890                            }
10891                        }
10892                    }
10893
10894                    // 3. Output projection: [m, h] @ [h, h] → dst
10895                    crate::blas::sgemm(&attn_out, wo, dst, m, h, h);
10896                    if *has_bias {
10897                        let bias = sl(*out_b, base, h);
10898                        crate::blas::bias_add(dst, bias, m, h);
10899                    }
10900                }
10901            }
10902
10903            Thunk::Rope {
10904                src,
10905                cos,
10906                sin,
10907                dst,
10908                batch,
10909                seq,
10910                hidden,
10911                head_dim,
10912                n_rot,
10913                cos_len,
10914                src_row_stride,
10915            } => {
10916                let (b, s, hs, dh, nr) = (
10917                    *batch as usize,
10918                    *seq as usize,
10919                    *hidden as usize,
10920                    *head_dim as usize,
10921                    *n_rot as usize,
10922                );
10923                let tab_half = dh / 2;
10924                let rot_half = nr / 2;
10925                let nh = hs / dh;
10926                let cl = *cos_len as usize;
10927                let src_rs = *src_row_stride as usize;
10928                unsafe {
10929                    let x = sl(*src, base, b * s * src_rs);
10930                    let cos_tab = sl(*cos, base, cl);
10931                    let sin_tab = sl(*sin, base, cl);
10932                    let out = sl_mut(*dst, base, b * s * hs);
10933
10934                    let total = b * s;
10935                    let x_ptr = x.as_ptr() as usize;
10936                    let o_ptr = out.as_mut_ptr() as usize;
10937                    let c_ptr = cos_tab.as_ptr() as usize;
10938                    let s_ptr = sin_tab.as_ptr() as usize;
10939
10940                    crate::pool::par_for(total, 4, &|off, cnt| {
10941                        for idx in off..off + cnt {
10942                            let bi = idx / s;
10943                            let si = idx % s;
10944                            let tab_off = si * tab_half;
10945
10946                            for hi in 0..nh {
10947                                let src_base = bi * s * src_rs + si * src_rs + hi * dh;
10948                                let dst_base = bi * s * hs + si * hs + hi * dh;
10949                                let xp = (x_ptr as *const f32).add(src_base);
10950                                let op = (o_ptr as *mut f32).add(dst_base);
10951                                let cp = (c_ptr as *const f32).add(tab_off);
10952                                let sp = (s_ptr as *const f32).add(tab_off);
10953
10954                                for i in 0..rot_half {
10955                                    let x1 = *xp.add(i);
10956                                    let x2 = *xp.add(rot_half + i);
10957                                    let cv = *cp.add(i);
10958                                    let sv = *sp.add(i);
10959                                    *op.add(i) = x1 * cv - x2 * sv;
10960                                    *op.add(rot_half + i) = x2 * cv + x1 * sv;
10961                                }
10962                                for j in nr..dh {
10963                                    *op.add(j) = *xp.add(j);
10964                                }
10965                            }
10966                        }
10967                    });
10968                }
10969            }
10970            Thunk::FusedBertLayer {
10971                hidden,
10972                qkv_w,
10973                qkv_b,
10974                out_w,
10975                out_b,
10976                mask,
10977                ln1_g,
10978                ln1_b,
10979                eps1,
10980                fc1_w,
10981                fc1_b,
10982                fc2_w,
10983                fc2_b,
10984                ln2_g,
10985                ln2_b,
10986                eps2,
10987                out,
10988                batch,
10989                seq,
10990                hs,
10991                nh,
10992                dh,
10993                int_dim,
10994            } => {
10995                let (b, s, h, n_h, d_h) = (
10996                    *batch as usize,
10997                    *seq as usize,
10998                    *hs as usize,
10999                    *nh as usize,
11000                    *dh as usize,
11001                );
11002                let m = b * s;
11003                let id = *int_dim as usize;
11004                let scale = (d_h as f32).powf(-0.5);
11005                let _half = d_h / 2;
11006                #[cfg(target_arch = "aarch64")]
11007                let neon_chunks = d_h / 4;
11008                unsafe {
11009                    let inp = sl(*hidden, base, m * h);
11010                    let dst = sl_mut(*out, base, m * h);
11011                    let mk = sl(*mask, base, b * s);
11012
11013                    // Pre-allocated buffers (zero malloc per layer — allocated once before thunk loop)
11014                    let qkv = std::slice::from_raw_parts_mut(fl_qkv.as_mut_ptr(), m * 3 * h);
11015                    let attn = std::slice::from_raw_parts_mut(fl_attn.as_mut_ptr(), m * h);
11016                    let res = std::slice::from_raw_parts_mut(fl_res.as_mut_ptr(), m * h);
11017                    let normed = std::slice::from_raw_parts_mut(fl_normed.as_mut_ptr(), m * h);
11018                    let ffn = std::slice::from_raw_parts_mut(fl_ffn.as_mut_ptr(), m * id);
11019                    let sc = std::slice::from_raw_parts_mut(fl_sc.as_mut_ptr(), s * s);
11020
11021                    // QKV (parallelized across cores — multiple AMX coprocessors)
11022                    crate::blas::par_sgemm_bias(
11023                        inp,
11024                        sl(*qkv_w, base, h * 3 * h),
11025                        sl(*qkv_b, base, 3 * h),
11026                        qkv,
11027                        m,
11028                        h,
11029                        3 * h,
11030                    );
11031
11032                    // SDPA per head (sequential NEON, inline — zero overhead)
11033                    for bi in 0..b {
11034                        for hi in 0..n_h {
11035                            for qi in 0..s {
11036                                for ki in 0..s {
11037                                    let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
11038                                    let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
11039                                    #[cfg(target_arch = "aarch64")]
11040                                    let dot;
11041                                    #[cfg(not(target_arch = "aarch64"))]
11042                                    let mut dot = 0f32;
11043                                    #[cfg(target_arch = "aarch64")]
11044                                    {
11045                                        use std::arch::aarch64::*;
11046                                        let mut acc = vdupq_n_f32(0.0);
11047                                        for c in 0..neon_chunks {
11048                                            acc = vfmaq_f32(
11049                                                acc,
11050                                                vld1q_f32(qkv.as_ptr().add(q_base + c * 4)),
11051                                                vld1q_f32(qkv.as_ptr().add(k_base + c * 4)),
11052                                            );
11053                                        }
11054                                        dot = vaddvq_f32(acc);
11055                                    }
11056                                    #[cfg(not(target_arch = "aarch64"))]
11057                                    for d in 0..d_h {
11058                                        dot += qkv[q_base + d] * qkv[k_base + d];
11059                                    }
11060                                    sc[qi * s + ki] = dot * scale;
11061                                    if mk[bi * s + ki] < mask_thr {
11062                                        sc[qi * s + ki] = mask_neg;
11063                                    }
11064                                }
11065                            }
11066                            crate::kernels::neon_softmax(&mut sc[..s * s], s, s);
11067                            for qi in 0..s {
11068                                let o = bi * s * h + qi * h + hi * d_h;
11069                                for d in 0..d_h {
11070                                    attn[o + d] = 0.0;
11071                                }
11072                                for ki in 0..s {
11073                                    let w = sc[qi * s + ki];
11074                                    if w > score_thr {
11075                                        let v = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
11076                                        #[cfg(target_arch = "aarch64")]
11077                                        {
11078                                            use std::arch::aarch64::*;
11079                                            let vw = vdupq_n_f32(w);
11080                                            for c in 0..neon_chunks {
11081                                                let off = c * 4;
11082                                                vst1q_f32(
11083                                                    attn.as_mut_ptr().add(o + off),
11084                                                    vfmaq_f32(
11085                                                        vld1q_f32(attn.as_ptr().add(o + off)),
11086                                                        vw,
11087                                                        vld1q_f32(qkv.as_ptr().add(v + off)),
11088                                                    ),
11089                                                );
11090                                            }
11091                                        }
11092                                        #[cfg(not(target_arch = "aarch64"))]
11093                                        for d in 0..d_h {
11094                                            attn[o + d] += w * qkv[v + d];
11095                                        }
11096                                    }
11097                                }
11098                            }
11099                        }
11100                    }
11101
11102                    // Out proj (sgemm + bias fused) + residual add with NEON
11103                    crate::blas::sgemm_bias(
11104                        attn,
11105                        sl(*out_w, base, h * h),
11106                        sl(*out_b, base, h),
11107                        res,
11108                        m,
11109                        h,
11110                        h,
11111                    );
11112                    #[cfg(target_arch = "aarch64")]
11113                    {
11114                        use std::arch::aarch64::*;
11115                        let chunks_h = (m * h) / 4;
11116                        for c in 0..chunks_h {
11117                            let off = c * 4;
11118                            vst1q_f32(
11119                                res.as_mut_ptr().add(off),
11120                                vaddq_f32(
11121                                    vld1q_f32(res.as_ptr().add(off)),
11122                                    vld1q_f32(inp.as_ptr().add(off)),
11123                                ),
11124                            );
11125                        }
11126                        for i in (chunks_h * 4)..(m * h) {
11127                            res[i] += inp[i];
11128                        }
11129                    }
11130                    #[cfg(not(target_arch = "aarch64"))]
11131                    for i in 0..m * h {
11132                        res[i] += inp[i];
11133                    }
11134
11135                    // LN1 (fused residual already done above — just normalize)
11136                    let g1 = sl(*ln1_g, base, h);
11137                    let b1 = sl(*ln1_b, base, h);
11138                    for r in 0..m {
11139                        crate::kernels::layer_norm_row(
11140                            &res[r * h..(r + 1) * h],
11141                            g1,
11142                            b1,
11143                            &mut normed[r * h..(r + 1) * h],
11144                            h,
11145                            *eps1,
11146                        );
11147                    }
11148
11149                    // FFN: fc1 (parallel across cores) + GELU
11150                    crate::blas::par_sgemm_bias(
11151                        normed,
11152                        sl(*fc1_w, base, h * id),
11153                        sl(*fc1_b, base, id),
11154                        ffn,
11155                        m,
11156                        h,
11157                        id,
11158                    );
11159                    crate::kernels::par_gelu_inplace(ffn);
11160
11161                    // fc2 + bias (parallel across cores) + residual with NEON
11162                    crate::blas::par_sgemm_bias(
11163                        ffn,
11164                        sl(*fc2_w, base, id * h),
11165                        sl(*fc2_b, base, h),
11166                        res,
11167                        m,
11168                        id,
11169                        h,
11170                    );
11171                    #[cfg(target_arch = "aarch64")]
11172                    {
11173                        use std::arch::aarch64::*;
11174                        let chunks_h = (m * h) / 4;
11175                        for c in 0..chunks_h {
11176                            let off = c * 4;
11177                            vst1q_f32(
11178                                res.as_mut_ptr().add(off),
11179                                vaddq_f32(
11180                                    vld1q_f32(res.as_ptr().add(off)),
11181                                    vld1q_f32(normed.as_ptr().add(off)),
11182                                ),
11183                            );
11184                        }
11185                        for i in (chunks_h * 4)..(m * h) {
11186                            res[i] += normed[i];
11187                        }
11188                    }
11189                    #[cfg(not(target_arch = "aarch64"))]
11190                    for i in 0..m * h {
11191                        res[i] += normed[i];
11192                    }
11193
11194                    // LN2 → output
11195                    let g2 = sl(*ln2_g, base, h);
11196                    let b2 = sl(*ln2_b, base, h);
11197                    for r in 0..m {
11198                        crate::kernels::layer_norm_row(
11199                            &res[r * h..(r + 1) * h],
11200                            g2,
11201                            b2,
11202                            &mut dst[r * h..(r + 1) * h],
11203                            h,
11204                            *eps2,
11205                        );
11206                    }
11207                }
11208            }
11209
11210            Thunk::FusedNomicLayer {
11211                hidden,
11212                qkv_w,
11213                out_w,
11214                mask,
11215                cos,
11216                sin,
11217                cos_len,
11218                ln1_g,
11219                ln1_b,
11220                eps1,
11221                fc11_w,
11222                fc12_w: _,
11223                fc2_w,
11224                ln2_g,
11225                ln2_b,
11226                eps2,
11227                out,
11228                batch,
11229                seq,
11230                hs,
11231                nh,
11232                dh,
11233                int_dim,
11234            } => {
11235                let (b, s, h, n_h, d_h) = (
11236                    *batch as usize,
11237                    *seq as usize,
11238                    *hs as usize,
11239                    *nh as usize,
11240                    *dh as usize,
11241                );
11242                let m = b * s;
11243                let id = *int_dim as usize;
11244                let scale = (d_h as f32).powf(-0.5);
11245                let half_dh = d_h / 2;
11246                #[cfg(target_arch = "aarch64")]
11247                let neon_chunks = d_h / 4;
11248                unsafe {
11249                    let inp = sl(*hidden, base, m * h);
11250                    let dst = sl_mut(*out, base, m * h);
11251                    let mk = sl(*mask, base, b * s);
11252                    let cos_tab = sl(*cos, base, *cos_len as usize);
11253                    let sin_tab = sl(*sin, base, *cos_len as usize);
11254                    // fc11_w is the fused [h, 2*int_dim] weight (fc11 || fc12 concatenated)
11255                    let fused_fc_w = sl(*fc11_w, base, h * 2 * id);
11256
11257                    let mut qkv = vec![0f32; m * 3 * h];
11258                    let mut attn = vec![0f32; m * h];
11259                    let mut res = vec![0f32; m * h];
11260                    let mut normed = vec![0f32; m * h];
11261                    let mut ffn_concat = vec![0f32; m * 2 * id]; // fc11||fc12 output
11262                    let mut sc = vec![0f32; s * s];
11263
11264                    // QKV (no bias)
11265                    crate::blas::sgemm(inp, sl(*qkv_w, base, h * 3 * h), &mut qkv, m, h, 3 * h);
11266
11267                    // SDPA with inline RoPE
11268                    for bi in 0..b {
11269                        for hi in 0..n_h {
11270                            for qi in 0..s {
11271                                for ki in 0..s {
11272                                    let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
11273                                    let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
11274                                    let mut dot = 0f32;
11275                                    for i in 0..half_dh {
11276                                        let q1 = qkv[q_base + i];
11277                                        let q2 = qkv[q_base + half_dh + i];
11278                                        let k1 = qkv[k_base + i];
11279                                        let k2 = qkv[k_base + half_dh + i];
11280                                        let cq = cos_tab[qi * half_dh + i];
11281                                        let sq = sin_tab[qi * half_dh + i];
11282                                        let ck = cos_tab[ki * half_dh + i];
11283                                        let sk = sin_tab[ki * half_dh + i];
11284                                        dot += (q1 * cq - q2 * sq) * (k1 * ck - k2 * sk)
11285                                            + (q2 * cq + q1 * sq) * (k2 * ck + k1 * sk);
11286                                    }
11287                                    sc[qi * s + ki] = dot * scale;
11288                                    if mk[bi * s + ki] < mask_thr {
11289                                        sc[qi * s + ki] = mask_neg;
11290                                    }
11291                                }
11292                            }
11293                            crate::kernels::neon_softmax(&mut sc[..s * s], s, s);
11294                            for qi in 0..s {
11295                                let o = bi * s * h + qi * h + hi * d_h;
11296                                for d in 0..d_h {
11297                                    attn[o + d] = 0.0;
11298                                }
11299                                for ki in 0..s {
11300                                    let w = sc[qi * s + ki];
11301                                    if w > score_thr {
11302                                        let v = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
11303                                        #[cfg(target_arch = "aarch64")]
11304                                        {
11305                                            use std::arch::aarch64::*;
11306                                            let vw = vdupq_n_f32(w);
11307                                            for c in 0..neon_chunks {
11308                                                let off = c * 4;
11309                                                vst1q_f32(
11310                                                    attn.as_mut_ptr().add(o + off),
11311                                                    vfmaq_f32(
11312                                                        vld1q_f32(attn.as_ptr().add(o + off)),
11313                                                        vw,
11314                                                        vld1q_f32(qkv.as_ptr().add(v + off)),
11315                                                    ),
11316                                                );
11317                                            }
11318                                        }
11319                                        #[cfg(not(target_arch = "aarch64"))]
11320                                        for d in 0..d_h {
11321                                            attn[o + d] += w * qkv[v + d];
11322                                        }
11323                                    }
11324                                }
11325                            }
11326                        }
11327                    }
11328
11329                    // Out proj (no bias) + residual
11330                    crate::blas::sgemm(&attn, sl(*out_w, base, h * h), &mut res, m, h, h);
11331                    for i in 0..m * h {
11332                        res[i] += inp[i];
11333                    }
11334
11335                    // LN1
11336                    let g1 = sl(*ln1_g, base, h);
11337                    let b1 = sl(*ln1_b, base, h);
11338                    for r in 0..m {
11339                        crate::kernels::layer_norm_row(
11340                            &res[r * h..(r + 1) * h],
11341                            g1,
11342                            b1,
11343                            &mut normed[r * h..(r + 1) * h],
11344                            h,
11345                            *eps1,
11346                        );
11347                    }
11348
11349                    // SwiGLU: fused fc11+fc12 sgemm, then split, silu, mul
11350                    crate::blas::sgemm(&normed, fused_fc_w, &mut ffn_concat, m, h, 2 * id);
11351                    // Split: first id cols = fc11 (up), second id cols = fc12 (gate)
11352                    // SiLU on gate, then multiply up * gate → store in up region
11353                    for row in 0..m {
11354                        let bo = row * 2 * id;
11355                        // SiLU in-place on gate portion
11356                        for j in 0..id {
11357                            let x = ffn_concat[bo + id + j];
11358                            ffn_concat[bo + id + j] = x / (1.0 + (-x).exp());
11359                        }
11360                        // Multiply: up[j] *= gate[j]
11361                        for j in 0..id {
11362                            ffn_concat[bo + j] *= ffn_concat[bo + id + j];
11363                        }
11364                    }
11365
11366                    // fc2 (no bias) + residual  — read from first id cols of ffn_concat
11367                    // Need contiguous [m, id] for sgemm. Copy or use strided sgemm.
11368                    // The up*gate result is at ffn_concat[row * 2*id .. row * 2*id + id]
11369                    // Stride = 2*id. Use sgemm_general with lda = 2*id.
11370                    crate::blas::sgemm_general(
11371                        ffn_concat.as_ptr(),
11372                        sl(*fc2_w, base, id * h).as_ptr(),
11373                        res.as_mut_ptr(),
11374                        m,
11375                        h,
11376                        id,
11377                        1.0,
11378                        0.0,
11379                        2 * id,
11380                        h,
11381                        h,
11382                        false,
11383                        false,
11384                    );
11385                    for i in 0..m * h {
11386                        res[i] += normed[i];
11387                    }
11388
11389                    // LN2 → output
11390                    let g2 = sl(*ln2_g, base, h);
11391                    let b2 = sl(*ln2_b, base, h);
11392                    for r in 0..m {
11393                        crate::kernels::layer_norm_row(
11394                            &res[r * h..(r + 1) * h],
11395                            g2,
11396                            b2,
11397                            &mut dst[r * h..(r + 1) * h],
11398                            h,
11399                            *eps2,
11400                        );
11401                    }
11402                }
11403            }
11404
11405            Thunk::FusedSwiGLU {
11406                src,
11407                dst,
11408                n_half,
11409                total,
11410                gate_first,
11411            } => {
11412                let n = *n_half as usize;
11413                let t = *total as usize;
11414                let outer = t / n;
11415                let in_total = outer * 2 * n;
11416                let gate_first = *gate_first;
11417                unsafe {
11418                    let inp = sl(*src, base, in_total);
11419                    let out = sl_mut(*dst, base, t);
11420                    for o in 0..outer {
11421                        let in_row = &inp[o * 2 * n..(o + 1) * 2 * n];
11422                        let out_row = &mut out[o * n..(o + 1) * n];
11423                        for i in 0..n {
11424                            let (up, gate) = if gate_first {
11425                                (in_row[n + i], in_row[i])
11426                            } else {
11427                                (in_row[i], in_row[n + i])
11428                            };
11429                            out_row[i] = up * (gate / (1.0 + (-gate).exp()));
11430                        }
11431                    }
11432                }
11433            }
11434
11435            Thunk::Concat {
11436                dst,
11437                outer,
11438                inner,
11439                total_axis,
11440                inputs,
11441            } => {
11442                let outer = *outer as usize;
11443                let inner = *inner as usize;
11444                let total_axis = *total_axis as usize;
11445                let row_stride = total_axis * inner;
11446                let out_total = outer * row_stride;
11447                unsafe {
11448                    let out = sl_mut(*dst, base, out_total);
11449                    let mut cum: usize = 0;
11450                    for (src_off, in_axis) in inputs {
11451                        let in_axis = *in_axis as usize;
11452                        let copy_per_row = in_axis * inner;
11453                        let dst_col_off = cum * inner;
11454                        let in_total = outer * copy_per_row;
11455                        let inp = sl(*src_off, base, in_total);
11456                        for o in 0..outer {
11457                            let dst_row_start = o * row_stride + dst_col_off;
11458                            let src_row_start = o * copy_per_row;
11459                            out[dst_row_start..dst_row_start + copy_per_row]
11460                                .copy_from_slice(&inp[src_row_start..src_row_start + copy_per_row]);
11461                        }
11462                        cum += in_axis;
11463                    }
11464                }
11465            }
11466
11467            Thunk::ConcatF64 {
11468                dst,
11469                outer,
11470                inner,
11471                total_axis,
11472                inputs,
11473            } => {
11474                let outer = *outer as usize;
11475                let inner = *inner as usize;
11476                let total_axis = *total_axis as usize;
11477                let row_stride = total_axis * inner;
11478                let out_total = outer * row_stride;
11479                unsafe {
11480                    let out = sl_mut_f64(*dst, base, out_total);
11481                    let mut cum: usize = 0;
11482                    for (src_off, in_axis) in inputs {
11483                        let in_axis = *in_axis as usize;
11484                        let copy_per_row = in_axis * inner;
11485                        let dst_col_off = cum * inner;
11486                        let in_total = outer * copy_per_row;
11487                        let inp = sl_f64(*src_off, base, in_total);
11488                        for o in 0..outer {
11489                            let dst_row_start = o * row_stride + dst_col_off;
11490                            let src_row_start = o * copy_per_row;
11491                            out[dst_row_start..dst_row_start + copy_per_row]
11492                                .copy_from_slice(&inp[src_row_start..src_row_start + copy_per_row]);
11493                        }
11494                        cum += in_axis;
11495                    }
11496                }
11497            }
11498
11499            Thunk::Compare {
11500                lhs,
11501                rhs,
11502                dst,
11503                len,
11504                op,
11505                inputs_i64,
11506                inputs_elem_bytes,
11507                dst_elem_bytes,
11508            } => {
11509                let len = *len as usize;
11510                let arena_len = arena_buf.len();
11511                let elem = (*inputs_elem_bytes).max(1) as usize;
11512                let dst_eb = (*dst_elem_bytes).max(1) as usize;
11513                let max_l = (arena_len.saturating_sub(*lhs)) / elem;
11514                let max_r = (arena_len.saturating_sub(*rhs)) / elem;
11515                let max_d = (arena_len.saturating_sub(*dst)) / dst_eb;
11516                let len = len.min(max_l).min(max_r).min(max_d);
11517                if trace_thunks && len > 0 {
11518                    eprintln!("[compare] len={len} lhs={} rhs={} dst={}", *lhs, *rhs, *dst);
11519                }
11520                if elem == 1 {
11521                    let l = arena_buf[*lhs..*lhs + len].to_vec();
11522                    let r = arena_buf[*rhs..*rhs + len].to_vec();
11523                    for i in 0..len {
11524                        let v = match op {
11525                            CmpOp::Eq => l[i] == r[i],
11526                            CmpOp::Ne => l[i] != r[i],
11527                            CmpOp::Lt => l[i] < r[i],
11528                            CmpOp::Le => l[i] <= r[i],
11529                            CmpOp::Gt => l[i] > r[i],
11530                            CmpOp::Ge => l[i] >= r[i],
11531                        };
11532                        if *dst_elem_bytes == 1 {
11533                            arena_buf[*dst + i] = u8::from(v);
11534                        } else {
11535                            unsafe {
11536                                let o = sl_mut(*dst, base, len);
11537                                o[i] = if v { 1.0 } else { 0.0 };
11538                            }
11539                        }
11540                    }
11541                } else if *inputs_i64 != 0 {
11542                    unsafe {
11543                        let l = sl_i64(*lhs, base, len);
11544                        let r = sl_i64(*rhs, base, len);
11545                        for i in 0..len {
11546                            let v = match op {
11547                                CmpOp::Eq => l[i] == r[i],
11548                                CmpOp::Ne => l[i] != r[i],
11549                                CmpOp::Lt => l[i] < r[i],
11550                                CmpOp::Le => l[i] <= r[i],
11551                                CmpOp::Gt => l[i] > r[i],
11552                                CmpOp::Ge => l[i] >= r[i],
11553                            };
11554                            if *dst_elem_bytes == 1 {
11555                                arena_buf[*dst + i] = u8::from(v);
11556                            } else {
11557                                let o = sl_mut(*dst, base, len);
11558                                o[i] = if v { 1.0 } else { 0.0 };
11559                            }
11560                        }
11561                    }
11562                } else {
11563                    unsafe {
11564                        let l = sl(*lhs, base, len);
11565                        let r = sl(*rhs, base, len);
11566                        for i in 0..len {
11567                            let v = match op {
11568                                CmpOp::Eq => l[i] == r[i],
11569                                CmpOp::Ne => l[i] != r[i],
11570                                CmpOp::Lt => l[i] < r[i],
11571                                CmpOp::Le => l[i] <= r[i],
11572                                CmpOp::Gt => l[i] > r[i],
11573                                CmpOp::Ge => l[i] >= r[i],
11574                            };
11575                            if *dst_elem_bytes == 1 {
11576                                arena_buf[*dst + i] = u8::from(v);
11577                            } else {
11578                                let o = sl_mut(*dst, base, len);
11579                                o[i] = if v { 1.0 } else { 0.0 };
11580                            }
11581                        }
11582                    }
11583                }
11584            }
11585
11586            Thunk::Where {
11587                cond,
11588                on_true,
11589                on_false,
11590                dst,
11591                len,
11592                elem_bytes,
11593                cond_elem_bytes,
11594            } => {
11595                let len = *len as usize;
11596                let eb = *elem_bytes as usize;
11597                let cond_eb = (*cond_elem_bytes).max(1) as usize;
11598                let arena_len = arena_buf.len();
11599                let len = len
11600                    .min((arena_len.saturating_sub(*cond)) / cond_eb)
11601                    .min((arena_len.saturating_sub(*on_true)) / eb)
11602                    .min((arena_len.saturating_sub(*on_false)) / eb)
11603                    .min((arena_len.saturating_sub(*dst)) / eb);
11604                unsafe {
11605                    if *elem_bytes == 8 {
11606                        let t = sl_i64(*on_true, base, len);
11607                        let e = sl_i64(*on_false, base, len);
11608                        let o = sl_mut_i64(*dst, base, len);
11609                        if *cond_elem_bytes == 1 {
11610                            let c = &arena_buf[*cond..*cond + len];
11611                            for i in 0..len {
11612                                o[i] = if c[i] != 0 { t[i] } else { e[i] };
11613                            }
11614                        } else {
11615                            let c = sl_i64(*cond, base, len);
11616                            for i in 0..len {
11617                                o[i] = if c[i] != 0 { t[i] } else { e[i] };
11618                            }
11619                        }
11620                    } else if *cond_elem_bytes == 1 {
11621                        let c = &arena_buf[*cond..*cond + len];
11622                        let t = sl(*on_true, base, len);
11623                        let e = sl(*on_false, base, len);
11624                        let o = sl_mut(*dst, base, len);
11625                        for i in 0..len {
11626                            o[i] = if c[i] != 0 { t[i] } else { e[i] };
11627                        }
11628                    } else {
11629                        let c = sl(*cond, base, len);
11630                        let t = sl(*on_true, base, len);
11631                        let e = sl(*on_false, base, len);
11632                        let o = sl_mut(*dst, base, len);
11633                        for i in 0..len {
11634                            o[i] = if c[i] != 0.0 { t[i] } else { e[i] };
11635                        }
11636                    }
11637                }
11638            }
11639
11640            Thunk::ScatterAdd {
11641                updates,
11642                indices,
11643                dst,
11644                num_updates,
11645                out_dim,
11646                trailing,
11647            } => {
11648                let num_updates = *num_updates as usize;
11649                let out_dim = *out_dim as usize;
11650                let trailing = *trailing as usize;
11651                unsafe {
11652                    let upd = sl(*updates, base, num_updates * trailing);
11653                    let ids = sl(*indices, base, num_updates);
11654                    let out = sl_mut(*dst, base, out_dim * trailing);
11655                    // Zero the output first — semantics are accumulate-into-zeros.
11656                    for v in out.iter_mut() {
11657                        *v = 0.0;
11658                    }
11659                    for i in 0..num_updates {
11660                        let row = ids[i] as usize;
11661                        debug_assert!(row < out_dim, "ScatterAdd index out of range");
11662                        let src_off = i * trailing;
11663                        let dst_off = row * trailing;
11664                        for j in 0..trailing {
11665                            out[dst_off + j] += upd[src_off + j];
11666                        }
11667                    }
11668                }
11669            }
11670
            Thunk::GroupedMatMul {
                input,
                weight,
                expert_idx,
                dst,
                m,
                k_dim,
                n,
                num_experts,
            } => {
                // Mixture-of-experts matmul: row i of `input` ([m, k_dim]) is multiplied
                // by the [k_dim, n] weight slab of expert `expert_idx[i]`, producing row i
                // of `dst` ([m, n]). Rows are grouped per expert so each expert needs a
                // single sgemm call.
                let m = *m as usize;
                let k_dim = *k_dim as usize;
                let n = *n as usize;
                let num_experts = *num_experts as usize;
                unsafe {
                    let inp = sl(*input, base, m * k_dim);
                    // `num_experts` contiguous slabs of `k_dim * n` f32 each.
                    let wt = sl(*weight, base, num_experts * k_dim * n);
                    // Expert ids are stored as f32 and cast with `as usize` below
                    // (negative/NaN saturate to 0 under Rust `as` semantics).
                    let ids = sl(*expert_idx, base, m);
                    let out = sl_mut(*dst, base, m * n);

                    // Counting-sort tokens by their assigned expert.
                    // counts[e] = how many tokens routed to expert e.
                    // Only debug builds check the range here; in release an
                    // out-of-range id panics on the `counts[e]` bounds check instead.
                    let mut counts = vec![0usize; num_experts];
                    for i in 0..m {
                        let e = ids[i] as usize;
                        debug_assert!(
                            e < num_experts,
                            "expert_idx out of range: {e} >= {num_experts}"
                        );
                        counts[e] += 1;
                    }
                    // Cumulative offsets into the packed buffer.
                    // offsets[e]..offsets[e + 1] is expert e's row range.
                    let mut offsets = vec![0usize; num_experts + 1];
                    for e in 0..num_experts {
                        offsets[e + 1] = offsets[e] + counts[e];
                    }
                    // Pack: each expert's rows land contiguously in `packed_in`.
                    // `original_pos[packed_idx] = original_token_idx` for the
                    // unpermute step at the end. Rows keep their relative order
                    // within an expert (stable), since `i` is scanned ascending.
                    let mut packed_in = vec![0f32; m * k_dim];
                    let mut original_pos = vec![0usize; m];
                    let mut write_idx = vec![0usize; num_experts];
                    for i in 0..m {
                        let e = ids[i] as usize;
                        let dst_row = offsets[e] + write_idx[e];
                        packed_in[dst_row * k_dim..(dst_row + 1) * k_dim]
                            .copy_from_slice(&inp[i * k_dim..(i + 1) * k_dim]);
                        original_pos[dst_row] = i;
                        write_idx[e] += 1;
                    }

                    // One BLAS sgemm per expert. Skip experts with no
                    // tokens — common at the tail when M is much smaller
                    // than num_experts × k.
                    let mut packed_out = vec![0f32; m * n];
                    let expert_stride = k_dim * n;
                    // `next_gmm_ord` is called once per execution of this thunk, so the
                    // ordinal depends on how many GroupedMatMul thunks ran before it.
                    // NOTE(review): `/ 3` presumably assumes three grouped matmuls per
                    // MoE layer (e.g. gate/up/down projections) — confirm against
                    // `moe_residency`.
                    let gmm_ord = crate::moe_residency::next_gmm_ord();
                    let moe_layer = gmm_ord / 3;
                    for e in 0..num_experts {
                        let count = counts[e];
                        if count == 0 {
                            continue;
                        }
                        crate::moe_residency::record_expert_tokens(moe_layer, e, count);
                        let in_start = offsets[e];
                        let in_slice = &packed_in[in_start * k_dim..(in_start + count) * k_dim];
                        // Weight source: when the residency tracker reports the expert as
                        // not on device and a host copy pointer is available, read the
                        // slab from that pointer; otherwise read it from the arena.
                        // NOTE(review): `from_raw_parts` trusts the returned pointer to
                        // reference at least `expert_stride` valid f32s — confirm in
                        // `moe_residency::host_expert_weight_ptr`.
                        let w_slab: &[f32] =
                            if !crate::moe_residency::expert_on_device_for_layer(moe_layer, e) {
                                if let Some(ptr) =
                                    crate::moe_residency::host_expert_weight_ptr(gmm_ord, e)
                                {
                                    std::slice::from_raw_parts(ptr, expert_stride)
                                } else {
                                    &wt[e * expert_stride..(e + 1) * expert_stride]
                                }
                            } else {
                                &wt[e * expert_stride..(e + 1) * expert_stride]
                            };
                        let out_slice = &mut packed_out[in_start * n..(in_start + count) * n];
                        // Presumably [count, k_dim] x [k_dim, n] -> [count, n] (argument
                        // order m, k, n) — matches the slice sizes above.
                        crate::blas::sgemm(in_slice, w_slab, out_slice, count, k_dim, n);
                    }

                    // Unpermute back to original token order.
                    for packed_idx in 0..m {
                        let i = original_pos[packed_idx];
                        out[i * n..(i + 1) * n]
                            .copy_from_slice(&packed_out[packed_idx * n..(packed_idx + 1) * n]);
                    }
                }
            }
11761
11762            Thunk::DequantGroupedMatMulGguf {
11763                input,
11764                w_q,
11765                expert_idx,
11766                dst,
11767                m,
11768                k_dim,
11769                n,
11770                num_experts,
11771                scheme,
11772            } => {
11773                let m = *m as usize;
11774                let k_dim = *k_dim as usize;
11775                let n = *n as usize;
11776                let num_experts = *num_experts as usize;
11777                let block_elems = scheme.gguf_block_size() as usize;
11778                let block_bytes = scheme.gguf_block_bytes() as usize;
11779                let slab_bytes = (k_dim * n) / block_elems * block_bytes;
11780                unsafe {
11781                    let inp = sl(*input, base, m * k_dim);
11782                    let wt = std::slice::from_raw_parts(
11783                        base.add(*w_q) as *const u8,
11784                        num_experts * slab_bytes,
11785                    );
11786                    let ids = sl(*expert_idx, base, m);
11787                    let out = sl_mut(*dst, base, m * n);
11788                    crate::gguf_matmul::gguf_grouped_matmul_bt(
11789                        inp,
11790                        wt,
11791                        ids,
11792                        out,
11793                        m,
11794                        k_dim,
11795                        n,
11796                        num_experts,
11797                        *scheme,
11798                    );
11799                }
11800            }
11801
11802            Thunk::DequantMoEWeightsGguf {
11803                w_q,
11804                dst,
11805                k_dim,
11806                n,
11807                num_experts,
11808                scheme,
11809            } => {
11810                let k_dim = *k_dim as usize;
11811                let n = *n as usize;
11812                let num_experts = *num_experts as usize;
11813                let block_elems = scheme.gguf_block_size() as usize;
11814                let block_bytes = scheme.gguf_block_bytes() as usize;
11815                let slab_bytes = (k_dim * n) / block_elems * block_bytes;
11816                unsafe {
11817                    let wt = std::slice::from_raw_parts(
11818                        base.add(*w_q) as *const u8,
11819                        num_experts * slab_bytes,
11820                    );
11821                    let out = sl_mut(*dst, base, num_experts * k_dim * n);
11822                    crate::gguf_matmul::dequant_moe_weights_to_grouped_f32(
11823                        wt,
11824                        out,
11825                        num_experts,
11826                        k_dim,
11827                        n,
11828                        *scheme,
11829                    );
11830                }
11831            }
11832
11833            Thunk::TopK {
11834                src,
11835                dst,
11836                outer,
11837                axis_dim,
11838                k,
11839                indices_i64,
11840            } => {
11841                let outer = *outer as usize;
11842                let axis_dim = *axis_dim as usize;
11843                let k = *k as usize;
11844                unsafe {
11845                    let inp = sl(*src, base, outer * axis_dim);
11846                    // Repeated argmax with masking. O(k * axis_dim) per row;
11847                    // good enough for small k (MoE typical k=2–8). For larger
11848                    // k a partial heap would win.
11849                    let mut row_buf: Vec<f32> = vec![0.0; axis_dim];
11850                    if *indices_i64 != 0 {
11851                        let out = sl_mut_i64(*dst, base, outer * k);
11852                        for o in 0..outer {
11853                            row_buf.copy_from_slice(&inp[o * axis_dim..(o + 1) * axis_dim]);
11854                            for ki in 0..k {
11855                                let mut best_i = 0usize;
11856                                let mut best_v = row_buf[0];
11857                                for i in 1..axis_dim {
11858                                    let v = row_buf[i];
11859                                    if v > best_v {
11860                                        best_v = v;
11861                                        best_i = i;
11862                                    }
11863                                }
11864                                out[o * k + ki] = best_i as i64;
11865                                row_buf[best_i] = f32::NEG_INFINITY;
11866                            }
11867                        }
11868                    } else {
11869                        let out = sl_mut(*dst, base, outer * k);
11870                        for o in 0..outer {
11871                            row_buf.copy_from_slice(&inp[o * axis_dim..(o + 1) * axis_dim]);
11872                            for ki in 0..k {
11873                                let mut best_i = 0usize;
11874                                let mut best_v = row_buf[0];
11875                                for i in 1..axis_dim {
11876                                    let v = row_buf[i];
11877                                    if v > best_v {
11878                                        best_v = v;
11879                                        best_i = i;
11880                                    }
11881                                }
11882                                out[o * k + ki] = best_i as f32;
11883                                row_buf[best_i] = f32::NEG_INFINITY;
11884                            }
11885                        }
11886                        if let Some(cap) = schedule.moe_topk_capture.as_ref() {
11887                            cap.push_topk_f32(&out[..outer * k], axis_dim);
11888                        }
11889                    }
11890                }
11891            }
11892
11893            Thunk::Reduce {
11894                src,
11895                dst,
11896                outer,
11897                reduced,
11898                inner,
11899                op,
11900            } => {
11901                let outer = *outer as usize;
11902                let reduced = *reduced as usize;
11903                let inner = *inner as usize;
11904                let in_total = outer * reduced * inner;
11905                let out_total = outer * inner;
11906                unsafe {
11907                    let inp = sl(*src, base, in_total);
11908                    let out = sl_mut(*dst, base, out_total);
11909                    for o in 0..outer {
11910                        for i in 0..inner {
11911                            let mut acc = match op {
11912                                ReduceOp::Max => f32::NEG_INFINITY,
11913                                ReduceOp::Min => f32::INFINITY,
11914                                ReduceOp::Prod => 1.0f32,
11915                                _ => 0.0f32, // Sum / Mean
11916                            };
11917                            // Walk the reduced axis with stride `inner`.
11918                            for r in 0..reduced {
11919                                let v = inp[o * reduced * inner + r * inner + i];
11920                                acc = match op {
11921                                    ReduceOp::Sum | ReduceOp::Mean => acc + v,
11922                                    ReduceOp::Max => acc.max(v),
11923                                    ReduceOp::Min => acc.min(v),
11924                                    ReduceOp::Prod => acc * v,
11925                                };
11926                            }
11927                            if matches!(op, ReduceOp::Mean) {
11928                                acc /= reduced as f32;
11929                            }
11930                            out[o * inner + i] = acc;
11931                        }
11932                    }
11933                }
11934            }
11935
11936            Thunk::Conv2D1x1 {
11937                src,
11938                weight,
11939                dst,
11940                n,
11941                c_in,
11942                c_out,
11943                hw,
11944            } => {
11945                let n = *n as usize;
11946                let c_in = *c_in as usize;
11947                let c_out = *c_out as usize;
11948                let hw = *hw as usize;
11949                unsafe {
11950                    let inp = sl(*src, base, n * c_in * hw);
11951                    let wt = sl(*weight, base, c_out * c_in);
11952                    let out = sl_mut(*dst, base, n * c_out * hw);
11953                    // Per-batch sgemm: weight [c_out, c_in] @ input
11954                    // [c_in, hw] = output [c_out, hw]. The weight is
11955                    // shared across batches, so we get to dispatch
11956                    // BLAS once per N (typically 1).
11957                    for ni in 0..n {
11958                        let in_off = ni * c_in * hw;
11959                        let out_off = ni * c_out * hw;
11960                        crate::blas::sgemm(
11961                            wt,
11962                            &inp[in_off..in_off + c_in * hw],
11963                            &mut out[out_off..out_off + c_out * hw],
11964                            c_out,
11965                            c_in,
11966                            hw,
11967                        );
11968                    }
11969                }
11970            }
11971
11972            Thunk::Conv2D {
11973                src,
11974                weight,
11975                dst,
11976                n,
11977                c_in,
11978                h,
11979                w,
11980                c_out,
11981                h_out,
11982                w_out,
11983                kh,
11984                kw,
11985                sh,
11986                sw,
11987                ph,
11988                pw,
11989                dh,
11990                dw,
11991                groups,
11992            } => {
11993                let n = *n as usize;
11994                let c_in = *c_in as usize;
11995                let h = *h as usize;
11996                let w = *w as usize;
11997                let c_out = *c_out as usize;
11998                let h_out = *h_out as usize;
11999                let w_out = *w_out as usize;
12000                let kh = *kh as usize;
12001                let kw = *kw as usize;
12002                let sh = *sh as usize;
12003                let sw = *sw as usize;
12004                let ph = *ph as usize;
12005                let pw = *pw as usize;
12006                let dh = *dh as usize;
12007                let dw = *dw as usize;
12008                let groups = *groups as usize;
12009                let c_in_per_g = c_in / groups;
12010                let c_out_per_g = c_out / groups;
12011                unsafe {
12012                    let inp = sl(*src, base, n * c_in * h * w);
12013                    let wt = sl(*weight, base, c_out * c_in_per_g * kh * kw);
12014                    let out = sl_mut(*dst, base, n * c_out * h_out * w_out);
12015                    for ni in 0..n {
12016                        for co in 0..c_out {
12017                            let g = co / c_out_per_g;
12018                            let ci_start = g * c_in_per_g;
12019                            for ho in 0..h_out {
12020                                for wo in 0..w_out {
12021                                    let mut acc = 0f32;
12022                                    for ci_off in 0..c_in_per_g {
12023                                        let ci = ci_start + ci_off;
12024                                        let in_chan = ((ni * c_in) + ci) * h * w;
12025                                        let wt_chan = ((co * c_in_per_g) + ci_off) * kh * kw;
12026                                        for ki in 0..kh {
12027                                            for kj in 0..kw {
12028                                                let hi = ho * sh + ki * dh;
12029                                                let wi = wo * sw + kj * dw;
12030                                                if hi < ph || wi < pw {
12031                                                    continue;
12032                                                }
12033                                                let hi = hi - ph;
12034                                                let wi = wi - pw;
12035                                                if hi >= h || wi >= w {
12036                                                    continue;
12037                                                }
12038                                                acc += inp[in_chan + hi * w + wi]
12039                                                    * wt[wt_chan + ki * kw + kj];
12040                                            }
12041                                        }
12042                                    }
12043                                    out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] =
12044                                        acc;
12045                                }
12046                            }
12047                        }
12048                    }
12049                }
12050            }
12051
12052            Thunk::Pool2D {
12053                src,
12054                dst,
12055                n,
12056                c,
12057                h,
12058                w,
12059                h_out,
12060                w_out,
12061                kh,
12062                kw,
12063                sh,
12064                sw,
12065                ph,
12066                pw,
12067                kind,
12068            } => {
12069                let n = *n as usize;
12070                let c = *c as usize;
12071                let h = *h as usize;
12072                let w = *w as usize;
12073                let h_out = *h_out as usize;
12074                let w_out = *w_out as usize;
12075                let kh = *kh as usize;
12076                let kw = *kw as usize;
12077                let sh = *sh as usize;
12078                let sw = *sw as usize;
12079                let ph = *ph as usize;
12080                let pw = *pw as usize;
12081                let kernel_area = (kh * kw) as f32;
12082                unsafe {
12083                    let inp = sl(*src, base, n * c * h * w);
12084                    let out = sl_mut(*dst, base, n * c * h_out * w_out);
12085                    for ni in 0..n {
12086                        for ci in 0..c {
12087                            let in_chan = ni * c * h * w + ci * h * w;
12088                            let out_chan = ni * c * h_out * w_out + ci * h_out * w_out;
12089                            for ho in 0..h_out {
12090                                for wo in 0..w_out {
12091                                    let mut acc = match kind {
12092                                        ReduceOp::Max => f32::NEG_INFINITY,
12093                                        _ => 0f32, // Mean (and Sum/Min/Prod fall back here)
12094                                    };
12095                                    for ki in 0..kh {
12096                                        for kj in 0..kw {
12097                                            let hi = ho * sh + ki;
12098                                            let wi = wo * sw + kj;
12099                                            // Padded-zero region.
12100                                            if hi < ph || wi < pw {
12101                                                continue;
12102                                            }
12103                                            let hi = hi - ph;
12104                                            let wi = wi - pw;
12105                                            if hi >= h || wi >= w {
12106                                                continue;
12107                                            }
12108                                            let v = inp[in_chan + hi * w + wi];
12109                                            match kind {
12110                                                ReduceOp::Max => acc = acc.max(v),
12111                                                _ => acc += v,
12112                                            }
12113                                        }
12114                                    }
12115                                    if matches!(kind, ReduceOp::Mean) {
12116                                        acc /= kernel_area;
12117                                    }
12118                                    out[out_chan + ho * w_out + wo] = acc;
12119                                }
12120                            }
12121                        }
12122                    }
12123                }
12124            }
12125
12126            Thunk::ReluBackward { x, dy, dx, len } => {
12127                let len = *len as usize;
12128                unsafe {
12129                    let xs = sl(*x, base, len);
12130                    let dys = sl(*dy, base, len);
12131                    let out = sl_mut(*dx, base, len);
12132                    for i in 0..len {
12133                        out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
12134                    }
12135                }
12136            }
12137
12138            Thunk::ReluBackwardF64 { x, dy, dx, len } => {
12139                let len = *len as usize;
12140                unsafe {
12141                    let xs = sl_f64(*x, base, len);
12142                    let dys = sl_f64(*dy, base, len);
12143                    let out = sl_mut_f64(*dx, base, len);
12144                    for i in 0..len {
12145                        out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
12146                    }
12147                }
12148            }
12149
12150            Thunk::QMatMul {
12151                x,
12152                w,
12153                bias,
12154                out,
12155                m,
12156                k,
12157                n,
12158                x_zp,
12159                w_zp,
12160                out_zp,
12161                mult,
12162            } => {
12163                let m = *m as usize;
12164                let k = *k as usize;
12165                let n = *n as usize;
12166                unsafe {
12167                    let x_ptr = base.add(*x) as *const i8;
12168                    let w_ptr = base.add(*w) as *const i8;
12169                    let bias_ptr = base.add(*bias) as *const i32;
12170                    let out_ptr = base.add(*out) as *mut i8;
12171                    for mi in 0..m {
12172                        for ni in 0..n {
12173                            let mut acc: i32 = *bias_ptr.add(ni);
12174                            for ki in 0..k {
12175                                let xv = *x_ptr.add(mi * k + ki) as i32 - *x_zp;
12176                                let wv = *w_ptr.add(ki * n + ni) as i32 - *w_zp;
12177                                acc += xv * wv;
12178                            }
12179                            // Requantize: round(acc · mult) + out_zp,
12180                            // clamped to i8.
12181                            let r = (acc as f32 * *mult).round() as i32 + *out_zp;
12182                            let r = r.clamp(-128, 127) as i8;
12183                            *out_ptr.add(mi * n + ni) = r;
12184                        }
12185                    }
12186                }
12187            }
12188
12189            Thunk::QConv2d {
12190                x,
12191                w,
12192                bias,
12193                out,
12194                n,
12195                c_in,
12196                h,
12197                w_in,
12198                c_out,
12199                h_out,
12200                w_out,
12201                kh,
12202                kw,
12203                sh,
12204                sw,
12205                ph,
12206                pw,
12207                dh,
12208                dw,
12209                groups,
12210                x_zp,
12211                w_zp,
12212                out_zp,
12213                mult,
12214            } => {
12215                let n = *n as usize;
12216                let c_in = *c_in as usize;
12217                let h = *h as usize;
12218                let w_in = *w_in as usize;
12219                let c_out = *c_out as usize;
12220                let h_out = *h_out as usize;
12221                let w_out = *w_out as usize;
12222                let kh = *kh as usize;
12223                let kw = *kw as usize;
12224                let sh = *sh as usize;
12225                let sw = *sw as usize;
12226                let ph = *ph as usize;
12227                let pw = *pw as usize;
12228                let dh = *dh as usize;
12229                let dw = *dw as usize;
12230                let groups = *groups as usize;
12231                let c_in_per_g = c_in / groups;
12232                let c_out_per_g = c_out / groups;
12233                unsafe {
12234                    let x_ptr = base.add(*x) as *const i8;
12235                    let w_ptr = base.add(*w) as *const i8;
12236                    let bias_ptr = base.add(*bias) as *const i32;
12237                    let out_ptr = base.add(*out) as *mut i8;
12238                    for ni in 0..n {
12239                        for co in 0..c_out {
12240                            let g = co / c_out_per_g;
12241                            let ci_start = g * c_in_per_g;
12242                            for ho in 0..h_out {
12243                                for wo in 0..w_out {
12244                                    let mut acc: i32 = *bias_ptr.add(co);
12245                                    for ci_off in 0..c_in_per_g {
12246                                        let ci = ci_start + ci_off;
12247                                        let in_chan = ((ni * c_in) + ci) * h * w_in;
12248                                        let wt_chan = ((co * c_in_per_g) + ci_off) * kh * kw;
12249                                        for ki in 0..kh {
12250                                            for kj in 0..kw {
12251                                                let hi = ho * sh + ki * dh;
12252                                                let wi = wo * sw + kj * dw;
12253                                                if hi < ph || wi < pw {
12254                                                    continue;
12255                                                }
12256                                                let hi = hi - ph;
12257                                                let wi = wi - pw;
12258                                                if hi >= h || wi >= w_in {
12259                                                    continue;
12260                                                }
12261                                                let xv = *x_ptr.add(in_chan + hi * w_in + wi)
12262                                                    as i32
12263                                                    - *x_zp;
12264                                                let wv = *w_ptr.add(wt_chan + ki * kw + kj) as i32
12265                                                    - *w_zp;
12266                                                acc += xv * wv;
12267                                            }
12268                                        }
12269                                    }
12270                                    let r = (acc as f32 * *mult).round() as i32 + *out_zp;
12271                                    let r = r.clamp(-128, 127) as i8;
12272                                    let dst = ((ni * c_out) + co) * h_out * w_out + ho * w_out + wo;
12273                                    *out_ptr.add(dst) = r;
12274                                }
12275                            }
12276                        }
12277                    }
12278                }
12279            }
12280
12281            Thunk::Quantize {
12282                x,
12283                q,
12284                len,
12285                chan_axis: _,
12286                chan_dim,
12287                inner,
12288                scales,
12289                zero_points,
12290            } => {
12291                let len = *len as usize;
12292                let chan_dim = *chan_dim as usize;
12293                let inner = *inner as usize;
12294                unsafe {
12295                    let xs = sl(*x, base, len);
12296                    let q_ptr = base.add(*q) as *mut i8;
12297                    for i in 0..len {
12298                        let c = if chan_dim == 1 {
12299                            0
12300                        } else {
12301                            (i / inner) % chan_dim
12302                        };
12303                        let inv_scale = 1.0 / scales[c];
12304                        let zp = zero_points[c];
12305                        let v = (xs[i] * inv_scale).round() as i32 + zp;
12306                        *q_ptr.add(i) = v.clamp(-128, 127) as i8;
12307                    }
12308                }
12309            }
12310
12311            Thunk::Dequantize {
12312                q,
12313                x,
12314                len,
12315                chan_axis: _,
12316                chan_dim,
12317                inner,
12318                scales,
12319                zero_points,
12320            } => {
12321                let len = *len as usize;
12322                let chan_dim = *chan_dim as usize;
12323                let inner = *inner as usize;
12324                unsafe {
12325                    let q_ptr = base.add(*q) as *const i8;
12326                    let out = sl_mut(*x, base, len);
12327                    for i in 0..len {
12328                        let c = if chan_dim == 1 {
12329                            0
12330                        } else {
12331                            (i / inner) % chan_dim
12332                        };
12333                        let scale = scales[c];
12334                        let zp = zero_points[c];
12335                        let qv = *q_ptr.add(i) as i32;
12336                        out[i] = (qv - zp) as f32 * scale;
12337                    }
12338                }
12339            }
12340
12341            Thunk::FakeQuantize {
12342                x,
12343                out,
12344                len,
12345                chan_axis: _,
12346                chan_dim,
12347                inner,
12348                bits,
12349                ste: _,
12350                scale_mode,
12351                state_off,
12352            } => {
12353                use rlx_ir::op::ScaleMode;
12354                let len = *len as usize;
12355                let chan_dim = *chan_dim as usize;
12356                let inner = *inner as usize;
12357                let q_max: f32 = match *bits {
12358                    8 => 127.0,
12359                    4 => 7.0,
12360                    2 => 1.0,
12361                    n => panic!("FakeQuantize: unsupported bits {n}"),
12362                };
12363                unsafe {
12364                    let xs = sl(*x, base, len);
12365                    let outs = sl_mut(*out, base, len);
12366
12367                    let mut scale = vec![0f32; chan_dim];
12368                    match scale_mode {
12369                        ScaleMode::PerBatch => {
12370                            let mut max_abs = vec![0f32; chan_dim];
12371                            for i in 0..len {
12372                                let c = if chan_dim == 1 {
12373                                    0
12374                                } else {
12375                                    (i / inner) % chan_dim
12376                                };
12377                                let a = xs[i].abs();
12378                                if a > max_abs[c] {
12379                                    max_abs[c] = a;
12380                                }
12381                            }
12382                            for c in 0..chan_dim {
12383                                scale[c] = (max_abs[c] / q_max).max(1e-12);
12384                            }
12385                        }
12386                        ScaleMode::EMA { decay } => {
12387                            // Per-channel current max-abs, then blend
12388                            // into the running state in place.
12389                            let mut max_abs = vec![0f32; chan_dim];
12390                            for i in 0..len {
12391                                let c = if chan_dim == 1 {
12392                                    0
12393                                } else {
12394                                    (i / inner) % chan_dim
12395                                };
12396                                let a = xs[i].abs();
12397                                if a > max_abs[c] {
12398                                    max_abs[c] = a;
12399                                }
12400                            }
12401                            let state =
12402                                sl_mut(state_off.expect("EMA needs state_off"), base, chan_dim);
12403                            for c in 0..chan_dim {
12404                                let cur = (max_abs[c] / q_max).max(1e-12);
12405                                // Cold-start: state==0 → seed directly.
12406                                let blended = if state[c] <= 0.0 {
12407                                    cur
12408                                } else {
12409                                    *decay * state[c] + (1.0 - *decay) * cur
12410                                };
12411                                state[c] = blended;
12412                                scale[c] = blended;
12413                            }
12414                        }
12415                        ScaleMode::Fixed => {
12416                            let state =
12417                                sl(state_off.expect("Fixed needs state_off"), base, chan_dim);
12418                            for c in 0..chan_dim {
12419                                scale[c] = state[c].max(1e-12);
12420                            }
12421                        }
12422                    }
12423
12424                    for i in 0..len {
12425                        let c = if chan_dim == 1 {
12426                            0
12427                        } else {
12428                            (i / inner) % chan_dim
12429                        };
12430                        let s = scale[c];
12431                        let qv = (xs[i] / s).round().clamp(-q_max, q_max);
12432                        outs[i] = qv * s;
12433                    }
12434                }
12435            }
12436
12437            Thunk::ActivationBackward {
12438                x,
12439                dy,
12440                dx,
12441                len,
12442                kind,
12443            } => {
12444                let len = *len as usize;
12445                unsafe {
12446                    let xs = sl(*x, base, len);
12447                    let dys = sl(*dy, base, len);
12448                    let out = sl_mut(*dx, base, len);
12449                    activation_backward_kernel(*kind, xs, dys, out);
12450                }
12451            }
12452
12453            Thunk::ActivationBackwardF64 {
12454                x,
12455                dy,
12456                dx,
12457                len,
12458                kind,
12459            } => {
12460                let len = *len as usize;
12461                unsafe {
12462                    let xs = sl_f64(*x, base, len);
12463                    let dys = sl_f64(*dy, base, len);
12464                    let out = sl_mut_f64(*dx, base, len);
12465                    activation_backward_kernel_f64(*kind, xs, dys, out);
12466                }
12467            }
12468
12469            Thunk::FakeQuantizeLSQ {
12470                x,
12471                scale_off,
12472                out,
12473                len,
12474                chan_axis: _,
12475                chan_dim,
12476                inner,
12477                bits,
12478            } => {
12479                let len = *len as usize;
12480                let chan_dim = *chan_dim as usize;
12481                let inner = *inner as usize;
12482                let q_max: f32 = match *bits {
12483                    8 => 127.0,
12484                    4 => 7.0,
12485                    2 => 1.0,
12486                    n => panic!("FakeQuantizeLSQ: bad bits {n}"),
12487                };
12488                unsafe {
12489                    let xs = sl(*x, base, len);
12490                    let scale = sl(*scale_off, base, chan_dim);
12491                    let outs = sl_mut(*out, base, len);
12492                    for i in 0..len {
12493                        let c = if chan_dim == 1 {
12494                            0
12495                        } else {
12496                            (i / inner) % chan_dim
12497                        };
12498                        let s = scale[c].max(1e-12);
12499                        let qv = (xs[i] / s).round().clamp(-q_max, q_max);
12500                        outs[i] = qv * s;
12501                    }
12502                }
12503            }
12504
12505            Thunk::FakeQuantizeLSQBackwardX {
12506                x,
12507                scale_off,
12508                dy,
12509                dx,
12510                len,
12511                chan_axis: _,
12512                chan_dim,
12513                inner,
12514                bits,
12515            } => {
12516                let len = *len as usize;
12517                let chan_dim = *chan_dim as usize;
12518                let inner = *inner as usize;
12519                let q_max: f32 = match *bits {
12520                    8 => 127.0,
12521                    4 => 7.0,
12522                    2 => 1.0,
12523                    n => panic!("FakeQuantizeLSQBackwardX: bad bits {n}"),
12524                };
12525                unsafe {
12526                    let xs = sl(*x, base, len);
12527                    let scale = sl(*scale_off, base, chan_dim);
12528                    let dys = sl(*dy, base, len);
12529                    let outs = sl_mut(*dx, base, len);
12530                    // STE-clipped: dx = dy when |x/s| ≤ q_max, else 0.
12531                    for i in 0..len {
12532                        let c = if chan_dim == 1 {
12533                            0
12534                        } else {
12535                            (i / inner) % chan_dim
12536                        };
12537                        let z = xs[i] / scale[c].max(1e-12);
12538                        outs[i] = if z.abs() <= q_max { dys[i] } else { 0.0 };
12539                    }
12540                }
12541            }
12542
12543            Thunk::FakeQuantizeLSQBackwardScale {
12544                x,
12545                scale_off,
12546                dy,
12547                dscale,
12548                len,
12549                chan_axis: _,
12550                chan_dim,
12551                inner,
12552                bits,
12553            } => {
12554                let len = *len as usize;
12555                let chan_dim = *chan_dim as usize;
12556                let inner = *inner as usize;
12557                let q_max: f32 = match *bits {
12558                    8 => 127.0,
12559                    4 => 7.0,
12560                    2 => 1.0,
12561                    n => panic!("FakeQuantizeLSQBackwardScale: bad bits {n}"),
12562                };
12563                unsafe {
12564                    let xs = sl(*x, base, len);
12565                    let scale = sl(*scale_off, base, chan_dim);
12566                    let dys = sl(*dy, base, len);
12567                    let outs = sl_mut(*dscale, base, chan_dim);
12568                    for v in outs.iter_mut() {
12569                        *v = 0.0;
12570                    }
12571                    // ψ(z) = -z + round(z) inside range, sign(z)·q_max outside.
12572                    // dscale[c] = sum_i ψ(x_i/s[c]) * upstream[i].
12573                    for i in 0..len {
12574                        let c = if chan_dim == 1 {
12575                            0
12576                        } else {
12577                            (i / inner) % chan_dim
12578                        };
12579                        let s = scale[c].max(1e-12);
12580                        let z = xs[i] / s;
12581                        let psi = if z.abs() <= q_max {
12582                            -z + z.round()
12583                        } else if z > 0.0 {
12584                            q_max
12585                        } else {
12586                            -q_max
12587                        };
12588                        outs[c] += psi * dys[i];
12589                    }
12590                }
12591            }
12592
12593            Thunk::FakeQuantizeBackward {
12594                x,
12595                dy,
12596                dx,
12597                len,
12598                chan_axis: _,
12599                chan_dim,
12600                inner,
12601                bits,
12602                ste,
12603            } => {
12604                use rlx_ir::op::SteKind;
12605                let len = *len as usize;
12606                let chan_dim = *chan_dim as usize;
12607                let inner = *inner as usize;
12608                let q_max: f32 = match *bits {
12609                    8 => 127.0,
12610                    4 => 7.0,
12611                    2 => 1.0,
12612                    n => panic!("FakeQuantizeBackward: bad bits {n}"),
12613                };
12614                unsafe {
12615                    let xs = sl(*x, base, len);
12616                    let dys = sl(*dy, base, len);
12617                    let outs = sl_mut(*dx, base, len);
12618
12619                    // Per-channel max-abs → scale, same as forward.
12620                    let mut max_abs = vec![0f32; chan_dim];
12621                    for i in 0..len {
12622                        let c = if chan_dim == 1 {
12623                            0
12624                        } else {
12625                            (i / inner) % chan_dim
12626                        };
12627                        let a = xs[i].abs();
12628                        if a > max_abs[c] {
12629                            max_abs[c] = a;
12630                        }
12631                    }
12632                    let mut scale = vec![0f32; chan_dim];
12633                    for c in 0..chan_dim {
12634                        scale[c] = (max_abs[c] / q_max).max(1e-12);
12635                    }
12636
12637                    match *ste {
12638                        SteKind::Identity => {
12639                            // dx = dy unchanged.
12640                            outs.copy_from_slice(dys);
12641                        }
12642                        SteKind::ClippedIdentity => {
12643                            // dx = dy * (|x| <= q_max·s); zero if the
12644                            // forward saturated.
12645                            for i in 0..len {
12646                                let c = if chan_dim == 1 {
12647                                    0
12648                                } else {
12649                                    (i / inner) % chan_dim
12650                                };
12651                                let bound = q_max * scale[c];
12652                                outs[i] = if xs[i].abs() <= bound { dys[i] } else { 0.0 };
12653                            }
12654                        }
12655                        SteKind::Tanh => {
12656                            // dx = dy * (1 - tanh²(x/s)).
12657                            for i in 0..len {
12658                                let c = if chan_dim == 1 {
12659                                    0
12660                                } else {
12661                                    (i / inner) % chan_dim
12662                                };
12663                                let t = (xs[i] / scale[c]).tanh();
12664                                outs[i] = dys[i] * (1.0 - t * t);
12665                            }
12666                        }
12667                        SteKind::HardTanh => {
12668                            // dx = dy * max(0, 1 - |x/(q_max·s)|).
12669                            for i in 0..len {
12670                                let c = if chan_dim == 1 {
12671                                    0
12672                                } else {
12673                                    (i / inner) % chan_dim
12674                                };
12675                                let bound = q_max * scale[c];
12676                                let attenuation = (1.0 - (xs[i] / bound).abs()).max(0.0);
12677                                outs[i] = dys[i] * attenuation;
12678                            }
12679                        }
12680                    }
12681                }
12682            }
12683
12684            Thunk::LayerNormBackwardInput {
12685                x,
12686                gamma,
12687                dy,
12688                dx,
12689                rows,
12690                h,
12691                eps,
12692            } => {
12693                let rows = *rows as usize;
12694                let h = *h as usize;
12695                let eps = *eps;
12696                unsafe {
12697                    let xs = sl(*x, base, rows * h);
12698                    let g = sl(*gamma, base, h);
12699                    let dys = sl(*dy, base, rows * h);
12700                    let out = sl_mut(*dx, base, rows * h);
12701                    let n_inv = 1.0 / h as f32;
12702                    for r in 0..rows {
12703                        let xr = &xs[r * h..(r + 1) * h];
12704                        let dyr = &dys[r * h..(r + 1) * h];
12705                        // Per-row mean and inv_std (recompute — no saved
12706                        // tensor from the forward pass).
12707                        let mut sum = 0f32;
12708                        for &v in xr {
12709                            sum += v;
12710                        }
12711                        let mean = sum * n_inv;
12712                        let mut var = 0f32;
12713                        for &v in xr {
12714                            let d = v - mean;
12715                            var += d * d;
12716                        }
12717                        let inv_std = 1.0 / (var * n_inv + eps).sqrt();
12718
12719                        // sums needed for the closed-form:
12720                        //   mean(dy·γ) and mean(dy·γ·x̂)
12721                        let mut s_sy = 0f32;
12722                        let mut s_sxh = 0f32;
12723                        for d in 0..h {
12724                            let xh = (xr[d] - mean) * inv_std;
12725                            let sy = dyr[d] * g[d];
12726                            s_sy += sy;
12727                            s_sxh += sy * xh;
12728                        }
12729                        let m_sy = s_sy * n_inv;
12730                        let m_sxh = s_sxh * n_inv;
12731
12732                        for d in 0..h {
12733                            let xh = (xr[d] - mean) * inv_std;
12734                            let sy = dyr[d] * g[d];
12735                            out[r * h + d] = inv_std * (sy - m_sy - xh * m_sxh);
12736                        }
12737                    }
12738                }
12739            }
12740
12741            Thunk::BatchNormInferenceBackwardInput {
12742                x,
12743                gamma,
12744                mean,
12745                var,
12746                dy,
12747                dx,
12748                count,
12749                channels,
12750                eps,
12751            } => {
12752                let count = *count as usize;
12753                let c = *channels as usize;
12754                let n = count * c;
12755                let eps = *eps;
12756                unsafe {
12757                    crate::kernels::batch_norm_inference_backward_input(
12758                        sl(*x, base, n),
12759                        sl(*gamma, base, c),
12760                        sl(*mean, base, c),
12761                        sl(*var, base, c),
12762                        sl(*dy, base, n),
12763                        sl_mut(*dx, base, n),
12764                        c,
12765                        eps,
12766                    );
12767                }
12768            }
12769
12770            Thunk::BatchNormInferenceBackwardGamma {
12771                x,
12772                mean,
12773                var,
12774                dy,
12775                dgamma,
12776                count,
12777                channels,
12778                eps,
12779            } => {
12780                let count = *count as usize;
12781                let c = *channels as usize;
12782                let n = count * c;
12783                let eps = *eps;
12784                unsafe {
12785                    crate::kernels::batch_norm_inference_backward_gamma(
12786                        sl(*x, base, n),
12787                        sl(*mean, base, c),
12788                        sl(*var, base, c),
12789                        sl(*dy, base, n),
12790                        sl_mut(*dgamma, base, c),
12791                        c,
12792                        eps,
12793                    );
12794                }
12795            }
12796
12797            Thunk::BatchNormInferenceBackwardBeta {
12798                dy,
12799                dbeta,
12800                count,
12801                channels,
12802            } => {
12803                let count = *count as usize;
12804                let c = *channels as usize;
12805                let n = count * c;
12806                unsafe {
12807                    crate::kernels::batch_norm_inference_backward_beta(
12808                        sl(*dy, base, n),
12809                        sl_mut(*dbeta, base, c),
12810                        c,
12811                    );
12812                }
12813            }
12814
12815            Thunk::LayerNormBackwardGamma {
12816                x,
12817                dy,
12818                dgamma,
12819                rows,
12820                h,
12821                eps,
12822            } => {
12823                let rows = *rows as usize;
12824                let h = *h as usize;
12825                let eps = *eps;
12826                unsafe {
12827                    let xs = sl(*x, base, rows * h);
12828                    let dys = sl(*dy, base, rows * h);
12829                    let out = sl_mut(*dgamma, base, h);
12830                    for v in out.iter_mut() {
12831                        *v = 0.0;
12832                    }
12833                    let n_inv = 1.0 / h as f32;
12834                    for r in 0..rows {
12835                        let xr = &xs[r * h..(r + 1) * h];
12836                        let dyr = &dys[r * h..(r + 1) * h];
12837                        let mut sum = 0f32;
12838                        for &v in xr {
12839                            sum += v;
12840                        }
12841                        let mean = sum * n_inv;
12842                        let mut var = 0f32;
12843                        for &v in xr {
12844                            let d = v - mean;
12845                            var += d * d;
12846                        }
12847                        let inv_std = 1.0 / (var * n_inv + eps).sqrt();
12848                        for d in 0..h {
12849                            let xh = (xr[d] - mean) * inv_std;
12850                            out[d] += dyr[d] * xh;
12851                        }
12852                    }
12853                }
12854            }
12855
12856            Thunk::RmsNormBackwardInput {
12857                x,
12858                gamma,
12859                beta,
12860                dy,
12861                dx,
12862                rows,
12863                h,
12864                eps,
12865            } => {
12866                let (rows, h) = (*rows as usize, *h as usize);
12867                unsafe {
12868                    let xs = sl(*x, base, rows * h);
12869                    let g = sl(*gamma, base, h);
12870                    let b = sl(*beta, base, h);
12871                    let dys = sl(*dy, base, rows * h);
12872                    let out = sl_mut(*dx, base, rows * h);
12873                    let mut dg = vec![0f32; h];
12874                    let mut db = vec![0f32; h];
12875                    for r in 0..rows {
12876                        crate::training_bwd::rms_norm_backward_row(
12877                            &xs[r * h..(r + 1) * h],
12878                            g,
12879                            b,
12880                            &dys[r * h..(r + 1) * h],
12881                            &mut out[r * h..(r + 1) * h],
12882                            &mut dg,
12883                            &mut db,
12884                            *eps,
12885                        );
12886                    }
12887                }
12888            }
12889
12890            Thunk::RmsNormBackwardGamma {
12891                x,
12892                gamma,
12893                beta,
12894                dy,
12895                dgamma,
12896                rows,
12897                h,
12898                eps,
12899            } => {
12900                let (rows, h) = (*rows as usize, *h as usize);
12901                unsafe {
12902                    let xs = sl(*x, base, rows * h);
12903                    let g = sl(*gamma, base, h);
12904                    let b = sl(*beta, base, h);
12905                    let dys = sl(*dy, base, rows * h);
12906                    let out = sl_mut(*dgamma, base, h);
12907                    for v in out.iter_mut() {
12908                        *v = 0.0;
12909                    }
12910                    let mut dx = vec![0f32; h];
12911                    let mut db = vec![0f32; h];
12912                    for r in 0..rows {
12913                        crate::training_bwd::rms_norm_backward_row(
12914                            &xs[r * h..(r + 1) * h],
12915                            g,
12916                            b,
12917                            &dys[r * h..(r + 1) * h],
12918                            &mut dx,
12919                            &mut *out,
12920                            &mut db,
12921                            *eps,
12922                        );
12923                    }
12924                }
12925            }
12926
12927            Thunk::RmsNormBackwardBeta {
12928                x,
12929                gamma,
12930                beta,
12931                dy,
12932                dbeta,
12933                rows,
12934                h,
12935                eps,
12936            } => {
12937                let (rows, h) = (*rows as usize, *h as usize);
12938                unsafe {
12939                    let xs = sl(*x, base, rows * h);
12940                    let g = sl(*gamma, base, h);
12941                    let b = sl(*beta, base, h);
12942                    let dys = sl(*dy, base, rows * h);
12943                    let out = sl_mut(*dbeta, base, h);
12944                    for v in out.iter_mut() {
12945                        *v = 0.0;
12946                    }
12947                    let mut dx = vec![0f32; h];
12948                    let mut dg = vec![0f32; h];
12949                    for r in 0..rows {
12950                        crate::training_bwd::rms_norm_backward_row(
12951                            &xs[r * h..(r + 1) * h],
12952                            g,
12953                            b,
12954                            &dys[r * h..(r + 1) * h],
12955                            &mut dx,
12956                            &mut dg,
12957                            &mut *out,
12958                            *eps,
12959                        );
12960                    }
12961                }
12962            }
12963
12964            Thunk::RopeBackward {
12965                dy,
12966                cos,
12967                sin,
12968                dx,
12969                batch,
12970                seq,
12971                hidden,
12972                head_dim,
12973                n_rot,
12974                cos_len,
12975            } => {
12976                let (b, s, hs, dh, nr, cl) = (
12977                    *batch as usize,
12978                    *seq as usize,
12979                    *hidden as usize,
12980                    *head_dim as usize,
12981                    *n_rot as usize,
12982                    *cos_len as usize,
12983                );
12984                let nh = hs / dh;
12985                let tab_half = dh / 2;
12986                unsafe {
12987                    let dys = sl(*dy, base, b * s * hs);
12988                    let cos_tab = sl(*cos, base, cl);
12989                    let sin_tab = sl(*sin, base, cl);
12990                    let out = sl_mut(*dx, base, b * s * hs);
12991                    for bi in 0..b {
12992                        for si in 0..s {
12993                            let tab_off = si.saturating_mul(tab_half) % cl.max(1);
12994                            let cp = &cos_tab[tab_off..tab_off + tab_half.min(cl)];
12995                            let sp = &sin_tab[tab_off..tab_off + tab_half.min(cl)];
12996                            for hi in 0..nh {
12997                                let base_idx = bi * s * hs + si * hs + hi * dh;
12998                                crate::training_bwd::rope_backward_row(
12999                                    &dys[base_idx..base_idx + dh],
13000                                    cp,
13001                                    sp,
13002                                    &mut out[base_idx..base_idx + dh],
13003                                    dh,
13004                                    nr,
13005                                );
13006                            }
13007                        }
13008                    }
13009                }
13010            }
13011
13012            Thunk::CumsumBackward {
13013                dy,
13014                dx,
13015                rows,
13016                cols,
13017                exclusive,
13018            } => {
13019                let (rows, cols) = (*rows as usize, *cols as usize);
13020                unsafe {
13021                    let dys = sl(*dy, base, rows * cols);
13022                    let out = sl_mut(*dx, base, rows * cols);
13023                    for r in 0..rows {
13024                        crate::training_bwd::cumsum_backward_row(
13025                            &dys[r * cols..(r + 1) * cols],
13026                            &mut out[r * cols..(r + 1) * cols],
13027                            *exclusive,
13028                        );
13029                    }
13030                }
13031            }
13032
13033            Thunk::GroupNormBackwardInput {
13034                x,
13035                gamma,
13036                beta: _beta,
13037                dy,
13038                dx,
13039                n,
13040                c,
13041                h,
13042                w,
13043                num_groups,
13044                eps,
13045            } => {
13046                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
13047                let plane = c * h * w;
13048                unsafe {
13049                    let xs = sl(*x, base, n * plane);
13050                    let g = sl(*gamma, base, c);
13051                    let dys = sl(*dy, base, n * plane);
13052                    let out = sl_mut(*dx, base, n * plane);
13053                    crate::training_bwd::group_norm_backward_input_nchw(
13054                        xs,
13055                        g,
13056                        dys,
13057                        out,
13058                        n,
13059                        c,
13060                        h,
13061                        w,
13062                        *num_groups as usize,
13063                        *eps,
13064                    );
13065                }
13066            }
13067
13068            Thunk::GroupNormBackwardGamma {
13069                x,
13070                dy,
13071                dgamma,
13072                n,
13073                c,
13074                h,
13075                w,
13076                num_groups,
13077                eps,
13078            } => {
13079                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
13080                let plane = c * h * w;
13081                unsafe {
13082                    let xs = sl(*x, base, n * plane);
13083                    let dys = sl(*dy, base, n * plane);
13084                    let out = sl_mut(*dgamma, base, c);
13085                    crate::training_bwd::group_norm_backward_gamma_nchw(
13086                        xs,
13087                        dys,
13088                        out,
13089                        n,
13090                        c,
13091                        h,
13092                        w,
13093                        *num_groups as usize,
13094                        *eps,
13095                    );
13096                }
13097            }
13098
13099            Thunk::GroupNormBackwardBeta {
13100                dy,
13101                dbeta,
13102                n,
13103                c,
13104                h,
13105                w,
13106            } => {
13107                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
13108                let plane = c * h * w;
13109                unsafe {
13110                    let dys = sl(*dy, base, n * plane);
13111                    let out = sl_mut(*dbeta, base, c);
13112                    crate::training_bwd::group_norm_backward_beta_nchw(dys, out, n, c, h, w);
13113                }
13114            }
13115
13116            Thunk::GatherBackward {
13117                dy,
13118                indices,
13119                dst,
13120                outer,
13121                axis_dim,
13122                num_idx,
13123                trailing,
13124            } => {
13125                let (outer, axis_dim, num_idx, trailing) = (
13126                    *outer as usize,
13127                    *axis_dim as usize,
13128                    *num_idx as usize,
13129                    *trailing as usize,
13130                );
13131                unsafe {
13132                    let dys = sl(*dy, base, outer * num_idx * trailing);
13133                    let ids = sl(*indices, base, num_idx);
13134                    let out = sl_mut(*dst, base, outer * axis_dim * trailing);
13135                    for v in out.iter_mut() {
13136                        *v = 0.0;
13137                    }
13138                    crate::training_bwd::gather_axis_backward(
13139                        dys, ids, out, outer, axis_dim, num_idx, trailing,
13140                    );
13141                }
13142            }
13143
13144            Thunk::MaxPool2dBackward {
13145                x,
13146                dy,
13147                dx,
13148                n,
13149                c,
13150                h,
13151                w,
13152                h_out,
13153                w_out,
13154                kh,
13155                kw,
13156                sh,
13157                sw,
13158                ph,
13159                pw,
13160            } => unsafe {
13161                execute_maxpool2d_backward_f32(
13162                    *x, *dy, *dx, *n, *c, *h, *w, *h_out, *w_out, *kh, *kw, *sh, *sw, *ph, *pw,
13163                    base,
13164                );
13165            },
13166
13167            Thunk::Conv2dBackwardInput {
13168                dy,
13169                w,
13170                dx,
13171                n,
13172                c_in,
13173                h,
13174                w_in,
13175                c_out,
13176                h_out,
13177                w_out,
13178                kh,
13179                kw,
13180                sh,
13181                sw,
13182                ph,
13183                pw,
13184                dh,
13185                dw,
13186                groups,
13187            } => {
13188                // Per-group GEMM + col2im. Two orders of magnitude faster
13189                // than the naive 6-deep nested loop on training shapes.
13190                //
13191                //   dcol_n_g = w_g^T  @  dy_n_g            (sgemm)
13192                //   dx_n_g  += col2im(dcol_n_g)            (scatter-add)
13193                //
13194                // Layouts (all row-major):
13195                //   w_g       [c_out_per_g, c_in_per_g · kh · kw]
13196                //   dy_n_g    [c_out_per_g, h_out · w_out]
13197                //   dcol_n_g  [c_in_per_g · kh · kw, h_out · w_out]
13198                //   dx_n_g    [c_in_per_g, h · w_in]
13199                let n = *n as usize;
13200                let c_in = *c_in as usize;
13201                let h = *h as usize;
13202                let w_in = *w_in as usize;
13203                let c_out = *c_out as usize;
13204                let h_out = *h_out as usize;
13205                let w_out = *w_out as usize;
13206                let kh = *kh as usize;
13207                let kw = *kw as usize;
13208                let sh = *sh as usize;
13209                let sw = *sw as usize;
13210                let ph = *ph as usize;
13211                let pw = *pw as usize;
13212                let dh = *dh as usize;
13213                let dw = *dw as usize;
13214                let groups = *groups as usize;
13215                let c_in_per_g = c_in / groups;
13216                let c_out_per_g = c_out / groups;
13217
13218                let m_dim = c_in_per_g * kh * kw;
13219                let n_dim = h_out * w_out;
13220                let k_dim = c_out_per_g;
13221
13222                let dy_stride_n = c_out * h_out * w_out;
13223                let dy_stride_g = c_out_per_g * h_out * w_out;
13224                let w_stride_g = c_out_per_g * c_in_per_g * kh * kw;
13225                let dx_stride_n = c_in * h * w_in;
13226                let dx_stride_g = c_in_per_g * h * w_in;
13227
13228                unsafe {
13229                    let dys = sl(*dy, base, n * c_out * h_out * w_out);
13230                    let ws = sl(*w, base, c_out * c_in_per_g * kh * kw);
13231                    let dxs = sl_mut(*dx, base, n * c_in * h * w_in);
13232                    for v in dxs.iter_mut() {
13233                        *v = 0.0;
13234                    }
13235
13236                    // Reused scratch buffer for the [m_dim, n_dim] dcol.
13237                    let mut dcol = vec![0f32; m_dim * n_dim];
13238
13239                    for ni in 0..n {
13240                        for g in 0..groups {
13241                            let w_g_off = g * w_stride_g;
13242                            let dy_n_g_off = ni * dy_stride_n + g * dy_stride_g;
13243                            let dx_n_g_off = ni * dx_stride_n + g * dx_stride_g;
13244
13245                            // dcol = w_g^T @ dy_n_g
13246                            // w_g  is stored as [k_dim rows, m_dim cols] row-major
13247                            // (i.e. K×M storage with lda = M = m_dim — exactly what
13248                            // sgemm_general wants for trans_a=true).
13249                            crate::blas::sgemm_general(
13250                                ws.as_ptr().add(w_g_off),
13251                                dys.as_ptr().add(dy_n_g_off),
13252                                dcol.as_mut_ptr(),
13253                                m_dim,
13254                                n_dim,
13255                                k_dim,
13256                                1.0,
13257                                0.0,
13258                                /*lda=*/ m_dim,
13259                                /*ldb=*/ n_dim,
13260                                /*ldc=*/ n_dim,
13261                                /*trans_a=*/ true,
13262                                /*trans_b=*/ false,
13263                            );
13264
13265                            // dx_n_g += col2im(dcol)
13266                            col2im(
13267                                &dcol,
13268                                &mut dxs[dx_n_g_off..dx_n_g_off + dx_stride_g],
13269                                c_in_per_g,
13270                                h,
13271                                w_in,
13272                                h_out,
13273                                w_out,
13274                                kh,
13275                                kw,
13276                                sh,
13277                                sw,
13278                                ph,
13279                                pw,
13280                                dh,
13281                                dw,
13282                            );
13283                        }
13284                    }
13285                }
13286            }
13287
13288            Thunk::Conv2dBackwardWeight {
13289                x,
13290                dy,
13291                dw,
13292                n,
13293                c_in,
13294                h,
13295                w,
13296                c_out,
13297                h_out,
13298                w_out,
13299                kh,
13300                kw,
13301                sh,
13302                sw,
13303                ph,
13304                pw,
13305                dh,
13306                dw_dil,
13307                groups,
13308            } => {
13309                let n = *n as usize;
13310                let c_in = *c_in as usize;
13311                let h = *h as usize;
13312                let w = *w as usize;
13313                // Per-group im2col + GEMM, summed across batch.
13314                //
13315                //   col_n_g  = im2col(x_n_g)               (gather)
13316                //   dw_g    += dy_n_g  @  col_n_g^T        (sgemm, β=1)
13317                //
13318                // Layouts:
13319                //   x_n_g     [c_in_per_g, h · w]
13320                //   col_n_g   [c_in_per_g · kh · kw, h_out · w_out]
13321                //   dy_n_g    [c_out_per_g, h_out · w_out]
13322                //   dw_g      [c_out_per_g, c_in_per_g · kh · kw]
13323                let c_out = *c_out as usize;
13324                let h_out = *h_out as usize;
13325                let w_out = *w_out as usize;
13326                let kh = *kh as usize;
13327                let kw = *kw as usize;
13328                let sh = *sh as usize;
13329                let sw = *sw as usize;
13330                let ph = *ph as usize;
13331                let pw = *pw as usize;
13332                let dh = *dh as usize;
13333                let dw_dil = *dw_dil as usize;
13334                let groups = *groups as usize;
13335                let c_in_per_g = c_in / groups;
13336                let c_out_per_g = c_out / groups;
13337
13338                let m_dim = c_out_per_g;
13339                let n_dim = c_in_per_g * kh * kw;
13340                let k_dim = h_out * w_out;
13341
13342                let x_stride_n = c_in * h * w;
13343                let x_stride_g = c_in_per_g * h * w;
13344                let dy_stride_n = c_out * h_out * w_out;
13345                let dy_stride_g = c_out_per_g * h_out * w_out;
13346                let dw_stride_g = c_out_per_g * c_in_per_g * kh * kw;
13347
13348                unsafe {
13349                    let xs = sl(*x, base, n * c_in * h * w);
13350                    let dys = sl(*dy, base, n * c_out * h_out * w_out);
13351                    let dws = sl_mut(*dw, base, c_out * c_in_per_g * kh * kw);
13352                    for v in dws.iter_mut() {
13353                        *v = 0.0;
13354                    }
13355
13356                    let mut col = vec![0f32; n_dim * k_dim];
13357
13358                    for ni in 0..n {
13359                        for g in 0..groups {
13360                            let x_n_g_off = ni * x_stride_n + g * x_stride_g;
13361                            im2col(
13362                                &xs[x_n_g_off..x_n_g_off + x_stride_g],
13363                                &mut col,
13364                                c_in_per_g,
13365                                h,
13366                                w,
13367                                h_out,
13368                                w_out,
13369                                kh,
13370                                kw,
13371                                sh,
13372                                sw,
13373                                ph,
13374                                pw,
13375                                dh,
13376                                dw_dil,
13377                            );
13378
13379                            let dy_n_g_off = ni * dy_stride_n + g * dy_stride_g;
13380                            let dw_g_off = g * dw_stride_g;
13381
13382                            // dw_g += dy_n_g @ col^T
13383                            //
13384                            // Output shape m × n_out = c_out_per_g × (c_in_per_g·kh·kw).
13385                            // dy_n_g is stored M×K row-major (lda = K = k_dim).
13386                            // col is stored as N×K row-major; with trans_b=true,
13387                            // sgemm_general uses ldb = K = k_dim and treats it as
13388                            // transposed. β=1 accumulates across the batch loop.
13389                            crate::blas::sgemm_general(
13390                                dys.as_ptr().add(dy_n_g_off),
13391                                col.as_ptr(),
13392                                dws.as_mut_ptr().add(dw_g_off),
13393                                m_dim,
13394                                n_dim,
13395                                k_dim,
13396                                1.0,
13397                                1.0,
13398                                /*lda=*/ k_dim,
13399                                /*ldb=*/ k_dim,
13400                                /*ldc=*/ n_dim,
13401                                /*trans_a=*/ false,
13402                                /*trans_b=*/ true,
13403                            );
13404                        }
13405                    }
13406                }
13407            }
13408
13409            Thunk::Im2Col {
13410                x,
13411                col,
13412                n,
13413                c_in,
13414                h,
13415                w,
13416                h_out,
13417                w_out,
13418                kh,
13419                kw,
13420                sh,
13421                sw,
13422                ph,
13423                pw,
13424                dh,
13425                dw_dil,
13426            } => {
13427                let c_in = *c_in as usize;
13428                let h = *h as usize;
13429                let w = *w as usize;
13430                let h_out = *h_out as usize;
13431                let w_out = *w_out as usize;
13432                let kh = *kh as usize;
13433                let kw = *kw as usize;
13434                let sh = *sh as usize;
13435                let sw = *sw as usize;
13436                let ph = *ph as usize;
13437                let pw = *pw as usize;
13438                let dh = *dh as usize;
13439                let dw_dil = *dw_dil as usize;
13440                let per_batch = c_in * h * w;
13441                unsafe {
13442                    let n_eff = if *n == 0 { 0usize } else { *n as usize };
13443                    let x_floats = if n_eff == 0 {
13444                        per_batch.max(1)
13445                    } else {
13446                        n_eff * per_batch
13447                    };
13448                    let xs = sl(*x, base, x_floats);
13449                    let n = if *n == 0 {
13450                        xs.len() / per_batch.max(1)
13451                    } else {
13452                        n_eff
13453                    };
13454                    let m = n * h_out * w_out;
13455                    let k = c_in * kh * kw;
13456                    let cols = sl_mut(*col, base, m * k);
13457                    crate::im2col::im2col_rows_layout(
13458                        xs, cols, n, c_in, h, w, h_out, w_out, kh, kw, sh, sw, ph, pw, dh, dw_dil,
13459                    );
13460                }
13461            }
13462
13463            Thunk::SoftmaxCrossEntropy {
13464                logits,
13465                labels,
13466                dst,
13467                n,
13468                c,
13469            } => {
13470                let n = *n as usize;
13471                let c = *c as usize;
13472                unsafe {
13473                    let lg = sl(*logits, base, n * c);
13474                    let lb = sl(*labels, base, n);
13475                    let out = sl_mut(*dst, base, n);
13476                    for ni in 0..n {
13477                        let row = &lg[ni * c..(ni + 1) * c];
13478                        // log-sum-exp: max-subtract for stability.
13479                        let mut m = f32::NEG_INFINITY;
13480                        for &v in row {
13481                            if v > m {
13482                                m = v;
13483                            }
13484                        }
13485                        let mut sum = 0f32;
13486                        for &v in row {
13487                            sum += (v - m).exp();
13488                        }
13489                        let lse = m + sum.ln();
13490                        let label_idx = lb[ni] as usize;
13491                        // loss = -(logits[label] - lse) = lse - logits[label].
13492                        out[ni] = lse - row[label_idx];
13493                    }
13494                }
13495            }
13496
13497            Thunk::SoftmaxCrossEntropyBackward {
13498                logits,
13499                labels,
13500                d_loss,
13501                dlogits,
13502                n,
13503                c,
13504            } => {
13505                let n = *n as usize;
13506                let c = *c as usize;
13507                unsafe {
13508                    let lg = sl(*logits, base, n * c);
13509                    let lb = sl(*labels, base, n);
13510                    let dl = sl(*d_loss, base, n);
13511                    let out = sl_mut(*dlogits, base, n * c);
13512                    for ni in 0..n {
13513                        let row = &lg[ni * c..(ni + 1) * c];
13514                        let label_idx = lb[ni] as usize;
13515                        let scale = dl[ni];
13516                        let mut m = f32::NEG_INFINITY;
13517                        for &v in row {
13518                            if v > m {
13519                                m = v;
13520                            }
13521                        }
13522                        let mut sum = 0f32;
13523                        for &v in row {
13524                            sum += (v - m).exp();
13525                        }
13526                        let inv_sum = 1.0 / sum;
13527                        let dst_row = &mut out[ni * c..(ni + 1) * c];
13528                        for k in 0..c {
13529                            let p = (row[k] - m).exp() * inv_sum;
13530                            let one_hot = if k == label_idx { 1.0 } else { 0.0 };
13531                            dst_row[k] = (p - one_hot) * scale;
13532                        }
13533                    }
13534                }
13535            }
13536
13537            Thunk::GatherAxis {
13538                table,
13539                idx,
13540                dst,
13541                outer,
13542                axis_dim,
13543                num_idx,
13544                trailing,
13545                idx_i64,
13546                table_bytes,
13547            } => {
13548                let outer = *outer as usize;
13549                let axis_dim = *axis_dim as usize;
13550                let num_idx = *num_idx as usize;
13551                let trailing = *trailing as usize;
13552                unsafe {
13553                    if *table_bytes == 8 {
13554                        let tab = sl_i64(*table, base, outer * axis_dim * trailing);
13555                        let out = sl_mut_i64(*dst, base, outer * num_idx * trailing);
13556                        for o in 0..outer {
13557                            let tab_outer = o * axis_dim * trailing;
13558                            let out_outer = o * num_idx * trailing;
13559                            if *idx_i64 != 0 {
13560                                let ids = sl_i64(*idx, base, num_idx);
13561                                for k in 0..num_idx {
13562                                    let row = ids[k].max(0) as usize;
13563                                    if row < axis_dim {
13564                                        let tab_row = tab_outer + row * trailing;
13565                                        let out_row = out_outer + k * trailing;
13566                                        out[out_row..out_row + trailing]
13567                                            .copy_from_slice(&tab[tab_row..tab_row + trailing]);
13568                                    }
13569                                }
13570                            } else {
13571                                let ids = sl(*idx, base, num_idx);
13572                                for k in 0..num_idx {
13573                                    let row = ids[k] as usize;
13574                                    if row < axis_dim {
13575                                        let tab_row = tab_outer + row * trailing;
13576                                        let out_row = out_outer + k * trailing;
13577                                        out[out_row..out_row + trailing]
13578                                            .copy_from_slice(&tab[tab_row..tab_row + trailing]);
13579                                    }
13580                                }
13581                            }
13582                        }
13583                    } else {
13584                        let tab = sl(*table, base, outer * axis_dim * trailing);
13585                        let out = sl_mut(*dst, base, outer * num_idx * trailing);
13586                        for o in 0..outer {
13587                            let tab_outer = o * axis_dim * trailing;
13588                            let out_outer = o * num_idx * trailing;
13589                            if *idx_i64 != 0 {
13590                                let ids = sl_i64(*idx, base, num_idx);
13591                                for k in 0..num_idx {
13592                                    let row = ids[k].max(0) as usize;
13593                                    if row < axis_dim {
13594                                        let tab_row = tab_outer + row * trailing;
13595                                        let out_row = out_outer + k * trailing;
13596                                        out[out_row..out_row + trailing]
13597                                            .copy_from_slice(&tab[tab_row..tab_row + trailing]);
13598                                    }
13599                                }
13600                            } else {
13601                                let ids = sl(*idx, base, num_idx);
13602                                for k in 0..num_idx {
13603                                    let row = ids[k] as usize;
13604                                    if row < axis_dim {
13605                                        let tab_row = tab_outer + row * trailing;
13606                                        let out_row = out_outer + k * trailing;
13607                                        out[out_row..out_row + trailing]
13608                                            .copy_from_slice(&tab[tab_row..tab_row + trailing]);
13609                                    }
13610                                }
13611                            }
13612                        }
13613                    }
13614                }
13615            }
13616
13617            Thunk::Transpose {
13618                src,
13619                dst,
13620                in_total,
13621                out_dims,
13622                in_strides,
13623                elem_bytes,
13624            } => {
13625                // N-D index walk: for each output flat index, decompose into
13626                // multi-dim coords using out_dims, then dot with in_strides
13627                // to find the source flat index. Stride 0 = broadcast (read
13628                // the same input element repeatedly along that dim).
13629                let rank = out_dims.len();
13630                let total: usize = out_dims.iter().map(|&d| d as usize).product();
13631                let in_total = *in_total as usize;
13632                unsafe {
13633                    if *elem_bytes == 8 {
13634                        let inp = sl_i64(*src, base, in_total);
13635                        let out = sl_mut_i64(*dst, base, total);
13636                        let mut idx = vec![0usize; rank];
13637                        for o in 0..total {
13638                            let mut src_idx = 0usize;
13639                            for d in 0..rank {
13640                                src_idx += idx[d] * in_strides[d] as usize;
13641                            }
13642                            out[o] = inp[src_idx];
13643                            for d in (0..rank).rev() {
13644                                idx[d] += 1;
13645                                if idx[d] < out_dims[d] as usize {
13646                                    break;
13647                                }
13648                                idx[d] = 0;
13649                            }
13650                        }
13651                    } else {
13652                        let inp = sl(*src, base, in_total);
13653                        let out = sl_mut(*dst, base, total);
13654                        let mut idx = vec![0usize; rank];
13655                        for o in 0..total {
13656                            let mut src_idx = 0usize;
13657                            for d in 0..rank {
13658                                src_idx += idx[d] * in_strides[d] as usize;
13659                            }
13660                            out[o] = inp[src_idx];
13661                            for d in (0..rank).rev() {
13662                                idx[d] += 1;
13663                                if idx[d] < out_dims[d] as usize {
13664                                    break;
13665                                }
13666                                idx[d] = 0;
13667                            }
13668                        }
13669                    }
13670                }
13671            }
13672
13673            // (Thunk::DenseSolveF64 / Thunk::ScanBackward had panic
13674            // stubs here as placeholders during the wire-up; both
13675            // are now reached by the real implementations earlier in
13676            // this same match — the stubs were dead code shadowed by
13677            // the specific-pattern arms above. Removed.)
13678            Thunk::CustomOp {
13679                kernel,
13680                inputs,
13681                output,
13682                attrs,
13683            } => {
13684                let (out_off, out_len, out_shape) = output;
13685                unsafe {
13686                    dispatch_custom_op(
13687                        &**kernel, inputs, *out_off, *out_len, out_shape, attrs, base,
13688                    );
13689                }
13690            }
13691        }
13692        if trace_done {
13693            eprintln!("[thunk {i} done]");
13694        }
13695    }
13696}
13697
13698/// Griewank treeverse: process backward iterations `[t_lo..=t_hi]` (with
13699/// the carry entering iteration `t_lo` supplied as `anchor_carry`) by
13700/// recursive binary subdivision. Total work `O((t_hi-t_lo+1) · log)`,
13701/// auxiliary memory `O(log · carry_bytes)` for the recursion stack.
13702///
13703/// Compared to the iterative segment-cached scheme, this trades extra
13704/// recompute for less working memory — each level of recursion holds
13705/// one `cb`-sized intermediate carry on the stack but never the whole
13706/// segment at once. With K saved outer checkpoints, the outer driver
13707/// invokes this helper once per segment.
13708///
13709/// `process_iter(t, carry_at_t)` is the per-iteration leaf action: it
13710/// runs `body_vjp` at iteration `t` with the supplied carry, threads
13711/// `dcarry` backward, and (for ScanBackwardXs) writes `dxs[t]`.
13712#[allow(clippy::too_many_arguments)]
13713unsafe fn griewank_process_segment(
13714    t_lo: usize,
13715    t_hi: usize,
13716    anchor_carry: &[u8],
13717    cb: usize,
13718    fwd_sched: &ThunkSchedule,
13719    fwd_init: &[u8],
13720    fwd_carry_in_off: usize,
13721    fwd_output_off: usize,
13722    fwd_x_offs: &[usize],
13723    base: *mut u8,
13724    outer_xs_offs: &[(usize, u32)],
13725    fwd_buf: &mut Vec<u8>,
13726    leaf_threshold: usize,
13727    process_iter: &mut dyn FnMut(usize, &[u8]),
13728) {
13729    unsafe {
13730        let size = t_hi - t_lo + 1;
13731        if size == 1 {
13732            process_iter(t_lo, anchor_carry);
13733            return;
13734        }
13735        if size <= leaf_threshold {
13736            // Walk forward, cache each carry, run backward in reverse.
13737            let mut cache: Vec<u8> = Vec::with_capacity(size * cb);
13738            cache.extend_from_slice(anchor_carry);
13739            fwd_buf.copy_from_slice(fwd_init);
13740            std::ptr::copy_nonoverlapping(
13741                anchor_carry.as_ptr(),
13742                fwd_buf.as_mut_ptr().add(fwd_carry_in_off),
13743                cb,
13744            );
13745            for i in 1..size {
13746                let cur_iter = t_lo + i - 1;
13747                for (idx, fb_x_off) in fwd_x_offs.iter().enumerate() {
13748                    let (outer_xs_off, x_psb) = outer_xs_offs[idx];
13749                    let xb = x_psb as usize;
13750                    std::ptr::copy_nonoverlapping(
13751                        base.add(outer_xs_off + cur_iter * xb),
13752                        fwd_buf.as_mut_ptr().add(*fb_x_off),
13753                        xb,
13754                    );
13755                }
13756                execute_thunks(fwd_sched, fwd_buf);
13757                if fwd_output_off != fwd_carry_in_off {
13758                    fwd_buf.copy_within(fwd_output_off..fwd_output_off + cb, fwd_carry_in_off);
13759                }
13760                cache.extend_from_slice(&fwd_buf[fwd_carry_in_off..fwd_carry_in_off + cb]);
13761            }
13762            // Process backward.
13763            for t in (t_lo..=t_hi).rev() {
13764                let idx = t - t_lo;
13765                let carry = &cache[idx * cb..(idx + 1) * cb];
13766                process_iter(t, carry);
13767            }
13768            return;
13769        }
13770
13771        // Split: walk forward from anchor to compute carry entering `mid`.
13772        // (We need `mid - t_lo` body executions: one per iteration in
13773        // [t_lo, mid).)
13774        let mid = t_lo + size / 2;
13775        fwd_buf.copy_from_slice(fwd_init);
13776        std::ptr::copy_nonoverlapping(
13777            anchor_carry.as_ptr(),
13778            fwd_buf.as_mut_ptr().add(fwd_carry_in_off),
13779            cb,
13780        );
13781        for cur_iter in t_lo..mid {
13782            for (idx, fb_x_off) in fwd_x_offs.iter().enumerate() {
13783                let (outer_xs_off, x_psb) = outer_xs_offs[idx];
13784                let xb = x_psb as usize;
13785                std::ptr::copy_nonoverlapping(
13786                    base.add(outer_xs_off + cur_iter * xb),
13787                    fwd_buf.as_mut_ptr().add(*fb_x_off),
13788                    xb,
13789                );
13790            }
13791            execute_thunks(fwd_sched, fwd_buf);
13792            if fwd_output_off != fwd_carry_in_off {
13793                fwd_buf.copy_within(fwd_output_off..fwd_output_off + cb, fwd_carry_in_off);
13794            }
13795        }
13796        let mid_carry: Vec<u8> = fwd_buf[fwd_carry_in_off..fwd_carry_in_off + cb].to_vec();
13797
13798        // Right half first (higher t values processed first to match the
13799        // canonical reverse-mode iteration order: dcarry threads from
13800        // t=length-1 down to t=0).
13801        griewank_process_segment(
13802            mid,
13803            t_hi,
13804            &mid_carry,
13805            cb,
13806            fwd_sched,
13807            fwd_init,
13808            fwd_carry_in_off,
13809            fwd_output_off,
13810            fwd_x_offs,
13811            base,
13812            outer_xs_offs,
13813            fwd_buf,
13814            leaf_threshold,
13815            process_iter,
13816        );
13817        // Then left half with original anchor.
13818        griewank_process_segment(
13819            t_lo,
13820            mid - 1,
13821            anchor_carry,
13822            cb,
13823            fwd_sched,
13824            fwd_init,
13825            fwd_carry_in_off,
13826            fwd_output_off,
13827            fwd_x_offs,
13828            base,
13829            outer_xs_offs,
13830            fwd_buf,
13831            leaf_threshold,
13832            process_iter,
13833        );
13834    }
13835}
13836
13837/// Execute a batched 1D FFT in the f64 2N-real-block layout.
13838/// Each "row" is `2N` f64 elements: first `N` real, then `N` imag.
13839/// The `outer` rows are independent and processed sequentially.
13840///
13841/// Both forward and inverse use the same Cooley-Tukey radix-2 DIT
13842/// kernel — only the twiddle-factor sign differs. Power-of-2 only
13843/// (the IR builder rejects non-power-of-2 sizes at graph-build time).
13844/// Batched 1D FFT on the f64 2N-real-block layout. Public so other
13845/// backend crates can invoke this as a host fallback against a
13846/// unified-memory arena (e.g. rlx-metal: sync the command buffer,
13847/// pass the Metal `Buffer::contents()` pointer as `base`, restart the
13848/// command buffer). Self-contained — no rlx-cpu state required.
13849///
13850/// Safety: `base + src` and `base + dst` must be valid for the
13851/// `outer * 2 * n_complex * sizeof::<f64>()` byte range and stay
13852/// alive for the duration of the call.
13853pub unsafe fn execute_fft1d_f64(
13854    src: usize,
13855    dst: usize,
13856    outer: usize,
13857    n_complex: usize,
13858    inverse: bool,
13859    norm_tag: u32,
13860    base: *mut u8,
13861) {
13862    let row_elems = 2 * n_complex;
13863    let mut re = vec![0f64; n_complex];
13864    let mut im = vec![0f64; n_complex];
13865    let norm = rlx_ir::fft::FftNorm::from_tag(norm_tag);
13866    let scale = norm.output_scale(n_complex, inverse);
13867    // Scratch reused across rows for the Bluestein path. Empty when
13868    // we're on the radix-2 fast path.
13869    let mut scratch = if n_complex.is_power_of_two() || n_complex <= 16 {
13870        BluesteinScratchF64::empty()
13871    } else {
13872        BluesteinScratchF64::build(n_complex, inverse)
13873    };
13874    for o in 0..outer {
13875        let row_offset = src + o * row_elems * std::mem::size_of::<f64>();
13876        let s = unsafe { sl_f64(row_offset, base, row_elems) };
13877        re.copy_from_slice(&s[..n_complex]);
13878        im.copy_from_slice(&s[n_complex..]);
13879        if n_complex.is_power_of_two() {
13880            fft_radix2_inplace_f64(&mut re, &mut im, inverse);
13881        } else if n_complex <= 16 {
13882            fft_naive_inplace_f64(&mut re, &mut im, inverse);
13883        } else {
13884            fft_bluestein_inplace_f64(&mut re, &mut im, inverse, &mut scratch);
13885        }
13886        if scale != 1.0 {
13887            re.iter_mut().for_each(|v| *v *= scale);
13888            im.iter_mut().for_each(|v| *v *= scale);
13889        }
13890        let dst_offset = dst + o * row_elems * std::mem::size_of::<f64>();
13891        let d = unsafe { sl_mut_f64(dst_offset, base, row_elems) };
13892        d[..n_complex].copy_from_slice(&re);
13893        d[n_complex..].copy_from_slice(&im);
13894    }
13895}
13896
13897/// f32 counterpart of `execute_fft1d_f64`. Same 2N-real-block layout
13898/// (first N real, second N imag per row), same unnormalized
13899/// convention; only the element width differs. Twiddle factors are
13900/// computed in f64 and cast to f32 to keep large-N error closer to
13901/// the f64 path (the savings from f32 are in memory bandwidth, not in
13902/// twiddle precision).
13903/// Host-fallback entry for `Op::GatedDeltaNet` (Metal / unified memory).
13904/// When `state == 0`, uses a zero-initialized scratch state per batch item.
13905pub unsafe fn execute_gated_delta_net_f32(
13906    q: usize,
13907    k: usize,
13908    v: usize,
13909    g: usize,
13910    beta: usize,
13911    state: usize,
13912    dst: usize,
13913    batch: usize,
13914    seq: usize,
13915    heads: usize,
13916    state_size: usize,
13917    base: *mut u8,
13918) {
13919    use rayon::prelude::*;
13920
13921    #[derive(Copy, Clone)]
13922    struct ArenaPtr(usize);
13923    unsafe impl Send for ArenaPtr {}
13924    unsafe impl Sync for ArenaPtr {}
13925    impl ArenaPtr {
13926        #[inline]
13927        fn get(self) -> *mut u8 {
13928            self.0 as *mut u8
13929        }
13930    }
13931
13932    unsafe {
13933        let arena = ArenaPtr(base as usize);
13934        let (b, s, h, n) = (batch, seq, heads, state_size);
13935        let scale = 1.0f32 / (n as f32).sqrt();
13936        let use_external = state != 0;
13937        let mut owned_state = vec![0f32; h * n * n];
13938
13939        crate::pool::num_threads();
13940
13941        assert!(
13942            n <= crate::gdn::GDN_MAX_STATE,
13943            "GatedDeltaNet state_size={n} exceeds stack scratch ({})",
13944            crate::gdn::GDN_MAX_STATE
13945        );
13946
13947        let qs = sl(q, arena.get(), b * s * h * n);
13948        let ks = sl(k, arena.get(), b * s * h * n);
13949        let vs = sl(v, arena.get(), b * s * h * n);
13950        let gs = sl(g, arena.get(), b * s * h);
13951        let betas = sl(beta, arena.get(), b * s * h);
13952        let _out = sl_mut(dst, arena.get(), b * s * h * n);
13953        let hs_n = h * n;
13954
13955        let run_head = |bi: usize, hi: usize, s_mat: &mut [f32], sk: &mut [f32]| {
13956            for ti in 0..s {
13957                let qkv_step = bi * s * hs_n + ti * hs_n + hi * n;
13958                let gb_step = bi * s * h + ti * h + hi;
13959                let out_row = sl_mut(dst + qkv_step * std::mem::size_of::<f32>(), arena.get(), n);
13960                crate::gdn::gdn_step_blas(
13961                    s_mat,
13962                    &qs[qkv_step..qkv_step + n],
13963                    &ks[qkv_step..qkv_step + n],
13964                    &vs[qkv_step..qkv_step + n],
13965                    gs[gb_step],
13966                    betas[gb_step],
13967                    out_row,
13968                    sk,
13969                    n,
13970                    scale,
13971                );
13972            }
13973        };
13974
13975        // Prefill (seq>1, ephemeral state): time-outer, parallel over heads —
13976        // better occupancy than head-outer when prompt length dominates.
13977        if !use_external && s > 1 {
13978            for bi in 0..b {
13979                (0..h).into_par_iter().for_each(|hi| {
13980                    let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
13981                    let sk = &mut sk_buf[..n];
13982                    let mut local_state =
13983                        [0f32; crate::gdn::GDN_MAX_STATE * crate::gdn::GDN_MAX_STATE];
13984                    let s_mat = &mut local_state[..n * n];
13985                    s_mat.fill(0.0);
13986                    run_head(bi, hi, s_mat, sk);
13987                });
13988            }
13989            return;
13990        }
13991
13992        if use_external {
13993            let state_bytes = state;
13994            (0..b * h).into_par_iter().for_each(|bhi| {
13995                let bi = bhi / h;
13996                let hi = bhi % h;
13997                let elem_off = bi * h * n * n + hi * n * n;
13998                let s_mat = sl_mut(
13999                    state_bytes + elem_off * std::mem::size_of::<f32>(),
14000                    arena.get(),
14001                    n * n,
14002                );
14003                let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
14004                run_head(bi, hi, s_mat, &mut sk_buf[..n]);
14005            });
14006        } else {
14007            for bi in 0..b {
14008                owned_state.fill(0.0);
14009                owned_state
14010                    .par_chunks_mut(n * n)
14011                    .enumerate()
14012                    .for_each(|(hi, s_mat)| {
14013                        let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
14014                        run_head(bi, hi, s_mat, &mut sk_buf[..n]);
14015                    });
14016            }
14017        }
14018    }
14019}
14020
14021/// Host-fallback: `Op::RmsNormBackwardInput` (GPU unified-memory / D2H arenas).
14022pub unsafe fn execute_rms_norm_backward_input_f32(
14023    x: usize,
14024    gamma: usize,
14025    beta: usize,
14026    dy: usize,
14027    dx: usize,
14028    rows: u32,
14029    h: u32,
14030    eps: f32,
14031    base: *mut u8,
14032) {
14033    let (rows, h) = (rows as usize, h as usize);
14034    let mut dg = vec![0f32; h];
14035    let mut db = vec![0f32; h];
14036    let xs = sl(x, base, rows * h);
14037    let dys = sl(dy, base, rows * h);
14038    let g = sl(gamma, base, h);
14039    let b = sl(beta, base, h);
14040    let out = sl_mut(dx, base, rows * h);
14041    for r in 0..rows {
14042        crate::training_bwd::rms_norm_backward_row(
14043            &xs[r * h..(r + 1) * h],
14044            g,
14045            b,
14046            &dys[r * h..(r + 1) * h],
14047            &mut out[r * h..(r + 1) * h],
14048            &mut dg,
14049            &mut db,
14050            eps,
14051        );
14052    }
14053}
14054
14055pub unsafe fn execute_rms_norm_backward_gamma_f32(
14056    x: usize,
14057    gamma: usize,
14058    beta: usize,
14059    dy: usize,
14060    dgamma: usize,
14061    rows: u32,
14062    h: u32,
14063    eps: f32,
14064    base: *mut u8,
14065) {
14066    let (rows, h) = (rows as usize, h as usize);
14067    let out = sl_mut(dgamma, base, h);
14068    out.fill(0.0);
14069    let mut dx = vec![0f32; h];
14070    let mut db = vec![0f32; h];
14071    let xs = sl(x, base, rows * h);
14072    let dys = sl(dy, base, rows * h);
14073    let g = sl(gamma, base, h);
14074    let b = sl(beta, base, h);
14075    for r in 0..rows {
14076        crate::training_bwd::rms_norm_backward_row(
14077            &xs[r * h..(r + 1) * h],
14078            g,
14079            b,
14080            &dys[r * h..(r + 1) * h],
14081            &mut dx,
14082            out,
14083            &mut db,
14084            eps,
14085        );
14086    }
14087}
14088
14089pub unsafe fn execute_rms_norm_backward_beta_f32(
14090    x: usize,
14091    gamma: usize,
14092    beta: usize,
14093    dy: usize,
14094    dbeta: usize,
14095    rows: u32,
14096    h: u32,
14097    eps: f32,
14098    base: *mut u8,
14099) {
14100    let (rows, h) = (rows as usize, h as usize);
14101    let out = sl_mut(dbeta, base, h);
14102    out.fill(0.0);
14103    let mut dx = vec![0f32; h];
14104    let mut dg = vec![0f32; h];
14105    let xs = sl(x, base, rows * h);
14106    let dys = sl(dy, base, rows * h);
14107    let g = sl(gamma, base, h);
14108    let b = sl(beta, base, h);
14109    for r in 0..rows {
14110        crate::training_bwd::rms_norm_backward_row(
14111            &xs[r * h..(r + 1) * h],
14112            g,
14113            b,
14114            &dys[r * h..(r + 1) * h],
14115            &mut dx,
14116            &mut dg,
14117            out,
14118            eps,
14119        );
14120    }
14121}
14122
14123#[allow(clippy::too_many_arguments)]
14124pub unsafe fn execute_conv2d_forward_f32(
14125    src: usize,
14126    weight: usize,
14127    dst: usize,
14128    n: u32,
14129    c_in: u32,
14130    h: u32,
14131    w: u32,
14132    c_out: u32,
14133    h_out: u32,
14134    w_out: u32,
14135    kh: u32,
14136    kw: u32,
14137    sh: u32,
14138    sw: u32,
14139    ph: u32,
14140    pw: u32,
14141    dh: u32,
14142    dw: u32,
14143    groups: u32,
14144    base: *mut u8,
14145) {
14146    let n = n as usize;
14147    let c_in = c_in as usize;
14148    let h = h as usize;
14149    let w = w as usize;
14150    let c_out = c_out as usize;
14151    let h_out = h_out as usize;
14152    let w_out = w_out as usize;
14153    let kh = kh as usize;
14154    let kw = kw as usize;
14155    let sh = sh as usize;
14156    let sw = sw as usize;
14157    let ph = ph as usize;
14158    let pw = pw as usize;
14159    let dh = dh as usize;
14160    let dw = dw as usize;
14161    let groups = groups as usize;
14162    let c_in_per_g = c_in / groups;
14163    let inp = sl(src, base, n * c_in * h * w);
14164    let wt = sl(weight, base, c_out * c_in_per_g * kh * kw);
14165    let out = sl_mut(dst, base, n * c_out * h_out * w_out);
14166    crate::conv_fwd::conv2d_forward_nchw_f32(
14167        inp, wt, out, n, c_in, h, w, c_out, h_out, w_out, kh, kw, sh, sw, ph, pw, dh, dw, groups,
14168    );
14169}
14170
14171pub unsafe fn execute_maxpool2d_backward_f32(
14172    x: usize,
14173    dy: usize,
14174    dx: usize,
14175    n: u32,
14176    c: u32,
14177    h: u32,
14178    w: u32,
14179    h_out: u32,
14180    w_out: u32,
14181    kh: u32,
14182    kw: u32,
14183    sh: u32,
14184    sw: u32,
14185    ph: u32,
14186    pw: u32,
14187    base: *mut u8,
14188) {
14189    let (n, c, h, w) = (n as usize, c as usize, h as usize, w as usize);
14190    let (h_out, w_out) = (h_out as usize, w_out as usize);
14191    let (kh, kw) = (kh as usize, kw as usize);
14192    let (sh, sw) = (sh as usize, sw as usize);
14193    let (ph, pw) = (ph as usize, pw as usize);
14194    let xs = sl(x, base, n * c * h * w);
14195    let dys = sl(dy, base, n * c * h_out * w_out);
14196    let dxs = sl_mut(dx, base, n * c * h * w);
14197    crate::training_bwd::maxpool2d_backward_nchw(
14198        xs, dys, dxs, n, c, h, w, h_out, w_out, kh, kw, sh, sw, ph, pw,
14199    );
14200}
14201
14202pub unsafe fn execute_rope_backward_f32(
14203    dy: usize,
14204    cos: usize,
14205    sin: usize,
14206    dx: usize,
14207    batch: u32,
14208    seq: u32,
14209    hidden: u32,
14210    head_dim: u32,
14211    n_rot: u32,
14212    cos_len: u32,
14213    base: *mut u8,
14214) {
14215    let (b, s, hs, dh, nr, cl) = (
14216        batch as usize,
14217        seq as usize,
14218        hidden as usize,
14219        head_dim as usize,
14220        n_rot as usize,
14221        cos_len as usize,
14222    );
14223    let nh = hs / dh;
14224    let tab_half = dh / 2;
14225    let dys = sl(dy, base, b * s * hs);
14226    let cos_tab = sl(cos, base, cl);
14227    let sin_tab = sl(sin, base, cl);
14228    let out = sl_mut(dx, base, b * s * hs);
14229    for bi in 0..b {
14230        for si in 0..s {
14231            let tab_off = si.saturating_mul(tab_half) % cl.max(1);
14232            let cp = &cos_tab[tab_off..tab_off + tab_half.min(cl)];
14233            let sp = &sin_tab[tab_off..tab_off + tab_half.min(cl)];
14234            for hi in 0..nh {
14235                let base_idx = bi * s * hs + si * hs + hi * dh;
14236                crate::training_bwd::rope_backward_row(
14237                    &dys[base_idx..base_idx + dh],
14238                    cp,
14239                    sp,
14240                    &mut out[base_idx..base_idx + dh],
14241                    dh,
14242                    nr,
14243                );
14244            }
14245        }
14246    }
14247}
14248
14249pub unsafe fn execute_cumsum_backward_f32(
14250    dy: usize,
14251    dx: usize,
14252    rows: u32,
14253    cols: u32,
14254    exclusive: bool,
14255    base: *mut u8,
14256) {
14257    let (rows, cols) = (rows as usize, cols as usize);
14258    let dys = sl(dy, base, rows * cols);
14259    let out = sl_mut(dx, base, rows * cols);
14260    for r in 0..rows {
14261        crate::training_bwd::cumsum_backward_row(
14262            &dys[r * cols..(r + 1) * cols],
14263            &mut out[r * cols..(r + 1) * cols],
14264            exclusive,
14265        );
14266    }
14267}
14268
14269pub unsafe fn execute_gather_backward_f32(
14270    dy: usize,
14271    indices: usize,
14272    dst: usize,
14273    outer: u32,
14274    axis_dim: u32,
14275    num_idx: u32,
14276    trailing: u32,
14277    base: *mut u8,
14278) {
14279    let (outer, axis_dim, num_idx, trailing) = (
14280        outer as usize,
14281        axis_dim as usize,
14282        num_idx as usize,
14283        trailing as usize,
14284    );
14285    let out = sl_mut(dst, base, outer * axis_dim * trailing);
14286    out.fill(0.0);
14287    crate::training_bwd::gather_axis_backward(
14288        sl(dy, base, outer * num_idx * trailing),
14289        sl(indices, base, num_idx),
14290        out,
14291        outer,
14292        axis_dim,
14293        num_idx,
14294        trailing,
14295    );
14296}
14297
14298/// Host-fallback entry for GGUF `Op::DequantMatMul` (Metal unified memory).
14299pub unsafe fn execute_dequant_matmul_gguf_f32(
14300    x: usize,
14301    w_q: usize,
14302    dst: usize,
14303    m: usize,
14304    k: usize,
14305    n: usize,
14306    scheme: rlx_ir::quant::QuantScheme,
14307    base: *mut u8,
14308) {
14309    unsafe {
14310        let block_bytes = scheme.gguf_block_bytes() as usize;
14311        let block_elems = scheme.gguf_block_size() as usize;
14312        let total_bytes = (k * n) / block_elems * block_bytes;
14313        let xs = sl(x, base, m * k);
14314        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, total_bytes);
14315        let out = sl_mut(dst, base, m * n);
14316        crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, scheme);
14317    }
14318}
14319
14320/// Host-fallback entry for GGUF `Op::DequantGroupedMatMul` (MoE expert stack).
14321pub unsafe fn execute_dequant_grouped_matmul_gguf_f32(
14322    input: usize,
14323    w_q: usize,
14324    expert_idx: usize,
14325    dst: usize,
14326    m: usize,
14327    k: usize,
14328    n: usize,
14329    num_experts: usize,
14330    scheme: rlx_ir::quant::QuantScheme,
14331    base: *mut u8,
14332) {
14333    unsafe {
14334        let block_bytes = scheme.gguf_block_bytes() as usize;
14335        let block_elems = scheme.gguf_block_size() as usize;
14336        let slab_bytes = (k * n) / block_elems * block_bytes;
14337        let xs = sl(input, base, m * k);
14338        let w_bytes =
14339            std::slice::from_raw_parts(base.add(w_q) as *const u8, num_experts * slab_bytes);
14340        let ids = sl(expert_idx, base, m);
14341        let out = sl_mut(dst, base, m * n);
14342        crate::gguf_matmul::gguf_grouped_matmul_bt(
14343            xs,
14344            w_bytes,
14345            ids,
14346            out,
14347            m,
14348            k,
14349            n,
14350            num_experts,
14351            scheme,
14352        );
14353    }
14354}
14355
14356/// Host-fallback entry for Int4 `Op::DequantMatMul` (Metal unified memory).
14357pub unsafe fn execute_dequant_matmul_int4_f32(
14358    x: usize,
14359    w_q: usize,
14360    scale: usize,
14361    zp: usize,
14362    dst: usize,
14363    m: usize,
14364    k: usize,
14365    n: usize,
14366    block_size: u32,
14367    is_asymmetric: bool,
14368    base: *mut u8,
14369) {
14370    let bs = block_size as usize;
14371    let n_blocks = k.div_ceil(bs);
14372    unsafe {
14373        let xs = sl(x, base, m * k);
14374        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, (k * n).div_ceil(2));
14375        let scales = sl(scale, base, n_blocks * n);
14376        let zps = if is_asymmetric {
14377            sl(zp, base, n_blocks * n)
14378        } else {
14379            &[][..]
14380        };
14381        let out = sl_mut(dst, base, m * n);
14382        dequant_matmul_int4(xs, w_bytes, scales, zps, out, m, k, n, bs, is_asymmetric);
14383    }
14384}
14385
14386/// Host-fallback entry for FP8 `Op::DequantMatMul` (Metal unified memory).
14387pub unsafe fn execute_dequant_matmul_fp8_f32(
14388    x: usize,
14389    w_q: usize,
14390    scale: usize,
14391    dst: usize,
14392    m: usize,
14393    k: usize,
14394    n: usize,
14395    e5m2: bool,
14396    base: *mut u8,
14397) {
14398    unsafe {
14399        let xs = sl(x, base, m * k);
14400        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, k * n);
14401        let scales = sl(scale, base, n);
14402        let out = sl_mut(dst, base, m * n);
14403        dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, e5m2);
14404    }
14405}
14406
14407/// Host-fallback entry for NVFP4 `Op::DequantMatMul` (Metal unified memory).
14408pub unsafe fn execute_dequant_matmul_nvfp4_f32(
14409    x: usize,
14410    w_q: usize,
14411    scale: usize,
14412    global_scale: usize,
14413    dst: usize,
14414    m: usize,
14415    k: usize,
14416    n: usize,
14417    base: *mut u8,
14418) {
14419    let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
14420    unsafe {
14421        let xs = sl(x, base, m * k);
14422        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, (k * n).div_ceil(2));
14423        let scale_bytes = std::slice::from_raw_parts(base.add(scale) as *const u8, n_scale);
14424        let gs = sl(global_scale, base, 1)[0];
14425        let out = sl_mut(dst, base, m * n);
14426        dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
14427    }
14428}
14429
/// Host-fallback entry for f16 `Op::GatedDeltaNet` tensors on Metal.
///
/// Decodes the little-endian f16 inputs from the arena into f32 buffers, runs
/// the gated delta-rule recurrence in f32, and encodes the results back to f16.
///
/// Layouts implied by the indexing below:
/// - `q`, `k`, `v`, `dst`: `[batch, seq, heads, state_size]`
/// - `g`, `beta`: `[batch, seq, heads]`
/// - `state`: `[batch, heads, state_size, state_size]`. It is read as the
///   initial state and overwritten with the final state.
///
/// `state == 0` is treated as "no state tensor". In that case every batch item
/// starts from a zero state and the final state is not written anywhere.
/// NOTE(review): offset 0 doubles as the "absent" sentinel, so a real state
/// tensor placed at arena offset 0 would be ignored. Confirm callers never
/// do that.
///
/// # Safety
/// `base` plus each offset must address enough bytes for the element counts
/// above, at 2 bytes per f16 element.
pub unsafe fn execute_gated_delta_net_f16(
    q: usize,
    k: usize,
    v: usize,
    g: usize,
    beta: usize,
    state: usize,
    dst: usize,
    batch: usize,
    seq: usize,
    heads: usize,
    state_size: usize,
    base: *mut u8,
) {
    use half::f16;
    unsafe {
        // Decode `len` little-endian f16 values at byte offset `off` into f32.
        let read_f16 = |off: usize, len: usize| -> Vec<f32> {
            let raw = std::slice::from_raw_parts(base.add(off) as *const u8, len * 2);
            raw.chunks_exact(2)
                .map(|c| f16::from_le_bytes([c[0], c[1]]).to_f32())
                .collect()
        };
        // Encode `data` as little-endian f16 at byte offset `off`.
        let write_f16 = |off: usize, data: &[f32]| {
            let out = std::slice::from_raw_parts_mut(base.add(off), data.len() * 2);
            for (i, &v) in data.iter().enumerate() {
                let le = f16::from_f32(v).to_le_bytes();
                out[i * 2] = le[0];
                out[i * 2 + 1] = le[1];
            }
        };

        let (b, s, h, n) = (batch, seq, heads, state_size);
        let q_f = read_f16(q, b * s * h * n);
        let k_f = read_f16(k, b * s * h * n);
        let v_f = read_f16(v, b * s * h * n);
        let g_f = read_f16(g, b * s * h);
        let b_f = read_f16(beta, b * s * h);
        let mut state_f = if state != 0 {
            read_f16(state, b * h * n * n)
        } else {
            vec![0f32; b * h * n * n]
        };
        let mut out_f = vec![0f32; b * s * h * n];
        // 1/sqrt(state_size) scale. It is applied to each output row, which is
        // equivalent to pre-scaling q.
        let scale = 1.0f32 / (n as f32).sqrt();
        // Per-row scratch: first holds Sᵀ·k, then the beta-scaled delta.
        let mut sk_buf = vec![0f32; n];
        // Per-batch-item zero state, used only when no state tensor was given.
        let mut owned_state = vec![0f32; h * n * n];

        for bi in 0..b {
            let state_slice: &mut [f32] = if state != 0 {
                let start = bi * h * n * n;
                &mut state_f[start..start + h * n * n]
            } else {
                owned_state.fill(0.0);
                &mut owned_state
            };

            // Timesteps are processed in order: each step updates the state
            // that the next step reads.
            for ti in 0..s {
                let qkv_step_base = bi * s * h * n + ti * h * n;
                let gb_step_base = bi * s * h + ti * h;

                for hi in 0..h {
                    let q_row = &q_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
                    let k_row = &k_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
                    let v_row = &v_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
                    let g_t = g_f[gb_step_base + hi];
                    let beta_t = b_f[gb_step_base + hi];

                    // This head's n×n state matrix, row-major S[i, j].
                    let s_base = hi * n * n;
                    let s_mat = &mut state_slice[s_base..s_base + n * n];

                    // Decay: S ← exp(g) · S.
                    let g_exp = g_t.exp();
                    for st in s_mat.iter_mut() {
                        *st *= g_exp;
                    }

                    // sk[j] = Σ_i S[i, j] · k[i]  (Sᵀ·k).
                    for j in 0..n {
                        let mut acc = 0f32;
                        for i in 0..n {
                            acc += s_mat[i * n + j] * k_row[i];
                        }
                        sk_buf[j] = acc;
                    }

                    // delta[j] = (v[j] − sk[j]) · beta.
                    for j in 0..n {
                        sk_buf[j] = (v_row[j] - sk_buf[j]) * beta_t;
                    }

                    // Rank-1 update: S ← S + k ⊗ delta.
                    for i in 0..n {
                        let ki = k_row[i];
                        for j in 0..n {
                            s_mat[i * n + j] += ki * sk_buf[j];
                        }
                    }

                    // Output from the updated state: o[j] = scale · Σ_i S[i, j] · q[i].
                    let out_row = &mut out_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
                    for j in 0..n {
                        let mut acc = 0f32;
                        for i in 0..n {
                            acc += s_mat[i * n + j] * q_row[i];
                        }
                        out_row[j] = acc * scale;
                    }
                }
            }
        }

        // Write back as f16. The state is persisted only when a tensor was supplied.
        write_f16(dst, &out_f);
        if state != 0 {
            write_f16(state, &state_f);
        }
    }
}
14543
14544/// Host fallback for NCHW group norm (Metal unified-memory arena).
14545pub unsafe fn execute_group_norm_nchw_f32(
14546    src: usize,
14547    g: usize,
14548    b: usize,
14549    dst: usize,
14550    n: usize,
14551    c: usize,
14552    h: usize,
14553    w: usize,
14554    num_groups: usize,
14555    eps: f32,
14556    base: *mut u8,
14557) {
14558    let plane = c * h * w;
14559    for ni in 0..n {
14560        let input = unsafe { sl(src + ni * plane * std::mem::size_of::<f32>(), base, plane) };
14561        let gamma = unsafe { sl(g, base, c) };
14562        let beta = unsafe { sl(b, base, c) };
14563        let output = unsafe { sl_mut(dst + ni * plane * std::mem::size_of::<f32>(), base, plane) };
14564        crate::kernels::group_norm_nchw(input, gamma, beta, output, 1, c, h, w, num_groups, eps);
14565    }
14566}
14567
14568/// Host fallback for NCHW LayerNorm2d (SAM / candle semantics).
14569pub unsafe fn execute_layer_norm2d_nchw_f32(
14570    src: usize,
14571    g: usize,
14572    b: usize,
14573    dst: usize,
14574    n: usize,
14575    c: usize,
14576    h: usize,
14577    w: usize,
14578    eps: f32,
14579    base: *mut u8,
14580) {
14581    let plane = c * h * w;
14582    unsafe {
14583        let input = sl(src, base, n * plane);
14584        let gamma = sl(g, base, c);
14585        let beta = sl(b, base, c);
14586        let output = sl_mut(dst, base, n * plane);
14587        crate::kernels::layer_norm2d_nchw(input, gamma, beta, output, n, c, h, w, eps);
14588    }
14589}
14590
14591/// Host fallback for NCHW ConvTranspose2d.
14592pub unsafe fn execute_conv_transpose2d_nchw_f32(
14593    src: usize,
14594    weight: usize,
14595    dst: usize,
14596    n: usize,
14597    c_in: usize,
14598    h: usize,
14599    w_in: usize,
14600    c_out: usize,
14601    h_out: usize,
14602    w_out: usize,
14603    kh: usize,
14604    kw: usize,
14605    sh: usize,
14606    sw: usize,
14607    ph: usize,
14608    pw: usize,
14609    dh: usize,
14610    dw: usize,
14611    groups: usize,
14612    base: *mut u8,
14613) {
14614    let in_elems = n * c_in * h * w_in;
14615    let w_elems = c_in * (c_out / groups) * kh * kw;
14616    let out_elems = n * c_out * h_out * w_out;
14617    unsafe {
14618        let input = sl(src, base, in_elems);
14619        let wt = sl(weight, base, w_elems);
14620        let output = sl_mut(dst, base, out_elems);
14621        crate::kernels::conv_transpose2d_nchw(
14622            input, wt, output, n, c_in, h, w_in, c_out, h_out, w_out, kh, kw, sh, sw, ph, pw, dh,
14623            dw, groups,
14624        );
14625    }
14626}
14627
14628/// Host fallback for nearest 2× upsample on NCHW.
14629pub unsafe fn execute_resize_nearest_2x_f32(
14630    src: usize,
14631    dst: usize,
14632    n: usize,
14633    c: usize,
14634    h: usize,
14635    w: usize,
14636    base: *mut u8,
14637) {
14638    let in_plane = c * h * w;
14639    let out_plane = c * h * 2 * w * 2;
14640    for ni in 0..n {
14641        let input = unsafe {
14642            sl(
14643                src + ni * in_plane * std::mem::size_of::<f32>(),
14644                base,
14645                in_plane,
14646            )
14647        };
14648        let output = unsafe {
14649            sl_mut(
14650                dst + ni * out_plane * std::mem::size_of::<f32>(),
14651                base,
14652                out_plane,
14653            )
14654        };
14655        crate::kernels::resize_nearest_2x_nchw(input, output, c, h, w);
14656    }
14657}
14658
14659/// Host axial 2-D RoPE for Metal (and other) fallbacks on unified memory.
14660pub unsafe fn execute_axial_rope2d_f32(
14661    src: usize,
14662    dst: usize,
14663    batch: usize,
14664    seq: usize,
14665    hidden: usize,
14666    end_x: usize,
14667    end_y: usize,
14668    head_dim: usize,
14669    num_heads: usize,
14670    theta: f32,
14671    repeat_factor: usize,
14672    base: *mut u8,
14673) {
14674    let plane = seq * hidden;
14675    let plane_bytes = plane * std::mem::size_of::<f32>();
14676    for bi in 0..batch {
14677        let in_off = src + bi * plane_bytes;
14678        let input = unsafe { sl(in_off, base, plane) };
14679        let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
14680            input,
14681            num_heads,
14682            seq,
14683            head_dim,
14684            end_x,
14685            end_y,
14686            theta,
14687            repeat_factor,
14688        );
14689        let out_off = dst + bi * plane_bytes;
14690        let output = unsafe { sl_mut(out_off, base, plane) };
14691        output.copy_from_slice(&rotated);
14692    }
14693}
14694
14695/// Ternary pruned radix-2 butterfly stage on `[batch, n_fft, 2]` interleaved state.
14696pub unsafe fn execute_fft_butterfly_stage_f32(
14697    state_src: usize,
14698    state_dst: usize,
14699    gate_src: usize,
14700    rev_src: usize,
14701    tw_re_src: usize,
14702    tw_im_src: usize,
14703    batch: usize,
14704    n_fft: usize,
14705    stage: usize,
14706    base: *mut u8,
14707) {
14708    let half = n_fft / 2;
14709    let stride = 1usize << stage;
14710    let gate = unsafe { sl(gate_src, base, half) };
14711    let rev = unsafe { sl(rev_src, base, half) };
14712    let tw_re = unsafe { sl(tw_re_src, base, half) };
14713    let tw_im = unsafe { sl(tw_im_src, base, half) };
14714    let row_elems = n_fft * 2;
14715    for b in 0..batch {
14716        let in_off = state_src + b * row_elems * std::mem::size_of::<f32>();
14717        let out_off = state_dst + b * row_elems * std::mem::size_of::<f32>();
14718        let inp = unsafe { sl(in_off, base, row_elems) };
14719        let out = unsafe { sl_mut(out_off, base, row_elems) };
14720        out.copy_from_slice(inp);
14721        for bf in 0..half {
14722            if gate[bf] == 0.0 {
14723                continue;
14724            }
14725            let group = bf / stride;
14726            let k = bf % stride;
14727            let i0 = group * 2 * stride + k;
14728            let i1 = i0 + stride;
14729            let w_re = tw_re[bf];
14730            let w_im = tw_im[bf];
14731            let in_a_re = inp[i0 * 2];
14732            let in_a_im = inp[i0 * 2 + 1];
14733            let in_b_re = inp[i1 * 2];
14734            let in_b_im = inp[i1 * 2 + 1];
14735            let (b_re, b_im) = (
14736                in_b_re * w_re - in_b_im * w_im,
14737                in_b_re * w_im + in_b_im * w_re,
14738            );
14739            let (top_re, top_im) = (in_a_re + b_re, in_a_im + b_im);
14740            let (bot_re, bot_im) = (in_a_re - b_re, in_a_im - b_im);
14741            let (oa_re, oa_im, ob_re, ob_im) = if rev[bf] >= 0.5 {
14742                (bot_re, bot_im, top_re, top_im)
14743            } else {
14744                (top_re, top_im, bot_re, bot_im)
14745            };
14746            out[i0 * 2] = oa_re;
14747            out[i0 * 2 + 1] = oa_im;
14748            out[i1 * 2] = ob_re;
14749            out[i1 * 2 + 1] = ob_im;
14750        }
14751    }
14752}
14753
14754/// f32 mirror of `execute_fft1d_f64`. Same public-host-fallback role.
14755pub unsafe fn execute_fft1d_f32(
14756    src: usize,
14757    dst: usize,
14758    outer: usize,
14759    n_complex: usize,
14760    inverse: bool,
14761    norm_tag: u32,
14762    base: *mut u8,
14763) {
14764    let row_elems = 2 * n_complex;
14765    let mut re = vec![0f32; n_complex];
14766    let mut im = vec![0f32; n_complex];
14767    let norm = rlx_ir::fft::FftNorm::from_tag(norm_tag);
14768    let scale = norm.output_scale(n_complex, inverse) as f32;
14769    let mut scratch = if n_complex.is_power_of_two() || n_complex <= 16 {
14770        BluesteinScratchF32::empty()
14771    } else {
14772        BluesteinScratchF32::build(n_complex, inverse)
14773    };
14774    for o in 0..outer {
14775        let row_offset = src + o * row_elems * std::mem::size_of::<f32>();
14776        let s = unsafe { sl(row_offset, base, row_elems) };
14777        re.copy_from_slice(&s[..n_complex]);
14778        im.copy_from_slice(&s[n_complex..]);
14779        if n_complex.is_power_of_two() {
14780            fft_radix2_inplace_f32(&mut re, &mut im, inverse);
14781        } else if n_complex <= 16 {
14782            fft_naive_inplace_f32(&mut re, &mut im, inverse);
14783        } else {
14784            fft_bluestein_inplace_f32(&mut re, &mut im, inverse, &mut scratch);
14785        }
14786        if scale != 1.0 {
14787            re.iter_mut().for_each(|v| *v *= scale);
14788            im.iter_mut().for_each(|v| *v *= scale);
14789        }
14790        let dst_offset = dst + o * row_elems * std::mem::size_of::<f32>();
14791        let d = unsafe { sl_mut(dst_offset, base, row_elems) };
14792        d[..n_complex].copy_from_slice(&re);
14793        d[n_complex..].copy_from_slice(&im);
14794    }
14795}
14796
14797/// C64 interleaved layout: each complex element is `[re: f32, im: f32]`.
14798pub unsafe fn execute_fft1d_c64(
14799    src: usize,
14800    dst: usize,
14801    outer: usize,
14802    n_complex: usize,
14803    inverse: bool,
14804    norm_tag: u32,
14805    base: *mut u8,
14806) {
14807    let row_bytes = n_complex * 8;
14808    let mut re = vec![0f32; n_complex];
14809    let mut im = vec![0f32; n_complex];
14810    let norm = rlx_ir::fft::FftNorm::from_tag(norm_tag);
14811    let scale = norm.output_scale(n_complex, inverse) as f32;
14812    let mut scratch = if n_complex.is_power_of_two() || n_complex <= 16 {
14813        BluesteinScratchF32::empty()
14814    } else {
14815        BluesteinScratchF32::build(n_complex, inverse)
14816    };
14817    for o in 0..outer {
14818        let row_offset = src + o * row_bytes;
14819        for i in 0..n_complex {
14820            let elem_off = row_offset + i * 8;
14821            re[i] = f32::from_le_bytes([
14822                *base.add(elem_off),
14823                *base.add(elem_off + 1),
14824                *base.add(elem_off + 2),
14825                *base.add(elem_off + 3),
14826            ]);
14827            im[i] = f32::from_le_bytes([
14828                *base.add(elem_off + 4),
14829                *base.add(elem_off + 5),
14830                *base.add(elem_off + 6),
14831                *base.add(elem_off + 7),
14832            ]);
14833        }
14834        if n_complex.is_power_of_two() {
14835            fft_radix2_inplace_f32(&mut re, &mut im, inverse);
14836        } else if n_complex <= 16 {
14837            fft_naive_inplace_f32(&mut re, &mut im, inverse);
14838        } else {
14839            fft_bluestein_inplace_f32(&mut re, &mut im, inverse, &mut scratch);
14840        }
14841        if scale != 1.0 {
14842            re.iter_mut().for_each(|v| *v *= scale);
14843            im.iter_mut().for_each(|v| *v *= scale);
14844        }
14845        let dst_row = dst + o * row_bytes;
14846        for i in 0..n_complex {
14847            let elem_off = dst_row + i * 8;
14848            let re_b = re[i].to_le_bytes();
14849            let im_b = im[i].to_le_bytes();
14850            for j in 0..4 {
14851                *base.add(elem_off + j) = re_b[j];
14852                *base.add(elem_off + 4 + j) = im_b[j];
14853            }
14854        }
14855    }
14856}
14857
14858/// Dtype-dispatching host entry for `Op::LogMel` (shared by GPU host fallbacks).
14859pub unsafe fn execute_log_mel(
14860    spec: usize,
14861    filters: usize,
14862    dst: usize,
14863    outer: usize,
14864    n_fft: usize,
14865    n_bins: usize,
14866    n_mels: usize,
14867    base: *mut u8,
14868) {
14869    execute_log_mel_f32(spec, filters, dst, outer, n_fft, n_bins, n_mels, base);
14870}
14871
14872pub unsafe fn execute_log_mel_f32(
14873    spec: usize,
14874    filters: usize,
14875    dst: usize,
14876    outer: usize,
14877    n_fft: usize,
14878    n_bins: usize,
14879    n_mels: usize,
14880    base: *mut u8,
14881) {
14882    let spec_ptr = base.add(spec) as *const f32;
14883    let filt_ptr = base.add(filters) as *const f32;
14884    let dst_ptr = base.add(dst) as *mut f32;
14885    let spec = std::slice::from_raw_parts(spec_ptr, outer * n_fft * 2);
14886    let filters = std::slice::from_raw_parts(filt_ptr, n_mels * n_bins);
14887    let out = std::slice::from_raw_parts_mut(dst_ptr, outer * n_mels);
14888    rlx_ir::audio::log_mel_block_f32(spec, filters, outer, n_fft, n_bins, n_mels, out);
14889}
14890
14891pub unsafe fn execute_welch_peaks_f32(
14892    spec: usize,
14893    dst: usize,
14894    welch_batch: usize,
14895    n_fft: usize,
14896    n_segments: usize,
14897    k: usize,
14898    base: *mut u8,
14899) {
14900    let spec_ptr = base.add(spec) as *const f32;
14901    let dst_ptr = base.add(dst) as *mut f32;
14902    let outer = welch_batch * n_segments;
14903    let spec = std::slice::from_raw_parts(spec_ptr, outer * n_fft * 2);
14904    let out = std::slice::from_raw_parts_mut(dst_ptr, welch_batch * k * 2);
14905    rlx_ir::audio::welch_peaks_block_f32(spec, welch_batch, n_fft, n_segments, k, out);
14906}
14907
14908pub unsafe fn execute_log_mel_backward_f32(
14909    spec: usize,
14910    filters: usize,
14911    dy: usize,
14912    dst: usize,
14913    outer: usize,
14914    n_fft: usize,
14915    n_bins: usize,
14916    n_mels: usize,
14917    base: *mut u8,
14918) {
14919    let spec_ptr = base.add(spec) as *const f32;
14920    let filt_ptr = base.add(filters) as *const f32;
14921    let dy_ptr = base.add(dy) as *const f32;
14922    let dst_ptr = base.add(dst) as *mut f32;
14923    let spec = std::slice::from_raw_parts(spec_ptr, outer * n_fft * 2);
14924    let filters = std::slice::from_raw_parts(filt_ptr, n_mels * n_bins);
14925    let dy = std::slice::from_raw_parts(dy_ptr, outer * n_mels);
14926    let d_spec = std::slice::from_raw_parts_mut(dst_ptr, outer * n_fft * 2);
14927    d_spec.fill(0.0);
14928    rlx_ir::audio::log_mel_block_vjp(spec, filters, dy, outer, n_fft, n_bins, n_mels, d_spec);
14929}
14930
14931/// Dtype-dispatching host entry for `Op::Fft` (shared by GPU host fallbacks).
14932pub unsafe fn execute_fft1d(
14933    src: usize,
14934    dst: usize,
14935    outer: usize,
14936    n_complex: usize,
14937    inverse: bool,
14938    norm_tag: u32,
14939    dtype: rlx_ir::DType,
14940    base: *mut u8,
14941) {
14942    match dtype {
14943        rlx_ir::DType::F32 => {
14944            execute_fft1d_f32(src, dst, outer, n_complex, inverse, norm_tag, base)
14945        }
14946        rlx_ir::DType::F64 => {
14947            execute_fft1d_f64(src, dst, outer, n_complex, inverse, norm_tag, base)
14948        }
14949        rlx_ir::DType::C64 => {
14950            execute_fft1d_c64(src, dst, outer, n_complex, inverse, norm_tag, base)
14951        }
14952        other => panic!("execute_fft1d: unsupported dtype {other:?}"),
14953    }
14954}
14955
/// In-place radix-2 decimation-in-time Cooley-Tukey FFT on split f32
/// `(re, im)` buffers of power-of-two length.
///
/// Mirrors the f64 routine. The twiddle recurrence is carried in f64 and
/// rounded to f32 per butterfly, so accumulated rotation drift doesn't
/// dominate the per-stage error budget at larger N. Both directions are
/// unnormalized; `inverse` flips the twiddle sign to `exp(+2πi/len)`.
fn fft_radix2_inplace_f32(re: &mut [f32], im: &mut [f32], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    debug_assert!(
        n.is_power_of_two(),
        "fft_radix2_f32: n={n} must be a power of two"
    );
    if n <= 1 {
        return;
    }

    // Bit-reversal permutation. `rev` is advanced as a bit-reversed counter:
    // clear the run of set high bits, then set the next one down.
    let mut rev = 0usize;
    for idx in 1..n {
        let mut mask = n >> 1;
        while rev & mask != 0 {
            rev ^= mask;
            mask >>= 1;
        }
        rev |= mask;
        if idx < rev {
            re.swap(idx, rev);
            im.swap(idx, rev);
        }
    }

    let direction = if inverse { 1.0_f64 } else { -1.0_f64 };
    let mut len = 2usize;
    while len <= n {
        let half = len / 2;
        let angle = direction * 2.0 * std::f64::consts::PI / (len as f64);
        let step_re = angle.cos();
        let step_im = angle.sin();
        // Each `len`-sized segment: lower half pairs with upper half.
        for (seg_re, seg_im) in re.chunks_exact_mut(len).zip(im.chunks_exact_mut(len)) {
            let (lo_re, hi_re) = seg_re.split_at_mut(half);
            let (lo_im, hi_im) = seg_im.split_at_mut(half);
            let (mut tw_re, mut tw_im) = (1.0_f64, 0.0_f64);
            for j in 0..half {
                let (cr, ci) = (tw_re as f32, tw_im as f32);
                let t_re = cr * hi_re[j] - ci * hi_im[j];
                let t_im = cr * hi_im[j] + ci * hi_re[j];
                let (u_re, u_im) = (lo_re[j], lo_im[j]);
                lo_re[j] = u_re + t_re;
                lo_im[j] = u_im + t_im;
                hi_re[j] = u_re - t_re;
                hi_im[j] = u_im - t_im;
                // Rotate the twiddle by the stage step (both terms use the old value).
                let next_re = tw_re * step_re - tw_im * step_im;
                tw_im = tw_re * step_im + tw_im * step_re;
                tw_re = next_re;
            }
        }
        len <<= 1;
    }
}
15017
/// In-place radix-2 decimation-in-time Cooley-Tukey FFT on split f64
/// `(re, im)` buffers. `n = re.len() = im.len()` must be a power of two.
///
/// Forward uses ω = exp(-2πi/n) and inverse uses ω = exp(+2πi/n). Neither
/// direction applies a 1/N scale.
fn fft_radix2_inplace_f64(re: &mut [f64], im: &mut [f64], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    debug_assert!(
        n.is_power_of_two(),
        "fft_radix2: n={n} must be a power of two"
    );
    if n <= 1 {
        return;
    }

    // Bit-reversal permutation driven by a bit-reversed counter.
    let mut rev = 0usize;
    for idx in 1..n {
        let mut mask = n >> 1;
        while rev & mask != 0 {
            rev ^= mask;
            mask >>= 1;
        }
        rev |= mask;
        if idx < rev {
            re.swap(idx, rev);
            im.swap(idx, rev);
        }
    }

    // Butterfly stages of width 2, 4, …, n with step ω_len = exp(±2πi/len).
    let direction = if inverse { 1.0_f64 } else { -1.0_f64 };
    let mut len = 2usize;
    while len <= n {
        let half = len / 2;
        let angle = direction * 2.0 * std::f64::consts::PI / (len as f64);
        let step_re = angle.cos();
        let step_im = angle.sin();
        for (seg_re, seg_im) in re.chunks_exact_mut(len).zip(im.chunks_exact_mut(len)) {
            let (lo_re, hi_re) = seg_re.split_at_mut(half);
            let (lo_im, hi_im) = seg_im.split_at_mut(half);
            // Twiddle restarts at 1 + 0i for every segment.
            let (mut tw_re, mut tw_im) = (1.0_f64, 0.0_f64);
            for j in 0..half {
                let t_re = tw_re * hi_re[j] - tw_im * hi_im[j];
                let t_im = tw_re * hi_im[j] + tw_im * hi_re[j];
                let (u_re, u_im) = (lo_re[j], lo_im[j]);
                lo_re[j] = u_re + t_re;
                lo_im[j] = u_im + t_im;
                hi_re[j] = u_re - t_re;
                hi_im[j] = u_im - t_im;
                let next_re = tw_re * step_re - tw_im * step_im;
                tw_im = tw_re * step_im + tw_im * step_re;
                tw_re = next_re;
            }
        }
        len <<= 1;
    }
}
15079
/// Pre-computed chirp + filter-spectrum for one (N, direction) pair.
/// Built once per call to `execute_fft1d_f64` and reused across rows
/// when `outer > 1` — the chirp and FFT(b) don't depend on the input.
///
/// `empty()` produces a placeholder with no buffers (`m == 0`) for lengths
/// that never reach the Bluestein path. `build(n, inverse)` fills every field.
struct BluesteinScratchF64 {
    /// Power-of-two convolution length, ≥ 2N - 1.
    /// `build` uses 1 when N ≤ 1; `empty()` leaves it at 0.
    m: usize,
    /// `w[k] = exp(sign · iπ · k² / N)` for k=0..N, where sign matches
    /// the requested direction. Forward chirp on the way in, output
    /// chirp on the way out. Both vectors have length N.
    w_re: Vec<f64>,
    w_im: Vec<f64>,
    /// FFT of the embedded filter `b[k] = conj(w[|k|])` in length-M.
    /// Doesn't depend on the input — precomputed once.
    bf_re: Vec<f64>,
    bf_im: Vec<f64>,
    /// Workspace reused per row (avoids per-row allocation). Length M. It
    /// holds the chirped, zero-padded input, then the convolution result.
    ar: Vec<f64>,
    ai: Vec<f64>,
}
15099
15100impl BluesteinScratchF64 {
15101    fn empty() -> Self {
15102        Self {
15103            m: 0,
15104            w_re: Vec::new(),
15105            w_im: Vec::new(),
15106            bf_re: Vec::new(),
15107            bf_im: Vec::new(),
15108            ar: Vec::new(),
15109            ai: Vec::new(),
15110        }
15111    }
15112
15113    fn build(n: usize, inverse: bool) -> Self {
15114        // M = next power of two ≥ 2N - 1 keeps the inner FFT on the
15115        // fast radix-2 path. For N=1 fall back to M=1 (no-op convolution).
15116        let m = if n <= 1 {
15117            1
15118        } else {
15119            (2 * n - 1).next_power_of_two()
15120        };
15121
15122        // Chirp arg reduced via k² mod 2N — without this, large N
15123        // bleeds precision into the trig call (n² grows quadratically).
15124        let mod_2n = (2 * n) as u64;
15125        let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
15126        let mut w_re = vec![0.0_f64; n];
15127        let mut w_im = vec![0.0_f64; n];
15128        for k in 0..n {
15129            let k2 = (k as u64).wrapping_mul(k as u64) % mod_2n;
15130            let theta = sign * std::f64::consts::PI * (k2 as f64) / (n as f64);
15131            w_re[k] = theta.cos();
15132            w_im[k] = theta.sin();
15133        }
15134
15135        // Embed b[k] = conj(w[|k|]) into length M with the negative
15136        // indices wrapping to the tail: b[-j] → B[M-j] for j=1..N-1.
15137        let mut bf_re = vec![0.0_f64; m];
15138        let mut bf_im = vec![0.0_f64; m];
15139        if n > 0 {
15140            bf_re[0] = w_re[0];
15141            bf_im[0] = -w_im[0];
15142            for k in 1..n {
15143                bf_re[k] = w_re[k];
15144                bf_im[k] = -w_im[k];
15145                bf_re[m - k] = w_re[k];
15146                bf_im[m - k] = -w_im[k];
15147            }
15148        }
15149        if m > 1 {
15150            fft_radix2_inplace_f64(&mut bf_re, &mut bf_im, false);
15151        }
15152
15153        Self {
15154            m,
15155            w_re,
15156            w_im,
15157            bf_re,
15158            bf_im,
15159            ar: vec![0.0_f64; m],
15160            ai: vec![0.0_f64; m],
15161        }
15162    }
15163}
15164
/// Direct O(N²) DFT for small non-pow2 N (faster than Bluestein setup).
///
/// Unnormalized in both directions: forward uses `exp(-2πi·t·k/N)`, inverse
/// uses `exp(+2πi·t·k/N)`.
///
/// The exponent depends only on `(t·k) mod N`, so the N distinct twiddles are
/// computed once (N trig pairs instead of N²). Each twiddle comes from an
/// angle already reduced to `[0, 2π)`, which keeps the trig argument, and its
/// rounding error, small for any N. This is the same reduction the Bluestein
/// chirp applies to k².
fn fft_naive_inplace_f64(re: &mut [f64], im: &mut [f64], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    if n <= 1 {
        return;
    }
    let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
    let step = sign * 2.0 * std::f64::consts::PI / (n as f64);
    // twiddle[j] = exp(sign · 2πi · j / N) for j in 0..N.
    let (tw_re, tw_im): (Vec<f64>, Vec<f64>) = (0..n)
        .map(|j| {
            let theta = step * j as f64;
            (theta.cos(), theta.sin())
        })
        .unzip();

    let mut out_re = vec![0.0_f64; n];
    let mut out_im = vec![0.0_f64; n];
    for k in 0..n {
        let (mut acc_re, mut acc_im) = (0.0_f64, 0.0_f64);
        for t in 0..n {
            let j = (t * k) % n;
            let (c, s) = (tw_re[j], tw_im[j]);
            acc_re += re[t] * c - im[t] * s;
            acc_im += re[t] * s + im[t] * c;
        }
        out_re[k] = acc_re;
        out_im[k] = acc_im;
    }
    re.copy_from_slice(&out_re);
    im.copy_from_slice(&out_im);
}
15186
/// f32 direct O(N²) DFT for small non-power-of-two N.
///
/// Unnormalized in both directions: forward uses `exp(-2πi·t·k/N)`, inverse
/// uses `exp(+2πi·t·k/N)`. The accumulation is done in f32.
///
/// The file's other f32 paths (radix-2, Bluestein chirp) evaluate their
/// twiddles in f64 and only do butterflies in f32; this routine follows that
/// policy. The N distinct twiddles are precomputed in f64 from the reduced
/// index `(t·k) mod N`, so the angle stays in `[0, 2π)`, and then rounded to
/// f32. Previously the angle `2π·t·k/N` was formed and evaluated entirely in
/// f32, reaching ~88 rad for N = 16 and losing several ulps per twiddle.
fn fft_naive_inplace_f32(re: &mut [f32], im: &mut [f32], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    if n <= 1 {
        return;
    }
    let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
    let step = sign * 2.0 * std::f64::consts::PI / (n as f64);
    // twiddle[j] = exp(sign · 2πi · j / N), evaluated in f64, stored as f32.
    let (tw_re, tw_im): (Vec<f32>, Vec<f32>) = (0..n)
        .map(|j| {
            let theta = step * j as f64;
            (theta.cos() as f32, theta.sin() as f32)
        })
        .unzip();

    let mut out_re = vec![0.0_f32; n];
    let mut out_im = vec![0.0_f32; n];
    for k in 0..n {
        let (mut acc_re, mut acc_im) = (0.0_f32, 0.0_f32);
        for t in 0..n {
            let j = (t * k) % n;
            let (c, s) = (tw_re[j], tw_im[j]);
            acc_re += re[t] * c - im[t] * s;
            acc_im += re[t] * s + im[t] * c;
        }
        out_re[k] = acc_re;
        out_im[k] = acc_im;
    }
    re.copy_from_slice(&out_re);
    im.copy_from_slice(&out_im);
}
15207
15208/// Bluestein (chirp-z) FFT for arbitrary N. Identity used:
15209///   `n·k = (n² + k² - (k-n)²) / 2`
15210/// which lets the DFT be written as a linear convolution sandwiched
15211/// between two chirp multiplies:
15212///   `X[k] = w[k] · ((x·w) ⊛ conj(w))[k]`   where `w[n] = exp(±iπ·n²/N)`.
15213/// The convolution is computed via a length-M radix-2 FFT (M ≥ 2N-1).
15214/// Both directions stay unnormalized to match the radix-2 path, so the
15215/// chain rule keeps working without scaling.
15216fn fft_bluestein_inplace_f64(
15217    re: &mut [f64],
15218    im: &mut [f64],
15219    _inverse: bool,
15220    s: &mut BluesteinScratchF64,
15221) {
15222    let n = re.len();
15223    debug_assert_eq!(im.len(), n);
15224    debug_assert_eq!(s.w_re.len(), n);
15225    if n <= 1 {
15226        return;
15227    }
15228    let m = s.m;
15229
15230    // Pre-chirp: a[k] = x[k] · w[k], zero-padded to M.
15231    for k in 0..m {
15232        s.ar[k] = 0.0;
15233        s.ai[k] = 0.0;
15234    }
15235    for k in 0..n {
15236        s.ar[k] = re[k] * s.w_re[k] - im[k] * s.w_im[k];
15237        s.ai[k] = re[k] * s.w_im[k] + im[k] * s.w_re[k];
15238    }
15239
15240    // Length-M forward FFT of the padded chirped input.
15241    fft_radix2_inplace_f64(&mut s.ar, &mut s.ai, false);
15242
15243    // Pointwise product with FFT(b). Stored back into (ar, ai).
15244    for k in 0..m {
15245        let ar = s.ar[k];
15246        let ai = s.ai[k];
15247        let br = s.bf_re[k];
15248        let bi = s.bf_im[k];
15249        s.ar[k] = ar * br - ai * bi;
15250        s.ai[k] = ar * bi + ai * br;
15251    }
15252
15253    // Inverse FFT — radix-2 here is the unnormalized inverse, so we
15254    // divide by M to recover the true circular convolution.
15255    fft_radix2_inplace_f64(&mut s.ar, &mut s.ai, true);
15256    let inv_m = 1.0 / (m as f64);
15257
15258    // Post-chirp: X[k] = w[k] · Y[k] / M for k = 0..N.
15259    for k in 0..n {
15260        let yr = s.ar[k] * inv_m;
15261        let yi = s.ai[k] * inv_m;
15262        re[k] = yr * s.w_re[k] - yi * s.w_im[k];
15263        im[k] = yr * s.w_im[k] + yi * s.w_re[k];
15264    }
15265}
15266
15267/// f32 mirror of `BluesteinScratchF64`. Chirp is computed in f64 for
15268/// precision (same justification as the radix-2 f32 path: twiddles in
15269/// f64, butterflies in f32). The actual conv buffers are f32.
15270struct BluesteinScratchF32 {
15271    m: usize,
15272    w_re: Vec<f32>,
15273    w_im: Vec<f32>,
15274    bf_re: Vec<f32>,
15275    bf_im: Vec<f32>,
15276    ar: Vec<f32>,
15277    ai: Vec<f32>,
15278}
15279
15280impl BluesteinScratchF32 {
15281    fn empty() -> Self {
15282        Self {
15283            m: 0,
15284            w_re: Vec::new(),
15285            w_im: Vec::new(),
15286            bf_re: Vec::new(),
15287            bf_im: Vec::new(),
15288            ar: Vec::new(),
15289            ai: Vec::new(),
15290        }
15291    }
15292
15293    fn build(n: usize, inverse: bool) -> Self {
15294        let m = if n <= 1 {
15295            1
15296        } else {
15297            (2 * n - 1).next_power_of_two()
15298        };
15299
15300        let mod_2n = (2 * n) as u64;
15301        let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
15302        let mut w_re = vec![0.0_f32; n];
15303        let mut w_im = vec![0.0_f32; n];
15304        for k in 0..n {
15305            let k2 = (k as u64).wrapping_mul(k as u64) % mod_2n;
15306            let theta = sign * std::f64::consts::PI * (k2 as f64) / (n as f64);
15307            w_re[k] = theta.cos() as f32;
15308            w_im[k] = theta.sin() as f32;
15309        }
15310
15311        let mut bf_re = vec![0.0_f32; m];
15312        let mut bf_im = vec![0.0_f32; m];
15313        if n > 0 {
15314            bf_re[0] = w_re[0];
15315            bf_im[0] = -w_im[0];
15316            for k in 1..n {
15317                bf_re[k] = w_re[k];
15318                bf_im[k] = -w_im[k];
15319                bf_re[m - k] = w_re[k];
15320                bf_im[m - k] = -w_im[k];
15321            }
15322        }
15323        if m > 1 {
15324            fft_radix2_inplace_f32(&mut bf_re, &mut bf_im, false);
15325        }
15326
15327        Self {
15328            m,
15329            w_re,
15330            w_im,
15331            bf_re,
15332            bf_im,
15333            ar: vec![0.0_f32; m],
15334            ai: vec![0.0_f32; m],
15335        }
15336    }
15337}
15338
15339fn fft_bluestein_inplace_f32(
15340    re: &mut [f32],
15341    im: &mut [f32],
15342    _inverse: bool,
15343    s: &mut BluesteinScratchF32,
15344) {
15345    let n = re.len();
15346    debug_assert_eq!(im.len(), n);
15347    debug_assert_eq!(s.w_re.len(), n);
15348    if n <= 1 {
15349        return;
15350    }
15351    let m = s.m;
15352
15353    for k in 0..m {
15354        s.ar[k] = 0.0;
15355        s.ai[k] = 0.0;
15356    }
15357    for k in 0..n {
15358        s.ar[k] = re[k] * s.w_re[k] - im[k] * s.w_im[k];
15359        s.ai[k] = re[k] * s.w_im[k] + im[k] * s.w_re[k];
15360    }
15361
15362    fft_radix2_inplace_f32(&mut s.ar, &mut s.ai, false);
15363
15364    for k in 0..m {
15365        let ar = s.ar[k];
15366        let ai = s.ai[k];
15367        let br = s.bf_re[k];
15368        let bi = s.bf_im[k];
15369        s.ar[k] = ar * br - ai * bi;
15370        s.ai[k] = ar * bi + ai * br;
15371    }
15372
15373    fft_radix2_inplace_f32(&mut s.ar, &mut s.ai, true);
15374    let inv_m = 1.0_f32 / (m as f32);
15375
15376    for k in 0..n {
15377        let yr = s.ar[k] * inv_m;
15378        let yi = s.ai[k] * inv_m;
15379        re[k] = yr * s.w_re[k] - yi * s.w_im[k];
15380        im[k] = yr * s.w_im[k] + yi * s.w_re[k];
15381    }
15382}
15383
15384/// Shared dispatch path for `Thunk::CustomOp`. Builds a typed
15385/// [`CpuTensorRef`] for each input *at that input's declared dtype*
15386/// (so a sparse-LU op with mixed F64/I32 inputs gets the right
15387/// typed slices) and a [`CpuTensorMut`] for the output, then calls
15388/// the kernel's single `execute` method.
15389unsafe fn dispatch_custom_op(
15390    kernel: &dyn crate::op_registry::CpuKernel,
15391    inputs: &[(usize, u32, Shape)],
15392    out_off: usize,
15393    out_len: u32,
15394    out_shape: &Shape,
15395    attrs: &[u8],
15396    base: *mut u8,
15397) {
15398    use crate::op_registry::{CpuTensorMut, CpuTensorRef};
15399    use rlx_ir::DType;
15400
15401    // One arm per `DType` variant — single source of truth for
15402    // "which dtypes the CPU custom-op dispatcher wires." If a new
15403    // DType lands in `rlx-ir`, the compiler flags this match as
15404    // non-exhaustive and the gap gets named at the right place.
15405    macro_rules! build_in_view {
15406        ($shape:expr, $off:expr, $n:expr, $variant:ident, $rust_ty:ty) => {
15407            CpuTensorRef::$variant {
15408                data: unsafe { sl_typed::<$rust_ty>($off, base, $n) },
15409                shape: $shape,
15410            }
15411        };
15412    }
15413    macro_rules! build_out_view {
15414        ($variant:ident, $rust_ty:ty) => {
15415            CpuTensorMut::$variant {
15416                data: unsafe { sl_mut_typed::<$rust_ty>(out_off, base, out_len as usize) },
15417                shape: out_shape,
15418            }
15419        };
15420    }
15421
15422    let in_views: Vec<CpuTensorRef<'_>> = inputs
15423        .iter()
15424        .map(|(off, len, shape)| {
15425            let n = *len as usize;
15426            let off = *off;
15427            match shape.dtype() {
15428                DType::F32 => build_in_view!(shape, off, n, F32, f32),
15429                DType::F64 => build_in_view!(shape, off, n, F64, f64),
15430                DType::F16 => build_in_view!(shape, off, n, F16, half::f16),
15431                DType::BF16 => build_in_view!(shape, off, n, BF16, half::bf16),
15432                DType::I8 => build_in_view!(shape, off, n, I8, i8),
15433                DType::I16 => build_in_view!(shape, off, n, I16, i16),
15434                DType::I32 => build_in_view!(shape, off, n, I32, i32),
15435                DType::I64 => build_in_view!(shape, off, n, I64, i64),
15436                DType::U8 => build_in_view!(shape, off, n, U8, u8),
15437                DType::U32 => build_in_view!(shape, off, n, U32, u32),
15438                DType::Bool => build_in_view!(shape, off, n, Bool, u8),
15439                // C64 isn't a CpuTensor variant today; the user-registered
15440                // op_registry path doesn't see complex inputs (those are
15441                // handled by built-in ops with dedicated kernels).
15442                DType::C64 => panic!(
15443                    "Op::Custom kernel input has DType::C64 — built-in \
15444                 complex ops handle their own kernels; user-registered \
15445                 ops don't yet see complex tensors"
15446                ),
15447            }
15448        })
15449        .collect();
15450
15451    let result = match out_shape.dtype() {
15452        DType::F32 => kernel.execute(&in_views, build_out_view!(F32, f32), attrs),
15453        DType::F64 => kernel.execute(&in_views, build_out_view!(F64, f64), attrs),
15454        DType::F16 => kernel.execute(&in_views, build_out_view!(F16, half::f16), attrs),
15455        DType::BF16 => kernel.execute(&in_views, build_out_view!(BF16, half::bf16), attrs),
15456        DType::I8 => kernel.execute(&in_views, build_out_view!(I8, i8), attrs),
15457        DType::I16 => kernel.execute(&in_views, build_out_view!(I16, i16), attrs),
15458        DType::I32 => kernel.execute(&in_views, build_out_view!(I32, i32), attrs),
15459        DType::I64 => kernel.execute(&in_views, build_out_view!(I64, i64), attrs),
15460        DType::U8 => kernel.execute(&in_views, build_out_view!(U8, u8), attrs),
15461        DType::U32 => kernel.execute(&in_views, build_out_view!(U32, u32), attrs),
15462        DType::Bool => kernel.execute(&in_views, build_out_view!(Bool, u8), attrs),
15463        DType::C64 => panic!("Op::Custom output DType::C64 not supported"),
15464    };
15465    if let Err(e) = result {
15466        panic!("Op::Custom('{}') CPU kernel failed: {e}", kernel.name());
15467    }
15468}
15469
/// Generic counterpart of the per-dtype `sl_*` helpers. It reinterprets
/// `len` elements of `T` starting `offset` bytes past `base`.
///
/// The custom-op dispatcher uses this (and [`sl_mut_typed`]) to cover every
/// `DType` with one helper instead of one per element type. The concrete
/// `sl_*` helpers remain for call sites with a fixed dtype.
/// `offset == usize::MAX` yields an empty slice without touching `base`.
///
/// # Safety
/// Unless `offset == usize::MAX`, `base + offset` must be aligned for `T`
/// and point to `len` initialized values of `T`. That memory must remain
/// valid, and must not be written through another path, for as long as the
/// returned slice is used. The `'static` lifetime is not checked.
#[inline(always)]
unsafe fn sl_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static [T] {
    if offset != usize::MAX {
        let first = unsafe { base.add(offset) }.cast::<T>();
        unsafe { std::slice::from_raw_parts(first, len) }
    } else {
        &[]
    }
}

/// Mutable sibling of [`sl_typed`]: `len` elements of `T` starting
/// `offset` bytes past `base`. Unlike [`sl_typed`], there is no
/// `usize::MAX` sentinel check.
///
/// # Safety
/// `base + offset` must be aligned for `T` and point to `len` values of `T`.
/// No other reference to that memory may be used while the returned slice
/// is alive.
#[inline(always)]
unsafe fn sl_mut_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static mut [T] {
    let first = unsafe { base.add(offset) }.cast::<T>();
    unsafe { std::slice::from_raw_parts_mut(first, len) }
}
15487
15488// Unsafe helpers to create slices from arena base + offset
15489#[inline(always)]
15490/// In-place per-element activation. Mirrors the dispatch in
15491/// `Thunk::ActivationInPlace`. Used by `Thunk::FusedMmBiasAct` to
15492/// apply the activation after `bias_add` for all non-Gelu cases.
15493fn apply_activation_inplace(d: &mut [f32], act: rlx_ir::op::Activation) {
15494    use rlx_ir::op::Activation;
15495    match act {
15496        Activation::Gelu => crate::kernels::par_gelu_inplace(d),
15497        Activation::GeluApprox => crate::kernels::par_gelu_approx_inplace(d),
15498        Activation::Silu => crate::kernels::par_silu_inplace(d),
15499        Activation::Relu => {
15500            for v in d.iter_mut() {
15501                *v = v.max(0.0);
15502            }
15503        }
15504        Activation::Sigmoid => {
15505            for v in d.iter_mut() {
15506                *v = 1.0 / (1.0 + (-*v).exp());
15507            }
15508        }
15509        Activation::Tanh => {
15510            for v in d.iter_mut() {
15511                *v = v.tanh();
15512            }
15513        }
15514        Activation::Exp => {
15515            for v in d.iter_mut() {
15516                *v = v.exp();
15517            }
15518        }
15519        Activation::Log => {
15520            for v in d.iter_mut() {
15521                *v = v.ln();
15522            }
15523        }
15524        Activation::Sqrt => {
15525            for v in d.iter_mut() {
15526                *v = v.sqrt();
15527            }
15528        }
15529        Activation::Rsqrt => {
15530            for v in d.iter_mut() {
15531                *v = 1.0 / v.sqrt();
15532            }
15533        }
15534        Activation::Neg => {
15535            for v in d.iter_mut() {
15536                *v = -*v;
15537            }
15538        }
15539        Activation::Abs => {
15540            for v in d.iter_mut() {
15541                *v = v.abs();
15542            }
15543        }
15544        Activation::Round => {
15545            for v in d.iter_mut() {
15546                *v = v.round();
15547            }
15548        }
15549        Activation::Sin => {
15550            for v in d.iter_mut() {
15551                *v = v.sin();
15552            }
15553        }
15554        Activation::Cos => {
15555            for v in d.iter_mut() {
15556                *v = v.cos();
15557            }
15558        }
15559        Activation::Tan => {
15560            for v in d.iter_mut() {
15561                *v = v.tan();
15562            }
15563        }
15564        Activation::Atan => {
15565            for v in d.iter_mut() {
15566                *v = v.atan();
15567            }
15568        }
15569    }
15570}
15571
/// im2col for a single image (one batch item, one group slice).
///
/// `x` is `[c_in, H, W]` row-major. `col` receives
/// `[c_in · kH · kW, H_out · W_out]` row-major, where row
/// `(ci · kH + ki) · kW + kj` holds, for every output position `(ho, wo)`:
///
/// `x[ci, ho·sh + ki·dh − ph, wo·sw + kj·dw_dil − pw]`
///
/// Taps that fall into the padding are written as `0.0`. Every element of
/// `col` is overwritten.
#[allow(clippy::too_many_arguments)]
fn im2col(
    x: &[f32],
    col: &mut [f32],
    c_in: usize,
    h: usize,
    w: usize,
    h_out: usize,
    w_out: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    dh: usize,
    dw_dil: usize,
) {
    /// Source coordinate for output position `pos` and kernel tap `tap`,
    /// or `None` when it lands outside `[0, extent)` (i.e. in the padding).
    fn source_index(
        pos: usize,
        stride: usize,
        tap: usize,
        dilation: usize,
        pad: usize,
        extent: usize,
    ) -> Option<usize> {
        let at = (pos * stride + tap * dilation) as isize - pad as isize;
        if at < 0 || at >= extent as isize {
            None
        } else {
            Some(at as usize)
        }
    }

    let plane = h_out * w_out;
    debug_assert_eq!(col.len(), c_in * kh * kw * plane);
    debug_assert_eq!(x.len(), c_in * h * w);

    for ci in 0..c_in {
        for ki in 0..kh {
            for kj in 0..kw {
                let row_start = (((ci * kh) + ki) * kw + kj) * plane;
                for ho in 0..h_out {
                    let line = row_start + ho * w_out;
                    let Some(hi) = source_index(ho, sh, ki, dh, ph, h) else {
                        // The whole output line reads from vertical padding.
                        col[line..line + w_out].fill(0.0);
                        continue;
                    };
                    let src_line = (ci * h + hi) * w;
                    for wo in 0..w_out {
                        col[line + wo] = match source_index(wo, sw, kj, dw_dil, pw, w) {
                            Some(wi) => x[src_line + wi],
                            None => 0.0,
                        };
                    }
                }
            }
        }
    }
}
15633
/// col2im: the adjoint of `im2col`, implemented as a scatter-accumulate.
///
/// For every `(ci, ki, kj, ho, wo)` whose source position
/// `(hi, wi) = (ho·sh + ki·dh − ph, wo·sw + kj·dw_dil − pw)` lies inside
/// `[0, H) × [0, W)`, this adds
/// `col[((ci · kH + ki) · kW + kj) · H_out·W_out + ho · W_out + wo]` into
/// `x[ci, hi, wi]`. Padding taps are dropped.
///
/// `x` is accumulated into, not overwritten. Callers must zero it first if
/// needed; the conv input-grad path does so once before its batch loop.
#[allow(clippy::too_many_arguments)]
fn col2im(
    col: &[f32],
    x: &mut [f32],
    c_in: usize,
    h: usize,
    w: usize,
    h_out: usize,
    w_out: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    dh: usize,
    dw_dil: usize,
) {
    /// Source coordinate for output position `pos` and kernel tap `tap`,
    /// or `None` when it lands in the padding.
    fn source_index(
        pos: usize,
        stride: usize,
        tap: usize,
        dilation: usize,
        pad: usize,
        extent: usize,
    ) -> Option<usize> {
        let at = (pos * stride + tap * dilation) as isize - pad as isize;
        if at < 0 || at >= extent as isize {
            None
        } else {
            Some(at as usize)
        }
    }

    let plane = h_out * w_out;
    debug_assert_eq!(col.len(), c_in * kh * kw * plane);
    debug_assert_eq!(x.len(), c_in * h * w);

    for ci in 0..c_in {
        for ki in 0..kh {
            for kj in 0..kw {
                let row_start = (((ci * kh) + ki) * kw + kj) * plane;
                for ho in 0..h_out {
                    if let Some(hi) = source_index(ho, sh, ki, dh, ph, h) {
                        let dst_line = (ci * h + hi) * w;
                        let line = row_start + ho * w_out;
                        for wo in 0..w_out {
                            if let Some(wi) = source_index(wo, sw, kj, dw_dil, pw, w) {
                                x[dst_line + wi] += col[line + wo];
                            }
                        }
                    }
                }
            }
        }
    }
}
15689
15690/// Element-wise backward for `Op::Activation`. `xs` is the original
15691/// input to the forward activation; `dys` is the upstream gradient.
15692/// Writes `out[i] = (d/dx act(xs[i])) * dys[i]`.
15693/// Decompose a per-channel quantization shape into the
15694/// `(chan_axis, chan_dim, inner)` triplet the kernel needs to map a
15695/// flat output index to a channel index. Per-tensor (`axis = None`)
15696/// degenerates to `chan_dim = 1, inner = len`, which makes the
15697/// kernel's `(i / inner) % chan_dim` always 0 — same fast path the
15698/// scalar version used.
15699fn quant_layout(shape: &rlx_ir::Shape, axis: Option<usize>) -> (usize, usize, usize) {
15700    match axis {
15701        None => (0, 1, shape.num_elements().unwrap_or(0).max(1)),
15702        Some(d) => {
15703            let chan_dim = shape.dim(d).unwrap_static();
15704            let inner: usize = (d + 1..shape.rank())
15705                .map(|i| shape.dim(i).unwrap_static())
15706                .product::<usize>()
15707                .max(1);
15708            (d, chan_dim, inner)
15709        }
15710    }
15711}
15712
15713fn activation_backward_kernel(
15714    act: rlx_ir::op::Activation,
15715    xs: &[f32],
15716    dys: &[f32],
15717    out: &mut [f32],
15718) {
15719    use rlx_ir::op::Activation;
15720    let n = xs.len();
15721    debug_assert_eq!(dys.len(), n);
15722    debug_assert_eq!(out.len(), n);
15723    match act {
15724        Activation::Relu => {
15725            for i in 0..n {
15726                out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
15727            }
15728        }
15729        Activation::Sigmoid => {
15730            for i in 0..n {
15731                let s = 1.0 / (1.0 + (-xs[i]).exp());
15732                out[i] = s * (1.0 - s) * dys[i];
15733            }
15734        }
15735        Activation::Tanh => {
15736            for i in 0..n {
15737                let t = xs[i].tanh();
15738                out[i] = (1.0 - t * t) * dys[i];
15739            }
15740        }
15741        Activation::Silu => {
15742            // y = x * σ(x);  dy/dx = σ(x) * (1 + x * (1 - σ(x))).
15743            for i in 0..n {
15744                let s = 1.0 / (1.0 + (-xs[i]).exp());
15745                out[i] = s * (1.0 + xs[i] * (1.0 - s)) * dys[i];
15746            }
15747        }
15748        Activation::Gelu => {
15749            // Exact erf-based GELU:  y = 0.5 x (1 + erf(x / √2)).
15750            //   dy/dx = 0.5 (1 + erf(x/√2)) + (x / √(2π)) · exp(-x²/2)
15751            const INV_SQRT2: f32 = 0.707_106_77;
15752            const INV_SQRT_2PI: f32 = 0.398_942_3;
15753            for i in 0..n {
15754                let x = xs[i];
15755                let phi = 0.5 * (1.0 + erf_f32(x * INV_SQRT2));
15756                let pdf = INV_SQRT_2PI * (-(x * x) * 0.5).exp();
15757                out[i] = (phi + x * pdf) * dys[i];
15758            }
15759        }
15760        Activation::GeluApprox => {
15761            // Tanh-approximation:
15762            //   y = 0.5 x (1 + tanh(c · (x + 0.044715 x³))) where c = √(2/π).
15763            const C: f32 = 0.797_884_6; // √(2/π)
15764            const A: f32 = 0.044_715;
15765            for i in 0..n {
15766                let x = xs[i];
15767                let inner = C * (x + A * x * x * x);
15768                let t = inner.tanh();
15769                let dinner = C * (1.0 + 3.0 * A * x * x);
15770                let d = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * dinner;
15771                out[i] = d * dys[i];
15772            }
15773        }
15774        Activation::Exp => {
15775            for i in 0..n {
15776                out[i] = xs[i].exp() * dys[i];
15777            }
15778        }
15779        Activation::Log => {
15780            for i in 0..n {
15781                out[i] = dys[i] / xs[i];
15782            }
15783        }
15784        Activation::Sqrt => {
15785            // d/dx √x = 0.5 / √x — undefined at x=0; clamp to 0.
15786            for i in 0..n {
15787                let s = xs[i].sqrt();
15788                out[i] = if s > 0.0 { 0.5 * dys[i] / s } else { 0.0 };
15789            }
15790        }
15791        Activation::Rsqrt => {
15792            // d/dx (1/√x) = -0.5 · x^(-3/2).
15793            for i in 0..n {
15794                let s = xs[i].sqrt();
15795                out[i] = if s > 0.0 {
15796                    -0.5 * dys[i] / (xs[i] * s)
15797                } else {
15798                    0.0
15799                };
15800            }
15801        }
15802        Activation::Neg => {
15803            for i in 0..n {
15804                out[i] = -dys[i];
15805            }
15806        }
15807        Activation::Abs => {
15808            // sign(x); 0 at x=0.
15809            for i in 0..n {
15810                let x = xs[i];
15811                let s = if x > 0.0 {
15812                    1.0
15813                } else if x < 0.0 {
15814                    -1.0
15815                } else {
15816                    0.0
15817                };
15818                out[i] = s * dys[i];
15819            }
15820        }
15821        Activation::Round => {
15822            // STE: pretend the round was identity in the backward
15823            // pass. The round step has zero gradient almost
15824            // everywhere, so without this trick the optimizer can't
15825            // learn through it.
15826            out.copy_from_slice(dys);
15827        }
15828        Activation::Sin => {
15829            // d/dx sin(x) = cos(x).
15830            for i in 0..n {
15831                out[i] = xs[i].cos() * dys[i];
15832            }
15833        }
15834        Activation::Cos => {
15835            for i in 0..n {
15836                out[i] = -xs[i].sin() * dys[i];
15837            }
15838        }
15839        Activation::Tan => {
15840            // d/dx tan(x) = sec²(x) = 1 + tan²(x)
15841            for i in 0..n {
15842                let t = xs[i].tan();
15843                out[i] = (1.0 + t * t) * dys[i];
15844            }
15845        }
15846        Activation::Atan => {
15847            // d/dx atan(x) = 1 / (1 + x²)
15848            for i in 0..n {
15849                let x = xs[i];
15850                out[i] = dys[i] / (1.0 + x * x);
15851            }
15852        }
15853    }
15854}
15855
15856/// f64 sibling of `activation_backward_kernel`. Same math, twice the
15857/// precision — used by f64 graphs where the f32 kernel reading bytes
15858/// as `&[f32]` would silently discard half of every f64 value.
15859fn activation_backward_kernel_f64(
15860    act: rlx_ir::op::Activation,
15861    xs: &[f64],
15862    dys: &[f64],
15863    out: &mut [f64],
15864) {
15865    use rlx_ir::op::Activation;
15866    let n = xs.len();
15867    debug_assert_eq!(dys.len(), n);
15868    debug_assert_eq!(out.len(), n);
15869    match act {
15870        Activation::Relu => {
15871            for i in 0..n {
15872                out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
15873            }
15874        }
15875        Activation::Sigmoid => {
15876            for i in 0..n {
15877                let s = 1.0 / (1.0 + (-xs[i]).exp());
15878                out[i] = s * (1.0 - s) * dys[i];
15879            }
15880        }
15881        Activation::Tanh => {
15882            for i in 0..n {
15883                let t = xs[i].tanh();
15884                out[i] = (1.0 - t * t) * dys[i];
15885            }
15886        }
15887        Activation::Silu => {
15888            for i in 0..n {
15889                let s = 1.0 / (1.0 + (-xs[i]).exp());
15890                out[i] = s * (1.0 + xs[i] * (1.0 - s)) * dys[i];
15891            }
15892        }
15893        Activation::Gelu | Activation::GeluApprox => {
15894            // Both rare on f64 paths; use the high-quality libm erf.
15895            const INV_SQRT2: f64 = std::f64::consts::FRAC_1_SQRT_2;
15896            const INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
15897            for i in 0..n {
15898                let x = xs[i];
15899                let phi = 0.5 * (1.0 + erf_f64(x * INV_SQRT2));
15900                let pdf = INV_SQRT_2PI * (-(x * x) * 0.5).exp();
15901                out[i] = (phi + x * pdf) * dys[i];
15902            }
15903        }
15904        Activation::Exp => {
15905            for i in 0..n {
15906                out[i] = xs[i].exp() * dys[i];
15907            }
15908        }
15909        Activation::Log => {
15910            for i in 0..n {
15911                out[i] = dys[i] / xs[i];
15912            }
15913        }
15914        Activation::Sqrt => {
15915            for i in 0..n {
15916                let s = xs[i].sqrt();
15917                out[i] = if s > 0.0 { 0.5 * dys[i] / s } else { 0.0 };
15918            }
15919        }
15920        Activation::Rsqrt => {
15921            for i in 0..n {
15922                let s = xs[i].sqrt();
15923                out[i] = if s > 0.0 {
15924                    -0.5 * dys[i] / (xs[i] * s)
15925                } else {
15926                    0.0
15927                };
15928            }
15929        }
15930        Activation::Neg => {
15931            for i in 0..n {
15932                out[i] = -dys[i];
15933            }
15934        }
15935        Activation::Abs => {
15936            for i in 0..n {
15937                let x = xs[i];
15938                let s = if x > 0.0 {
15939                    1.0
15940                } else if x < 0.0 {
15941                    -1.0
15942                } else {
15943                    0.0
15944                };
15945                out[i] = s * dys[i];
15946            }
15947        }
15948        Activation::Round => {
15949            out.copy_from_slice(dys);
15950        }
15951        Activation::Sin => {
15952            for i in 0..n {
15953                out[i] = xs[i].cos() * dys[i];
15954            }
15955        }
15956        Activation::Cos => {
15957            for i in 0..n {
15958                out[i] = -xs[i].sin() * dys[i];
15959            }
15960        }
15961        Activation::Tan => {
15962            for i in 0..n {
15963                let t = xs[i].tan();
15964                out[i] = (1.0 + t * t) * dys[i];
15965            }
15966        }
15967        Activation::Atan => {
15968            for i in 0..n {
15969                let x = xs[i];
15970                out[i] = dys[i] / (1.0 + x * x);
15971            }
15972        }
15973    }
15974}
15975
/// Error function in f64, using Abramowitz & Stegun 7.1.26 (the same
/// rational approximation as `erf_f32`, evaluated at f64 width).
///
/// Accuracy is about 1.5e-7 absolute. The polynomial is the limit, not the
/// arithmetic, so this is fine for gradient kernels but not a
/// full-precision erf. Odd symmetry is applied through `signum`, so NaN
/// propagates.
#[inline(always)]
fn erf_f64(x: f64) -> f64 {
    // erf(|x|) ≈ 1 − (a1·t + a2·t² + a3·t³ + a4·t⁴ + a5·t⁵)·exp(−x²),
    // with t = 1 / (1 + p·|x|).
    const P: f64 = 0.327_591_1;
    // a5 … a1, highest degree first for Horner evaluation.
    const COEFFS: [f64; 5] = [
        1.061_405_43,
        -1.453_152_03,
        1.421_413_75,
        -0.284_496_74,
        0.254_829_59,
    ];
    let sign = x.signum();
    let ax = x.abs();
    let t = 1.0 / (1.0 + P * ax);
    let poly = COEFFS[1..].iter().fold(COEFFS[0], |acc, &c| acc * t + c);
    sign * (1.0 - poly * t * (-ax * ax).exp())
}
15992
/// Fast error function for f32 gradient kernels (Abramowitz & Stegun
/// 7.1.26). The approximation error is about 1.5e-7, well below f32
/// resolution near erf's saturation. Odd symmetry is applied through
/// `signum`, so NaN propagates.
#[inline(always)]
fn erf_f32(x: f32) -> f32 {
    // erf(|x|) ≈ 1 − (a1·t + … + a5·t⁵)·exp(−x²),  t = 1 / (1 + p·|x|).
    const P: f32 = 0.327_591_1;
    // a5 … a1, highest degree first for Horner evaluation.
    const COEFFS: [f32; 5] = [
        1.061_405_4,
        -1.453_152_1,
        1.421_413_8,
        -0.284_496_74,
        0.254_829_6,
    ];
    let sign = x.signum();
    let ax = x.abs();
    let t = 1.0 / (1.0 + P * ax);
    let poly = COEFFS[1..].iter().fold(COEFFS[0], |acc, &c| acc * t + c);
    sign * (1.0 - poly * t * (-ax * ax).exp())
}
16007
/// Builds a closure that copies `outer` rows of `inner` elements from a
/// strided source to a strided destination within one buffer.
///
/// `src` and `dst` are byte offsets from the `base` pointer the closure
/// receives. `src_stride`, `dst_stride` and `inner` are in elements of
/// `elem_bytes` bytes each. Row `o` copies `inner · elem_bytes` bytes from
/// `src + o·src_stride·elem_bytes` to `dst + o·dst_stride·elem_bytes`.
///
/// Source and destination may alias: the same-offset checks exist for
/// in-place use. Rows are therefore moved with `ptr::copy` (memmove
/// semantics), which is defined for partially overlapping ranges, where
/// `copy_nonoverlapping` would be undefined behaviour. Rows run in ascending
/// order. That is correct for in-place compaction, where `dst ≤ src` and
/// `dst_stride ≤ src_stride`. NOTE(review): a layout with `dst` above an
/// overlapping `src` would need a descending walk; narrow should not
/// produce one.
///
/// The returned closure performs no bounds checking. No arena length is
/// available on this path; the old "bogus offset" guard compared against
/// `usize::MAX` and could never fire. The caller must pass a `base` that
/// covers every row touched.
fn narrow_thunk_closure(
    src: usize,
    dst: usize,
    outer: u32,
    src_stride: u32,
    dst_stride: u32,
    inner: u32,
    elem_bytes: u8,
) -> Arc<dyn Fn(*mut u8) + Send + Sync> {
    let (outer, ss, ds, inner, eb) = (
        outer as usize,
        src_stride as usize,
        dst_stride as usize,
        inner as usize,
        elem_bytes as usize,
    );
    let row_bytes = inner.saturating_mul(eb);
    let src_row_stride = ss.saturating_mul(eb);
    let dst_row_stride = ds.saturating_mul(eb);
    // The copy is a no-op only if nothing moves. `src == dst` alone is not
    // enough: with differing strides, every row after the first still has
    // to be compacted.
    let is_noop =
        row_bytes == 0 || outer == 0 || (src == dst && src_row_stride == dst_row_stride);
    Arc::new(move |base: *mut u8| {
        if is_noop {
            return;
        }
        for o in 0..outer {
            let s_off = src + o * src_row_stride;
            let d_off = dst + o * dst_row_stride;
            if s_off == d_off {
                continue;
            }
            // SAFETY: the caller guarantees `base` points to a buffer covering
            // `[s_off, s_off + row_bytes)` and `[d_off, d_off + row_bytes)`.
            // `ptr::copy` permits the two ranges to overlap.
            unsafe {
                std::ptr::copy(base.add(s_off), base.add(d_off), row_bytes);
            }
        }
    })
}
16048
/// Reinterprets `len` f32 values starting `offset` bytes past the arena
/// base as a shared slice. `offset == usize::MAX` yields an empty slice
/// without touching `base`.
///
/// Marked `#[inline(always)]` like every other `sl_*` helper. This one
/// previously lacked the attribute because it had drifted onto an
/// unrelated function.
///
/// # Safety
/// Unless `offset == usize::MAX`, `base + offset` must be 4-byte aligned
/// and point to `len` initialized f32 values that stay valid, and are not
/// written through another path, while the slice is used. The `'static`
/// lifetime is not checked.
#[inline(always)]
unsafe fn sl(offset: usize, base: *mut u8, len: usize) -> &'static [f32] {
    if offset == usize::MAX {
        return &[];
    }
    unsafe { std::slice::from_raw_parts(base.add(offset) as *const f32, len) }
}

/// Mutable f32 view of `len` values starting `offset` bytes past the arena
/// base. There is no `usize::MAX` sentinel check.
///
/// # Safety
/// `base + offset` must be 4-byte aligned and point to `len` f32 values,
/// with no other reference to that memory used while the slice is alive.
#[inline(always)]
unsafe fn sl_mut(offset: usize, base: *mut u8, len: usize) -> &'static mut [f32] {
    unsafe { std::slice::from_raw_parts_mut(base.add(offset) as *mut f32, len) }
}
16060
/// Shared f64 view of `len` values starting `offset` bytes past the arena
/// base. The sentinel offset `usize::MAX` maps to an empty slice.
///
/// # Safety
/// Unless `offset == usize::MAX`, `base + offset` must be 8-byte aligned
/// and point to `len` initialized f64 values that outlive every use of the
/// returned slice.
#[inline(always)]
unsafe fn sl_f64(offset: usize, base: *mut u8, len: usize) -> &'static [f64] {
    match offset {
        usize::MAX => &[],
        byte_off => unsafe { std::slice::from_raw_parts(base.add(byte_off).cast::<f64>(), len) },
    }
}

/// Mutable f64 view of `len` values starting `offset` bytes past the arena
/// base.
///
/// # Safety
/// `base + offset` must be 8-byte aligned and point to `len` f64 values that
/// nothing else accesses while the slice is alive.
#[inline(always)]
unsafe fn sl_mut_f64(offset: usize, base: *mut u8, len: usize) -> &'static mut [f64] {
    let first = unsafe { base.add(offset) }.cast::<f64>();
    unsafe { std::slice::from_raw_parts_mut(first, len) }
}
16073
// Integer counterparts of `sl` / `sl_f64`, built the same way: a byte
// offset from the arena base, and a `usize::MAX` sentinel that yields an
// empty slice on the read-only variants. The i32 pair is currently
// unreferenced (hence `allow(dead_code)`). It is kept for integer-tensor
// thunks that have not landed yet (e.g. Sample, Gather index buffers), so
// the unsafe boilerplate need not be rederived.

/// Shared i32 view.
///
/// # Safety
/// Unless `offset == usize::MAX`, `base + offset` must be 4-byte aligned and
/// point to `len` initialized i32 values valid for every use of the slice.
#[inline(always)]
#[allow(dead_code)]
unsafe fn sl_i32(offset: usize, base: *mut u8, len: usize) -> &'static [i32] {
    match offset {
        usize::MAX => &[],
        byte_off => unsafe { std::slice::from_raw_parts(base.add(byte_off).cast::<i32>(), len) },
    }
}

/// Mutable i32 view.
///
/// # Safety
/// `base + offset` must be 4-byte aligned and point to `len` i32 values not
/// otherwise accessed while the slice is alive.
#[inline(always)]
#[allow(dead_code)]
unsafe fn sl_mut_i32(offset: usize, base: *mut u8, len: usize) -> &'static mut [i32] {
    let first = unsafe { base.add(offset) }.cast::<i32>();
    unsafe { std::slice::from_raw_parts_mut(first, len) }
}

/// Shared i64 view.
///
/// # Safety
/// Unless `offset == usize::MAX`, `base + offset` must be 8-byte aligned and
/// point to `len` initialized i64 values valid for every use of the slice.
#[inline(always)]
unsafe fn sl_i64(offset: usize, base: *mut u8, len: usize) -> &'static [i64] {
    match offset {
        usize::MAX => &[],
        byte_off => unsafe { std::slice::from_raw_parts(base.add(byte_off).cast::<i64>(), len) },
    }
}

/// Mutable i64 view.
///
/// # Safety
/// `base + offset` must be 8-byte aligned and point to `len` i64 values not
/// otherwise accessed while the slice is alive.
#[inline(always)]
unsafe fn sl_mut_i64(offset: usize, base: *mut u8, len: usize) -> &'static mut [i64] {
    let first = unsafe { base.add(offset) }.cast::<i64>();
    unsafe { std::slice::from_raw_parts_mut(first, len) }
}
16105
/// Strided N-D gather for f64 buffers, shared by Transpose and Expand.
///
/// Visits `out` in row-major order over `out_dims` (last axis fastest).
/// Each element is read from `inp` at `Σ idx[d] · in_strides[d]`, where
/// `in_strides[d]` is the source stride for output axis `d`. A stride of 0
/// repeats the source along that axis (broadcast).
fn transpose_walk_f64(inp: &[f64], out: &mut [f64], out_dims: &[u32], in_strides: &[u32]) {
    let rank = out_dims.len();
    // Odometer over output coordinates. `src` is kept equal to
    // Σ coord[d]·in_strides[d] incrementally instead of being recomputed
    // for every element.
    let mut coord = vec![0u32; rank];
    let mut src = 0usize;
    for slot in out.iter_mut() {
        *slot = inp[src];
        for axis in (0..rank).rev() {
            coord[axis] += 1;
            src += in_strides[axis] as usize;
            if coord[axis] < out_dims[axis] {
                break;
            }
            // Carry: rewind this axis to 0 and let the next-outer axis tick.
            src -= coord[axis] as usize * in_strides[axis] as usize;
            coord[axis] = 0;
        }
    }
}
16128
16129/// f64 elementwise activation. Reads `inp`, writes `out`. For now
16130/// covers what the autodiff-emitted gradient graph needs (Neg, Exp,
16131/// Log, Sqrt, Rsqrt, Abs, Tanh, Sigmoid, Relu — the
16132/// transcendental-free subset). Approximate Gelu/Silu deferred until a
16133/// workload demands them at f64.
16134fn apply_activation_f64(inp: &[f64], out: &mut [f64], kind: Activation) {
16135    match kind {
16136        Activation::Neg => {
16137            for (o, &v) in out.iter_mut().zip(inp) {
16138                *o = -v;
16139            }
16140        }
16141        Activation::Exp => {
16142            for (o, &v) in out.iter_mut().zip(inp) {
16143                *o = v.exp();
16144            }
16145        }
16146        Activation::Log => {
16147            for (o, &v) in out.iter_mut().zip(inp) {
16148                *o = v.ln();
16149            }
16150        }
16151        Activation::Sqrt => {
16152            for (o, &v) in out.iter_mut().zip(inp) {
16153                *o = v.sqrt();
16154            }
16155        }
16156        Activation::Rsqrt => {
16157            for (o, &v) in out.iter_mut().zip(inp) {
16158                *o = 1.0 / v.sqrt();
16159            }
16160        }
16161        Activation::Abs => {
16162            for (o, &v) in out.iter_mut().zip(inp) {
16163                *o = v.abs();
16164            }
16165        }
16166        Activation::Tanh => {
16167            for (o, &v) in out.iter_mut().zip(inp) {
16168                *o = v.tanh();
16169            }
16170        }
16171        Activation::Sigmoid => {
16172            for (o, &v) in out.iter_mut().zip(inp) {
16173                *o = 1.0 / (1.0 + (-v).exp());
16174            }
16175        }
16176        Activation::Relu => {
16177            for (o, &v) in out.iter_mut().zip(inp) {
16178                *o = v.max(0.0);
16179            }
16180        }
16181        Activation::Round => {
16182            for (o, &v) in out.iter_mut().zip(inp) {
16183                *o = v.round_ties_even();
16184            }
16185        }
16186        Activation::Sin => {
16187            for (o, &v) in out.iter_mut().zip(inp) {
16188                *o = v.sin();
16189            }
16190        }
16191        Activation::Cos => {
16192            for (o, &v) in out.iter_mut().zip(inp) {
16193                *o = v.cos();
16194            }
16195        }
16196        Activation::Tan => {
16197            for (o, &v) in out.iter_mut().zip(inp) {
16198                *o = v.tan();
16199            }
16200        }
16201        Activation::Atan => {
16202            for (o, &v) in out.iter_mut().zip(inp) {
16203                *o = v.atan();
16204            }
16205        }
16206        Activation::Gelu | Activation::GeluApprox | Activation::Silu => {
16207            panic!(
16208                "apply_activation_f64: {kind:?} not yet implemented at f64. \
16209                    Add when a workload needs it."
16210            );
16211        }
16212    }
16213}
16214
16215#[inline]
16216fn binary_op_f64(op: BinaryOp, a: f64, b: f64) -> f64 {
16217    match op {
16218        BinaryOp::Add => a + b,
16219        BinaryOp::Sub => a - b,
16220        BinaryOp::Mul => a * b,
16221        BinaryOp::Div => a / b,
16222        BinaryOp::Max => a.max(b),
16223        BinaryOp::Min => a.min(b),
16224        BinaryOp::Pow => a.powf(b),
16225    }
16226}
16227
/// Sums the middle axis of a contiguous `[outer, reduced, inner]` f64
/// tensor, writing the `[outer, inner]` result into `out`.
///
/// Each output accumulates in increasing `r` order starting from `+0.0`, so
/// an empty (`reduced == 0`) axis yields positive zeros.
fn reduce_sum_f64(inp: &[f64], out: &mut [f64], outer: usize, reduced: usize, inner: usize) {
    let plane = reduced * inner;
    for o in 0..outer {
        let src = o * plane;
        let dst = o * inner;
        for n in 0..inner {
            // `fold` from +0.0 (not `sum`) keeps the original signed-zero result.
            out[dst + n] = (0..reduced).fold(0.0_f64, |acc, r| acc + inp[src + r * inner + n]);
        }
    }
}
16241
16242#[cfg(test)]
16243mod tests {
16244    use super::*;
16245    use rlx_ir::*;
16246
16247    /// Plan #45: when a Narrow's only consumer is a Rope, the thunk
16248    /// fusion pass collapses them — the Narrow becomes Nop, and the
16249    /// Rope reads from the parent buffer with its row stride. This
16250    /// test runs the unfused path (batch*seq > FusedAttnBlock
16251    /// threshold) and asserts the rewrite happened.
16252    #[test]
16253    fn narrow_rope_fuses_in_unfused_path() {
16254        let f = DType::F32;
16255        let mut g = Graph::new("nr_fuse");
16256        // Force batch*seq > 64 so FusedAttnBlock doesn't pre-empt us.
16257        let qkv = g.input("qkv", Shape::new(&[16, 8, 192], f)); // 16*8=128 > 64
16258        let cos = g.input("cos", Shape::new(&[16], f));
16259        let sin = g.input("sin", Shape::new(&[16], f));
16260        // Last-axis narrow: Q = qkv[..., 0..64]
16261        let q = g.narrow_(qkv, 2, 0, 64);
16262        let q_rope = g.rope(q, cos, sin, 16);
16263        g.set_outputs(vec![q_rope]);
16264
16265        let plan = rlx_opt::memory::plan_memory(&g);
16266        let arena = crate::arena::Arena::from_plan(plan);
16267        let sched = compile_thunks(&g, &arena);
16268
16269        let mut narrow_count = 0;
16270        let mut rope_with_stride: Option<u32> = None;
16271        for t in &sched.thunks {
16272            match t {
16273                Thunk::Narrow { .. } => narrow_count += 1,
16274                Thunk::Rope { src_row_stride, .. } => rope_with_stride = Some(*src_row_stride),
16275                _ => {}
16276            }
16277        }
16278        // After fusion the Narrow is gone; only the Rope remains, and
16279        // it now walks with the parent QKV's row stride (3 * 64 = 192).
16280        assert_eq!(
16281            narrow_count, 0,
16282            "Narrow→Rope fusion should leave zero Narrow thunks; saw {narrow_count}"
16283        );
16284        assert_eq!(
16285            rope_with_stride,
16286            Some(192),
16287            "Rope's src_row_stride should be 192 (parent qkv axis), saw {rope_with_stride:?}"
16288        );
16289    }
16290
16291    /// Plan #15: SSM selective scan matches a naive Python-style
16292    /// Python-style sequential reference.
16293    #[test]
16294    fn ssm_selective_scan_matches_reference() {
16295        use rlx_ir::Philox4x32;
16296        let bch = 1usize;
16297        let s = 4usize;
16298        let h = 3usize;
16299        let n = 2usize;
16300
16301        let mut rng = Philox4x32::new(13);
16302        let mut x = vec![0f32; bch * s * h];
16303        rng.fill_normal(&mut x);
16304        let mut delta = vec![0f32; bch * s * h];
16305        // Keep Δ small so exp(Δ·A) doesn't blow up.
16306        for v in delta.iter_mut() {
16307            *v = (rng.next_f32() - 0.5) * 0.1;
16308        }
16309        let mut a = vec![0f32; h * n];
16310        for v in a.iter_mut() {
16311            *v = -(rng.next_f32() * 0.5 + 0.1);
16312        } // negative for stability
16313        let mut b = vec![0f32; bch * s * n];
16314        rng.fill_normal(&mut b);
16315        let mut c = vec![0f32; bch * s * n];
16316        rng.fill_normal(&mut c);
16317
16318        // Reference scan.
16319        let mut expected = vec![0f32; bch * s * h];
16320        for bi in 0..bch {
16321            let mut state = vec![0f32; h * n];
16322            for si in 0..s {
16323                for ci in 0..h {
16324                    let d = delta[bi * s * h + si * h + ci];
16325                    let xv = x[bi * s * h + si * h + ci];
16326                    let mut acc = 0f32;
16327                    for ni in 0..n {
16328                        let da = (d * a[ci * n + ni]).exp();
16329                        state[ci * n + ni] =
16330                            da * state[ci * n + ni] + d * b[bi * s * n + si * n + ni] * xv;
16331                        acc += c[bi * s * n + si * n + ni] * state[ci * n + ni];
16332                    }
16333                    expected[bi * s * h + si * h + ci] = acc;
16334                }
16335            }
16336        }
16337
16338        // RLX path.
16339        let f = DType::F32;
16340        let mut g = Graph::new("ssm");
16341        let xn = g.input("x", Shape::new(&[bch, s, h], f));
16342        let dn = g.input("delta", Shape::new(&[bch, s, h], f));
16343        let an = g.param("a", Shape::new(&[h, n], f));
16344        let bn = g.param("b", Shape::new(&[bch, s, n], f));
16345        let cn = g.param("c", Shape::new(&[bch, s, n], f));
16346        let yn = g.selective_scan(xn, dn, an, bn, cn, n, Shape::new(&[bch, s, h], f));
16347        g.set_outputs(vec![yn]);
16348
16349        let plan = rlx_opt::memory::plan_memory(&g);
16350        let mut arena = crate::arena::Arena::from_plan(plan);
16351        let sched = compile_thunks(&g, &arena);
16352
16353        let xn_off = arena.byte_offset(xn);
16354        let dn_off = arena.byte_offset(dn);
16355        let an_off = arena.byte_offset(an);
16356        let bn_off = arena.byte_offset(bn);
16357        let cn_off = arena.byte_offset(cn);
16358        let yn_off = arena.byte_offset(yn);
16359        let buf = arena.raw_buf_mut();
16360        unsafe {
16361            let copy = |dst: *mut f32, data: &[f32]| {
16362                for (i, &v) in data.iter().enumerate() {
16363                    *dst.add(i) = v;
16364                }
16365            };
16366            copy(buf.as_mut_ptr().add(xn_off) as *mut f32, &x);
16367            copy(buf.as_mut_ptr().add(dn_off) as *mut f32, &delta);
16368            copy(buf.as_mut_ptr().add(an_off) as *mut f32, &a);
16369            copy(buf.as_mut_ptr().add(bn_off) as *mut f32, &b);
16370            copy(buf.as_mut_ptr().add(cn_off) as *mut f32, &c);
16371        }
16372        execute_thunks(&sched, arena.raw_buf_mut());
16373
16374        let actual: Vec<f32> = unsafe {
16375            let p = arena.raw_buf().as_ptr().add(yn_off) as *const f32;
16376            (0..bch * s * h).map(|i| *p.add(i)).collect()
16377        };
16378
16379        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
16380            assert!(
16381                (e - a).abs() < 1e-3,
16382                "mismatch at {i}: expected {e}, got {a}"
16383            );
16384        }
16385    }
16386
16387    /// Plan #26: 1×1 conv lowers to per-batch sgemm and matches the
16388    /// scalar 7-loop reference.
16389    #[test]
16390    fn conv_1x1_fast_path_matches_scalar() {
16391        use rlx_ir::Philox4x32;
16392        // [N=2, C_in=4, H=3, W=3]
16393        let n = 2usize;
16394        let c_in = 4usize;
16395        let h = 3usize;
16396        let w = 3usize;
16397        let c_out = 5usize;
16398        let mut rng = Philox4x32::new(31);
16399        let mut x = vec![0f32; n * c_in * h * w];
16400        rng.fill_normal(&mut x);
16401        let mut weight = vec![0f32; c_out * c_in];
16402        rng.fill_normal(&mut weight);
16403
16404        // Reference: scalar 1×1 conv = per-batch matmul
16405        // out[ni, co, hi, wi] = sum_ci weight[co, ci] * x[ni, ci, hi, wi]
16406        let mut expected = vec![0f32; n * c_out * h * w];
16407        for ni in 0..n {
16408            for co in 0..c_out {
16409                for hi in 0..h {
16410                    for wi in 0..w {
16411                        let mut acc = 0f32;
16412                        for ci in 0..c_in {
16413                            acc += weight[co * c_in + ci]
16414                                * x[((ni * c_in) + ci) * h * w + hi * w + wi];
16415                        }
16416                        expected[((ni * c_out) + co) * h * w + hi * w + wi] = acc;
16417                    }
16418                }
16419            }
16420        }
16421
16422        // RLX path: build a graph with Op::Conv (kernel=[1,1], stride=[1,1], etc).
16423        let f = DType::F32;
16424        let mut g = Graph::new("conv1x1");
16425        let xn = g.input("x", Shape::new(&[n, c_in, h, w], f));
16426        let wn = g.param("w", Shape::new(&[c_out, c_in, 1, 1], f));
16427        // Manually add Op::Conv since there's no `g.conv()` helper.
16428        let cn = g.add_node(
16429            rlx_ir::Op::Conv {
16430                kernel_size: vec![1, 1],
16431                stride: vec![1, 1],
16432                padding: vec![0, 0],
16433                dilation: vec![1, 1],
16434                groups: 1,
16435            },
16436            vec![xn, wn],
16437            Shape::new(&[n, c_out, h, w], f),
16438        );
16439        g.set_outputs(vec![cn]);
16440
16441        let plan = rlx_opt::memory::plan_memory(&g);
16442        let mut arena = crate::arena::Arena::from_plan(plan);
16443        let sched = compile_thunks(&g, &arena);
16444
16445        // Verify the fast path was selected.
16446        let saw_fast = sched
16447            .thunks
16448            .iter()
16449            .any(|t| matches!(t, Thunk::Conv2D1x1 { .. }));
16450        let saw_slow = sched
16451            .thunks
16452            .iter()
16453            .any(|t| matches!(t, Thunk::Conv2D { .. }));
16454        assert!(saw_fast, "1×1 conv should emit Conv2D1x1");
16455        assert!(!saw_slow, "1×1 conv must not fall through to scalar Conv2D");
16456
16457        let xn_off = arena.byte_offset(xn);
16458        let wn_off = arena.byte_offset(wn);
16459        let cn_off = arena.byte_offset(cn);
16460        let buf = arena.raw_buf_mut();
16461        unsafe {
16462            let xp = buf.as_mut_ptr().add(xn_off) as *mut f32;
16463            for (i, &v) in x.iter().enumerate() {
16464                *xp.add(i) = v;
16465            }
16466            let wp = buf.as_mut_ptr().add(wn_off) as *mut f32;
16467            for (i, &v) in weight.iter().enumerate() {
16468                *wp.add(i) = v;
16469            }
16470        }
16471        execute_thunks(&sched, arena.raw_buf_mut());
16472
16473        let actual: Vec<f32> = unsafe {
16474            let p = arena.raw_buf().as_ptr().add(cn_off) as *const f32;
16475            (0..(n * c_out * h * w)).map(|i| *p.add(i)).collect()
16476        };
16477
16478        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
16479            assert!(
16480                (e - a).abs() < 1e-3,
16481                "mismatch at {i}: expected {e}, got {a}"
16482            );
16483        }
16484    }
16485
16486    /// Plan #5: fused dequant matmul matches the dequant-then-matmul
16487    /// reference (i.e. `(scale * (q - z)) @ x` materialized).
16488    #[test]
16489    fn dequant_matmul_int8_sym_matches_reference() {
16490        use rlx_ir::Philox4x32;
16491        use rlx_ir::quant::QuantScheme;
16492
16493        let m = 3usize;
16494        let k = 8usize;
16495        let n = 4usize;
16496        let block_size = 4usize; // 2 blocks per column
16497        let blocks_per_col = k / block_size;
16498
16499        // Random inputs: x f32, w_q i8, scales f32. Symmetric → no zp.
16500        let mut rng = Philox4x32::new(99);
16501        let mut x = vec![0f32; m * k];
16502        rng.fill_normal(&mut x);
16503        let w_q: Vec<i8> = (0..(k * n))
16504            .map(|i| ((i as i32 * 13 + 7) % 127 - 63) as i8)
16505            .collect();
16506        let scales: Vec<f32> = (0..(blocks_per_col * n))
16507            .map(|i| 0.01 + 0.001 * i as f32)
16508            .collect();
16509
16510        // Reference: build f32 weights from (q * scale) per block.
16511        let mut w_f32 = vec![0f32; k * n];
16512        for p in 0..k {
16513            let block = p / block_size;
16514            for j in 0..n {
16515                let s = scales[block * n + j];
16516                w_f32[p * n + j] = w_q[p * n + j] as f32 * s;
16517            }
16518        }
16519        let mut expected = vec![0f32; m * n];
16520        for i in 0..m {
16521            for j in 0..n {
16522                let mut acc = 0f32;
16523                for p in 0..k {
16524                    acc += x[i * k + p] * w_f32[p * n + j];
16525                }
16526                expected[i * n + j] = acc;
16527            }
16528        }
16529
16530        // RLX path.
16531        let f = DType::F32;
16532        let mut g = Graph::new("dq");
16533        let xn = g.input("x", Shape::new(&[m, k], f));
16534        let wn = g.param("w", Shape::new(&[k, n], DType::I8));
16535        let sn = g.param("scale", Shape::new(&[blocks_per_col, n], f));
16536        let zn = g.param("zp", Shape::new(&[blocks_per_col, n], f)); // unused (sym)
16537        let dq = g.dequant_matmul(
16538            xn,
16539            wn,
16540            sn,
16541            zn,
16542            QuantScheme::Int8Block {
16543                block_size: block_size as u32,
16544            },
16545            Shape::new(&[m, n], f),
16546        );
16547        g.set_outputs(vec![dq]);
16548
16549        let plan = rlx_opt::memory::plan_memory(&g);
16550        let mut arena = crate::arena::Arena::from_plan(plan);
16551        let sched = compile_thunks(&g, &arena);
16552
16553        let xn_off = arena.byte_offset(xn);
16554        let wn_off = arena.byte_offset(wn);
16555        let sn_off = arena.byte_offset(sn);
16556        let zn_off = arena.byte_offset(zn);
16557        let dq_off = arena.byte_offset(dq);
16558        let buf = arena.raw_buf_mut();
16559        unsafe {
16560            // Seed f32 inputs.
16561            let xp = buf.as_mut_ptr().add(xn_off) as *mut f32;
16562            for (i, &v) in x.iter().enumerate() {
16563                *xp.add(i) = v;
16564            }
16565            let sp = buf.as_mut_ptr().add(sn_off) as *mut f32;
16566            for (i, &v) in scales.iter().enumerate() {
16567                *sp.add(i) = v;
16568            }
16569            let zp = buf.as_mut_ptr().add(zn_off) as *mut f32;
16570            for i in 0..(blocks_per_col * n) {
16571                *zp.add(i) = 0.0;
16572            }
16573            // Seed i8 weights byte-by-byte.
16574            let wp = buf.as_mut_ptr().add(wn_off) as *mut i8;
16575            for (i, &v) in w_q.iter().enumerate() {
16576                *wp.add(i) = v;
16577            }
16578        }
16579        execute_thunks(&sched, arena.raw_buf_mut());
16580
16581        let actual: Vec<f32> = unsafe {
16582            let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
16583            (0..m * n).map(|i| *p.add(i)).collect()
16584        };
16585
16586        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
16587            assert!(
16588                (e - a).abs() < 1e-3,
16589                "mismatch at {i}: expected {e}, got {a}"
16590            );
16591        }
16592    }
16593
16594    /// Plan #9: LoRA matmul matches the unfused 3-matmul reference.
16595    #[test]
16596    fn lora_matmul_matches_unfused_reference() {
16597        use rlx_ir::Philox4x32;
16598
16599        let m = 4usize;
16600        let k = 8usize;
16601        let n = 6usize;
16602        let r = 2usize;
16603        let scale = 0.5f32;
16604
16605        // Random inputs (deterministic via Philox).
16606        let mut rng = Philox4x32::new(42);
16607        let mut x = vec![0f32; m * k];
16608        rng.fill_normal(&mut x);
16609        let mut w = vec![0f32; k * n];
16610        rng.fill_normal(&mut w);
16611        let mut a = vec![0f32; k * r];
16612        rng.fill_normal(&mut a);
16613        let mut b = vec![0f32; r * n];
16614        rng.fill_normal(&mut b);
16615
16616        // Reference: out = x·W + scale * x·A·B. Naive triple-loop.
16617        let naive = |a_buf: &[f32], b_buf: &[f32], rows: usize, inner: usize, cols: usize| {
16618            let mut o = vec![0f32; rows * cols];
16619            for i in 0..rows {
16620                for j in 0..cols {
16621                    let mut acc = 0f32;
16622                    for p in 0..inner {
16623                        acc += a_buf[i * inner + p] * b_buf[p * cols + j];
16624                    }
16625                    o[i * cols + j] = acc;
16626                }
16627            }
16628            o
16629        };
16630        let xw = naive(&x, &w, m, k, n);
16631        let xa = naive(&x, &a, m, k, r);
16632        let xab = naive(&xa, &b, m, r, n);
16633        let mut expected = xw;
16634        for i in 0..(m * n) {
16635            expected[i] += scale * xab[i];
16636        }
16637
16638        // RLX path: build a graph with one LoraMatMul.
16639        let f = DType::F32;
16640        let mut g = Graph::new("lora");
16641        let xn = g.input("x", Shape::new(&[m, k], f));
16642        let wn = g.param("w", Shape::new(&[k, n], f));
16643        let an = g.param("a", Shape::new(&[k, r], f));
16644        let bn = g.param("b", Shape::new(&[r, n], f));
16645        let lm = g.lora_matmul(xn, wn, an, bn, scale, Shape::new(&[m, n], f));
16646        g.set_outputs(vec![lm]);
16647
16648        let plan = rlx_opt::memory::plan_memory(&g);
16649        let mut arena = crate::arena::Arena::from_plan(plan);
16650        let sched = compile_thunks(&g, &arena);
16651
16652        let xn_off = arena.byte_offset(xn);
16653        let wn_off = arena.byte_offset(wn);
16654        let an_off = arena.byte_offset(an);
16655        let bn_off = arena.byte_offset(bn);
16656        let lm_off = arena.byte_offset(lm);
16657        let buf = arena.raw_buf_mut();
16658        unsafe {
16659            let copy = |dst: *mut f32, data: &[f32]| {
16660                for (i, &v) in data.iter().enumerate() {
16661                    *dst.add(i) = v;
16662                }
16663            };
16664            copy(buf.as_mut_ptr().add(xn_off) as *mut f32, &x);
16665            copy(buf.as_mut_ptr().add(wn_off) as *mut f32, &w);
16666            copy(buf.as_mut_ptr().add(an_off) as *mut f32, &a);
16667            copy(buf.as_mut_ptr().add(bn_off) as *mut f32, &b);
16668        }
16669        execute_thunks(&sched, arena.raw_buf_mut());
16670
16671        let actual: Vec<f32> = unsafe {
16672            let p = arena.raw_buf().as_ptr().add(lm_off) as *const f32;
16673            (0..m * n).map(|i| *p.add(i)).collect()
16674        };
16675
16676        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
16677            assert!(
16678                (e - a).abs() < 1e-3,
16679                "mismatch at {i}: expected {e}, got {a}"
16680            );
16681        }
16682    }
16683
16684    /// Plan #42: fused sampling kernel determinism + greedy fallback.
16685    #[test]
16686    fn sample_temperature_zero_is_argmax() {
16687        // Very low temperature → distribution collapses on argmax.
16688        // Same seed → same output bit-for-bit.
16689        let f = DType::F32;
16690        let mut g = Graph::new("samp");
16691        let logits = g.input("logits", Shape::new(&[1, 8], f));
16692        let s = g.sample(logits, 0, 1.0, 1e-3, 42, Shape::new(&[1], f));
16693        g.set_outputs(vec![s]);
16694        let plan = rlx_opt::memory::plan_memory(&g);
16695        let mut arena = crate::arena::Arena::from_plan(plan);
16696        let sched = compile_thunks(&g, &arena);
16697
16698        let logits_off = arena.byte_offset(logits);
16699        let s_off = arena.byte_offset(s);
16700        let buf = arena.raw_buf_mut();
16701        unsafe {
16702            let p = buf.as_mut_ptr().add(logits_off) as *mut f32;
16703            // argmax = index 5 (value 9.0).
16704            let inputs = [0.1f32, 0.2, 0.3, 0.4, 0.5, 9.0, 0.7, 0.8];
16705            for (i, &v) in inputs.iter().enumerate() {
16706                *p.add(i) = v;
16707            }
16708        }
16709        execute_thunks(&sched, arena.raw_buf_mut());
16710
16711        let token = unsafe {
16712            let p = arena.raw_buf().as_ptr().add(s_off) as *const f32;
16713            *p as usize
16714        };
16715        assert_eq!(token, 5, "low-temp sampling should pick the argmax");
16716    }
16717
16718    #[test]
16719    fn sample_top_k_one_is_deterministic() {
16720        // top_k=1 forces only the argmax to have nonzero probability.
16721        let f = DType::F32;
16722        let mut g = Graph::new("samp_k1");
16723        let logits = g.input("logits", Shape::new(&[1, 4], f));
16724        let s = g.sample(logits, 1, 1.0, 1.0, 7, Shape::new(&[1], f));
16725        g.set_outputs(vec![s]);
16726        let plan = rlx_opt::memory::plan_memory(&g);
16727        let mut arena = crate::arena::Arena::from_plan(plan);
16728        let sched = compile_thunks(&g, &arena);
16729
16730        let logits_off = arena.byte_offset(logits);
16731        let s_off = arena.byte_offset(s);
16732        let buf = arena.raw_buf_mut();
16733        unsafe {
16734            let p = buf.as_mut_ptr().add(logits_off) as *mut f32;
16735            let inputs = [0.1f32, 5.0, 0.3, 0.4]; // argmax = 1
16736            for (i, &v) in inputs.iter().enumerate() {
16737                *p.add(i) = v;
16738            }
16739        }
16740        execute_thunks(&sched, arena.raw_buf_mut());
16741        let token = unsafe {
16742            let p = arena.raw_buf().as_ptr().add(s_off) as *const f32;
16743            *p as usize
16744        };
16745        assert_eq!(token, 1);
16746    }
16747
16748    /// Plan #44: cumsum primitive parity vs. naive scan.
16749    #[test]
16750    fn cumsum_inclusive_matches_naive() {
16751        let f = DType::F32;
16752        let mut g = Graph::new("cumsum");
16753        let x = g.input("x", Shape::new(&[2, 4], f));
16754        let cs = g.cumsum(x, -1, false, Shape::new(&[2, 4], f));
16755        g.set_outputs(vec![cs]);
16756        let plan = rlx_opt::memory::plan_memory(&g);
16757        let mut arena = crate::arena::Arena::from_plan(plan);
16758        let sched = compile_thunks(&g, &arena);
16759
16760        // Cache offsets up-front so we can drop the immutable borrow.
16761        let x_off = arena.byte_offset(x);
16762        let out_off = arena.byte_offset(cs);
16763        let buf = arena.raw_buf_mut();
16764        unsafe {
16765            let p = buf.as_mut_ptr().add(x_off) as *mut f32;
16766            let inputs = [1.0f32, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0];
16767            for (i, &v) in inputs.iter().enumerate() {
16768                *p.add(i) = v;
16769            }
16770        }
16771        execute_thunks(&sched, arena.raw_buf_mut());
16772
16773        let out: Vec<f32> = unsafe {
16774            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
16775            (0..8).map(|i| *p.add(i)).collect()
16776        };
16777        assert_eq!(out, vec![1.0, 3.0, 6.0, 10.0, 10.0, 30.0, 60.0, 100.0]);
16778    }
16779
16780    /// Plan #46 deep: Narrow×3 → Attention fusion. The three QKV
16781    /// narrows that BERT/Nomic emit on the unfused (batch*seq > 64)
16782    /// path collapse into a single strided-Attention thunk.
16783    #[test]
16784    fn narrow_attention_fuses_in_unfused_path() {
16785        let f = DType::F32;
16786        let mut g = Graph::new("nattn_fuse");
16787        // batch*seq = 8*16 = 128 > 64 so FusedAttnBlock skips.
16788        let qkv = g.input("qkv", Shape::new(&[8, 16, 192], f)); // 3*64 = 192
16789        let mask = g.input("mask", Shape::new(&[8, 16], f));
16790        let q = g.narrow_(qkv, 2, 0, 64);
16791        let k = g.narrow_(qkv, 2, 64, 64);
16792        let v = g.narrow_(qkv, 2, 128, 64);
16793        let attn = g.attention(q, k, v, mask, 4, 16, Shape::new(&[8, 16, 64], f));
16794        g.set_outputs(vec![attn]);
16795
16796        let plan = rlx_opt::memory::plan_memory(&g);
16797        let arena = crate::arena::Arena::from_plan(plan);
16798        let sched = compile_thunks(&g, &arena);
16799
16800        let mut narrow_count = 0;
16801        let mut attn_strides: Option<(u32, u32, u32)> = None;
16802        for t in &sched.thunks {
16803            match t {
16804                Thunk::Narrow { .. } => narrow_count += 1,
16805                Thunk::Attention {
16806                    q_row_stride,
16807                    k_row_stride,
16808                    v_row_stride,
16809                    ..
16810                } => attn_strides = Some((*q_row_stride, *k_row_stride, *v_row_stride)),
16811                _ => {}
16812            }
16813        }
16814        // After fusion the 3 narrows are gone; Attention now walks the
16815        // QKV with parent row stride = 192 (3 × 64) on all three inputs.
16816        assert_eq!(
16817            narrow_count, 0,
16818            "Narrow×3→Attention fusion should eliminate all 3 narrows; saw {narrow_count}"
16819        );
16820        assert_eq!(
16821            attn_strides,
16822            Some((192, 192, 192)),
16823            "Attention should walk Q/K/V with parent row stride 192"
16824        );
16825    }
16826
16827    // ── Backward / training op parity tests ────────────────────
16828    //
16829    // Strategy: build a graph that contains exactly the backward op
16830    // under test (plus its inputs as graph Inputs), execute, and
16831    // compare against a hand-rolled scalar reference. For
16832    // Conv2dBackwardInput we additionally check against the numerical
16833    // gradient of the forward Conv2D — that's the gold-standard test
16834    // that validates the math, not just consistency between two
16835    // implementations of the same formula.
16836
16837    fn run_graph(
16838        g: &Graph,
16839        inputs: &[(NodeId, &[f32])],
16840        out_id: NodeId,
16841        out_len: usize,
16842    ) -> Vec<f32> {
16843        let plan = rlx_opt::memory::plan_memory(g);
16844        let mut arena = crate::arena::Arena::from_plan(plan);
16845        let sched = compile_thunks(g, &arena);
16846        for &(id, data) in inputs {
16847            let off = arena.byte_offset(id);
16848            let buf = arena.raw_buf_mut();
16849            unsafe {
16850                let p = buf.as_mut_ptr().add(off) as *mut f32;
16851                for (i, &v) in data.iter().enumerate() {
16852                    *p.add(i) = v;
16853                }
16854            }
16855        }
16856        execute_thunks(&sched, arena.raw_buf_mut());
16857        let off = arena.byte_offset(out_id);
16858        unsafe {
16859            let p = arena.raw_buf().as_ptr().add(off) as *const f32;
16860            (0..out_len).map(|i| *p.add(i)).collect()
16861        }
16862    }
16863
16864    #[test]
16865    fn relu_backward_matches_mask() {
16866        let f = DType::F32;
16867        let len = 7usize;
16868        let x: Vec<f32> = vec![-2.0, -0.1, 0.0, 0.1, 1.0, 3.0, -5.0];
16869        let dy: Vec<f32> = vec![0.5, 1.5, 2.5, -0.7, 4.0, -1.0, 9.0];
16870
16871        let mut g = Graph::new("relu_bw");
16872        let xn = g.input("x", Shape::new(&[len], f));
16873        let dyn_ = g.input("dy", Shape::new(&[len], f));
16874        let dx = g.relu_backward(xn, dyn_);
16875        g.set_outputs(vec![dx]);
16876
16877        let actual = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dx, len);
16878        // Reference: gradient is dy where x>0 strictly, else 0.
16879        // (zero is not "positive" — the forward applied max(0, x), and at
16880        // x=0 the subgradient could be anything in [0, dy]; we pick 0.)
16881        let expected: Vec<f32> = x
16882            .iter()
16883            .zip(&dy)
16884            .map(|(&xi, &dyi)| if xi > 0.0 { dyi } else { 0.0 })
16885            .collect();
16886        for (a, e) in actual.iter().zip(&expected) {
16887            assert!((a - e).abs() < 1e-6, "relu_bw mismatch: {a} vs {e}");
16888        }
16889    }
16890
16891    #[test]
16892    fn maxpool2d_backward_routes_to_argmax() {
16893        let f = DType::F32;
16894        // [N=1, C=1, H=4, W=4] → 2x2 max-pool stride 2 → [1,1,2,2].
16895        let x: Vec<f32> = vec![
16896            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
16897        ];
16898        // Argmax of each 2x2 window:
16899        //   (0,0)→6 (idx 5), (0,1)→8 (idx 7),
16900        //   (1,0)→14(idx 13),(1,1)→16(idx 15).
16901        let dy: Vec<f32> = vec![0.5, 1.0, 2.0, 4.0];
16902
16903        let mut g = Graph::new("maxpool_bw");
16904        let xn = g.input("x", Shape::new(&[1, 1, 4, 4], f));
16905        let dyn_ = g.input("dy", Shape::new(&[1, 1, 2, 2], f));
16906        let dx = g.maxpool2d_backward(xn, dyn_, vec![2, 2], vec![2, 2], vec![0, 0]);
16907        g.set_outputs(vec![dx]);
16908
16909        let actual = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dx, 16);
16910        let mut expected = vec![0f32; 16];
16911        expected[5] = 0.5;
16912        expected[7] = 1.0;
16913        expected[13] = 2.0;
16914        expected[15] = 4.0;
16915        for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
16916            assert!((a - e).abs() < 1e-6, "maxpool_bw[{i}] mismatch: {a} vs {e}");
16917        }
16918    }
16919
16920    #[test]
16921    fn conv2d_backward_input_matches_numerical_gradient() {
16922        use rlx_ir::Philox4x32;
16923        // Small enough to numerically differentiate exhaustively but
16924        // big enough to exercise stride/padding edge cases.
16925        let n = 1usize;
16926        let c_in = 2usize;
16927        let h = 4usize;
16928        let w = 4usize;
16929        let c_out = 3usize;
16930        let kh = 3usize;
16931        let kw = 3usize;
16932        let ph = 1usize;
16933        let pw = 1usize;
16934        let sh = 1usize;
16935        let sw = 1usize;
16936        // Output dims with padding=1, stride=1: same as input.
16937        let h_out = (h + 2 * ph - kh) / sh + 1;
16938        let w_out = (w + 2 * pw - kw) / sw + 1;
16939        assert_eq!(h_out, 4);
16940        assert_eq!(w_out, 4);
16941
16942        let mut rng = Philox4x32::new(7);
16943        let mut x = vec![0f32; n * c_in * h * w];
16944        rng.fill_normal(&mut x);
16945        let mut wt = vec![0f32; c_out * c_in * kh * kw];
16946        rng.fill_normal(&mut wt);
16947        let mut dy = vec![0f32; n * c_out * h_out * w_out];
16948        rng.fill_normal(&mut dy);
16949
16950        // Analytical: Conv2dBackwardInput on (dy, w).
16951        let f = DType::F32;
16952        let mut g = Graph::new("conv_bwi");
16953        let dy_in = g.input("dy", Shape::new(&[n, c_out, h_out, w_out], f));
16954        let w_in = g.input("w", Shape::new(&[c_out, c_in, kh, kw], f));
16955        let dx = g.conv2d_backward_input(
16956            dy_in,
16957            w_in,
16958            Shape::new(&[n, c_in, h, w], f),
16959            vec![kh, kw],
16960            vec![sh, sw],
16961            vec![ph, pw],
16962            vec![1, 1],
16963            1,
16964        );
16965        g.set_outputs(vec![dx]);
16966        let analytical = run_graph(&g, &[(dy_in, &dy), (w_in, &wt)], dx, n * c_in * h * w);
16967
16968        // Numerical: for each x[i], finite-difference forward conv twice.
16969        // Forward: y[j] = sum over filter window of w * x ; dot(dy, y) is
16970        // the scalar we differentiate. Then dx[i] = ∂(dot(dy, y))/∂x[i].
16971        let forward = |x: &[f32]| -> Vec<f32> {
16972            let mut out = vec![0f32; n * c_out * h_out * w_out];
16973            for ni in 0..n {
16974                for co in 0..c_out {
16975                    for ho in 0..h_out {
16976                        for wo in 0..w_out {
16977                            let mut acc = 0f32;
16978                            for ci in 0..c_in {
16979                                for ki in 0..kh {
16980                                    for kj in 0..kw {
16981                                        let hi = ho * sh + ki;
16982                                        let wi = wo * sw + kj;
16983                                        if hi < ph || wi < pw {
16984                                            continue;
16985                                        }
16986                                        let hi = hi - ph;
16987                                        let wi = wi - pw;
16988                                        if hi >= h || wi >= w {
16989                                            continue;
16990                                        }
16991                                        let xv = x[((ni * c_in) + ci) * h * w + hi * w + wi];
16992                                        let wv = wt[((co * c_in) + ci) * kh * kw + ki * kw + kj];
16993                                        acc += xv * wv;
16994                                    }
16995                                }
16996                            }
16997                            out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = acc;
16998                        }
16999                    }
17000                }
17001            }
17002            out
17003        };
17004        let dot = |a: &[f32], b: &[f32]| -> f32 { a.iter().zip(b).map(|(&u, &v)| u * v).sum() };
17005        let eps = 1e-3f32;
17006        let mut numerical = vec![0f32; x.len()];
17007        for i in 0..x.len() {
17008            let saved = x[i];
17009            x[i] = saved + eps;
17010            let plus = dot(&forward(&x), &dy);
17011            x[i] = saved - eps;
17012            let minus = dot(&forward(&x), &dy);
17013            x[i] = saved;
17014            numerical[i] = (plus - minus) / (2.0 * eps);
17015        }
17016        for (i, (a, n)) in analytical.iter().zip(&numerical).enumerate() {
17017            // f32 + eps=1e-3 numerical grad → ~1e-3 absolute is realistic.
17018            assert!(
17019                (a - n).abs() < 5e-3,
17020                "conv_bw_input[{i}]: analytical {a} vs numerical {n}"
17021            );
17022        }
17023    }
17024
17025    #[test]
17026    fn conv2d_backward_weight_matches_numerical_gradient() {
17027        use rlx_ir::Philox4x32;
17028        let n = 2usize;
17029        let c_in = 2usize;
17030        let h = 4usize;
17031        let w = 4usize;
17032        let c_out = 2usize;
17033        let kh = 3usize;
17034        let kw = 3usize;
17035        let ph = 0usize;
17036        let pw = 0usize;
17037        let sh = 1usize;
17038        let sw = 1usize;
17039        let h_out = (h + 2 * ph - kh) / sh + 1;
17040        let w_out = (w + 2 * pw - kw) / sw + 1;
17041
17042        let mut rng = Philox4x32::new(11);
17043        let mut x = vec![0f32; n * c_in * h * w];
17044        rng.fill_normal(&mut x);
17045        let mut wt = vec![0f32; c_out * c_in * kh * kw];
17046        rng.fill_normal(&mut wt);
17047        let mut dy = vec![0f32; n * c_out * h_out * w_out];
17048        rng.fill_normal(&mut dy);
17049
17050        let f = DType::F32;
17051        let mut g = Graph::new("conv_bww");
17052        let xn = g.input("x", Shape::new(&[n, c_in, h, w], f));
17053        let dyn_ = g.input("dy", Shape::new(&[n, c_out, h_out, w_out], f));
17054        let dwn = g.conv2d_backward_weight(
17055            xn,
17056            dyn_,
17057            Shape::new(&[c_out, c_in, kh, kw], f),
17058            vec![kh, kw],
17059            vec![sh, sw],
17060            vec![ph, pw],
17061            vec![1, 1],
17062            1,
17063        );
17064        g.set_outputs(vec![dwn]);
17065        let analytical = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dwn, c_out * c_in * kh * kw);
17066
17067        let forward = |wt: &[f32]| -> Vec<f32> {
17068            let mut out = vec![0f32; n * c_out * h_out * w_out];
17069            for ni in 0..n {
17070                for co in 0..c_out {
17071                    for ho in 0..h_out {
17072                        for wo in 0..w_out {
17073                            let mut acc = 0f32;
17074                            for ci in 0..c_in {
17075                                for ki in 0..kh {
17076                                    for kj in 0..kw {
17077                                        let hi = ho + ki;
17078                                        let wi = wo + kj;
17079                                        let xv = x[((ni * c_in) + ci) * h * w + hi * w + wi];
17080                                        let wv = wt[((co * c_in) + ci) * kh * kw + ki * kw + kj];
17081                                        acc += xv * wv;
17082                                    }
17083                                }
17084                            }
17085                            out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = acc;
17086                        }
17087                    }
17088                }
17089            }
17090            out
17091        };
17092        let dot = |a: &[f32], b: &[f32]| -> f32 { a.iter().zip(b).map(|(&u, &v)| u * v).sum() };
17093        let eps = 1e-3f32;
17094        let mut numerical = vec![0f32; wt.len()];
17095        for i in 0..wt.len() {
17096            let saved = wt[i];
17097            wt[i] = saved + eps;
17098            let plus = dot(&forward(&wt), &dy);
17099            wt[i] = saved - eps;
17100            let minus = dot(&forward(&wt), &dy);
17101            wt[i] = saved;
17102            numerical[i] = (plus - minus) / (2.0 * eps);
17103        }
17104        for (i, (a, n)) in analytical.iter().zip(&numerical).enumerate() {
17105            assert!(
17106                (a - n).abs() < 5e-3,
17107                "conv_bw_weight[{i}]: analytical {a} vs numerical {n}"
17108            );
17109        }
17110    }
17111
17112    #[test]
17113    fn softmax_cross_entropy_matches_reference() {
17114        let f = DType::F32;
17115        let logits: Vec<f32> = vec![
17116            1.0, 2.0, 3.0, // row 0: max=3 (idx 2)
17117            -1.0, 0.0, 4.0, // row 1: max=4 (idx 2)
17118            5.0, 5.0, 5.0, // row 2: uniform
17119        ];
17120        let labels: Vec<f32> = vec![2.0, 0.0, 1.0];
17121
17122        let mut g = Graph::new("sce");
17123        let lg = g.input("logits", Shape::new(&[3, 3], f));
17124        let lb = g.input("labels", Shape::new(&[3], f));
17125        let loss = g.softmax_cross_entropy_with_logits(lg, lb);
17126        g.set_outputs(vec![loss]);
17127        let actual = run_graph(&g, &[(lg, &logits), (lb, &labels)], loss, 3);
17128
17129        // Reference per-row: -log(softmax(row)[label]).
17130        let mut expected = vec![0f32; 3];
17131        for ni in 0..3 {
17132            let row = &logits[ni * 3..(ni + 1) * 3];
17133            let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
17134            let sum: f32 = row.iter().map(|&v| (v - m).exp()).sum();
17135            let lse = m + sum.ln();
17136            let label_idx = labels[ni] as usize;
17137            expected[ni] = lse - row[label_idx];
17138        }
17139        for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
17140            assert!((a - e).abs() < 1e-5, "sce loss[{i}]: {a} vs {e}");
17141        }
17142    }
17143
17144    #[test]
17145    fn softmax_cross_entropy_backward_matches_numerical_gradient() {
17146        use rlx_ir::Philox4x32;
17147        let n = 4usize;
17148        let c = 5usize;
17149        let mut rng = Philox4x32::new(23);
17150        let mut logits = vec![0f32; n * c];
17151        rng.fill_normal(&mut logits);
17152        let labels: Vec<f32> = (0..n).map(|i| (i % c) as f32).collect();
17153        let mut d_loss = vec![0f32; n];
17154        rng.fill_normal(&mut d_loss);
17155
17156        let f = DType::F32;
17157        let mut g = Graph::new("sce_bw");
17158        let lg = g.input("logits", Shape::new(&[n, c], f));
17159        let lb = g.input("labels", Shape::new(&[n], f));
17160        let dl = g.input("d_loss", Shape::new(&[n], f));
17161        let dlogits = g.softmax_cross_entropy_backward(lg, lb, dl);
17162        g.set_outputs(vec![dlogits]);
17163        let analytical = run_graph(
17164            &g,
17165            &[(lg, &logits), (lb, &labels), (dl, &d_loss)],
17166            dlogits,
17167            n * c,
17168        );
17169
17170        // Numerical: differentiate dot(d_loss, sce_loss(logits)) w.r.t. each logit.
17171        let sce_loss = |logits: &[f32]| -> Vec<f32> {
17172            let mut out = vec![0f32; n];
17173            for ni in 0..n {
17174                let row = &logits[ni * c..(ni + 1) * c];
17175                let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
17176                let sum: f32 = row.iter().map(|&v| (v - m).exp()).sum();
17177                out[ni] = (m + sum.ln()) - row[labels[ni] as usize];
17178            }
17179            out
17180        };
17181        let dot = |a: &[f32], b: &[f32]| a.iter().zip(b).map(|(&u, &v)| u * v).sum::<f32>();
17182        let eps = 1e-3f32;
17183        let mut numerical = vec![0f32; logits.len()];
17184        for i in 0..logits.len() {
17185            let saved = logits[i];
17186            logits[i] = saved + eps;
17187            let plus = dot(&sce_loss(&logits), &d_loss);
17188            logits[i] = saved - eps;
17189            let minus = dot(&sce_loss(&logits), &d_loss);
17190            logits[i] = saved;
17191            numerical[i] = (plus - minus) / (2.0 * eps);
17192        }
17193        for (i, (a, num)) in analytical.iter().zip(&numerical).enumerate() {
17194            assert!(
17195                (a - num).abs() < 5e-3,
17196                "sce_bw[{i}]: analytical {a} vs numerical {num}"
17197            );
17198        }
17199    }
17200
17201    // ── End-to-end autodiff parity tests ──────────────────────
17202    //
17203    // Build a forward graph, run `grad_with_loss` to produce a graph
17204    // that emits [loss, gradients...], execute it through rlx-cpu,
17205    // and compare each gradient to a finite-difference estimate
17206    // produced by re-running the forward graph with each parameter
17207    // entry perturbed. f32 + ε=1e-3 puts the tolerance floor around
17208    // 5e-3 absolute error.
17209
17210    /// Initialize Op::Constant slots in the arena with their literal
17211    /// data. Mirrors the loop in rlx_runtime::backend (which serves
17212    /// the same role for production runs).
17213    fn fill_constants_into_arena(graph: &Graph, arena: &mut crate::arena::Arena) {
17214        for node in graph.nodes() {
17215            if let Op::Constant { data } = &node.op
17216                && arena.has_buffer(node.id)
17217                && !data.is_empty()
17218            {
17219                let buf = arena.slice_mut(node.id);
17220                let n_floats = data.len() / 4;
17221                let n = buf.len().min(n_floats);
17222                for i in 0..n {
17223                    let bytes = [
17224                        data[i * 4],
17225                        data[i * 4 + 1],
17226                        data[i * 4 + 2],
17227                        data[i * 4 + 3],
17228                    ];
17229                    buf[i] = f32::from_le_bytes(bytes);
17230                }
17231            }
17232        }
17233    }
17234
17235    /// Compile + arena-prep helper for these tests. Returns the
17236    /// schedule and a populated arena. `seed_inputs` writes f32 input
17237    /// data into the arena slot for each (NodeId, &[f32]) pair.
17238    fn prepare(
17239        graph: &Graph,
17240        seed_inputs: &[(NodeId, &[f32])],
17241    ) -> (ThunkSchedule, crate::arena::Arena) {
17242        let plan = rlx_opt::memory::plan_memory(graph);
17243        let mut arena = crate::arena::Arena::from_plan(plan);
17244        let sched = compile_thunks(graph, &arena);
17245        fill_constants_into_arena(graph, &mut arena);
17246        for &(id, data) in seed_inputs {
17247            let off = arena.byte_offset(id);
17248            let buf = arena.raw_buf_mut();
17249            unsafe {
17250                let p = buf.as_mut_ptr().add(off) as *mut f32;
17251                for (i, &v) in data.iter().enumerate() {
17252                    *p.add(i) = v;
17253                }
17254            }
17255        }
17256        (sched, arena)
17257    }
17258
17259    fn read_arena(arena: &crate::arena::Arena, id: NodeId, len: usize) -> Vec<f32> {
17260        let off = arena.byte_offset(id);
17261        unsafe {
17262            let p = arena.raw_buf().as_ptr().add(off) as *const f32;
17263            (0..len).map(|i| *p.add(i)).collect()
17264        }
17265    }
17266
17267    fn write_arena(arena: &mut crate::arena::Arena, id: NodeId, data: &[f32]) {
17268        let off = arena.byte_offset(id);
17269        let buf = arena.raw_buf_mut();
17270        unsafe {
17271            let p = buf.as_mut_ptr().add(off) as *mut f32;
17272            for (i, &v) in data.iter().enumerate() {
17273                *p.add(i) = v;
17274            }
17275        }
17276    }
17277
17278    /// f64 sibling of `prepare`. Writes f64 input data into the arena.
17279    fn prepare_f64(
17280        graph: &Graph,
17281        seed_inputs: &[(NodeId, &[f64])],
17282    ) -> (ThunkSchedule, crate::arena::Arena) {
17283        let plan = rlx_opt::memory::plan_memory(graph);
17284        let mut arena = crate::arena::Arena::from_plan(plan);
17285        let sched = compile_thunks(graph, &arena);
17286        fill_constants_into_arena(graph, &mut arena);
17287        for &(id, data) in seed_inputs {
17288            let off = arena.byte_offset(id);
17289            let buf = arena.raw_buf_mut();
17290            unsafe {
17291                let p = buf.as_mut_ptr().add(off) as *mut f64;
17292                for (i, &v) in data.iter().enumerate() {
17293                    *p.add(i) = v;
17294                }
17295            }
17296        }
17297        (sched, arena)
17298    }
17299
17300    fn read_arena_f64(arena: &crate::arena::Arena, id: NodeId, len: usize) -> Vec<f64> {
17301        let off = arena.byte_offset(id);
17302        unsafe {
17303            let p = arena.raw_buf().as_ptr().add(off) as *const f64;
17304            (0..len).map(|i| *p.add(i)).collect()
17305        }
17306    }
17307
17308    /// End-to-end f64 DenseSolve through the full compile + execute
17309    /// path. Validates: IR shape inference, memory planner f64 sizing,
17310    /// arena f64 accessors, Thunk::DenseSolveF64 lowering, executor
17311    /// dispatch, Accelerate dgesv FFI.
17312    ///
17313    /// System:
17314    ///   A = [[2, 1],
17315    ///        [1, 3]]   b = [5, 10]
17316    ///   ⇒  x = [1, 3]   (verified by hand)
17317    #[test]
17318    fn dense_solve_f64_end_to_end() {
17319        let mut g = Graph::new("solve_e2e");
17320        let a = g.input("A", Shape::new(&[2, 2], DType::F64));
17321        let b = g.input("b", Shape::new(&[2], DType::F64));
17322        let x = g.dense_solve(a, b, Shape::new(&[2], DType::F64));
17323        g.set_outputs(vec![x]);
17324
17325        let a_data = [2.0, 1.0, 1.0, 3.0_f64];
17326        let b_data = [5.0, 10.0_f64];
17327        let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
17328        execute_thunks(&sched, arena.raw_buf_mut());
17329
17330        let got = read_arena_f64(&arena, x, 2);
17331        let want = [1.0, 3.0_f64];
17332        for i in 0..2 {
17333            assert!(
17334                (got[i] - want[i]).abs() < 1e-12,
17335                "x[{i}] = {} (expected {})",
17336                got[i],
17337                want[i]
17338            );
17339        }
17340    }
17341
17342    /// Scaled-up f64 DenseSolve — tridiagonal Laplacian-shape (typical
17343    /// MNA structure for a passive RC mesh in Circulax). Validates
17344    /// that the solve scales beyond the trivial 2×2 and that the
17345    /// row-major ↔ col-major dance in `dgesv` is correct for the
17346    /// general case.
17347    #[test]
17348    fn dense_solve_f64_5x5_laplacian() {
17349        let n = 5usize;
17350        let mut g = Graph::new("solve_5x5");
17351        let a = g.input("A", Shape::new(&[n, n], DType::F64));
17352        let b = g.input("b", Shape::new(&[n], DType::F64));
17353        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
17354        g.set_outputs(vec![x]);
17355
17356        // 1-D Laplacian: 2 on diagonal, -1 on off-diagonals, 0 elsewhere.
17357        let mut a_data = vec![0.0_f64; n * n];
17358        for i in 0..n {
17359            a_data[i * n + i] = 2.0;
17360            if i > 0 {
17361                a_data[i * n + (i - 1)] = -1.0;
17362            }
17363            if i + 1 < n {
17364                a_data[i * n + (i + 1)] = -1.0;
17365            }
17366        }
17367        let b_data: Vec<f64> = (0..n).map(|i| (i + 1) as f64).collect();
17368        let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
17369        execute_thunks(&sched, arena.raw_buf_mut());
17370
17371        let got = read_arena_f64(&arena, x, n);
17372        // Verify A·x ≈ b by computing the residual.
17373        let mut residual = vec![0.0_f64; n];
17374        for i in 0..n {
17375            for j in 0..n {
17376                residual[i] += a_data[i * n + j] * got[j];
17377            }
17378        }
17379        for i in 0..n {
17380            assert!(
17381                (residual[i] - b_data[i]).abs() < 1e-10,
17382                "row {i}: residual {} vs b {}",
17383                residual[i],
17384                b_data[i]
17385            );
17386        }
17387    }
17388
17389    /// Hello Resistor: end-to-end f64 gradient through a dense solve.
17390    ///
17391    /// Forward:
17392    ///   A      : Param  [N, N]   f64
17393    ///   b      : Input  [N]      f64
17394    ///   x      = solve(A, b)            (DenseSolve)
17395    ///   loss   = sum(x)                 (Reduce::Sum)
17396    ///
17397    /// Backward (via grad_with_loss):
17398    ///   ones [N] = expand(d_output, [N])      (Reduce::Sum VJP)
17399    ///   dx_int   = solve(Aᵀ, ones)             (DenseSolve VJP step 1)
17400    ///   dA       = -outer(dx_int, x)           (DenseSolve VJP step 2)
17401    ///   db       = dx_int                       (DenseSolve VJP step 3)
17402    ///
17403    /// Closed form: with loss = sum(solve(A, b)) = ones·x and
17404    /// implicit-function calculus, db = (Aᵀ)⁻¹·ones, dA = -db ⊗ x.
17405    /// We verify this against the autodiff-emitted graph's output and
17406    /// against a finite-difference baseline.
17407    #[test]
17408    fn hello_resistor_gradient_end_to_end() {
17409        use rlx_opt::autodiff::grad_with_loss;
17410        let n = 3usize;
17411
17412        // ── Build forward graph ──
17413        let mut g = Graph::new("hello_resistor");
17414        let a = g.param("A", Shape::new(&[n, n], DType::F64));
17415        let b = g.input("b", Shape::new(&[n], DType::F64));
17416        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
17417        let loss = g.reduce(
17418            x,
17419            ReduceOp::Sum,
17420            vec![0],
17421            false,
17422            Shape::new(&[1], DType::F64),
17423        );
17424        g.set_outputs(vec![loss]);
17425
17426        // ── Run reverse-mode AD ──
17427        let bwd = grad_with_loss(&g, &[a, b]);
17428        assert_eq!(bwd.outputs.len(), 3, "expect [loss, dA, db]");
17429
17430        // ── Locate the inputs the bwd graph still needs from us ──
17431        // grad_with_loss copies forward nodes into bwd, so A/b/d_output
17432        // appear under their original names. Find them by name.
17433        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
17434            for node in graph.nodes() {
17435                let name = match &node.op {
17436                    rlx_ir::Op::Input { name } => Some(name.as_str()),
17437                    rlx_ir::Op::Param { name } => Some(name.as_str()),
17438                    _ => None,
17439                };
17440                if name == Some(want) {
17441                    return node.id;
17442                }
17443            }
17444            panic!("no node named {want:?} in bwd graph");
17445        };
17446        let a_bwd = find_by_name(&bwd, "A");
17447        let b_bwd = find_by_name(&bwd, "b");
17448        let d_out_bwd = find_by_name(&bwd, "d_output");
17449
17450        // ── Test data ──
17451        // A = [[2,1,0],[1,3,1],[0,1,2]]   (SPD tridiagonal, well-conditioned)
17452        // b = [1,2,3]
17453        let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
17454        let b_data = [1.0, 2.0, 3.0_f64];
17455        let d_output = [1.0_f64]; // ∂loss/∂loss
17456
17457        // ── Compile + execute backward graph ──
17458        let (sched, mut arena) = prepare_f64(
17459            &bwd,
17460            &[(a_bwd, &a_data), (b_bwd, &b_data), (d_out_bwd, &d_output)],
17461        );
17462        execute_thunks(&sched, arena.raw_buf_mut());
17463
17464        let loss_out = read_arena_f64(&arena, bwd.outputs[0], 1);
17465        let da_out = read_arena_f64(&arena, bwd.outputs[1], n * n);
17466        let db_out = read_arena_f64(&arena, bwd.outputs[2], n);
17467
17468        // ── Closed-form reference ──
17469        // x = A⁻¹ b ; loss = sum(x).
17470        let x_ref = {
17471            let mut a = a_data;
17472            let mut b = b_data;
17473            let info = crate::blas::dgesv(&mut a, &mut b, n, 1);
17474            assert_eq!(info, 0);
17475            b
17476        };
17477        let loss_ref: f64 = x_ref.iter().sum();
17478        // db = (Aᵀ)⁻¹ · 1
17479        let db_ref = {
17480            let mut at = [0.0_f64; 9];
17481            for i in 0..n {
17482                for j in 0..n {
17483                    at[i * n + j] = a_data[j * n + i];
17484                }
17485            }
17486            let mut ones = [1.0_f64; 3];
17487            let info = crate::blas::dgesv(&mut at, &mut ones, n, 1);
17488            assert_eq!(info, 0);
17489            ones
17490        };
17491        // dA = -outer(db, x) ; dA[i,j] = -db[i] * x[j]
17492        let mut da_ref = [0.0_f64; 9];
17493        for i in 0..n {
17494            for j in 0..n {
17495                da_ref[i * n + j] = -db_ref[i] * x_ref[j];
17496            }
17497        }
17498
17499        // ── Assertions vs analytic answer ──
17500        assert!(
17501            (loss_out[0] - loss_ref).abs() < 1e-10,
17502            "loss: got {}, want {}",
17503            loss_out[0],
17504            loss_ref
17505        );
17506        for i in 0..n {
17507            assert!(
17508                (db_out[i] - db_ref[i]).abs() < 1e-10,
17509                "db[{i}]: got {}, want {}",
17510                db_out[i],
17511                db_ref[i]
17512            );
17513        }
17514        for i in 0..n * n {
17515            assert!(
17516                (da_out[i] - da_ref[i]).abs() < 1e-10,
17517                "dA[{i}]: got {}, want {}",
17518                da_out[i],
17519                da_ref[i]
17520            );
17521        }
17522
17523        // ── Cross-check vs finite differences on db (a few entries) ──
17524        // ∂loss/∂b[k] ≈ (loss(b + h·e_k) - loss(b - h·e_k)) / (2h).
17525        let h = 1e-6_f64;
17526        for k in 0..n {
17527            let mut bp = b_data;
17528            bp[k] += h;
17529            let mut bm = b_data;
17530            bm[k] -= h;
17531            let lp = {
17532                let mut ac = a_data;
17533                let info = crate::blas::dgesv(&mut ac, &mut bp, n, 1);
17534                assert_eq!(info, 0);
17535                bp.iter().sum::<f64>()
17536            };
17537            let lm = {
17538                let mut ac = a_data;
17539                let info = crate::blas::dgesv(&mut ac, &mut bm, n, 1);
17540                assert_eq!(info, 0);
17541                bm.iter().sum::<f64>()
17542            };
17543            let fd = (lp - lm) / (2.0 * h);
17544            assert!(
17545                (db_out[k] - fd).abs() < 1e-7,
17546                "FD mismatch on db[{k}]: AD={} FD={}",
17547                db_out[k],
17548                fd
17549            );
17550        }
17551    }
17552
17553    /// Smallest possible Op::Scan basic test: geometric growth.
17554    /// init = [1, 1, 1] f64, body = (x → x + 0.1·x) = (x → 1.1·x),
17555    /// length = 10. Final carry must equal init·(1.1)^10 ≈ 2.5937…
17556    /// to f64 precision.
17557    #[test]
17558    fn scan_geometric_growth_f64() {
17559        let n = 3usize;
17560        let length = 10u32;
17561
17562        // Body: (x) → x + 0.1·x. One Input, one output, same shape/dtype.
17563        let mut body = Graph::new("scan_body");
17564        let x = body.input("carry", Shape::new(&[n], DType::F64));
17565        let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 0.1_f64.to_le_bytes()).collect();
17566        let scale = body.add_node(
17567            Op::Constant { data: scale_bytes },
17568            vec![],
17569            Shape::new(&[n], DType::F64),
17570        );
17571        let scaled = body.binary(BinaryOp::Mul, x, scale, Shape::new(&[n], DType::F64));
17572        let next = body.binary(BinaryOp::Add, x, scaled, Shape::new(&[n], DType::F64));
17573        body.set_outputs(vec![next]);
17574
17575        // Outer graph: scan(init, body, length).
17576        let mut g = Graph::new("scan_outer");
17577        let init = g.input("init", Shape::new(&[n], DType::F64));
17578        let final_carry = g.scan(init, body, length);
17579        g.set_outputs(vec![final_carry]);
17580
17581        let init_data = vec![1.0_f64; n];
17582        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
17583        execute_thunks(&sched, arena.raw_buf_mut());
17584        let got = read_arena_f64(&arena, final_carry, n);
17585        let want: f64 = 1.1_f64.powi(length as i32);
17586        for i in 0..n {
17587            assert!(
17588                (got[i] - want).abs() < 1e-12,
17589                "got[{i}] = {} want {}",
17590                got[i],
17591                want
17592            );
17593        }
17594    }
17595
17596    /// Per-step xs scan: cumulative-sum.
17597    ///   carry_0 = init
17598    ///   carry_{t+1} = carry_t + xs\[t\]
17599    ///   final = sum_{t<length} xs\[t\] + init
17600    /// Body has 2 inputs (carry, x_t) in that NodeId order; one output
17601    /// (next carry). Validates the per-step-input plumbing end-to-end.
17602    #[test]
17603    fn scan_with_xs_cumulative_sum() {
17604        let n = 3usize;
17605        let length = 4u32;
17606
17607        let mut body = Graph::new("cumsum_body");
17608        // carry must come first in NodeId order — declare it first.
17609        let carry = body.input("carry", Shape::new(&[n], DType::F64));
17610        let x_t = body.input("x_t", Shape::new(&[n], DType::F64));
17611        let next = body.binary(BinaryOp::Add, carry, x_t, Shape::new(&[n], DType::F64));
17612        body.set_outputs(vec![next]);
17613
17614        let mut g = Graph::new("cumsum_outer");
17615        let init = g.input("init", Shape::new(&[n], DType::F64));
17616        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
17617        let final_carry = g.scan_with_xs(init, &[xs], body, length);
17618        g.set_outputs(vec![final_carry]);
17619
17620        let init_data = vec![0.0_f64; n];
17621        let xs_data: Vec<f64> = (0..length as usize * n).map(|i| (i + 1) as f64).collect(); // 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12
17622        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data), (xs, &xs_data)]);
17623        execute_thunks(&sched, arena.raw_buf_mut());
17624        let got = read_arena_f64(&arena, final_carry, n);
17625
17626        // Reference: column-wise sum of xs rows + init. With our row-major
17627        // layout, column j of xs is xs_data[j], xs_data[n+j], xs_data[2n+j], ...
17628        // (per-step row at offset t*n contributes element j to slot j).
17629        let mut want = init_data.clone();
17630        for t in 0..length as usize {
17631            for j in 0..n {
17632                want[j] += xs_data[t * n + j];
17633            }
17634        }
17635        for i in 0..n {
17636            assert!(
17637                (got[i] - want[i]).abs() < 1e-12,
17638                "got[{i}] = {} want {}",
17639                got[i],
17640                want[i]
17641            );
17642        }
17643    }
17644
17645    /// Per-step xs scan composing with DenseSolve — Circulax-shaped:
17646    ///   carry_{t+1} = solve(M, carry_t + xs\[t\])
17647    /// Models a Backward-Euler step driven by a time-varying source.
17648    #[test]
17649    fn scan_with_xs_be_with_drive() {
17650        let n = 3usize;
17651        let length = 4u32;
17652        let dt = 0.1_f64;
17653
17654        let mut m_data = vec![0.0_f64; n * n];
17655        for i in 0..n {
17656            m_data[i * n + i] = 1.0 + dt * 2.0;
17657            if i > 0 {
17658                m_data[i * n + (i - 1)] = -dt;
17659            }
17660            if i + 1 < n {
17661                m_data[i * n + (i + 1)] = -dt;
17662            }
17663        }
17664        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
17665
17666        let mut body = Graph::new("be_drive_body");
17667        let carry = body.input("carry", Shape::new(&[n], DType::F64));
17668        let drive = body.input("drive", Shape::new(&[n], DType::F64));
17669        let m = body.add_node(
17670            Op::Constant { data: m_bytes },
17671            vec![],
17672            Shape::new(&[n, n], DType::F64),
17673        );
17674        let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
17675        let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
17676        body.set_outputs(vec![next]);
17677
17678        let mut g = Graph::new("be_drive_outer");
17679        let init = g.input("init", Shape::new(&[n], DType::F64));
17680        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
17681        let final_carry = g.scan_with_xs(init, &[xs], body, length);
17682        g.set_outputs(vec![final_carry]);
17683
17684        let init_data = vec![0.0_f64; n];
17685        // Drive the system with a unit pulse on element 0 at t=0,
17686        // zeros after.
17687        let mut xs_data = vec![0.0_f64; length as usize * n];
17688        xs_data[0] = 1.0;
17689
17690        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data), (xs, &xs_data)]);
17691        execute_thunks(&sched, arena.raw_buf_mut());
17692        let got = read_arena_f64(&arena, final_carry, n);
17693
17694        // Reference: per-step in pure Rust.
17695        let mut x = init_data.clone();
17696        for t in 0..length as usize {
17697            for j in 0..n {
17698                x[j] += xs_data[t * n + j];
17699            }
17700            let mut a_copy = m_data.clone();
17701            crate::blas::dgesv(&mut a_copy, &mut x, n, 1);
17702        }
17703        for i in 0..n {
17704            assert!(
17705                (got[i] - x[i]).abs() < 1e-12,
17706                "got[{i}] = {} ref {}",
17707                got[i],
17708                x[i]
17709            );
17710        }
17711    }
17712
17713    /// Reverse-mode AD through Op::BatchedDenseSolve. Forward solves
17714    /// `[B, N, N] · x = [B, N]`; loss = sum of all entries. Closed
17715    /// form: dB = (Aᵀ)⁻¹·1, dA = -(Aᵀ)⁻¹·1 ⊗ x. Verified analytically
17716    /// per batch (each slice matches what the unbatched DenseSolve VJP
17717    /// would compute).
17718    #[test]
17719    fn batched_dense_solve_gradient_matches_per_batch_analytic() {
17720        use rlx_opt::autodiff::grad_with_loss;
17721        let n = 3usize;
17722        let batch = 4usize;
17723
17724        let mut g = Graph::new("bds_grad");
17725        let a = g.param("A", Shape::new(&[batch, n, n], DType::F64));
17726        let b = g.input("b", Shape::new(&[batch, n], DType::F64));
17727        let x = g.batched_dense_solve(a, b, Shape::new(&[batch, n], DType::F64));
17728        let loss = g.reduce(
17729            x,
17730            ReduceOp::Sum,
17731            vec![0, 1],
17732            false,
17733            Shape::new(&[1], DType::F64),
17734        );
17735        g.set_outputs(vec![loss]);
17736
17737        let bwd = grad_with_loss(&g, &[a, b]);
17738
17739        let find = |graph: &Graph, want: &str| -> NodeId {
17740            for node in graph.nodes() {
17741                let name = match &node.op {
17742                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17743                    _ => None,
17744                };
17745                if name == Some(want) {
17746                    return node.id;
17747                }
17748            }
17749            panic!("no node named {want}");
17750        };
17751        let a_id = find(&bwd, "A");
17752        let b_id = find(&bwd, "b");
17753        let d_out_id = find(&bwd, "d_output");
17754
17755        let mut rng = rlx_ir::Philox4x32::new(0x57e1_u64);
17756        let mut a_data = vec![0.0_f64; batch * n * n];
17757        let mut b_data = vec![0.0_f64; batch * n];
17758        for bi in 0..batch {
17759            for i in 0..n {
17760                for j in 0..n {
17761                    a_data[bi * n * n + i * n + j] = rng.next_f32() as f64 * 0.1;
17762                }
17763                a_data[bi * n * n + i * n + i] += 1.0 + n as f64;
17764            }
17765            for i in 0..n {
17766                b_data[bi * n + i] = rng.next_f32() as f64;
17767            }
17768        }
17769        let d_seed = [1.0_f64];
17770
17771        let (sched, mut arena) = prepare_f64(
17772            &bwd,
17773            &[(a_id, &a_data), (b_id, &b_data), (d_out_id, &d_seed)],
17774        );
17775        execute_thunks(&sched, arena.raw_buf_mut());
17776        let da_out = read_arena_f64(&arena, bwd.outputs[1], batch * n * n);
17777        let db_out = read_arena_f64(&arena, bwd.outputs[2], batch * n);
17778
17779        // Reference: per-batch analytic solve. dB_i = (A_iᵀ)⁻¹ · 1,
17780        // dA_i = -dB_i ⊗ x_i.
17781        for bi in 0..batch {
17782            let a_slice: Vec<f64> = a_data[bi * n * n..(bi + 1) * n * n].to_vec();
17783            let mut b_slice: Vec<f64> = b_data[bi * n..(bi + 1) * n].to_vec();
17784            let mut a_copy = a_slice.clone();
17785            crate::blas::dgesv(&mut a_copy, &mut b_slice, n, 1);
17786            let x_ref = b_slice.clone();
17787            // dB: solve(A^T, ones)
17788            let mut at = vec![0.0_f64; n * n];
17789            for i in 0..n {
17790                for j in 0..n {
17791                    at[i * n + j] = a_slice[j * n + i];
17792                }
17793            }
17794            let mut ones = vec![1.0_f64; n];
17795            crate::blas::dgesv(&mut at, &mut ones, n, 1);
17796            let db_ref = ones;
17797            for i in 0..n {
17798                let got = db_out[bi * n + i];
17799                assert!(
17800                    (got - db_ref[i]).abs() < 1e-10,
17801                    "batch {bi}, db[{i}]: got {got} ref {}",
17802                    db_ref[i]
17803                );
17804            }
17805            // dA: -outer(db, x)
17806            for i in 0..n {
17807                for j in 0..n {
17808                    let got = da_out[bi * n * n + i * n + j];
17809                    let want = -db_ref[i] * x_ref[j];
17810                    assert!(
17811                        (got - want).abs() < 1e-10,
17812                        "batch {bi}, dA[{i},{j}]: got {got} ref {want}"
17813                    );
17814                }
17815            }
17816        }
17817    }
17818
17819    /// AD knob: gradient through `scan_checkpointed` automatically
17820    /// uses the recompute backward path. Compares dinit from a plain
17821    /// scan against the same forward written with `scan_checkpointed`,
17822    /// both run through `grad_with_loss`. They must match to f64.
17823    #[test]
17824    fn scan_checkpointed_grad_matches_plain_scan_grad() {
17825        use rlx_opt::autodiff::grad_with_loss;
17826        let n = 2usize;
17827        let length = 6u32;
17828
17829        let make_body = || {
17830            let mut body = Graph::new("ck_body");
17831            let carry = body.input("carry", Shape::new(&[n], DType::F64));
17832            let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 1.05_f64.to_le_bytes()).collect();
17833            let scale = body.add_node(
17834                Op::Constant { data: scale_bytes },
17835                vec![],
17836                Shape::new(&[n], DType::F64),
17837            );
17838            let next = body.binary(BinaryOp::Mul, carry, scale, Shape::new(&[n], DType::F64));
17839            body.set_outputs(vec![next]);
17840            body
17841        };
17842
17843        // Plain scan path.
17844        let mut g_plain = Graph::new("ck_plain");
17845        let init_p = g_plain.input("init", Shape::new(&[n], DType::F64));
17846        let final_p = g_plain.scan(init_p, make_body(), length);
17847        let loss_p = g_plain.reduce(
17848            final_p,
17849            ReduceOp::Sum,
17850            vec![0],
17851            false,
17852            Shape::new(&[1], DType::F64),
17853        );
17854        g_plain.set_outputs(vec![loss_p]);
17855        let bwd_p = grad_with_loss(&g_plain, &[init_p]);
17856
17857        // Checkpointed scan path with K=2 (length=6).
17858        let mut g_ck = Graph::new("ck_ckpt");
17859        let init_c = g_ck.input("init", Shape::new(&[n], DType::F64));
17860        let final_c = g_ck.scan_checkpointed(init_c, make_body(), length, 2);
17861        let loss_c = g_ck.reduce(
17862            final_c,
17863            ReduceOp::Sum,
17864            vec![0],
17865            false,
17866            Shape::new(&[1], DType::F64),
17867        );
17868        g_ck.set_outputs(vec![loss_c]);
17869        let bwd_c = grad_with_loss(&g_ck, &[init_c]);
17870
17871        let find = |graph: &Graph, want: &str| -> NodeId {
17872            for node in graph.nodes() {
17873                let name = match &node.op {
17874                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17875                    _ => None,
17876                };
17877                if name == Some(want) {
17878                    return node.id;
17879                }
17880            }
17881            panic!("no {want}");
17882        };
17883
17884        let init_data = vec![0.5_f64, -0.5];
17885        let d_seed = [1.0_f64];
17886
17887        let (s_p, mut a_p) = prepare_f64(
17888            &bwd_p,
17889            &[
17890                (find(&bwd_p, "init"), &init_data),
17891                (find(&bwd_p, "d_output"), &d_seed),
17892            ],
17893        );
17894        execute_thunks(&s_p, a_p.raw_buf_mut());
17895        let dinit_p = read_arena_f64(&a_p, bwd_p.outputs[1], n);
17896
17897        let (s_c, mut a_c) = prepare_f64(
17898            &bwd_c,
17899            &[
17900                (find(&bwd_c, "init"), &init_data),
17901                (find(&bwd_c, "d_output"), &d_seed),
17902            ],
17903        );
17904        execute_thunks(&s_c, a_c.raw_buf_mut());
17905        let dinit_c = read_arena_f64(&a_c, bwd_c.outputs[1], n);
17906
17907        for i in 0..n {
17908            assert!(
17909                (dinit_p[i] - dinit_c[i]).abs() < 1e-12,
17910                "dinit[{i}]: plain={} checkpointed={}",
17911                dinit_p[i],
17912                dinit_c[i]
17913            );
17914        }
17915    }
17916
17917    /// Recursive checkpointing end-to-end: build a ScanBackward
17918    /// configured with K=2 checkpoints (for length=4), and compare
17919    /// dinit against the same backward graph with full trajectory
17920    /// (K=0). Forward computes a cumulative-sum-style scan; loss = sum.
17921    /// Both paths must agree to f64 precision.
17922    #[test]
17923    fn recursive_checkpointing_matches_full_trajectory() {
17924        let n = 2usize;
17925        let length = 4u32;
17926
17927        // Body: carry + ones (deterministic, no xs)
17928        let build_body = || -> Graph {
17929            let mut body = Graph::new("rc_body");
17930            let carry = body.input("carry", Shape::new(&[n], DType::F64));
17931            let ones_bytes: Vec<u8> = (0..n).flat_map(|_| 1.0_f64.to_le_bytes()).collect();
17932            let ones = body.add_node(
17933                Op::Constant { data: ones_bytes },
17934                vec![],
17935                Shape::new(&[n], DType::F64),
17936            );
17937            let next = body.binary(BinaryOp::Add, carry, ones, Shape::new(&[n], DType::F64));
17938            body.set_outputs(vec![next]);
17939            body
17940        };
17941
17942        // body_vjp: same body + d_output, output dcarry. body_vjp is
17943        // used by ScanBackward to walk the chain rule per step.
17944        let body_vjp_for = || -> Graph {
17945            use rlx_opt::autodiff::grad;
17946            let body = build_body();
17947            // grad(body, [carry_id]) → graph with dcarry as the output.
17948            let carry_id = body
17949                .nodes()
17950                .iter()
17951                .find(|n| matches!(n.op, Op::Input { .. }))
17952                .map(|n| n.id)
17953                .unwrap();
17954            grad(&body, &[carry_id])
17955        };
17956
17957        // ── Forward (All-strategy): scan with full trajectory ──
17958        let mut g_full = Graph::new("rc_outer_full");
17959        let init_full = g_full.input("init", Shape::new(&[n], DType::F64));
17960        let traj_full_id = g_full.scan_trajectory(init_full, build_body(), length);
17961        // Hand-build a ScanBackward node that reads the full trajectory.
17962        let upstream_full = g_full.input("upstream", Shape::new(&[length as usize, n], DType::F64));
17963        let dinit_full_id = g_full.scan_backward(
17964            init_full,
17965            traj_full_id,
17966            upstream_full,
17967            &[],
17968            body_vjp_for(),
17969            length,
17970            true,
17971            Shape::new(&[n], DType::F64),
17972        );
17973        g_full.set_outputs(vec![dinit_full_id]);
17974
17975        // ── Forward (Recursive-2): scan saves only K=2 rows ──
17976        // Build the trajectory shape [K, *carry] = [2, 2].
17977        let k = 2u32;
17978        let mut g_rec = Graph::new("rc_outer_rec");
17979        let init_rec = g_rec.input("init", Shape::new(&[n], DType::F64));
17980        let traj_rec_id = g_rec.add_node(
17981            Op::Scan {
17982                body: Box::new(build_body()),
17983                length,
17984                save_trajectory: true,
17985                num_bcast: 0,
17986                num_xs: 0,
17987                num_checkpoints: k,
17988            },
17989            vec![init_rec],
17990            Shape::new(&[k as usize, n], DType::F64),
17991        );
17992        // Same upstream shape as the full version (the upstream is per
17993        // *forward step*, length rows — independent of K).
17994        let upstream_rec = g_rec.input("upstream", Shape::new(&[length as usize, n], DType::F64));
17995        let dinit_rec_id = g_rec.add_node(
17996            Op::ScanBackward {
17997                body_vjp: Box::new(body_vjp_for()),
17998                length,
17999                save_trajectory: true,
18000                num_xs: 0,
18001                num_checkpoints: k,
18002                forward_body: Some(Box::new(build_body())),
18003            },
18004            vec![init_rec, traj_rec_id, upstream_rec],
18005            Shape::new(&[n], DType::F64),
18006        );
18007        g_rec.set_outputs(vec![dinit_rec_id]);
18008
18009        // ── Run both, same inputs ──
18010        let init_data = vec![0.5_f64, -0.5];
18011        let upstream_data: Vec<f64> = (0..length as usize * n).map(|i| (i as f64) * 0.1).collect();
18012
18013        let find = |graph: &Graph, want: &str| -> NodeId {
18014            for node in graph.nodes() {
18015                if let Op::Input { name } = &node.op
18016                    && name == want
18017                {
18018                    return node.id;
18019                }
18020            }
18021            panic!("no input {want}");
18022        };
18023
18024        let (s_full, mut a_full) = prepare_f64(
18025            &g_full,
18026            &[
18027                (find(&g_full, "init"), &init_data),
18028                (find(&g_full, "upstream"), &upstream_data),
18029            ],
18030        );
18031        execute_thunks(&s_full, a_full.raw_buf_mut());
18032        let dinit_full = read_arena_f64(&a_full, g_full.outputs[0], n);
18033
18034        let (s_rec, mut a_rec) = prepare_f64(
18035            &g_rec,
18036            &[
18037                (find(&g_rec, "init"), &init_data),
18038                (find(&g_rec, "upstream"), &upstream_data),
18039            ],
18040        );
18041        execute_thunks(&s_rec, a_rec.raw_buf_mut());
18042        let dinit_rec = read_arena_f64(&a_rec, g_rec.outputs[0], n);
18043
18044        for i in 0..n {
18045            assert!(
18046                (dinit_full[i] - dinit_rec[i]).abs() < 1e-12,
18047                "i={i}: full={} rec={}",
18048                dinit_full[i],
18049                dinit_rec[i]
18050            );
18051        }
18052    }
18053
18054    /// vmap-of-grad: gradient through Scan, vmap'd over init.
18055    /// Forward (per row):
18056    ///   carry_{t+1} = carry_t + ones    (body adds a constant)
18057    ///   loss = sum(carry_length) = sum(init) + length·n
18058    /// Closed form: dloss/dinit_i = 1 for every i. vmap over init at
18059    /// batch=3 → dinit_batched is all-ones [3, n]. Cross-checks
18060    /// against per-row grad_with_loss runs. Validates the vmap rule
18061    /// for Op::ScanBackward.
18062    #[test]
18063    fn vmap_of_grad_scan_matches_per_row_runs() {
18064        use rlx_opt::autodiff::grad_with_loss;
18065        use rlx_opt::vmap::vmap;
18066        let n = 2usize;
18067        let length = 3u32;
18068        let batch = 3usize;
18069
18070        let mut body = Graph::new("scan_grad_body");
18071        let carry = body.input("carry", Shape::new(&[n], DType::F64));
18072        let ones_bytes: Vec<u8> = (0..n).flat_map(|_| 1.0_f64.to_le_bytes()).collect();
18073        let ones = body.add_node(
18074            Op::Constant { data: ones_bytes },
18075            vec![],
18076            Shape::new(&[n], DType::F64),
18077        );
18078        let next = body.binary(BinaryOp::Add, carry, ones, Shape::new(&[n], DType::F64));
18079        body.set_outputs(vec![next]);
18080
18081        let mut g = Graph::new("scan_grad_outer");
18082        let init = g.input("init", Shape::new(&[n], DType::F64));
18083        let final_x = g.scan(init, body, length);
18084        let loss = g.reduce(
18085            final_x,
18086            ReduceOp::Sum,
18087            vec![0],
18088            false,
18089            Shape::new(&[1], DType::F64),
18090        );
18091        g.set_outputs(vec![loss]);
18092
18093        let bwd = grad_with_loss(&g, &[init]);
18094        let bg = vmap(&bwd, &["init"], batch);
18095
18096        let find = |graph: &Graph, want: &str| -> NodeId {
18097            for node in graph.nodes() {
18098                let name = match &node.op {
18099                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
18100                    _ => None,
18101                };
18102                if name == Some(want) {
18103                    return node.id;
18104                }
18105            }
18106            panic!("no node named {want}");
18107        };
18108        let init_b = find(&bg, "init");
18109        let d_out_b = find(&bg, "d_output");
18110
18111        let init_data: Vec<f64> = (0..batch * n).map(|i| (i as f64) * 0.5).collect();
18112        let d_seed = [1.0_f64];
18113
18114        let (sched, mut arena) = prepare_f64(&bg, &[(init_b, &init_data), (d_out_b, &d_seed)]);
18115        execute_thunks(&sched, arena.raw_buf_mut());
18116        let dinit_b = read_arena_f64(&arena, bg.outputs[1], batch * n);
18117
18118        for i in 0..batch * n {
18119            assert!(
18120                (dinit_b[i] - 1.0).abs() < 1e-12,
18121                "dinit[{i}] = {} (expected 1.0)",
18122                dinit_b[i]
18123            );
18124        }
18125
18126        // Cross-check vs per-row grad_with_loss.
18127        for bi in 0..batch {
18128            let row = &init_data[bi * n..(bi + 1) * n];
18129            let mut g2 = Graph::new("per_row_grad");
18130            let init2 = g2.input("init", Shape::new(&[n], DType::F64));
18131            let mut body2 = Graph::new("per_row_body");
18132            let c2 = body2.input("carry", Shape::new(&[n], DType::F64));
18133            let ones2_bytes: Vec<u8> = (0..n).flat_map(|_| 1.0_f64.to_le_bytes()).collect();
18134            let ones2 = body2.add_node(
18135                Op::Constant { data: ones2_bytes },
18136                vec![],
18137                Shape::new(&[n], DType::F64),
18138            );
18139            let next2 = body2.binary(BinaryOp::Add, c2, ones2, Shape::new(&[n], DType::F64));
18140            body2.set_outputs(vec![next2]);
18141            let final2 = g2.scan(init2, body2, length);
18142            let loss2 = g2.reduce(
18143                final2,
18144                ReduceOp::Sum,
18145                vec![0],
18146                false,
18147                Shape::new(&[1], DType::F64),
18148            );
18149            g2.set_outputs(vec![loss2]);
18150            let bwd2 = grad_with_loss(&g2, &[init2]);
18151            let init2_id = find(&bwd2, "init");
18152            let d_out2_id = find(&bwd2, "d_output");
18153            let (s2, mut a2) = prepare_f64(&bwd2, &[(init2_id, row), (d_out2_id, &d_seed)]);
18154            execute_thunks(&s2, a2.raw_buf_mut());
18155            let row_dinit = read_arena_f64(&a2, bwd2.outputs[1], n);
18156            for j in 0..n {
18157                let got = dinit_b[bi * n + j];
18158                let want = row_dinit[j];
18159                assert!(
18160                    (got - want).abs() < 1e-12,
18161                    "row {bi}, j {j}: vmap'd={got} per-row={want}"
18162                );
18163            }
18164        }
18165    }
18166
18167    /// vmap of Op::Scan: batched cumulative-sum. Forward
18168    ///   carry_{t+1} = carry_t + xs\[t\]
18169    ///   final = init + sum(xs)
18170    /// vmap over both init and xs at batch=3. Each batch row should
18171    /// equal the scalar run of the same body+xs subset.
18172    #[test]
18173    fn vmap_scan_cumulative_sum_matches_scalar_runs() {
18174        use rlx_opt::vmap::vmap;
18175        let n = 2usize;
18176        let length = 4u32;
18177        let batch = 3usize;
18178
18179        // Body: (carry, x_t) → carry + x_t
18180        let mut body = Graph::new("scan_body_cumsum");
18181        let carry = body.input("carry", Shape::new(&[n], DType::F64));
18182        let x_t = body.input("x_t", Shape::new(&[n], DType::F64));
18183        let next = body.binary(BinaryOp::Add, carry, x_t, Shape::new(&[n], DType::F64));
18184        body.set_outputs(vec![next]);
18185
18186        let mut g = Graph::new("scan_outer_cumsum");
18187        let init = g.input("init", Shape::new(&[n], DType::F64));
18188        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
18189        let final_carry = g.scan_with_xs(init, &[xs], body, length);
18190        g.set_outputs(vec![final_carry]);
18191
18192        // vmap over both init and xs.
18193        let bg = vmap(&g, &["init", "xs"], batch);
18194
18195        // Test data — distinct per-batch rows.
18196        let init_data: Vec<f64> = (0..batch * n).map(|i| (i + 1) as f64).collect();
18197        // xs has shape [B, length, n] after vmap (the outer's xs is
18198        // [length, n]; vmap lifts it to [B, length, n]).
18199        let xs_data: Vec<f64> = (0..batch * length as usize * n)
18200            .map(|i| 0.1 * (i as f64))
18201            .collect();
18202
18203        let find = |graph: &Graph, want: &str| -> NodeId {
18204            for node in graph.nodes() {
18205                if let Op::Input { name } = &node.op
18206                    && name == want
18207                {
18208                    return node.id;
18209                }
18210            }
18211            panic!("no input {want}");
18212        };
18213        let init_b = find(&bg, "init");
18214        let xs_b = find(&bg, "xs");
18215        let (sched, mut arena) = prepare_f64(&bg, &[(init_b, &init_data), (xs_b, &xs_data)]);
18216        execute_thunks(&sched, arena.raw_buf_mut());
18217        let batched_out = read_arena_f64(&arena, bg.outputs[0], batch * n);
18218
18219        // Reference: per-batch scalar Scan.
18220        for bi in 0..batch {
18221            let init_slice = &init_data[bi * n..(bi + 1) * n];
18222            let mut x = init_slice.to_vec();
18223            for t in 0..length as usize {
18224                for j in 0..n {
18225                    x[j] += xs_data[bi * length as usize * n + t * n + j];
18226                }
18227            }
18228
18229            for i in 0..n {
18230                let got = batched_out[bi * n + i];
18231                assert!(
18232                    (got - x[i]).abs() < 1e-12,
18233                    "row {bi}, i {i}: got {got} ref {}",
18234                    x[i]
18235                );
18236            }
18237        }
18238    }
18239
18240    /// vmap of dense solve — Circulax-shaped batched parameter sweep.
18241    /// Forward: x = solve(A, b). vmap over both A (batched [B,N,N])
18242    /// and b (batched [B,N]). Run on CPU and compare each batch row
18243    /// against an independent scalar dgesv.
18244    #[test]
18245    fn vmap_dense_solve_matches_scalar_runs() {
18246        use rlx_opt::vmap::vmap;
18247        let n = 3usize;
18248        let batch = 4usize;
18249
18250        let mut g = Graph::new("solve_forward");
18251        let a = g.input("A", Shape::new(&[n, n], DType::F64));
18252        let b = g.input("b", Shape::new(&[n], DType::F64));
18253        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
18254        g.set_outputs(vec![x]);
18255
18256        // vmap both A and b across the batch.
18257        let bg = vmap(&g, &["A", "b"], batch);
18258
18259        // Independent A and b per batch row.
18260        let mut rng = rlx_ir::Philox4x32::new(0xb47c_u64);
18261        let mut a_data = vec![0.0_f64; batch * n * n];
18262        let mut b_data = vec![0.0_f64; batch * n];
18263        for bi in 0..batch {
18264            // Diagonally dominant A — guaranteed non-singular.
18265            for i in 0..n {
18266                for j in 0..n {
18267                    a_data[bi * n * n + i * n + j] = rng.next_f32() as f64 * 0.1;
18268                }
18269                a_data[bi * n * n + i * n + i] += 1.0 + n as f64;
18270            }
18271            for i in 0..n {
18272                b_data[bi * n + i] = rng.next_f32() as f64;
18273            }
18274        }
18275
18276        let find = |graph: &Graph, want: &str| -> NodeId {
18277            for node in graph.nodes() {
18278                if let Op::Input { name } = &node.op
18279                    && name == want
18280                {
18281                    return node.id;
18282                }
18283            }
18284            panic!("no input named {want}");
18285        };
18286        let ba = find(&bg, "A");
18287        let bb = find(&bg, "b");
18288        let (sched, mut arena) = prepare_f64(&bg, &[(ba, &a_data), (bb, &b_data)]);
18289        execute_thunks(&sched, arena.raw_buf_mut());
18290        let batched_x = read_arena_f64(&arena, bg.outputs[0], batch * n);
18291
18292        // Reference: per-batch dgesv.
18293        for bi in 0..batch {
18294            let mut a_slice: Vec<f64> = a_data[bi * n * n..(bi + 1) * n * n].to_vec();
18295            let mut b_slice: Vec<f64> = b_data[bi * n..(bi + 1) * n].to_vec();
18296            crate::blas::dgesv(&mut a_slice, &mut b_slice, n, 1);
18297            for i in 0..n {
18298                let got = batched_x[bi * n + i];
18299                let want = b_slice[i];
18300                assert!(
18301                    (got - want).abs() < 1e-12,
18302                    "row {bi}, i {i}: got {got} want {want}"
18303                );
18304            }
18305        }
18306    }
18307
18308    /// vmap end-to-end: build a graph that computes y = MatMul(x, w) + b
18309    /// and reduces to a per-element loss. vmap over x with batch=4.
18310    /// Run the batched graph and compare each output row against an
18311    /// independent scalar run of the original graph. Validates the
18312    /// structural lift + the runtime path for batched MatMul +
18313    /// batched Binary + batched Reduce.
18314    #[test]
18315    fn vmap_matmul_add_reduce_matches_scalar_runs() {
18316        use rlx_opt::vmap::vmap;
18317        let n = 3usize;
18318        let batch = 4usize;
18319
18320        // Forward graph: y = MatMul(reshape(x, [1,n]), w) + b ; loss = sum(y).
18321        let mut g = Graph::new("vmap_e2e_forward");
18322        let x = g.input("x", Shape::new(&[n], DType::F64));
18323        let w = g.input("w", Shape::new(&[n, n], DType::F64));
18324        let b = g.input("b", Shape::new(&[n], DType::F64));
18325        let x_row = g.add_node(
18326            Op::Reshape {
18327                new_shape: vec![1, n as i64],
18328            },
18329            vec![x],
18330            Shape::new(&[1, n], DType::F64),
18331        );
18332        let mm = g.matmul(x_row, w, Shape::new(&[1, n], DType::F64));
18333        let mm_flat = g.add_node(
18334            Op::Reshape {
18335                new_shape: vec![n as i64],
18336            },
18337            vec![mm],
18338            Shape::new(&[n], DType::F64),
18339        );
18340        let yv = g.binary(BinaryOp::Add, mm_flat, b, Shape::new(&[n], DType::F64));
18341        let loss = g.reduce(
18342            yv,
18343            ReduceOp::Sum,
18344            vec![0],
18345            false,
18346            Shape::new(&[1], DType::F64),
18347        );
18348        g.set_outputs(vec![loss]);
18349
18350        // Build the vmap'd version (batch over x; w and b shared).
18351        let bg = vmap(&g, &["x"], batch);
18352
18353        // Test data — distinct rows so we can verify the per-row dispatch.
18354        let mut rng = rlx_ir::Philox4x32::new(0xc1c0_u64);
18355        let n_w = n * n;
18356        let w_data: Vec<f64> = (0..n_w).map(|_| rng.next_f32() as f64).collect();
18357        let b_data: Vec<f64> = (0..n).map(|_| rng.next_f32() as f64).collect();
18358        let mut x_data_batched: Vec<f64> = Vec::with_capacity(batch * n);
18359        for _ in 0..batch * n {
18360            x_data_batched.push(rng.next_f32() as f64);
18361        }
18362
18363        // Run the batched graph.
18364        let find = |graph: &Graph, want: &str| -> NodeId {
18365            for node in graph.nodes() {
18366                if let Op::Input { name } = &node.op
18367                    && name == want
18368                {
18369                    return node.id;
18370                }
18371            }
18372            panic!("no input named {want}");
18373        };
18374        let bx = find(&bg, "x");
18375        let bw = find(&bg, "w");
18376        let bb = find(&bg, "b");
18377        let (sched, mut arena) =
18378            prepare_f64(&bg, &[(bx, &x_data_batched), (bw, &w_data), (bb, &b_data)]);
18379        execute_thunks(&sched, arena.raw_buf_mut());
18380        // Reduce::Sum on shifted axis 1 with keep_dim=false → output [B, 1]
18381        // (it preserves the leading batch axis but reduces what was [n] to [].
18382        // Since the original output was [1] f64 and the reduce was over
18383        // axis 0, after vmap the leading-axis-shifted reduce keeps the
18384        // leading 1 from the original output's [1] shape.)
18385        let batched_out = read_arena_f64(&arena, bg.outputs[0], batch);
18386
18387        // Reference: run the original (un-batched) graph once per batch row.
18388        for bi in 0..batch {
18389            let xs_slice = &x_data_batched[bi * n..(bi + 1) * n];
18390            let mut g2 = Graph::new("scalar_run");
18391            let x2 = g2.input("x", Shape::new(&[n], DType::F64));
18392            let w2 = g2.input("w", Shape::new(&[n, n], DType::F64));
18393            let b2 = g2.input("b", Shape::new(&[n], DType::F64));
18394            let xr = g2.add_node(
18395                Op::Reshape {
18396                    new_shape: vec![1, n as i64],
18397                },
18398                vec![x2],
18399                Shape::new(&[1, n], DType::F64),
18400            );
18401            let m = g2.matmul(xr, w2, Shape::new(&[1, n], DType::F64));
18402            let mf = g2.add_node(
18403                Op::Reshape {
18404                    new_shape: vec![n as i64],
18405                },
18406                vec![m],
18407                Shape::new(&[n], DType::F64),
18408            );
18409            let yv2 = g2.binary(BinaryOp::Add, mf, b2, Shape::new(&[n], DType::F64));
18410            let l2 = g2.reduce(
18411                yv2,
18412                ReduceOp::Sum,
18413                vec![0],
18414                false,
18415                Shape::new(&[1], DType::F64),
18416            );
18417            g2.set_outputs(vec![l2]);
18418            let (s2, mut a2) = prepare_f64(&g2, &[(x2, xs_slice), (w2, &w_data), (b2, &b_data)]);
18419            execute_thunks(&s2, a2.raw_buf_mut());
18420            let scalar_out = read_arena_f64(&a2, l2, 1);
18421            assert!(
18422                (batched_out[bi] - scalar_out[0]).abs() < 1e-12,
18423                "row {bi}: batched={} scalar={}",
18424                batched_out[bi],
18425                scalar_out[0]
18426            );
18427        }
18428    }
18429
18430    /// Full gradient through scan-with-xs: dinit AND dxs both checked
18431    /// against finite differences. Forward
18432    ///   carry_{t+1} = solve(M, carry_t + xs\[t\])
18433    ///   loss        = sum(carry_length)
18434    /// Verifies that grad_with_loss returns gradients w.r.t. both
18435    /// `init` and `xs` and that dxs matches per-element FD.
18436    #[test]
18437    fn scan_with_xs_dxs_matches_fd() {
18438        use rlx_opt::autodiff::grad_with_loss;
18439        let n = 3usize;
18440        let length = 3u32;
18441        let dt = 0.1_f64;
18442
18443        let mut m_data = vec![0.0_f64; n * n];
18444        for i in 0..n {
18445            m_data[i * n + i] = 1.0 + dt * 2.0;
18446            if i > 0 {
18447                m_data[i * n + (i - 1)] = -dt;
18448            }
18449            if i + 1 < n {
18450                m_data[i * n + (i + 1)] = -dt;
18451            }
18452        }
18453        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
18454
18455        let mut body = Graph::new("be_dxs_body");
18456        let carry = body.input("carry", Shape::new(&[n], DType::F64));
18457        let drive = body.input("drive", Shape::new(&[n], DType::F64));
18458        let m = body.add_node(
18459            Op::Constant { data: m_bytes },
18460            vec![],
18461            Shape::new(&[n, n], DType::F64),
18462        );
18463        let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
18464        let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
18465        body.set_outputs(vec![next]);
18466
18467        let mut g = Graph::new("be_dxs_outer");
18468        let init = g.input("init", Shape::new(&[n], DType::F64));
18469        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
18470        let final_carry = g.scan_with_xs(init, &[xs], body, length);
18471        let loss = g.reduce(
18472            final_carry,
18473            ReduceOp::Sum,
18474            vec![0],
18475            false,
18476            Shape::new(&[1], DType::F64),
18477        );
18478        g.set_outputs(vec![loss]);
18479
18480        // wrt = [init, xs] — get both gradients back.
18481        let bwd = grad_with_loss(&g, &[init, xs]);
18482        assert_eq!(bwd.outputs.len(), 3, "[loss, dinit, dxs]");
18483
18484        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
18485            for node in graph.nodes() {
18486                let name = match &node.op {
18487                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
18488                    _ => None,
18489                };
18490                if name == Some(want) {
18491                    return node.id;
18492                }
18493            }
18494            panic!("no node named {want:?}");
18495        };
18496        let init_bwd = find_by_name(&bwd, "init");
18497        let xs_bwd = find_by_name(&bwd, "xs");
18498        let d_out_bwd = find_by_name(&bwd, "d_output");
18499
18500        let init_data = vec![0.5_f64, 0.0, -0.5];
18501        let xs_data: Vec<f64> = (0..length as usize * n)
18502            .map(|i| 0.1_f64 * ((i as f64) - 4.0))
18503            .collect();
18504        let d_seed = [1.0_f64];
18505
18506        let (sched, mut arena) = prepare_f64(
18507            &bwd,
18508            &[
18509                (init_bwd, &init_data),
18510                (xs_bwd, &xs_data),
18511                (d_out_bwd, &d_seed),
18512            ],
18513        );
18514        execute_thunks(&sched, arena.raw_buf_mut());
18515        let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
18516        let dxs = read_arena_f64(&arena, bwd.outputs[2], length as usize * n);
18517
18518        let h = 1e-6;
18519        let loss_at = |x0: &[f64], xs_in: &[f64]| -> f64 {
18520            let mut acc = x0.to_vec();
18521            for t in 0..length as usize {
18522                for j in 0..n {
18523                    acc[j] += xs_in[t * n + j];
18524                }
18525                let mut a_copy = m_data.clone();
18526                crate::blas::dgesv(&mut a_copy, &mut acc, n, 1);
18527            }
18528            acc.iter().sum()
18529        };
18530
18531        // FD on dinit (sanity).
18532        for i in 0..n {
18533            let mut ip = init_data.to_vec();
18534            ip[i] += h;
18535            let mut im = init_data.to_vec();
18536            im[i] -= h;
18537            let fd = (loss_at(&ip, &xs_data) - loss_at(&im, &xs_data)) / (2.0 * h);
18538            assert!(
18539                (dinit[i] - fd).abs() < 1e-7,
18540                "FD dinit[{i}]: AD={} FD={}",
18541                dinit[i],
18542                fd
18543            );
18544        }
18545
18546        // FD on every dxs entry — full per-step gradient check.
18547        for t in 0..length as usize {
18548            for j in 0..n {
18549                let idx = t * n + j;
18550                let mut xp = xs_data.clone();
18551                xp[idx] += h;
18552                let mut xm = xs_data.clone();
18553                xm[idx] -= h;
18554                let fd = (loss_at(&init_data, &xp) - loss_at(&init_data, &xm)) / (2.0 * h);
18555                assert!(
18556                    (dxs[idx] - fd).abs() < 1e-7,
18557                    "FD dxs[t={t},j={j}]: AD={} FD={}",
18558                    dxs[idx],
18559                    fd
18560                );
18561            }
18562        }
18563    }
18564
18565    /// Gradient through a scan with per-step xs (Circulax-shaped).
18566    /// Forward:
18567    ///   carry_{t+1} = solve(M, carry_t + xs\[t\])
18568    ///   loss = sum(carry_length)
18569    /// dxs is out of MVP (asserted in the VJP rule's body_vjp `wrt`),
18570    /// but `dinit` flows correctly through the body's reverse Jacobian
18571    /// even with xs in the chain. Verify dinit against finite differences.
18572    #[test]
18573    fn scan_with_xs_gradient_dinit_matches_fd() {
18574        use rlx_opt::autodiff::grad_with_loss;
18575        let n = 3usize;
18576        let length = 3u32;
18577        let dt = 0.1_f64;
18578
18579        let mut m_data = vec![0.0_f64; n * n];
18580        for i in 0..n {
18581            m_data[i * n + i] = 1.0 + dt * 2.0;
18582            if i > 0 {
18583                m_data[i * n + (i - 1)] = -dt;
18584            }
18585            if i + 1 < n {
18586                m_data[i * n + (i + 1)] = -dt;
18587            }
18588        }
18589        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
18590
18591        let mut body = Graph::new("be_xs_grad_body");
18592        let carry = body.input("carry", Shape::new(&[n], DType::F64));
18593        let drive = body.input("drive", Shape::new(&[n], DType::F64));
18594        let m = body.add_node(
18595            Op::Constant { data: m_bytes },
18596            vec![],
18597            Shape::new(&[n, n], DType::F64),
18598        );
18599        let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
18600        let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
18601        body.set_outputs(vec![next]);
18602
18603        let mut g = Graph::new("be_xs_grad_outer");
18604        let init = g.input("init", Shape::new(&[n], DType::F64));
18605        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
18606        let final_carry = g.scan_with_xs(init, &[xs], body, length);
18607        let loss = g.reduce(
18608            final_carry,
18609            ReduceOp::Sum,
18610            vec![0],
18611            false,
18612            Shape::new(&[1], DType::F64),
18613        );
18614        g.set_outputs(vec![loss]);
18615
18616        let bwd = grad_with_loss(&g, &[init]);
18617
18618        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
18619            for node in graph.nodes() {
18620                let name = match &node.op {
18621                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
18622                    _ => None,
18623                };
18624                if name == Some(want) {
18625                    return node.id;
18626                }
18627            }
18628            panic!("no node named {want:?}");
18629        };
18630        let init_bwd = find_by_name(&bwd, "init");
18631        let xs_bwd = find_by_name(&bwd, "xs");
18632        let d_out_bwd = find_by_name(&bwd, "d_output");
18633
18634        let init_data = vec![0.5_f64, 0.0, -0.5];
18635        // Drive: small per-step pulse, varying per element.
18636        let xs_data: Vec<f64> = (0..length as usize * n)
18637            .map(|i| 0.1_f64 * ((i as f64) - 4.0))
18638            .collect();
18639        let d_seed = [1.0_f64];
18640
18641        let (sched, mut arena) = prepare_f64(
18642            &bwd,
18643            &[
18644                (init_bwd, &init_data),
18645                (xs_bwd, &xs_data),
18646                (d_out_bwd, &d_seed),
18647            ],
18648        );
18649        execute_thunks(&sched, arena.raw_buf_mut());
18650        let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
18651
18652        let h = 1e-6;
18653        let loss_at = |x0: &[f64]| -> f64 {
18654            let mut acc = x0.to_vec();
18655            for t in 0..length as usize {
18656                for j in 0..n {
18657                    acc[j] += xs_data[t * n + j];
18658                }
18659                let mut a_copy = m_data.clone();
18660                crate::blas::dgesv(&mut a_copy, &mut acc, n, 1);
18661            }
18662            acc.iter().sum()
18663        };
18664        for i in 0..n {
18665            let mut ip = init_data.to_vec();
18666            ip[i] += h;
18667            let mut im = init_data.to_vec();
18668            im[i] -= h;
18669            let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
18670            assert!(
18671                (dinit[i] - fd).abs() < 1e-7,
18672                "FD dinit[{i}]: AD={} FD={}",
18673                dinit[i],
18674                fd
18675            );
18676        }
18677    }
18678
18679    /// Gradient through a geometric-growth scan: forward
18680    ///   x_{t+1} = 1.1 · x_t,    x_0 = init
18681    ///   final   = x_length     = init · 1.1^length
18682    ///   loss    = sum(final)
18683    /// closed-form ∂loss/∂init\[i\] = 1.1^length for every i.
18684    /// Validates the VJP path: AD pre-pass rewrites save_trajectory=false
18685    /// to true, autodiff emits Op::ScanBackward, executor walks t back.
18686    #[test]
18687    fn scan_gradient_geometric_matches_closed_form() {
18688        use rlx_opt::autodiff::grad_with_loss;
18689        let n = 3usize;
18690        let length = 5u32;
18691
18692        let mut body = Graph::new("scan_grad_body");
18693        let x = body.input("carry", Shape::new(&[n], DType::F64));
18694        let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 1.1_f64.to_le_bytes()).collect();
18695        let scale = body.add_node(
18696            Op::Constant { data: scale_bytes },
18697            vec![],
18698            Shape::new(&[n], DType::F64),
18699        );
18700        let next = body.binary(BinaryOp::Mul, x, scale, Shape::new(&[n], DType::F64));
18701        body.set_outputs(vec![next]);
18702
18703        let mut g = Graph::new("scan_grad_outer");
18704        let init = g.input("init", Shape::new(&[n], DType::F64));
18705        let final_x = g.scan(init, body, length);
18706        let loss = g.reduce(
18707            final_x,
18708            ReduceOp::Sum,
18709            vec![0],
18710            false,
18711            Shape::new(&[1], DType::F64),
18712        );
18713        g.set_outputs(vec![loss]);
18714
18715        let bwd = grad_with_loss(&g, &[init]);
18716        assert_eq!(bwd.outputs.len(), 2);
18717
18718        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
18719            for node in graph.nodes() {
18720                let name = match &node.op {
18721                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
18722                    _ => None,
18723                };
18724                if name == Some(want) {
18725                    return node.id;
18726                }
18727            }
18728            panic!("no node named {want:?}");
18729        };
18730        let init_bwd = find_by_name(&bwd, "init");
18731        let d_out_bwd = find_by_name(&bwd, "d_output");
18732
18733        let init_data = vec![1.0_f64; n];
18734        let d_seed = [1.0_f64];
18735        let (sched, mut arena) = prepare_f64(&bwd, &[(init_bwd, &init_data), (d_out_bwd, &d_seed)]);
18736        execute_thunks(&sched, arena.raw_buf_mut());
18737        let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
18738
18739        let want = 1.1_f64.powi(length as i32);
18740        for i in 0..n {
18741            assert!(
18742                (dinit[i] - want).abs() < 1e-12,
18743                "dinit[{i}] = {} want {}",
18744                dinit[i],
18745                want
18746            );
18747        }
18748
18749        // Finite-difference cross-check on init[0].
18750        let h = 1e-6;
18751        let loss_at = |x: &[f64]| -> f64 {
18752            let mut acc = x.to_vec();
18753            for _ in 0..length {
18754                for v in acc.iter_mut() {
18755                    *v *= 1.1;
18756                }
18757            }
18758            acc.iter().sum()
18759        };
18760        let mut ip = init_data.clone();
18761        ip[0] += h;
18762        let mut im = init_data.clone();
18763        im[0] -= h;
18764        let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
18765        assert!(
18766            (dinit[0] - fd).abs() < 1e-7,
18767            "FD dinit[0]: AD={} FD={}",
18768            dinit[0],
18769            fd
18770        );
18771    }
18772
18773    /// Gradient through Backward Euler scan composing with DenseSolve.
18774    /// Asserts dinit matches finite-difference per coordinate.
18775    #[test]
18776    fn scan_gradient_backward_euler_matches_fd() {
18777        use rlx_opt::autodiff::grad_with_loss;
18778        let n = 4usize;
18779        let length = 3u32;
18780        let dt = 0.05_f64;
18781
18782        let mut m_data = vec![0.0_f64; n * n];
18783        for i in 0..n {
18784            m_data[i * n + i] = 1.0 + dt * 2.0;
18785            if i > 0 {
18786                m_data[i * n + (i - 1)] = -dt;
18787            }
18788            if i + 1 < n {
18789                m_data[i * n + (i + 1)] = -dt;
18790            }
18791        }
18792        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
18793
18794        let mut body = Graph::new("be_grad_body");
18795        let x = body.input("x", Shape::new(&[n], DType::F64));
18796        let m = body.add_node(
18797            Op::Constant { data: m_bytes },
18798            vec![],
18799            Shape::new(&[n, n], DType::F64),
18800        );
18801        let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
18802        body.set_outputs(vec![next]);
18803
18804        let mut g = Graph::new("be_grad_outer");
18805        let init = g.input("x0", Shape::new(&[n], DType::F64));
18806        let final_x = g.scan(init, body, length);
18807        let loss = g.reduce(
18808            final_x,
18809            ReduceOp::Sum,
18810            vec![0],
18811            false,
18812            Shape::new(&[1], DType::F64),
18813        );
18814        g.set_outputs(vec![loss]);
18815
18816        let bwd = grad_with_loss(&g, &[init]);
18817
18818        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
18819            for node in graph.nodes() {
18820                let name = match &node.op {
18821                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
18822                    _ => None,
18823                };
18824                if name == Some(want) {
18825                    return node.id;
18826                }
18827            }
18828            panic!("no node named {want:?}");
18829        };
18830        let init_bwd = find_by_name(&bwd, "x0");
18831        let d_out_bwd = find_by_name(&bwd, "d_output");
18832
18833        let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
18834        let d_seed = [1.0_f64];
18835        let (sched, mut arena) = prepare_f64(&bwd, &[(init_bwd, &init_data), (d_out_bwd, &d_seed)]);
18836        execute_thunks(&sched, arena.raw_buf_mut());
18837        let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
18838
18839        let h = 1e-6;
18840        let loss_at = |x0: &[f64]| -> f64 {
18841            let mut acc = x0.to_vec();
18842            for _ in 0..length {
18843                let mut a_copy = m_data.clone();
18844                crate::blas::dgesv(&mut a_copy, &mut acc, n, 1);
18845            }
18846            acc.iter().sum()
18847        };
18848        for i in 0..n {
18849            let mut ip = init_data.to_vec();
18850            ip[i] += h;
18851            let mut im = init_data.to_vec();
18852            im[i] -= h;
18853            let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
18854            assert!(
18855                (dinit[i] - fd).abs() < 1e-7,
18856                "FD dinit[{i}]: AD={} FD={}",
18857                dinit[i],
18858                fd
18859            );
18860        }
18861    }
18862
18863    /// Trajectory-mode scan: same Backward Euler body, but record the
18864    /// carry at every step. Output is `[length, n]` — row `t` is the
18865    /// state after step `t+1`. Validates the SaveAt-style waveform
18866    /// recording end-to-end, including that the last row equals what
18867    /// the no-trajectory variant would have returned.
18868    #[test]
18869    fn scan_trajectory_backward_euler_records_waveform() {
18870        let n = 4usize;
18871        let length = 5u32;
18872        let dt = 0.05_f64;
18873
18874        let mut m_data = vec![0.0_f64; n * n];
18875        for i in 0..n {
18876            m_data[i * n + i] = 1.0 + dt * 2.0;
18877            if i > 0 {
18878                m_data[i * n + (i - 1)] = -dt;
18879            }
18880            if i + 1 < n {
18881                m_data[i * n + (i + 1)] = -dt;
18882            }
18883        }
18884        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
18885
18886        let mut body = Graph::new("be_traj_body");
18887        let x = body.input("x", Shape::new(&[n], DType::F64));
18888        let m = body.add_node(
18889            Op::Constant { data: m_bytes },
18890            vec![],
18891            Shape::new(&[n, n], DType::F64),
18892        );
18893        let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
18894        body.set_outputs(vec![next]);
18895
18896        let mut g = Graph::new("be_traj_outer");
18897        let init = g.input("x0", Shape::new(&[n], DType::F64));
18898        let traj = g.scan_trajectory(init, body, length);
18899        g.set_outputs(vec![traj]);
18900
18901        let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
18902        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
18903        execute_thunks(&sched, arena.raw_buf_mut());
18904        let got = read_arena_f64(&arena, traj, length as usize * n);
18905
18906        // Reference: each step's solve, recorded.
18907        let mut want = Vec::<f64>::with_capacity(length as usize * n);
18908        let mut x_ref = init_data.to_vec();
18909        for _ in 0..length {
18910            let mut a_copy = m_data.clone();
18911            crate::blas::dgesv(&mut a_copy, &mut x_ref, n, 1);
18912            want.extend_from_slice(&x_ref);
18913        }
18914        for i in 0..length as usize * n {
18915            assert!(
18916                (got[i] - want[i]).abs() < 1e-12,
18917                "got[{i}] = {} ref {}",
18918                got[i],
18919                want[i]
18920            );
18921        }
18922
18923        // Sanity: trajectory rows are monotone-decreasing in mass
18924        // (Backward Euler diffuses; boundary leak removes mass).
18925        for t in 1..length as usize {
18926            let prev: f64 = got[(t - 1) * n..t * n].iter().sum();
18927            let curr: f64 = got[t * n..(t + 1) * n].iter().sum();
18928            assert!(
18929                curr <= prev + 1e-15,
18930                "mass should decay: row {} sum {prev}, row {t} sum {curr}",
18931                t - 1
18932            );
18933        }
18934
18935        // Last row of the trajectory equals what a non-trajectory
18936        // scan returns — verify by running the same forward through
18937        // the simpler API and comparing.
18938        let mut body2 = Graph::new("be_final_body");
18939        let x2 = body2.input("x", Shape::new(&[n], DType::F64));
18940        let m_bytes2: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
18941        let m2 = body2.add_node(
18942            Op::Constant { data: m_bytes2 },
18943            vec![],
18944            Shape::new(&[n, n], DType::F64),
18945        );
18946        let next2 = body2.dense_solve(m2, x2, Shape::new(&[n], DType::F64));
18947        body2.set_outputs(vec![next2]);
18948
18949        let mut g2 = Graph::new("be_final_outer");
18950        let init2 = g2.input("x0", Shape::new(&[n], DType::F64));
18951        let final_x = g2.scan(init2, body2, length);
18952        g2.set_outputs(vec![final_x]);
18953        let (sched2, mut arena2) = prepare_f64(&g2, &[(init2, &init_data)]);
18954        execute_thunks(&sched2, arena2.raw_buf_mut());
18955        let final_got = read_arena_f64(&arena2, final_x, n);
18956
18957        let last_row = &got[(length as usize - 1) * n..length as usize * n];
18958        for i in 0..n {
18959            assert!(
18960                (last_row[i] - final_got[i]).abs() < 1e-15,
18961                "last trajectory row[{i}] = {} vs final-scan = {}",
18962                last_row[i],
18963                final_got[i]
18964            );
18965        }
18966    }
18967
18968    /// Op::Scan composing with Op::DenseSolve — the Circulax-shaped
18969    /// pattern for Backward Euler.
18970    /// Body: x_{t+1} = solve(I + dt·A, x_t).
18971    /// 1-D heat-equation Laplacian A; analytic ground truth from
18972    /// composing the same per-step solve in Rust.
18973    #[test]
18974    fn scan_backward_euler_heat_f64() {
18975        let n = 4usize;
18976        let length = 5u32;
18977        let dt = 0.05_f64;
18978
18979        // Construct M = I + dt · L  where L is the Laplacian (-1, 2, -1).
18980        // M is constant across iterations; embed it in the body via Op::Constant.
18981        let mut m_data = vec![0.0_f64; n * n];
18982        for i in 0..n {
18983            m_data[i * n + i] = 1.0 + dt * 2.0;
18984            if i > 0 {
18985                m_data[i * n + (i - 1)] = -dt;
18986            }
18987            if i + 1 < n {
18988                m_data[i * n + (i + 1)] = -dt;
18989            }
18990        }
18991        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
18992
18993        let mut body = Graph::new("be_body");
18994        let x = body.input("x", Shape::new(&[n], DType::F64));
18995        let m = body.add_node(
18996            Op::Constant { data: m_bytes },
18997            vec![],
18998            Shape::new(&[n, n], DType::F64),
18999        );
19000        let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
19001        body.set_outputs(vec![next]);
19002
19003        let mut g = Graph::new("be_outer");
19004        let init = g.input("x0", Shape::new(&[n], DType::F64));
19005        let final_x = g.scan(init, body, length);
19006        g.set_outputs(vec![final_x]);
19007
19008        // Initial: a sharp pulse at index 1.
19009        let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
19010        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
19011        execute_thunks(&sched, arena.raw_buf_mut());
19012        let got = read_arena_f64(&arena, final_x, n);
19013
19014        // Reference: apply the same M-solve `length` times in pure Rust.
19015        let mut ref_x = init_data.to_vec();
19016        for _ in 0..length {
19017            let mut a_copy = m_data.clone();
19018            crate::blas::dgesv(&mut a_copy, &mut ref_x, n, 1);
19019        }
19020        for i in 0..n {
19021            assert!(
19022                (got[i] - ref_x[i]).abs() < 1e-12,
19023                "got[{i}] = {} ref {}",
19024                got[i],
19025                ref_x[i]
19026            );
19027        }
19028        // Sanity: pulse should diffuse, mass should be conserved-ish
19029        // (Backward Euler is mass-conserving for this stencil with
19030        // zero-flux boundaries — but our boundaries leak, so check
19031        // that mass strictly decreases instead).
19032        let mass: f64 = got.iter().sum();
19033        assert!(mass > 0.0 && mass < 1.0, "diffusion mass: {mass}");
19034    }
19035
19036    /// Multi-RHS forward DenseSolve: X = solve(A, B) with B [N, K]
19037    /// stays correct end-to-end. Verifies the executor/lowering and
19038    /// the LAPACK column-major dance both honour `nrhs > 1`.
19039    #[test]
19040    fn dense_solve_f64_multi_rhs_forward() {
19041        let n = 3usize;
19042        let k = 2usize;
19043        let mut g = Graph::new("solve_multi_rhs");
19044        let a = g.input("A", Shape::new(&[n, n], DType::F64));
19045        let b = g.input("B", Shape::new(&[n, k], DType::F64));
19046        let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
19047        g.set_outputs(vec![x]);
19048
19049        let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
19050        let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
19051        let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
19052        execute_thunks(&sched, arena.raw_buf_mut());
19053        let x_got = read_arena_f64(&arena, x, n * k);
19054        for c in 0..k {
19055            for i in 0..n {
19056                let mut acc = 0.0_f64;
19057                for j in 0..n {
19058                    acc += a_data[i * n + j] * x_got[j * k + c];
19059                }
19060                let want = b_data[i * k + c];
19061                assert!(
19062                    (acc - want).abs() < 1e-10,
19063                    "col {c} row {i}: got {acc} want {want}"
19064                );
19065            }
19066        }
19067    }
19068
19069    /// Multi-RHS reverse-mode VJP: dB = (Aᵀ)⁻¹·1, dA = -dB · Xᵀ.
19070    /// Verified analytically + finite differences on dB[0,0].
19071    #[test]
19072    fn dense_solve_f64_multi_rhs_gradient() {
19073        use rlx_opt::autodiff::grad_with_loss;
19074        let n = 3usize;
19075        let k = 2usize;
19076        let mut g = Graph::new("solve_mrhs_grad");
19077        let a = g.param("A", Shape::new(&[n, n], DType::F64));
19078        let b = g.input("B", Shape::new(&[n, k], DType::F64));
19079        let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
19080        let loss = g.reduce(
19081            x,
19082            ReduceOp::Sum,
19083            vec![0, 1],
19084            false,
19085            Shape::new(&[1], DType::F64),
19086        );
19087        g.set_outputs(vec![loss]);
19088
19089        let bwd = grad_with_loss(&g, &[a, b]);
19090        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
19091            for node in graph.nodes() {
19092                let name = match &node.op {
19093                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19094                    _ => None,
19095                };
19096                if name == Some(want) {
19097                    return node.id;
19098                }
19099            }
19100            panic!("no node named {want:?}");
19101        };
19102        let a_bwd = find_by_name(&bwd, "A");
19103        let b_bwd = find_by_name(&bwd, "B");
19104        let d_out = find_by_name(&bwd, "d_output");
19105
19106        let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
19107        let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
19108        let d_seed = [1.0_f64];
19109
19110        let (sched, mut arena) = prepare_f64(
19111            &bwd,
19112            &[(a_bwd, &a_data), (b_bwd, &b_data), (d_out, &d_seed)],
19113        );
19114        execute_thunks(&sched, arena.raw_buf_mut());
19115        let da_got = read_arena_f64(&arena, bwd.outputs[1], n * n);
19116        let db_got = read_arena_f64(&arena, bwd.outputs[2], n * k);
19117
19118        // Reference.
19119        let mut x_ref = b_data;
19120        {
19121            let mut a_copy = a_data;
19122            crate::blas::dgesv(&mut a_copy, &mut x_ref, n, k);
19123        }
19124        let mut at = [0.0_f64; 9];
19125        for i in 0..n {
19126            for j in 0..n {
19127                at[i * n + j] = a_data[j * n + i];
19128            }
19129        }
19130        let mut ones_nk = vec![1.0_f64; n * k];
19131        crate::blas::dgesv(&mut at, &mut ones_nk, n, k);
19132        let db_ref = ones_nk;
19133        let mut da_ref = [0.0_f64; 9];
19134        for i in 0..n {
19135            for j in 0..n {
19136                let mut acc = 0.0_f64;
19137                for c in 0..k {
19138                    acc += db_ref[i * k + c] * x_ref[j * k + c];
19139                }
19140                da_ref[i * n + j] = -acc;
19141            }
19142        }
19143        for i in 0..n * k {
19144            assert!(
19145                (db_got[i] - db_ref[i]).abs() < 1e-10,
19146                "dB[{i}]: got {} want {}",
19147                db_got[i],
19148                db_ref[i]
19149            );
19150        }
19151        for i in 0..n * n {
19152            assert!(
19153                (da_got[i] - da_ref[i]).abs() < 1e-10,
19154                "dA[{i}]: got {} want {}",
19155                da_got[i],
19156                da_ref[i]
19157            );
19158        }
19159
19160        // FD on dB[0,0].
19161        let h = 1e-6;
19162        let mut bp = b_data;
19163        bp[0] += h;
19164        let mut bm = b_data;
19165        bm[0] -= h;
19166        let xp = {
19167            let mut a_copy = a_data;
19168            crate::blas::dgesv(&mut a_copy, &mut bp, n, k);
19169            bp
19170        };
19171        let xm = {
19172            let mut a_copy = a_data;
19173            crate::blas::dgesv(&mut a_copy, &mut bm, n, k);
19174            bm
19175        };
19176        let lp: f64 = xp.iter().sum();
19177        let lm: f64 = xm.iter().sum();
19178        let fd = (lp - lm) / (2.0 * h);
19179        assert!(
19180            (db_got[0] - fd).abs() < 1e-7,
19181            "FD dB[0,0]: AD={} FD={}",
19182            db_got[0],
19183            fd
19184        );
19185    }
19186
19187    /// Multi-RHS forward-mode JVP w.r.t. B. Closed form: t_X = solve(A, t_B).
19188    #[test]
19189    fn dense_solve_f64_multi_rhs_jvp() {
19190        use rlx_opt::autodiff_fwd::jvp;
19191        let n = 3usize;
19192        let k = 2usize;
19193        let mut g = Graph::new("solve_mrhs_jvp");
19194        let a = g.input("A", Shape::new(&[n, n], DType::F64));
19195        let b = g.input("B", Shape::new(&[n, k], DType::F64));
19196        let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
19197        g.set_outputs(vec![x]);
19198
19199        let jg = jvp(&g, &[b]);
19200        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
19201            for node in graph.nodes() {
19202                let name = match &node.op {
19203                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19204                    _ => None,
19205                };
19206                if name == Some(want) {
19207                    return node.id;
19208                }
19209            }
19210            panic!("no node named {want:?}");
19211        };
19212        let a_id = find_by_name(&jg, "A");
19213        let b_id = find_by_name(&jg, "B");
19214        let tb_id = find_by_name(&jg, "tangent_B");
19215
19216        let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
19217        let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
19218        let tb_data = [0.5, 0.0, -0.25, 1.0, 1.0, -0.5_f64];
19219
19220        let (sched, mut arena) =
19221            prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (tb_id, &tb_data)]);
19222        execute_thunks(&sched, arena.raw_buf_mut());
19223        let tangent_x = read_arena_f64(&arena, jg.outputs[1], n * k);
19224
19225        let mut a_copy = a_data;
19226        let mut tb_copy = tb_data;
19227        crate::blas::dgesv(&mut a_copy, &mut tb_copy, n, k);
19228        for i in 0..n * k {
19229            assert!(
19230                (tangent_x[i] - tb_copy[i]).abs() < 1e-10,
19231                "t_X[{i}]: AD={} ref={}",
19232                tangent_x[i],
19233                tb_copy[i]
19234            );
19235        }
19236
19237        let h = 1e-6;
19238        let mut bp = b_data;
19239        let mut bm = b_data;
19240        for i in 0..n * k {
19241            bp[i] += h * tb_data[i];
19242            bm[i] -= h * tb_data[i];
19243        }
19244        let xp = {
19245            let mut a_copy = a_data;
19246            crate::blas::dgesv(&mut a_copy, &mut bp, n, k);
19247            bp
19248        };
19249        let xm = {
19250            let mut a_copy = a_data;
19251            crate::blas::dgesv(&mut a_copy, &mut bm, n, k);
19252            bm
19253        };
19254        for i in 0..n * k {
19255            let fd = (xp[i] - xm[i]) / (2.0 * h);
19256            assert!(
19257                (tangent_x[i] - fd).abs() < 1e-7,
19258                "FD t_X[{i}]: AD={} FD={}",
19259                tangent_x[i],
19260                fd
19261            );
19262        }
19263    }
19264
19265    /// Forward-mode JVP through DenseSolve, end-to-end at f64.
19266    ///
19267    /// Build forward x = solve(A, b), call `jvp(forward, [b])`,
19268    /// compile + run, and check the tangent output matches the
19269    /// closed form `t_x = solve(A, t_b)` plus a finite-difference
19270    /// cross-check `(solve(A, b + h·t_b) − solve(A, b − h·t_b)) / 2h`.
19271    #[test]
19272    fn jvp_dense_solve_b_runs_and_matches_fd() {
19273        use rlx_opt::autodiff_fwd::jvp;
19274        let n = 3usize;
19275
19276        // Forward.
19277        let mut g = Graph::new("jvp_b_e2e");
19278        let a = g.input("A", Shape::new(&[n, n], DType::F64));
19279        let b = g.input("b", Shape::new(&[n], DType::F64));
19280        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
19281        g.set_outputs(vec![x]);
19282
19283        // JVP graph perturbing b only.
19284        let jg = jvp(&g, &[b]);
19285        // The JVP graph holds a fresh "tangent_b" Input on top of A and b.
19286        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
19287            for node in graph.nodes() {
19288                let name = match &node.op {
19289                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19290                    _ => None,
19291                };
19292                if name == Some(want) {
19293                    return node.id;
19294                }
19295            }
19296            panic!("no node named {want:?}");
19297        };
19298        let a_id = find_by_name(&jg, "A");
19299        let b_id = find_by_name(&jg, "b");
19300        let tb_id = find_by_name(&jg, "tangent_b");
19301
19302        let a_data: [f64; 9] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
19303        let b_data: [f64; 3] = [1.0, 2.0, 3.0];
19304        // Pick an arbitrary perturbation direction.
19305        let tb_data: [f64; 3] = [0.5, -0.25, 1.0];
19306
19307        let (sched, mut arena) =
19308            prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (tb_id, &tb_data)]);
19309        execute_thunks(&sched, arena.raw_buf_mut());
19310
19311        // Outputs: [primal_x, tangent_x].
19312        let primal_x = read_arena_f64(&arena, jg.outputs[0], n);
19313        let tangent_x = read_arena_f64(&arena, jg.outputs[1], n);
19314
19315        // Closed form: t_x = solve(A, t_b).
19316        let t_x_ref = {
19317            let mut a = a_data;
19318            let mut tb = tb_data;
19319            let info = crate::blas::dgesv(&mut a, &mut tb, n, 1);
19320            assert_eq!(info, 0);
19321            tb
19322        };
19323        for i in 0..n {
19324            assert!(
19325                (tangent_x[i] - t_x_ref[i]).abs() < 1e-10,
19326                "t_x[{i}]: got {} want {}",
19327                tangent_x[i],
19328                t_x_ref[i]
19329            );
19330        }
19331
19332        // FD: x(b + h·tb) − x(b − h·tb)) / 2h
19333        let h = 1e-6;
19334        let mut bp = b_data;
19335        let mut bm = b_data;
19336        for i in 0..n {
19337            bp[i] += h * tb_data[i];
19338            bm[i] -= h * tb_data[i];
19339        }
19340        let xp = {
19341            let mut a = a_data;
19342            let info = crate::blas::dgesv(&mut a, &mut bp, n, 1);
19343            assert_eq!(info, 0);
19344            bp
19345        };
19346        let xm = {
19347            let mut a = a_data;
19348            let info = crate::blas::dgesv(&mut a, &mut bm, n, 1);
19349            assert_eq!(info, 0);
19350            bm
19351        };
19352        let fd: Vec<f64> = (0..n).map(|i| (xp[i] - xm[i]) / (2.0 * h)).collect();
19353        for i in 0..n {
19354            assert!(
19355                (tangent_x[i] - fd[i]).abs() < 1e-7,
19356                "FD mismatch t_x[{i}]: AD={} FD={}",
19357                tangent_x[i],
19358                fd[i]
19359            );
19360        }
19361        // Sanity: primal output is the actual solve.
19362        let primal_ref = {
19363            let mut a = a_data;
19364            let mut b = b_data;
19365            crate::blas::dgesv(&mut a, &mut b, n, 1);
19366            b
19367        };
19368        for i in 0..n {
19369            assert!((primal_x[i] - primal_ref[i]).abs() < 1e-10);
19370        }
19371    }
19372
19373    /// Forward-mode JVP through DenseSolve perturbing A. The tangent
19374    /// path includes the −t_A·x correction term.
19375    /// `t_x = −solve(A, t_A · x)` should match a finite-difference
19376    /// directional derivative of `solve(A, b)` w.r.t. A in the
19377    /// `t_A` direction.
19378    #[test]
19379    fn jvp_dense_solve_a_runs_and_matches_fd() {
19380        use rlx_opt::autodiff_fwd::jvp;
19381        let n = 3usize;
19382
19383        let mut g = Graph::new("jvp_a_e2e");
19384        let a = g.input("A", Shape::new(&[n, n], DType::F64));
19385        let b = g.input("b", Shape::new(&[n], DType::F64));
19386        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
19387        g.set_outputs(vec![x]);
19388
19389        let jg = jvp(&g, &[a]);
19390        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
19391            for node in graph.nodes() {
19392                let name = match &node.op {
19393                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19394                    _ => None,
19395                };
19396                if name == Some(want) {
19397                    return node.id;
19398                }
19399            }
19400            panic!("no node named {want:?}");
19401        };
19402        let a_id = find_by_name(&jg, "A");
19403        let b_id = find_by_name(&jg, "b");
19404        let ta_id = find_by_name(&jg, "tangent_A");
19405
19406        let a_data: [f64; 9] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
19407        let b_data: [f64; 3] = [1.0, 2.0, 3.0];
19408        // Asymmetric perturbation direction for A.
19409        let ta_data: [f64; 9] = [0.10, -0.05, 0.02, 0.03, 0.20, -0.04, -0.01, 0.07, 0.15];
19410
19411        let (sched, mut arena) =
19412            prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (ta_id, &ta_data)]);
19413        execute_thunks(&sched, arena.raw_buf_mut());
19414
19415        let tangent_x = read_arena_f64(&arena, jg.outputs[1], n);
19416
19417        // Closed form: x = solve(A, b); t_x = −solve(A, t_A · x).
19418        let x_ref = {
19419            let mut a = a_data;
19420            let mut b = b_data;
19421            crate::blas::dgesv(&mut a, &mut b, n, 1);
19422            b
19423        };
19424        let mut prod = [0.0_f64; 3];
19425        for i in 0..n {
19426            for j in 0..n {
19427                prod[i] += ta_data[i * n + j] * x_ref[j];
19428            }
19429        }
19430        let t_x_ref = {
19431            let mut a = a_data;
19432            let mut p = prod;
19433            crate::blas::dgesv(&mut a, &mut p, n, 1);
19434            [-p[0], -p[1], -p[2]]
19435        };
19436        for i in 0..n {
19437            assert!(
19438                (tangent_x[i] - t_x_ref[i]).abs() < 1e-10,
19439                "closed-form t_x[{i}]: AD={} ref={}",
19440                tangent_x[i],
19441                t_x_ref[i]
19442            );
19443        }
19444
19445        // FD: solve(A + h·t_A, b) and solve(A − h·t_A, b).
19446        let h = 1e-6;
19447        let mut ap = a_data;
19448        let mut am = a_data;
19449        for i in 0..n * n {
19450            ap[i] += h * ta_data[i];
19451            am[i] -= h * ta_data[i];
19452        }
19453        let xp = {
19454            let mut a = ap;
19455            let mut b = b_data;
19456            crate::blas::dgesv(&mut a, &mut b, n, 1);
19457            b
19458        };
19459        let xm = {
19460            let mut a = am;
19461            let mut b = b_data;
19462            crate::blas::dgesv(&mut a, &mut b, n, 1);
19463            b
19464        };
19465        for i in 0..n {
19466            let fd = (xp[i] - xm[i]) / (2.0 * h);
19467            assert!(
19468                (tangent_x[i] - fd).abs() < 1e-7,
19469                "FD t_x[{i}]: AD={} FD={}",
19470                tangent_x[i],
19471                fd
19472            );
19473        }
19474    }
19475
19476    /// Real INT8 conv2d parity. Same setup as QMatMul: pre-quantize
19477    /// f32 inputs to i8, run `Op::QConv2d`, compare against an
19478    /// in-test reference loop that does the same i32 accumulation
19479    /// and requantize math. Symmetric quant (zp=0) to keep the math
19480    /// head-to-head.
19481    #[test]
19482    fn q_conv2d_matches_reference() {
19483        use rlx_ir::Philox4x32;
19484        // Small NCHW shape — enough to exercise stride/padding edges.
19485        let n = 1usize;
19486        let c_in = 2usize;
19487        let h = 5usize;
19488        let w_in = 5usize;
19489        let c_out = 3usize;
19490        let kh = 3usize;
19491        let kw = 3usize;
19492        let ph = 1usize;
19493        let pw = 1usize;
19494        let sh = 1usize;
19495        let sw = 1usize;
19496        let h_out = (h + 2 * ph - kh) / sh + 1;
19497        let w_out = (w_in + 2 * pw - kw) / sw + 1;
19498
19499        let x_scale = 0.04f32;
19500        let w_scale = 0.02f32;
19501        let out_scale = 0.5f32;
19502        let mult = x_scale * w_scale / out_scale;
19503
19504        let mut rng = Philox4x32::new(2099);
19505        let mut xf = vec![0f32; n * c_in * h * w_in];
19506        rng.fill_normal(&mut xf);
19507        let mut wf = vec![0f32; c_out * c_in * kh * kw];
19508        rng.fill_normal(&mut wf);
19509        let xq: Vec<i8> = xf
19510            .iter()
19511            .map(|&v| ((v / x_scale).round() as i32).clamp(-128, 127) as i8)
19512            .collect();
19513        let wq: Vec<i8> = wf
19514            .iter()
19515            .map(|&v| ((v / w_scale).round() as i32).clamp(-128, 127) as i8)
19516            .collect();
19517        let bias: Vec<i32> = vec![0i32; c_out];
19518
19519        let mut g = Graph::new("qconv");
19520        let xn = g.input("x", Shape::new(&[n, c_in, h, w_in], DType::I8));
19521        let wn = g.input("w", Shape::new(&[c_out, c_in, kh, kw], DType::I8));
19522        let bn = g.input("b", Shape::new(&[c_out], DType::I32));
19523        let out = g.q_conv2d(
19524            xn,
19525            wn,
19526            bn,
19527            vec![kh, kw],
19528            vec![sh, sw],
19529            vec![ph, pw],
19530            vec![1, 1],
19531            1,
19532            0,
19533            0,
19534            0,
19535            mult,
19536            Shape::new(&[n, c_out, h_out, w_out], DType::I8),
19537        );
19538        g.set_outputs(vec![out]);
19539
19540        let plan = rlx_opt::memory::plan_memory(&g);
19541        let mut arena = crate::arena::Arena::from_plan(plan);
19542        let sched = compile_thunks(&g, &arena);
19543        // Capture offsets before borrowing the buf mutably (avoids
19544        // overlap between &mut and the &arena.byte_offset reads).
19545        let xn_off = arena.byte_offset(xn);
19546        let wn_off = arena.byte_offset(wn);
19547        let bn_off = arena.byte_offset(bn);
19548        let out_off = arena.byte_offset(out);
19549        let buf = arena.raw_buf_mut();
19550        unsafe {
19551            let p = buf.as_mut_ptr().add(xn_off) as *mut i8;
19552            for (i, &v) in xq.iter().enumerate() {
19553                *p.add(i) = v;
19554            }
19555            let p = buf.as_mut_ptr().add(wn_off) as *mut i8;
19556            for (i, &v) in wq.iter().enumerate() {
19557                *p.add(i) = v;
19558            }
19559            let p = buf.as_mut_ptr().add(bn_off) as *mut i32;
19560            for (i, &v) in bias.iter().enumerate() {
19561                *p.add(i) = v;
19562            }
19563        }
19564        execute_thunks(&sched, arena.raw_buf_mut());
19565        let out_q: Vec<i8> = unsafe {
19566            let p = arena.raw_buf().as_ptr().add(out_off) as *const i8;
19567            (0..n * c_out * h_out * w_out).map(|i| *p.add(i)).collect()
19568        };
19569
19570        // Reference: scalar loop in NCHW with the same requantize.
19571        let mut out_ref = vec![0i8; n * c_out * h_out * w_out];
19572        for ni in 0..n {
19573            for co in 0..c_out {
19574                for ho in 0..h_out {
19575                    for wo in 0..w_out {
19576                        let mut acc: i32 = 0;
19577                        for ci in 0..c_in {
19578                            for ki in 0..kh {
19579                                for kj in 0..kw {
19580                                    let hi = ho * sh + ki;
19581                                    let wi = wo * sw + kj;
19582                                    if hi < ph || wi < pw {
19583                                        continue;
19584                                    }
19585                                    let hi = hi - ph;
19586                                    let wi = wi - pw;
19587                                    if hi >= h || wi >= w_in {
19588                                        continue;
19589                                    }
19590                                    let xv =
19591                                        xq[((ni * c_in) + ci) * h * w_in + hi * w_in + wi] as i32;
19592                                    let wv = wq[((co * c_in) + ci) * kh * kw + ki * kw + kj] as i32;
19593                                    acc += xv * wv;
19594                                }
19595                            }
19596                        }
19597                        let r = (acc as f32 * mult).round() as i32;
19598                        let r = r.clamp(-128, 127) as i8;
19599                        out_ref[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = r;
19600                    }
19601                }
19602            }
19603        }
19604
19605        for (i, (a, r)) in out_q.iter().zip(&out_ref).enumerate() {
19606            assert_eq!(a, r, "q_conv2d[{i}]: kernel {a} vs reference {r}");
19607        }
19608    }
19609
19610    /// Real INT8 matmul parity: compare `Op::QMatMul` against the
19611    /// fake-quant reference `Dequantize → MatMul → Quantize` that
19612    /// would produce the same output if we round-tripped through
19613    /// f32. Both should agree element-for-element (or within ±1 i8
19614    /// step, since rounding in the requantize uses different code
19615    /// paths). Symmetric quantization (zp=0) for both paths to keep
19616    /// the math head-to-head.
19617    #[test]
19618    fn q_matmul_matches_fake_quant_reference() {
19619        use rlx_ir::Philox4x32;
19620        let m = 3usize;
19621        let k = 8usize;
19622        let n = 5usize;
19623        let mut rng = Philox4x32::new(2031);
19624
19625        // Pick scales and quantize random f32 inputs to i8.
19626        let x_scale = 0.05f32;
19627        let w_scale = 0.03f32;
19628        let out_scale = 0.4f32;
19629        let mult = x_scale * w_scale / out_scale;
19630        let mut xf = vec![0f32; m * k];
19631        rng.fill_normal(&mut xf);
19632        let mut wf = vec![0f32; k * n];
19633        rng.fill_normal(&mut wf);
19634        let xq: Vec<i8> = xf
19635            .iter()
19636            .map(|&v| ((v / x_scale).round() as i32).clamp(-128, 127) as i8)
19637            .collect();
19638        let wq: Vec<i8> = wf
19639            .iter()
19640            .map(|&v| ((v / w_scale).round() as i32).clamp(-128, 127) as i8)
19641            .collect();
19642        let bias: Vec<i32> = vec![0i32; n];
19643
19644        // ── Direct INT8 path ──
19645        let _f = DType::F32;
19646        let mut g_q = Graph::new("qmm_direct");
19647        let xn = g_q.input("x", Shape::new(&[m, k], DType::I8));
19648        let wn = g_q.input("w", Shape::new(&[k, n], DType::I8));
19649        let bn = g_q.input("b", Shape::new(&[n], DType::I32));
19650        let out = g_q.q_matmul(xn, wn, bn, 0, 0, 0, mult, Shape::new(&[m, n], DType::I8));
19651        g_q.set_outputs(vec![out]);
19652        let plan = rlx_opt::memory::plan_memory(&g_q);
19653        let mut arena = crate::arena::Arena::from_plan(plan);
19654        let sched = compile_thunks(&g_q, &arena);
19655
19656        // Fill inputs.
19657        let xn_off = arena.byte_offset(xn);
19658        let wn_off = arena.byte_offset(wn);
19659        let bn_off = arena.byte_offset(bn);
19660        let out_off = arena.byte_offset(out);
19661        let buf = arena.raw_buf_mut();
19662        unsafe {
19663            let p = buf.as_mut_ptr().add(xn_off) as *mut i8;
19664            for (i, &v) in xq.iter().enumerate() {
19665                *p.add(i) = v;
19666            }
19667            let p = buf.as_mut_ptr().add(wn_off) as *mut i8;
19668            for (i, &v) in wq.iter().enumerate() {
19669                *p.add(i) = v;
19670            }
19671            let p = buf.as_mut_ptr().add(bn_off) as *mut i32;
19672            for (i, &v) in bias.iter().enumerate() {
19673                *p.add(i) = v;
19674            }
19675        }
19676        execute_thunks(&sched, arena.raw_buf_mut());
19677        let out_q: Vec<i8> = unsafe {
19678            let p = arena.raw_buf().as_ptr().add(out_off) as *const i8;
19679            (0..m * n).map(|i| *p.add(i)).collect()
19680        };
19681
19682        // ── Fake-quant reference: scalar emulation in plain Rust ──
19683        // Same arithmetic the kernel does, but in a verifier loop:
19684        //   acc = Σ (x[m,k]) · (w[k,n]),  // zps are 0
19685        //   out[m,n] = saturate_i8(round(acc · mult) + 0)
19686        let mut out_ref = vec![0i8; m * n];
19687        for mi in 0..m {
19688            for ni in 0..n {
19689                let mut acc: i32 = 0;
19690                for ki in 0..k {
19691                    acc += (xq[mi * k + ki] as i32) * (wq[ki * n + ni] as i32);
19692                }
19693                let r = (acc as f32 * mult).round() as i32;
19694                out_ref[mi * n + ni] = r.clamp(-128, 127) as i8;
19695            }
19696        }
19697
19698        for (i, (a, r)) in out_q.iter().zip(&out_ref).enumerate() {
19699            assert_eq!(a, r, "q_matmul[{i}]: kernel {a} vs reference {r}");
19700        }
19701    }
19702
19703    /// Quantize/Dequantize round-trip — quantize an f32 tensor, then
19704    /// dequantize back, and confirm the result tracks the input
19705    /// within the per-element scale (the inevitable rounding error).
19706    /// Also pins the kernel's saturation behavior at the i8 limits.
19707    #[test]
19708    fn quantize_dequantize_round_trip() {
19709        use rlx_ir::Philox4x32;
19710        let len = 64;
19711        let mut rng = Philox4x32::new(2027);
19712        let mut x = vec![0f32; len];
19713        rng.fill_normal(&mut x);
19714        // Stretch a couple values past the +/- saturation cliff so
19715        // the saturate_i8 path is exercised.
19716        x[0] = 999.0;
19717        x[1] = -999.0;
19718
19719        let scale = 0.05f32;
19720        let zp = 3i32;
19721
19722        let f = DType::F32;
19723        let mut g = Graph::new("qdq");
19724        let xn = g.input("x", Shape::new(&[len], f));
19725        let q = g.quantize(xn, scale, zp);
19726        let dq = g.dequantize(q, scale, zp);
19727        g.set_outputs(vec![dq]);
19728
19729        let plan = rlx_opt::memory::plan_memory(&g);
19730        let mut arena = crate::arena::Arena::from_plan(plan);
19731        let sched = compile_thunks(&g, &arena);
19732        let xn_off = arena.byte_offset(xn);
19733        let dq_off = arena.byte_offset(dq);
19734        let buf = arena.raw_buf_mut();
19735        unsafe {
19736            let p = buf.as_mut_ptr().add(xn_off) as *mut f32;
19737            for (i, &v) in x.iter().enumerate() {
19738                *p.add(i) = v;
19739            }
19740        }
19741        execute_thunks(&sched, arena.raw_buf_mut());
19742        let out: Vec<f32> = unsafe {
19743            let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
19744            (0..len).map(|i| *p.add(i)).collect()
19745        };
19746
19747        // Saturated values at i=0,1 should clamp to ±127's dequant
19748        // range (= (±127 - zp) · scale).
19749        let sat_pos = (127 - zp) as f32 * scale;
19750        let sat_neg = (-128 - zp) as f32 * scale;
19751        assert!((out[0] - sat_pos).abs() < 1e-6, "+sat: {}", out[0]);
19752        assert!((out[1] - sat_neg).abs() < 1e-6, "-sat: {}", out[1]);
19753
19754        // Everything else should round-trip within `scale` (one quant
19755        // step = the worst-case rounding error).
19756        for i in 2..len {
19757            assert!(
19758                (out[i] - x[i]).abs() <= scale + 1e-5,
19759                "qdq[{i}]: {} → {}, scale={scale}",
19760                x[i],
19761                out[i]
19762            );
19763        }
19764    }
19765
19766    /// Per-channel quantize / dequantize: independent scale and zp
19767    /// per slice along an axis. Verifies (a) each channel uses its
19768    /// own scale (not a shared one), (b) saturation still respects
19769    /// the i8 range, (c) channel data layout decomposition is
19770    /// correct (no cross-channel leakage).
19771    #[test]
19772    fn quantize_per_channel_round_trip() {
19773        let c = 4usize;
19774        let inner = 5usize;
19775        // Different magnitudes per channel — proves the per-channel
19776        // scale is actually being read for each row.
19777        let mags = [0.01f32, 0.5, 5.0, 50.0];
19778        let mut x = vec![0f32; c * inner];
19779        for ci in 0..c {
19780            for ii in 0..inner {
19781                // Sweep through values that span [-max_abs, +max_abs]
19782                // for each channel, plus one value past the cliff to
19783                // trigger saturation.
19784                x[ci * inner + ii] = match ii {
19785                    0 => -mags[ci],
19786                    1 => 0.0,
19787                    2 => mags[ci],
19788                    3 => mags[ci] * 1000.0,  // saturates +
19789                    _ => -mags[ci] * 1000.0, // saturates -
19790                };
19791            }
19792        }
19793        let scales: Vec<f32> = mags.iter().map(|&m| m / 127.0).collect();
19794        let zps: Vec<i32> = vec![0, 0, 0, 0];
19795
19796        let f = DType::F32;
19797        let mut g = Graph::new("qdq_pc");
19798        let xn = g.input("x", Shape::new(&[c, inner], f));
19799        let q = g.quantize_per_channel(xn, 0, scales.clone(), zps.clone());
19800        let dq = g.dequantize_per_channel(q, 0, scales.clone(), zps);
19801        g.set_outputs(vec![dq]);
19802
19803        let plan = rlx_opt::memory::plan_memory(&g);
19804        let mut arena = crate::arena::Arena::from_plan(plan);
19805        let sched = compile_thunks(&g, &arena);
19806        let xn_off = arena.byte_offset(xn);
19807        let dq_off = arena.byte_offset(dq);
19808        let buf = arena.raw_buf_mut();
19809        unsafe {
19810            let p = buf.as_mut_ptr().add(xn_off) as *mut f32;
19811            for (i, &v) in x.iter().enumerate() {
19812                *p.add(i) = v;
19813            }
19814        }
19815        execute_thunks(&sched, arena.raw_buf_mut());
19816        let out: Vec<f32> = unsafe {
19817            let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
19818            (0..c * inner).map(|i| *p.add(i)).collect()
19819        };
19820
19821        for ci in 0..c {
19822            // Within-range entries (positions 0, 1, 2) must round-trip
19823            // within one quant step of *that channel's* scale.
19824            for ii in 0..3 {
19825                let idx = ci * inner + ii;
19826                assert!(
19827                    (out[idx] - x[idx]).abs() <= scales[ci] + 1e-5,
19828                    "ch {ci} idx {ii}: {} vs {}",
19829                    x[idx],
19830                    out[idx]
19831                );
19832            }
19833            // Saturated positions clamp to ±127 · scale[ci].
19834            let sat_pos = 127.0 * scales[ci];
19835            let sat_neg = -128.0 * scales[ci];
19836            assert!(
19837                (out[ci * inner + 3] - sat_pos).abs() < 1e-5,
19838                "ch {ci} +sat: {}",
19839                out[ci * inner + 3]
19840            );
19841            assert!(
19842                (out[ci * inner + 4] - sat_neg).abs() < 1e-5,
19843                "ch {ci} -sat: {}",
19844                out[ci * inner + 4]
19845            );
19846        }
19847    }
19848
19849    /// `Op::ActivationBackward` parity for every supported kind.
19850    /// Builds a single-op graph `dx = activation_backward(x, dy)` and
19851    /// compares each `dx[i]` to the central-difference `(act(x+ε) -
19852    /// act(x-ε)) / (2ε) · dy\[i\]`. Sweeps the closed-form covered by
19853    /// the kernel.
19854    #[test]
19855    fn activation_backward_matches_numerical_per_kind() {
19856        use rlx_ir::Philox4x32;
19857        use rlx_ir::op::Activation;
19858        let mut rng = Philox4x32::new(91);
19859        let len = 32;
19860        // x sampled away from kink/branch points: shifted positive
19861        // (exp/sqrt/log domain) for the unary-positive activations;
19862        // wide range otherwise. Two parallel tests would be cleaner
19863        // but this is concise enough.
19864        let mut x_pos = vec![0f32; len];
19865        rng.fill_normal(&mut x_pos);
19866        for v in x_pos.iter_mut() {
19867            *v = v.abs() + 0.5;
19868        }
19869        let mut x_any = vec![0f32; len];
19870        rng.fill_normal(&mut x_any);
19871        let mut dy = vec![0f32; len];
19872        rng.fill_normal(&mut dy);
19873
19874        for &(kind, x_data, eps, tol) in &[
19875            (Activation::Sigmoid, &x_any[..], 1e-3, 5e-3),
19876            (Activation::Tanh, &x_any[..], 1e-3, 5e-3),
19877            (Activation::Silu, &x_any[..], 1e-3, 5e-3),
19878            (Activation::Gelu, &x_any[..], 1e-3, 5e-3),
19879            (Activation::GeluApprox, &x_any[..], 1e-3, 5e-3),
19880            (Activation::Exp, &x_any[..], 1e-4, 5e-3),
19881            (Activation::Log, &x_pos[..], 1e-4, 5e-3),
19882            (Activation::Sqrt, &x_pos[..], 1e-4, 5e-3),
19883            (Activation::Rsqrt, &x_pos[..], 1e-4, 5e-3),
19884            (Activation::Neg, &x_any[..], 1e-3, 5e-4),
19885        ] {
19886            let f = DType::F32;
19887            let mut g = Graph::new("act_bw");
19888            let xn = g.input("x", Shape::new(&[len], f));
19889            let dyn_ = g.input("dy", Shape::new(&[len], f));
19890            let dx = g.activation_backward(kind, xn, dyn_);
19891            g.set_outputs(vec![dx]);
19892
19893            let plan = rlx_opt::memory::plan_memory(&g);
19894            let mut arena = crate::arena::Arena::from_plan(plan);
19895            let sched = compile_thunks(&g, &arena);
19896
19897            let xn_off = arena.byte_offset(xn);
19898            let dyn_off = arena.byte_offset(dyn_);
19899            let dx_off = arena.byte_offset(dx);
19900            let buf = arena.raw_buf_mut();
19901            unsafe {
19902                let p = buf.as_mut_ptr().add(xn_off) as *mut f32;
19903                for (i, &v) in x_data.iter().enumerate() {
19904                    *p.add(i) = v;
19905                }
19906                let p = buf.as_mut_ptr().add(dyn_off) as *mut f32;
19907                for (i, &v) in dy.iter().enumerate() {
19908                    *p.add(i) = v;
19909                }
19910            }
19911            execute_thunks(&sched, arena.raw_buf_mut());
19912            let analytical: Vec<f32> = unsafe {
19913                let p = arena.raw_buf().as_ptr().add(dx_off) as *const f32;
19914                (0..len).map(|i| *p.add(i)).collect()
19915            };
19916
19917            // Apply the forward activation manually; finite-difference
19918            // each element.
19919            let act_apply = |kind: Activation, x: f32| -> f32 {
19920                match kind {
19921                    Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
19922                    Activation::Tanh => x.tanh(),
19923                    Activation::Silu => x / (1.0 + (-x).exp()),
19924                    Activation::Gelu => {
19925                        // Match the kernel's exact erf form.
19926                        const INV_SQRT2: f32 = 0.707_106_77;
19927                        0.5 * x * (1.0 + erf_f32(x * INV_SQRT2))
19928                    }
19929                    Activation::GeluApprox => {
19930                        const C: f32 = 0.797_884_6;
19931                        const A: f32 = 0.044_715;
19932                        let inner = C * (x + A * x * x * x);
19933                        0.5 * x * (1.0 + inner.tanh())
19934                    }
19935                    Activation::Exp => x.exp(),
19936                    Activation::Log => x.ln(),
19937                    Activation::Sqrt => x.sqrt(),
19938                    Activation::Rsqrt => 1.0 / x.sqrt(),
19939                    Activation::Neg => -x,
19940                    Activation::Relu => x.max(0.0),
19941                    Activation::Abs => x.abs(),
19942                    Activation::Round => x.round(),
19943                    Activation::Sin => x.sin(),
19944                    Activation::Cos => x.cos(),
19945                    Activation::Tan => x.tan(),
19946                    Activation::Atan => x.atan(),
19947                }
19948            };
19949            for i in 0..len {
19950                let xv = x_data[i];
19951                let plus = act_apply(kind, xv + eps);
19952                let minus = act_apply(kind, xv - eps);
19953                let num = (plus - minus) / (2.0 * eps) * dy[i];
19954                assert!(
19955                    (analytical[i] - num).abs() < tol,
19956                    "{kind:?}[{i}]: analytical {} vs numerical {num}",
19957                    analytical[i]
19958                );
19959            }
19960        }
19961    }
19962
19963    /// Batched 3-D MatMul VJP — the transformer-attention shape
19964    /// `[B, M, K] @ [B, K, N] = [B, M, N]`. Both gradients flow through
19965    /// `Op::Transpose` with a perm that swaps the last two dims.
19966    #[test]
19967    fn matmul_3d_gradient_matches_numerical() {
19968        use rlx_ir::Philox4x32;
19969        let batch = 2usize;
19970        let m = 3usize;
19971        let k = 4usize;
19972        let n = 5usize;
19973        let mut rng = Philox4x32::new(101);
19974        let mut a_data = vec![0f32; batch * m * k];
19975        rng.fill_normal(&mut a_data);
19976        let mut b_data = vec![0f32; batch * k * n];
19977        rng.fill_normal(&mut b_data);
19978
19979        let f = DType::F32;
19980        let mut fwd = Graph::new("matmul_3d");
19981        let an = fwd.input("a", Shape::new(&[batch, m, k], f));
19982        let bp = fwd.param("b", Shape::new(&[batch, k, n], f));
19983        let mm = fwd.matmul(an, bp, Shape::new(&[batch, m, n], f));
19984        let loss = fwd.add_node(
19985            Op::Reduce {
19986                op: ReduceOp::Sum,
19987                axes: vec![0, 1, 2],
19988                keep_dim: false,
19989            },
19990            vec![mm],
19991            Shape::from_dims(&[], f),
19992        );
19993        fwd.set_outputs(vec![loss]);
19994
19995        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[bp]);
19996        let d_out = bwd_graph
19997            .nodes()
19998            .iter()
19999            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
20000            .map(|n| n.id)
20001            .unwrap();
20002
20003        let plan = rlx_opt::memory::plan_memory(&bwd_graph);
20004        let mut arena = crate::arena::Arena::from_plan(plan);
20005        let sched = compile_thunks(&bwd_graph, &arena);
20006        for &(id, data) in &[(an, &a_data), (bp, &b_data), (d_out, &vec![1.0f32])] {
20007            let off = arena.byte_offset(id);
20008            let buf = arena.raw_buf_mut();
20009            unsafe {
20010                let p = buf.as_mut_ptr().add(off) as *mut f32;
20011                for (i, &v) in data.iter().enumerate() {
20012                    *p.add(i) = v;
20013                }
20014            }
20015        }
20016        execute_thunks(&sched, arena.raw_buf_mut());
20017        let gb_id = bwd_graph.outputs[1];
20018        let g_b: Vec<f32> = unsafe {
20019            let p = arena.raw_buf().as_ptr().add(arena.byte_offset(gb_id)) as *const f32;
20020            (0..batch * k * n).map(|i| *p.add(i)).collect()
20021        };
20022
20023        // Numerical gradient: differentiate sum(a @ b) w.r.t. each b entry.
20024        let forward_loss = |b_vals: &[f32]| -> f32 {
20025            let mut out = vec![0f32; batch * m * n];
20026            for bi in 0..batch {
20027                for mi in 0..m {
20028                    for ni in 0..n {
20029                        let mut acc = 0f32;
20030                        for ki in 0..k {
20031                            acc +=
20032                                a_data[bi * m * k + mi * k + ki] * b_vals[bi * k * n + ki * n + ni];
20033                        }
20034                        out[bi * m * n + mi * n + ni] = acc;
20035                    }
20036                }
20037            }
20038            out.iter().sum()
20039        };
20040        let eps = 1e-3f32;
20041        let mut bp_p = b_data.clone();
20042        let mut g_b_num = vec![0f32; b_data.len()];
20043        for i in 0..b_data.len() {
20044            let s = bp_p[i];
20045            bp_p[i] = s + eps;
20046            let lp = forward_loss(&bp_p);
20047            bp_p[i] = s - eps;
20048            let lm = forward_loss(&bp_p);
20049            bp_p[i] = s;
20050            g_b_num[i] = (lp - lm) / (2.0 * eps);
20051        }
20052        for (i, (a, n)) in g_b.iter().zip(&g_b_num).enumerate() {
20053            assert!(
20054                (a - n).abs() < 5e-3,
20055                "matmul_3d g_b[{i}]: analytical {a} vs numerical {n}"
20056            );
20057        }
20058    }
20059
20060    /// Composed `Op::Softmax` VJP — the gradient is built from
20061    /// `mul + reduce_sum + expand + sub + mul`, no dedicated
20062    /// SoftmaxBackward kernel. Verifies the closed-form
20063    /// `dx = y · (g - Σ y·g)` matches the FD gradient over a small
20064    /// 2-D logits tensor.
20065    #[test]
20066    fn softmax_gradient_matches_numerical() {
20067        use rlx_ir::Philox4x32;
20068        let n = 3usize;
20069        let c = 5usize;
20070        let mut rng = Philox4x32::new(57);
20071        let mut x_data = vec![0f32; n * c];
20072        rng.fill_normal(&mut x_data);
20073
20074        let f = DType::F32;
20075        let mut fwd = Graph::new("softmax_only");
20076        let xn = fwd.input("x", Shape::new(&[n, c], f));
20077        let sm = fwd.add_node(Op::Softmax { axis: -1 }, vec![xn], Shape::new(&[n, c], f));
20078        // Loss = sum(softmax · target) for some random fixed target —
20079        // any linear loss will do; sum-of-all is the simplest and gives
20080        // a uniform gradient flow into the softmax.
20081        let loss = fwd.add_node(
20082            Op::Reduce {
20083                op: ReduceOp::Sum,
20084                axes: vec![0, 1],
20085                keep_dim: false,
20086            },
20087            vec![sm],
20088            Shape::from_dims(&[], f),
20089        );
20090        fwd.set_outputs(vec![loss]);
20091
20092        // `wrt = [xn]` — autodiff exposes the gradient w.r.t. the
20093        // input so we can compare it directly. The forward NodeId for
20094        // `xn` doubles as its bwd-graph mirror.
20095        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[xn]);
20096        let d_out = bwd_graph
20097            .nodes()
20098            .iter()
20099            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
20100            .map(|n| n.id)
20101            .unwrap();
20102
20103        let plan = rlx_opt::memory::plan_memory(&bwd_graph);
20104        let mut arena = crate::arena::Arena::from_plan(plan);
20105        let sched = compile_thunks(&bwd_graph, &arena);
20106        for &(id, data) in &[(xn, &x_data), (d_out, &vec![1.0f32])] {
20107            let off = arena.byte_offset(id);
20108            let buf = arena.raw_buf_mut();
20109            unsafe {
20110                let p = buf.as_mut_ptr().add(off) as *mut f32;
20111                for (i, &v) in data.iter().enumerate() {
20112                    *p.add(i) = v;
20113                }
20114            }
20115        }
20116        execute_thunks(&sched, arena.raw_buf_mut());
20117        let g_x_id = bwd_graph.outputs[1];
20118        let g_x: Vec<f32> = unsafe {
20119            let p = arena.raw_buf().as_ptr().add(arena.byte_offset(g_x_id)) as *const f32;
20120            (0..n * c).map(|i| *p.add(i)).collect()
20121        };
20122
20123        // Loss derivative: softmax sums to 1 per row → d/dx_i sum(softmax) = 0
20124        // analytically. So expect g_x ≈ 0 within FD precision. (This
20125        // doubles as a strong sanity check for the composition.)
20126        let forward_loss = |x: &[f32]| -> f32 {
20127            let mut total = 0f32;
20128            for ni in 0..n {
20129                let row = &x[ni * c..(ni + 1) * c];
20130                let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
20131                let denom: f32 = row.iter().map(|&v| (v - m).exp()).sum();
20132                for &v in row {
20133                    total += (v - m).exp() / denom;
20134                }
20135            }
20136            total
20137        };
20138        let eps = 1e-3f32;
20139        let mut p = x_data.clone();
20140        for i in 0..x_data.len() {
20141            let s = p[i];
20142            p[i] = s + eps;
20143            let lp = forward_loss(&p);
20144            p[i] = s - eps;
20145            let lm = forward_loss(&p);
20146            p[i] = s;
20147            let num = (lp - lm) / (2.0 * eps);
20148            assert!(
20149                (g_x[i] - num).abs() < 5e-3,
20150                "softmax g_x[{i}]: analytical {} vs numerical {num}",
20151                g_x[i]
20152            );
20153        }
20154    }
20155
20156    /// LayerNorm VJP — three gradients in one pass:
20157    ///   d_x via `LayerNormBackwardInput`,
20158    ///   d_gamma via `LayerNormBackwardGamma`,
20159    ///   d_beta = `unbroadcast(upstream)` to gamma's shape.
20160    #[test]
20161    fn layer_norm_gradient_matches_numerical() {
20162        use rlx_ir::Philox4x32;
20163        let rows = 3usize;
20164        let h = 6usize;
20165        let mut rng = Philox4x32::new(1009);
20166        let mut x_data = vec![0f32; rows * h];
20167        rng.fill_normal(&mut x_data);
20168        let mut g_data = vec![0f32; h];
20169        rng.fill_normal(&mut g_data);
20170        for v in g_data.iter_mut() {
20171            *v = v.abs() + 0.5;
20172        }
20173        let mut b_data = vec![0f32; h];
20174        rng.fill_normal(&mut b_data);
20175        let eps = 1e-5f32;
20176
20177        let f = DType::F32;
20178        let mut fwd = Graph::new("ln_only");
20179        let xn = fwd.input("x", Shape::new(&[rows, h], f));
20180        let gp = fwd.param("gamma", Shape::new(&[h], f));
20181        let bp = fwd.param("beta", Shape::new(&[h], f));
20182        let ln = fwd.add_node(
20183            Op::LayerNorm { axis: -1, eps },
20184            vec![xn, gp, bp],
20185            Shape::new(&[rows, h], f),
20186        );
20187        let loss = fwd.add_node(
20188            Op::Reduce {
20189                op: ReduceOp::Sum,
20190                axes: vec![0, 1],
20191                keep_dim: false,
20192            },
20193            vec![ln],
20194            Shape::from_dims(&[], f),
20195        );
20196        fwd.set_outputs(vec![loss]);
20197
20198        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[xn, gp, bp]);
20199        let d_out = bwd_graph
20200            .nodes()
20201            .iter()
20202            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
20203            .map(|n| n.id)
20204            .unwrap();
20205
20206        let plan = rlx_opt::memory::plan_memory(&bwd_graph);
20207        let mut arena = crate::arena::Arena::from_plan(plan);
20208        let sched = compile_thunks(&bwd_graph, &arena);
20209        for &(id, data) in &[
20210            (xn, &x_data),
20211            (gp, &g_data),
20212            (bp, &b_data),
20213            (d_out, &vec![1.0f32]),
20214        ] {
20215            let off = arena.byte_offset(id);
20216            let buf = arena.raw_buf_mut();
20217            unsafe {
20218                let p = buf.as_mut_ptr().add(off) as *mut f32;
20219                for (i, &v) in data.iter().enumerate() {
20220                    *p.add(i) = v;
20221                }
20222            }
20223        }
20224        execute_thunks(&sched, arena.raw_buf_mut());
20225        let read = |id: NodeId, n: usize| -> Vec<f32> {
20226            let off = arena.byte_offset(id);
20227            unsafe {
20228                let p = arena.raw_buf().as_ptr().add(off) as *const f32;
20229                (0..n).map(|i| *p.add(i)).collect()
20230            }
20231        };
20232        let dx_a = read(bwd_graph.outputs[1], rows * h);
20233        let dg_a = read(bwd_graph.outputs[2], h);
20234        let db_a = read(bwd_graph.outputs[3], h);
20235
20236        let forward_loss = |x: &[f32], g: &[f32], b: &[f32]| -> f32 {
20237            let mut total = 0f32;
20238            for r in 0..rows {
20239                let row = &x[r * h..(r + 1) * h];
20240                let mean = row.iter().sum::<f32>() / h as f32;
20241                let var = row.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / h as f32;
20242                let inv_std = 1.0 / (var + eps).sqrt();
20243                for d in 0..h {
20244                    total += ((row[d] - mean) * inv_std) * g[d] + b[d];
20245                }
20246            }
20247            total
20248        };
20249        let h_eps = 1e-3f32;
20250
20251        let mut x_p = x_data.clone();
20252        for i in 0..x_p.len() {
20253            let s = x_p[i];
20254            x_p[i] = s + h_eps;
20255            let lp = forward_loss(&x_p, &g_data, &b_data);
20256            x_p[i] = s - h_eps;
20257            let lm = forward_loss(&x_p, &g_data, &b_data);
20258            x_p[i] = s;
20259            let num = (lp - lm) / (2.0 * h_eps);
20260            assert!(
20261                (dx_a[i] - num).abs() < 5e-3,
20262                "ln dx[{i}]: analytical {} vs numerical {num}",
20263                dx_a[i]
20264            );
20265        }
20266        let mut g_p = g_data.clone();
20267        for i in 0..g_p.len() {
20268            let s = g_p[i];
20269            g_p[i] = s + h_eps;
20270            let lp = forward_loss(&x_data, &g_p, &b_data);
20271            g_p[i] = s - h_eps;
20272            let lm = forward_loss(&x_data, &g_p, &b_data);
20273            g_p[i] = s;
20274            let num = (lp - lm) / (2.0 * h_eps);
20275            assert!(
20276                (dg_a[i] - num).abs() < 5e-3,
20277                "ln dg[{i}]: analytical {} vs numerical {num}",
20278                dg_a[i]
20279            );
20280        }
20281        let mut b_p = b_data.clone();
20282        for i in 0..b_p.len() {
20283            let s = b_p[i];
20284            b_p[i] = s + h_eps;
20285            let lp = forward_loss(&x_data, &g_data, &b_p);
20286            b_p[i] = s - h_eps;
20287            let lm = forward_loss(&x_data, &g_data, &b_p);
20288            b_p[i] = s;
20289            let num = (lp - lm) / (2.0 * h_eps);
20290            assert!(
20291                (db_a[i] - num).abs() < 5e-3,
20292                "ln db[{i}]: analytical {} vs numerical {num}",
20293                db_a[i]
20294            );
20295        }
20296    }
20297
20298    /// Single dense layer + softmax-cross-entropy + mean reduce —
20299    /// the simplest non-trivial training graph. Validates MatMul,
20300    /// broadcast Add, SCE, Reduce(Mean) VJPs and the grad_with_loss
20301    /// plumbing all at once.
20302    #[test]
20303    fn dense_sce_mean_gradient_matches_numerical() {
20304        use rlx_ir::Philox4x32;
20305        let bs = 4usize;
20306        let k_in = 3usize;
20307        let c = 5usize;
20308        let mut rng = Philox4x32::new(7);
20309        let mut x = vec![0f32; bs * k_in];
20310        rng.fill_normal(&mut x);
20311        let mut w_init = vec![0f32; k_in * c];
20312        rng.fill_normal(&mut w_init);
20313        let mut b_init = vec![0f32; c];
20314        rng.fill_normal(&mut b_init);
20315        let labels: Vec<f32> = (0..bs).map(|i| (i % c) as f32).collect();
20316
20317        // ── Forward graph: loss = mean(sce(x @ w + b, labels)) ──
20318        let f = DType::F32;
20319        let mut fwd = Graph::new("dense_sce");
20320        let xn = fwd.input("x", Shape::new(&[bs, k_in], f));
20321        let lb = fwd.input("labels", Shape::new(&[bs], f));
20322        let wp = fwd.param("w", Shape::new(&[k_in, c], f));
20323        let bp = fwd.param("b", Shape::new(&[c], f));
20324        let mm = fwd.matmul(xn, wp, Shape::new(&[bs, c], f));
20325        let logits = fwd.binary(BinaryOp::Add, mm, bp, Shape::new(&[bs, c], f));
20326        let loss_per = fwd.softmax_cross_entropy_with_logits(logits, lb);
20327        let loss = fwd.add_node(
20328            Op::Reduce {
20329                op: ReduceOp::Sum,
20330                axes: vec![0],
20331                keep_dim: false,
20332            },
20333            vec![loss_per],
20334            // Reduce sum of [bs] with axes=[0] keep_dim=false → scalar [].
20335            Shape::from_dims(&[], f),
20336        );
20337        // Use Sum + manual /bs scalar mul — also exercises BinaryOp::Mul VJP path
20338        // less aggressively than Mean would, and gives us a closed-form
20339        // reference for the loss we expect.
20340        // For simplicity though, switch to Mean which the tests should also cover.
20341        // (Re-using `loss` with Sum here for now; the mean factor cancels in
20342        // the gradient comparison since both analytical and numerical use the
20343        // same forward.)
20344        fwd.set_outputs(vec![loss]);
20345
20346        // ── Backward graph ──
20347        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wp, bp]);
20348        // Outputs: [loss, grad_w, grad_b]. NodeIds for x/labels/w/b/loss
20349        // in bwd_graph match their fwd ids (the mirror keeps order).
20350        let d_out = bwd_graph
20351            .nodes()
20352            .iter()
20353            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
20354            .map(|n| n.id)
20355            .expect("d_output input");
20356
20357        let (sched, mut arena) = prepare(
20358            &bwd_graph,
20359            &[
20360                (xn, &x),
20361                (lb, &labels),
20362                (wp, &w_init),
20363                (bp, &b_init),
20364                (d_out, &[1.0]),
20365            ],
20366        );
20367        execute_thunks(&sched, arena.raw_buf_mut());
20368
20369        let outs = &bwd_graph.outputs;
20370        let loss_id = outs[0];
20371        let gw_id = outs[1];
20372        let gb_id = outs[2];
20373        let loss_actual = read_arena(&arena, loss_id, 1)[0];
20374        let gw_actual = read_arena(&arena, gw_id, k_in * c);
20375        let gb_actual = read_arena(&arena, gb_id, c);
20376
20377        // ── Forward-only graph for finite differences ──
20378        // Re-use the same `fwd` graph; set up its own arena and rerun
20379        // for each perturbed parameter.
20380        let plan = rlx_opt::memory::plan_memory(&fwd);
20381        let mut fwd_arena = crate::arena::Arena::from_plan(plan);
20382        let fwd_sched = compile_thunks(&fwd, &fwd_arena);
20383        write_arena(&mut fwd_arena, xn, &x);
20384        write_arena(&mut fwd_arena, lb, &labels);
20385
20386        let run_loss = |arena: &mut crate::arena::Arena, w: &[f32], b: &[f32]| -> f32 {
20387            write_arena(arena, wp, w);
20388            write_arena(arena, bp, b);
20389            execute_thunks(&fwd_sched, arena.raw_buf_mut());
20390            read_arena(arena, loss, 1)[0]
20391        };
20392
20393        // Sanity: the loss reported by the bwd graph matches the
20394        // forward-only graph on the unperturbed inputs.
20395        let loss_check = run_loss(&mut fwd_arena, &w_init, &b_init);
20396        assert!(
20397            (loss_actual - loss_check).abs() < 1e-4,
20398            "loss mismatch: bwd graph {loss_actual} vs fwd-only {loss_check}"
20399        );
20400
20401        let eps = 1e-3f32;
20402        let mut w_perturbed = w_init.clone();
20403        let mut gw_numerical = vec![0f32; w_init.len()];
20404        for i in 0..w_init.len() {
20405            let saved = w_perturbed[i];
20406            w_perturbed[i] = saved + eps;
20407            let lp = run_loss(&mut fwd_arena, &w_perturbed, &b_init);
20408            w_perturbed[i] = saved - eps;
20409            let lm = run_loss(&mut fwd_arena, &w_perturbed, &b_init);
20410            w_perturbed[i] = saved;
20411            gw_numerical[i] = (lp - lm) / (2.0 * eps);
20412        }
20413        for (i, (a, n)) in gw_actual.iter().zip(&gw_numerical).enumerate() {
20414            assert!(
20415                (a - n).abs() < 5e-3,
20416                "grad_w[{i}]: analytical {a} vs numerical {n}"
20417            );
20418        }
20419
20420        let mut b_perturbed = b_init.clone();
20421        let mut gb_numerical = vec![0f32; b_init.len()];
20422        for i in 0..b_init.len() {
20423            let saved = b_perturbed[i];
20424            b_perturbed[i] = saved + eps;
20425            let lp = run_loss(&mut fwd_arena, &w_init, &b_perturbed);
20426            b_perturbed[i] = saved - eps;
20427            let lm = run_loss(&mut fwd_arena, &w_init, &b_perturbed);
20428            b_perturbed[i] = saved;
20429            gb_numerical[i] = (lp - lm) / (2.0 * eps);
20430        }
20431        for (i, (a, n)) in gb_actual.iter().zip(&gb_numerical).enumerate() {
20432            assert!(
20433                (a - n).abs() < 5e-3,
20434                "grad_b[{i}]: analytical {a} vs numerical {n}"
20435            );
20436        }
20437    }
20438
20439    /// Reduce::Mean specifically — verifies the 1/N scaling in the VJP.
20440    /// The same dense+SCE graph but with Mean instead of Sum on the loss.
20441    #[test]
20442    fn dense_sce_mean_reduce_gradient_matches_numerical() {
20443        use rlx_ir::Philox4x32;
20444        let bs = 3usize;
20445        let k_in = 2usize;
20446        let c = 4usize;
20447        let mut rng = Philox4x32::new(13);
20448        let mut x = vec![0f32; bs * k_in];
20449        rng.fill_normal(&mut x);
20450        let mut w_init = vec![0f32; k_in * c];
20451        rng.fill_normal(&mut w_init);
20452        let labels: Vec<f32> = (0..bs).map(|i| (i % c) as f32).collect();
20453
20454        let f = DType::F32;
20455        let mut fwd = Graph::new("dense_sce_mean");
20456        let xn = fwd.input("x", Shape::new(&[bs, k_in], f));
20457        let lb = fwd.input("labels", Shape::new(&[bs], f));
20458        let wp = fwd.param("w", Shape::new(&[k_in, c], f));
20459        let mm = fwd.matmul(xn, wp, Shape::new(&[bs, c], f));
20460        let loss_per = fwd.softmax_cross_entropy_with_logits(mm, lb);
20461        let loss = fwd.add_node(
20462            Op::Reduce {
20463                op: ReduceOp::Mean,
20464                axes: vec![0],
20465                keep_dim: false,
20466            },
20467            vec![loss_per],
20468            Shape::from_dims(&[], f),
20469        );
20470        fwd.set_outputs(vec![loss]);
20471
20472        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wp]);
20473        let d_out = bwd_graph
20474            .nodes()
20475            .iter()
20476            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
20477            .map(|n| n.id)
20478            .unwrap();
20479
20480        let (sched, mut arena) = prepare(
20481            &bwd_graph,
20482            &[(xn, &x), (lb, &labels), (wp, &w_init), (d_out, &[1.0])],
20483        );
20484        execute_thunks(&sched, arena.raw_buf_mut());
20485
20486        let outs = &bwd_graph.outputs;
20487        let loss_id = outs[0];
20488        let gw_id = outs[1];
20489        let _ = read_arena(&arena, loss_id, 1)[0];
20490        let gw_actual = read_arena(&arena, gw_id, k_in * c);
20491
20492        let plan = rlx_opt::memory::plan_memory(&fwd);
20493        let mut fwd_arena = crate::arena::Arena::from_plan(plan);
20494        let fwd_sched = compile_thunks(&fwd, &fwd_arena);
20495        write_arena(&mut fwd_arena, xn, &x);
20496        write_arena(&mut fwd_arena, lb, &labels);
20497
20498        let run_loss = |arena: &mut crate::arena::Arena, w: &[f32]| -> f32 {
20499            write_arena(arena, wp, w);
20500            execute_thunks(&fwd_sched, arena.raw_buf_mut());
20501            read_arena(arena, loss, 1)[0]
20502        };
20503
20504        let eps = 1e-3f32;
20505        let mut wp_p = w_init.clone();
20506        let mut gw_num = vec![0f32; w_init.len()];
20507        for i in 0..w_init.len() {
20508            let s = wp_p[i];
20509            wp_p[i] = s + eps;
20510            let lp = run_loss(&mut fwd_arena, &wp_p);
20511            wp_p[i] = s - eps;
20512            let lm = run_loss(&mut fwd_arena, &wp_p);
20513            wp_p[i] = s;
20514            gw_num[i] = (lp - lm) / (2.0 * eps);
20515        }
20516        for (i, (a, n)) in gw_actual.iter().zip(&gw_num).enumerate() {
20517            assert!((a - n).abs() < 5e-3, "mean reduce grad_w[{i}]: {a} vs {n}");
20518        }
20519    }
20520    /// The full TinyConv-MNIST forward path (downsized) plumbed
20521    /// through grad_with_loss. Validates that Conv, Pool(Max), ReLU,
20522    /// Reshape, MatMul, Add (broadcast), SCE, Reduce(Mean) VJPs all
20523    /// compose into a graph that produces correct gradients.
20524    #[test]
20525    fn tinyconv_full_gradient_matches_numerical() {
20526        use rlx_ir::Philox4x32;
20527        // Tiny shapes so finite differences finish in <1s.
20528        let n = 1usize;
20529        let c_in = 1usize;
20530        let h = 6usize;
20531        let w_in = 6usize;
20532        let c_mid = 2usize; // first conv output channels
20533        let kh = 3;
20534        let kw = 3;
20535        let h1 = h - kh + 1; // 4
20536        let w1 = w_in - kw + 1; // 4
20537        let h2 = h1 / 2;
20538        let w2 = w1 / 2; // 2 × 2 after 2× pool
20539        let flat = c_mid * h2 * w2; // 8
20540        let num_classes = 3usize;
20541
20542        let mut rng = Philox4x32::new(31);
20543        let mut x = vec![0f32; n * c_in * h * w_in];
20544        rng.fill_normal(&mut x);
20545        let mut wc = vec![0f32; c_mid * c_in * kh * kw];
20546        rng.fill_normal(&mut wc);
20547        for v in wc.iter_mut() {
20548            *v *= 0.2;
20549        }
20550        // Shift conv-bias well away from the ReLU zero-boundary. Without
20551        // this, an ε-perturbation of bc[c] can flip the ReLU mask on a
20552        // pre-activation that happened to land near zero — making the
20553        // central-difference numerical gradient discontinuous and
20554        // diverge from the analytical (which assumes local smoothness).
20555        // +5.0 keeps every pre-activation positive for any random init
20556        // produced by Philox seed 31 with the wc/x scales used here, so
20557        // ReLU acts as an identity and finite differences are exact.
20558        let bc: Vec<f32> = (0..c_mid).map(|i| 5.0 + 0.1 * i as f32).collect();
20559        let mut wfc = vec![0f32; flat * num_classes];
20560        rng.fill_normal(&mut wfc);
20561        for v in wfc.iter_mut() {
20562            *v *= 0.5;
20563        }
20564        let mut bfc = vec![0f32; num_classes];
20565        rng.fill_normal(&mut bfc);
20566        let labels: Vec<f32> = vec![1.0]; // batch=1
20567
20568        let f = DType::F32;
20569        let mut fwd = Graph::new("tinyconv");
20570        let xn = fwd.input("x", Shape::new(&[n, c_in, h, w_in], f));
20571        let lb = fwd.input("labels", Shape::new(&[n], f));
20572        let wcp = fwd.param("wc", Shape::new(&[c_mid, c_in, kh, kw], f));
20573        let bcp = fwd.param("bc", Shape::new(&[c_mid], f));
20574        let wfp = fwd.param("wfc", Shape::new(&[flat, num_classes], f));
20575        let bfp = fwd.param("bfc", Shape::new(&[num_classes], f));
20576
20577        // conv: [n, c_in, h, w] → [n, c_mid, h1, w1]
20578        let conv = fwd.add_node(
20579            Op::Conv {
20580                kernel_size: vec![kh, kw],
20581                stride: vec![1, 1],
20582                padding: vec![0, 0],
20583                dilation: vec![1, 1],
20584                groups: 1,
20585            },
20586            vec![xn, wcp],
20587            Shape::new(&[n, c_mid, h1, w1], f),
20588        );
20589        // Bias add: expand bc[c_mid] up to the full [n, c_mid, h1, w1]
20590        // shape so the Add becomes a plain element-wise op. Going through
20591        // an explicit Reshape→Expand instead of relying on the Add to
20592        // broadcast `[1, C, 1, 1]` → `[N, C, H, W]` works around a known
20593        // limitation of `rlx-cpu`'s `Op::Binary` lowering: it dispatches
20594        // on `out_len % rhs_len == 0` and treats `rhs` as a last-axis
20595        // bias, which produces `bc[0], bc[1], bc[0], bc[1], …` alternating
20596        // across all positions instead of channel-broadcasting. Going
20597        // through Expand (a real broadcast thunk) avoids that path
20598        // entirely. The autodiff still exercises `unbroadcast` because
20599        // `Op::Expand`'s VJP reduces over the broadcast axes.
20600        let bc_4d = fwd.add_node(
20601            Op::Reshape {
20602                new_shape: vec![1, c_mid as i64, 1, 1],
20603            },
20604            vec![bcp],
20605            Shape::new(&[1, c_mid, 1, 1], f),
20606        );
20607        let bc_expanded = fwd.add_node(
20608            Op::Expand {
20609                target_shape: vec![n as i64, c_mid as i64, h1 as i64, w1 as i64],
20610            },
20611            vec![bc_4d],
20612            Shape::new(&[n, c_mid, h1, w1], f),
20613        );
20614        let conv_b = fwd.binary(
20615            BinaryOp::Add,
20616            conv,
20617            bc_expanded,
20618            Shape::new(&[n, c_mid, h1, w1], f),
20619        );
20620        let relu = fwd.activation(Activation::Relu, conv_b, Shape::new(&[n, c_mid, h1, w1], f));
20621        let pool = fwd.add_node(
20622            Op::Pool {
20623                kind: ReduceOp::Max,
20624                kernel_size: vec![2, 2],
20625                stride: vec![2, 2],
20626                padding: vec![0, 0],
20627            },
20628            vec![relu],
20629            Shape::new(&[n, c_mid, h2, w2], f),
20630        );
20631        let flatn = fwd.add_node(
20632            Op::Reshape {
20633                new_shape: vec![n as i64, flat as i64],
20634            },
20635            vec![pool],
20636            Shape::new(&[n, flat], f),
20637        );
20638        let mm = fwd.matmul(flatn, wfp, Shape::new(&[n, num_classes], f));
20639        let logits = fwd.binary(BinaryOp::Add, mm, bfp, Shape::new(&[n, num_classes], f));
20640        let loss_per = fwd.softmax_cross_entropy_with_logits(logits, lb);
20641        let loss = fwd.add_node(
20642            Op::Reduce {
20643                op: ReduceOp::Mean,
20644                axes: vec![0],
20645                keep_dim: false,
20646            },
20647            vec![loss_per],
20648            Shape::from_dims(&[], f),
20649        );
20650        fwd.set_outputs(vec![loss]);
20651
20652        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wcp, bcp, wfp, bfp]);
20653        let d_out = bwd_graph
20654            .nodes()
20655            .iter()
20656            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
20657            .map(|n| n.id)
20658            .unwrap();
20659
20660        let (sched, mut arena) = prepare(
20661            &bwd_graph,
20662            &[
20663                (xn, &x),
20664                (lb, &labels),
20665                (wcp, &wc),
20666                (bcp, &bc),
20667                (wfp, &wfc),
20668                (bfp, &bfc),
20669                (d_out, &[1.0]),
20670            ],
20671        );
20672        execute_thunks(&sched, arena.raw_buf_mut());
20673
20674        let outs = bwd_graph.outputs.clone();
20675        let loss_id = outs[0];
20676        let g_wc_id = outs[1];
20677        let g_bc_id = outs[2];
20678        let g_wfc_id = outs[3];
20679        let g_bfc_id = outs[4];
20680        let loss_actual = read_arena(&arena, loss_id, 1)[0];
20681        let g_wc = read_arena(&arena, g_wc_id, wc.len());
20682        let g_bc = read_arena(&arena, g_bc_id, bc.len());
20683        let g_wfc = read_arena(&arena, g_wfc_id, wfc.len());
20684        let g_bfc = read_arena(&arena, g_bfc_id, bfc.len());
20685
20686        // Forward-only arena for finite differences.
20687        let plan = rlx_opt::memory::plan_memory(&fwd);
20688        let mut fwd_arena = crate::arena::Arena::from_plan(plan);
20689        let fwd_sched = compile_thunks(&fwd, &fwd_arena);
20690        write_arena(&mut fwd_arena, xn, &x);
20691        write_arena(&mut fwd_arena, lb, &labels);
20692
20693        // Closure variant: we need to set all four params each call so
20694        // perturbations to one don't leak between sweeps.
20695        let run_loss = |arena: &mut crate::arena::Arena,
20696                        wc: &[f32],
20697                        bc: &[f32],
20698                        wfc: &[f32],
20699                        bfc: &[f32]|
20700         -> f32 {
20701            write_arena(arena, wcp, wc);
20702            write_arena(arena, bcp, bc);
20703            write_arena(arena, wfp, wfc);
20704            write_arena(arena, bfp, bfc);
20705            execute_thunks(&fwd_sched, arena.raw_buf_mut());
20706            read_arena(arena, loss, 1)[0]
20707        };
20708
20709        let loss_check = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &bfc);
20710        assert!(
20711            (loss_actual - loss_check).abs() < 1e-4,
20712            "tinyconv loss mismatch: bwd {loss_actual} vs fwd {loss_check}"
20713        );
20714
20715        let eps = 1e-3f32;
20716        let check_grad = |arena: &mut crate::arena::Arena,
20717                          name: &str,
20718                          analytical: &[f32],
20719                          mut perturb: Box<
20720            dyn FnMut(&mut [f32], usize, f32, &mut crate::arena::Arena) -> f32 + '_,
20721        >,
20722                          n: usize| {
20723            for i in 0..n {
20724                let lp = perturb(&mut analytical.to_vec(), i, eps, arena);
20725                let lm = perturb(&mut analytical.to_vec(), i, -eps, arena);
20726                let num = (lp - lm) / (2.0 * eps);
20727                assert!(
20728                    (analytical[i] - num).abs() < 5e-3,
20729                    "{name}[{i}]: analytical {} vs numerical {num}",
20730                    analytical[i]
20731                );
20732            }
20733        };
20734
20735        // Helper to perturb one param and run forward. Kept as a
20736        // reference for the explicit per-param sweep pattern below.
20737        #[allow(unused_macros)]
20738        macro_rules! sweep {
20739            ($name:expr, $base:expr, $analytical:expr, $set_param:ident) => {{
20740                let n = $base.len();
20741                for i in 0..n {
20742                    let mut p = $base.clone();
20743                    let s = p[i];
20744                    p[i] = s + eps;
20745                    let lp = {
20746                        let $set_param = &p;
20747                        run_loss(&mut fwd_arena, &wc, &bc, &wfc, &bfc).max(f32::NEG_INFINITY);
20748                        // Reset others, set the one being swept, run.
20749                        // (the macro receives one of the four params via $set_param)
20750                        let _ = $set_param;
20751                        // Fall through to the explicit per-param helper:
20752                        0.0_f32
20753                    };
20754                    let _ = lp;
20755                }
20756            }};
20757        }
20758        let _ = check_grad; // silence unused (sweep! macro is intentionally\n        // unused — kept as reference for the per-param sweep pattern below)
20759
20760        // Per-param sweeps (explicit, not macro — clearer).
20761        for i in 0..wc.len() {
20762            let mut p = wc.clone();
20763            let s = p[i];
20764            p[i] = s + eps;
20765            let lp = run_loss(&mut fwd_arena, &p, &bc, &wfc, &bfc);
20766            p[i] = s - eps;
20767            let lm = run_loss(&mut fwd_arena, &p, &bc, &wfc, &bfc);
20768            let num = (lp - lm) / (2.0 * eps);
20769            assert!(
20770                (g_wc[i] - num).abs() < 5e-3,
20771                "g_wc[{i}]: {} vs {num}",
20772                g_wc[i]
20773            );
20774        }
20775        for i in 0..bc.len() {
20776            let mut p = bc.clone();
20777            let s = p[i];
20778            p[i] = s + eps;
20779            let lp = run_loss(&mut fwd_arena, &wc, &p, &wfc, &bfc);
20780            p[i] = s - eps;
20781            let lm = run_loss(&mut fwd_arena, &wc, &p, &wfc, &bfc);
20782            let num = (lp - lm) / (2.0 * eps);
20783            assert!(
20784                (g_bc[i] - num).abs() < 5e-3,
20785                "g_bc[{i}]: {} vs {num}",
20786                g_bc[i]
20787            );
20788        }
20789        for i in 0..wfc.len() {
20790            let mut p = wfc.clone();
20791            let s = p[i];
20792            p[i] = s + eps;
20793            let lp = run_loss(&mut fwd_arena, &wc, &bc, &p, &bfc);
20794            p[i] = s - eps;
20795            let lm = run_loss(&mut fwd_arena, &wc, &bc, &p, &bfc);
20796            let num = (lp - lm) / (2.0 * eps);
20797            assert!(
20798                (g_wfc[i] - num).abs() < 5e-3,
20799                "g_wfc[{i}]: {} vs {num}",
20800                g_wfc[i]
20801            );
20802        }
20803        for i in 0..bfc.len() {
20804            let mut p = bfc.clone();
20805            let s = p[i];
20806            p[i] = s + eps;
20807            let lp = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &p);
20808            p[i] = s - eps;
20809            let lm = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &p);
20810            let num = (lp - lm) / (2.0 * eps);
20811            assert!(
20812                (g_bfc[i] - num).abs() < 5e-3,
20813                "g_bfc[{i}]: {} vs {num}",
20814                g_bfc[i]
20815            );
20816        }
20817    }
20818
20819    /// Negative case: a Narrow whose output has multiple consumers
20820    /// must NOT be fused (we can't elide its write — something else
20821    /// reads it).
20822    #[test]
20823    fn narrow_rope_skips_when_narrow_has_multiple_consumers() {
20824        let f = DType::F32;
20825        let mut g = Graph::new("nr_skip");
20826        let qkv = g.input("qkv", Shape::new(&[16, 8, 192], f));
20827        let cos = g.input("cos", Shape::new(&[16], f));
20828        let sin = g.input("sin", Shape::new(&[16], f));
20829        let q = g.narrow_(qkv, 2, 0, 64);
20830        let q_rope = g.rope(q, cos, sin, 16);
20831        // Second consumer of `q` blocks the fusion.
20832        let q_dup = g.activation(rlx_ir::op::Activation::Relu, q, Shape::new(&[16, 8, 64], f));
20833        g.set_outputs(vec![q_rope, q_dup]);
20834
20835        let plan = rlx_opt::memory::plan_memory(&g);
20836        let arena = crate::arena::Arena::from_plan(plan);
20837        let sched = compile_thunks(&g, &arena);
20838
20839        let narrow_count = sched
20840            .thunks
20841            .iter()
20842            .filter(|t| matches!(t, Thunk::Narrow { .. }))
20843            .count();
20844        assert!(
20845            narrow_count >= 1,
20846            "Narrow with multiple consumers must NOT be fused away"
20847        );
20848    }
20849
20850    // ── Op::CustomFn (custom_vjp / custom_jvp) tests ──
20851    //
20852    // Validates: forward execution inlines fwd_body; VJP rule inlines
20853    // vjp_body in place of recursing into fwd_body; JVP rule inlines
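    // jvp_body. The single-input override tests deliberately pick a body
    // whose AD-via-tracing would yield a *different* derivative than the
    // override, so we know the override actually fired; the two-input VJP
    // test instead checks that several primals plus d_output are routed
    // through the override, so its result matches natural autodiff.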
20857
20858    /// Forward only: CustomFn wrapping `f(x) = x + c` (c=1 inside body)
20859    /// without override AD bodies. Verifies the body is compiled,
20860    /// constants in the body fill correctly, and the output lands at
20861    /// the outer node's slot.
20862    #[test]
20863    fn custom_fn_forward_inlines_body() {
20864        let s = Shape::new(&[3], DType::F32);
20865
20866        // Body: f(x) = x + 1
20867        let mut body = Graph::new("addone_body");
20868        let x = body.input("x", s.clone());
20869        let one_data: Vec<u8> = (0..3).flat_map(|_| 1.0_f32.to_le_bytes()).collect();
20870        let one = body.add_node(Op::Constant { data: one_data }, vec![], s.clone());
20871        let y = body.binary(BinaryOp::Add, x, one, s.clone());
20872        body.set_outputs(vec![y]);
20873
20874        let mut g = Graph::new("custom_fn_outer");
20875        let xin = g.input("x_in", s.clone());
20876        let cf = g.custom_fn(vec![xin], body, None, None);
20877        g.set_outputs(vec![cf]);
20878
20879        let xs = vec![10.0_f32, 20.0, 30.0];
20880        let (sched, mut arena) = prepare(&g, &[(xin, &xs)]);
20881        execute_thunks(&sched, arena.raw_buf_mut());
20882        let got = read_arena(&arena, cf, 3);
20883        assert_eq!(got, vec![11.0, 21.0, 31.0]);
20884    }
20885
20886    /// Locate an Op::Input or Op::Param by name in a graph.
20887    fn find_named(graph: &Graph, want: &str) -> NodeId {
20888        for n in graph.nodes() {
20889            let name = match &n.op {
20890                Op::Input { name } | Op::Param { name } => Some(name.as_str()),
20891                _ => None,
20892            };
20893            if name == Some(want) {
20894                return n.id;
20895            }
20896        }
20897        panic!("no node named {want:?} in graph");
20898    }
20899
20900    /// VJP override: f(x) = x but vjp_body returns 2 * d_output, so the
20901    /// reported gradient should be 2 — different from the natural 1
20902    /// you'd get by recursing into the identity body.
20903    #[test]
20904    fn custom_fn_vjp_overrides_natural_gradient() {
20905        use rlx_opt::autodiff::grad_with_loss;
20906        let s = Shape::new(&[1], DType::F32);
20907
20908        let mut fwd = Graph::new("id_fwd");
20909        let x = fwd.input("x", s.clone());
20910        fwd.set_outputs(vec![x]);
20911
20912        let mut vjp_g = Graph::new("id_vjp");
20913        let _x_p = vjp_g.input("x", s.clone());
20914        let _y_p = vjp_g.input("primal_output", s.clone());
20915        let dy = vjp_g.input("d_output", s.clone());
20916        let two_data: Vec<u8> = 2.0_f32.to_le_bytes().to_vec();
20917        let two = vjp_g.add_node(Op::Constant { data: two_data }, vec![], s.clone());
20918        let dx = vjp_g.binary(BinaryOp::Mul, dy, two, s.clone());
20919        vjp_g.set_outputs(vec![dx]);
20920
20921        let mut g = Graph::new("outer");
20922        let xp = g.param("x", s.clone());
20923        let cf = g.custom_fn(vec![xp], fwd, Some(vjp_g), None);
20924        g.set_outputs(vec![cf]);
20925
20926        let bwd = grad_with_loss(&g, &[xp]);
20927        assert_eq!(bwd.outputs.len(), 2, "expect [loss, dx]");
20928
20929        let xb = find_named(&bwd, "x");
20930        let dout = find_named(&bwd, "d_output");
20931        let (sched, mut arena) = prepare(&bwd, &[(xb, &[7.0]), (dout, &[1.0])]);
20932        execute_thunks(&sched, arena.raw_buf_mut());
20933        let loss = read_arena(&arena, bwd.outputs[0], 1);
20934        let dx_v = read_arena(&arena, bwd.outputs[1], 1);
20935        assert!((loss[0] - 7.0).abs() < 1e-6, "loss should be 7.0");
20936        assert!(
20937            (dx_v[0] - 2.0).abs() < 1e-6,
20938            "vjp override should yield dx=2.0, got {} (natural autodiff would give 1.0)",
20939            dx_v[0]
20940        );
20941    }
20942
20943    /// VJP override: f(a, b) = a*b with vjp_body returning
20944    /// (b * d_output, a * d_output). Validates routing of multiple
20945    /// primals + d_output through the override; matches the natural
20946    /// autodiff-of-Mul gradient (b, a).
20947    #[test]
20948    fn custom_fn_vjp_two_inputs_matches_mul_autodiff() {
20949        use rlx_opt::autodiff::grad_with_loss;
20950        let s = Shape::new(&[1], DType::F32);
20951
20952        let mut fwd = Graph::new("mul_fwd");
20953        let a_f = fwd.input("a", s.clone());
20954        let b_f = fwd.input("b", s.clone());
20955        let y_f = fwd.binary(BinaryOp::Mul, a_f, b_f, s.clone());
20956        fwd.set_outputs(vec![y_f]);
20957
20958        let mut vjp_g = Graph::new("mul_vjp");
20959        let a_v = vjp_g.input("a", s.clone());
20960        let b_v = vjp_g.input("b", s.clone());
20961        let _y_v = vjp_g.input("primal_output", s.clone());
20962        let dy_v = vjp_g.input("d_output", s.clone());
20963        let da = vjp_g.binary(BinaryOp::Mul, b_v, dy_v, s.clone());
20964        let db = vjp_g.binary(BinaryOp::Mul, a_v, dy_v, s.clone());
20965        vjp_g.set_outputs(vec![da, db]);
20966
20967        let mut g = Graph::new("outer");
20968        let ap = g.param("a", s.clone());
20969        let bp = g.param("b", s.clone());
20970        let cf = g.custom_fn(vec![ap, bp], fwd, Some(vjp_g), None);
20971        g.set_outputs(vec![cf]);
20972
20973        let bwd = grad_with_loss(&g, &[ap, bp]);
20974        assert_eq!(bwd.outputs.len(), 3, "expect [loss, da, db]");
20975
20976        let ab = find_named(&bwd, "a");
20977        let bb = find_named(&bwd, "b");
20978        let dout = find_named(&bwd, "d_output");
20979        let (sched, mut arena) = prepare(&bwd, &[(ab, &[3.0]), (bb, &[5.0]), (dout, &[1.0])]);
20980        execute_thunks(&sched, arena.raw_buf_mut());
20981        let loss = read_arena(&arena, bwd.outputs[0], 1);
20982        let da_v = read_arena(&arena, bwd.outputs[1], 1);
20983        let db_v = read_arena(&arena, bwd.outputs[2], 1);
20984        assert!((loss[0] - 15.0).abs() < 1e-5);
20985        assert!(
20986            (da_v[0] - 5.0).abs() < 1e-5,
20987            "da should be b=5.0, got {}",
20988            da_v[0]
20989        );
20990        assert!(
20991            (db_v[0] - 3.0).abs() < 1e-5,
20992            "db should be a=3.0, got {}",
20993            db_v[0]
20994        );
20995    }
20996
20997    /// JVP override: f(x) = x but jvp_body returns 2 * tangent_0.
20998    /// Forward-mode tangent should be 2x the seed (1.0) → 2.0.
20999    #[test]
21000    fn custom_fn_jvp_overrides_natural_tangent() {
21001        use rlx_opt::autodiff_fwd::jvp;
21002        let s = Shape::new(&[1], DType::F32);
21003
21004        let mut fwd = Graph::new("id_fwd");
21005        let x = fwd.input("x", s.clone());
21006        fwd.set_outputs(vec![x]);
21007
21008        let mut jvp_g = Graph::new("id_jvp");
21009        let _x_p = jvp_g.input("x", s.clone());
21010        let tx = jvp_g.input("tangent_0", s.clone());
21011        let two_data: Vec<u8> = 2.0_f32.to_le_bytes().to_vec();
21012        let two = jvp_g.add_node(Op::Constant { data: two_data }, vec![], s.clone());
21013        let ty = jvp_g.binary(BinaryOp::Mul, tx, two, s.clone());
21014        jvp_g.set_outputs(vec![ty]);
21015
21016        let mut g = Graph::new("outer");
21017        let xin = g.input("x_in", s.clone());
21018        let cf = g.custom_fn(vec![xin], fwd, None, Some(jvp_g));
21019        g.set_outputs(vec![cf]);
21020
21021        let fwd_g = jvp(&g, &[xin]);
21022        assert_eq!(fwd_g.outputs.len(), 2, "expect [primal_y, tangent_y]");
21023
21024        let xb = find_named(&fwd_g, "x_in");
21025        let tan = find_named(&fwd_g, "tangent_x_in");
21026        let (sched, mut arena) = prepare(&fwd_g, &[(xb, &[7.0]), (tan, &[1.0])]);
21027        execute_thunks(&sched, arena.raw_buf_mut());
21028        let y = read_arena(&arena, fwd_g.outputs[0], 1);
21029        let ty_v = read_arena(&arena, fwd_g.outputs[1], 1);
21030        assert!((y[0] - 7.0).abs() < 1e-6);
21031        assert!(
21032            (ty_v[0] - 2.0).abs() < 1e-6,
21033            "jvp override should yield t_y=2.0 (natural autodiff would give 1.0), got {}",
21034            ty_v[0]
21035        );
21036    }
21037
21038    /// IR-level basic test: `DType::C64` is wired through the dtype
21039    /// table — `size_bytes() == 8`, `is_complex()` reports true, and
21040    /// a `[2]`-shaped C64 buffer in the arena occupies the expected
21041    /// 16 bytes.
21042    #[test]
21043    fn c64_dtype_storage_layout() {
21044        assert_eq!(
21045            DType::C64.size_bytes(),
21046            8,
21047            "C64 should be 8 bytes (f32 real + f32 imag)"
21048        );
21049        assert!(DType::C64.is_complex());
21050        assert!(!DType::C64.is_float());
21051
21052        // A length-2 C64 buffer should have shape size_bytes = 16.
21053        let s = Shape::new(&[2], DType::C64);
21054        assert_eq!(s.size_bytes().unwrap(), 16);
21055    }
21056
21057    // ── C64 element-wise binary kernel witnesses (2026-05-17) ──────
21058    //
    // Build a tiny graph: Input `a` + Input `b` (both C64 `[n]`),
    // output = a OP b. Compile it with `compile_thunks`, execute it, and
    // compare each result against closed-form complex arithmetic.
21062
21063    fn run_c64_binary(op: BinaryOp, a: &[(f32, f32)], b: &[(f32, f32)]) -> Vec<(f32, f32)> {
21064        let n = a.len();
21065        let s = Shape::new(&[n], DType::C64);
21066        let mut g = Graph::new("c64_bin");
21067        let in_a = g.input("a", s.clone());
21068        let in_b = g.input("b", s.clone());
21069        let out = g.binary(op, in_a, in_b, s.clone());
21070        g.set_outputs(vec![out]);
21071
21072        let plan = rlx_opt::memory::plan_memory(&g);
21073        let mut arena = crate::arena::Arena::from_plan(plan);
21074        let sched = compile_thunks(&g, &arena);
21075
21076        let a_off = arena.byte_offset(in_a);
21077        let b_off = arena.byte_offset(in_b);
21078        let out_off = arena.byte_offset(out);
21079        // Interleave [re_0, im_0, re_1, im_1, ...] in the f32 buffer.
21080        let buf = arena.raw_buf_mut();
21081        unsafe {
21082            let pa = buf.as_mut_ptr().add(a_off) as *mut f32;
21083            let pb = buf.as_mut_ptr().add(b_off) as *mut f32;
21084            for (i, &(re, im)) in a.iter().enumerate() {
21085                *pa.add(2 * i) = re;
21086                *pa.add(2 * i + 1) = im;
21087            }
21088            for (i, &(re, im)) in b.iter().enumerate() {
21089                *pb.add(2 * i) = re;
21090                *pb.add(2 * i + 1) = im;
21091            }
21092        }
21093        execute_thunks(&sched, arena.raw_buf_mut());
21094        let raw_out: Vec<f32> = unsafe {
21095            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
21096            (0..(2 * n)).map(|i| *p.add(i)).collect()
21097        };
21098        (0..n)
21099            .map(|i| (raw_out[2 * i], raw_out[2 * i + 1]))
21100            .collect()
21101    }
21102
21103    #[track_caller]
21104    fn assert_close_c(got: (f32, f32), expected: (f32, f32), tol: f32, label: &str) {
21105        let dr = (got.0 - expected.0).abs();
21106        let di = (got.1 - expected.1).abs();
21107        assert!(
21108            dr < tol && di < tol,
21109            "[{label}] got ({:+.4}, {:+.4}), expected ({:+.4}, {:+.4})",
21110            got.0,
21111            got.1,
21112            expected.0,
21113            expected.1
21114        );
21115    }
21116
21117    #[test]
21118    fn c64_binary_add_matches_complex_arithmetic() {
21119        let a = [(1.0_f32, 2.0_f32), (3.0_f32, -1.0_f32)];
21120        let b = [(4.0_f32, -1.0_f32), (0.5_f32, 0.5_f32)];
21121        let out = run_c64_binary(BinaryOp::Add, &a, &b);
21122        assert_close_c(out[0], (5.0, 1.0), 1e-6, "add[0]");
21123        assert_close_c(out[1], (3.5, -0.5), 1e-6, "add[1]");
21124    }
21125
21126    #[test]
21127    fn c64_binary_sub_matches_complex_arithmetic() {
21128        let a = [(5.0_f32, 1.0_f32)];
21129        let b = [(2.0_f32, 3.0_f32)];
21130        let out = run_c64_binary(BinaryOp::Sub, &a, &b);
21131        assert_close_c(out[0], (3.0, -2.0), 1e-6, "sub");
21132    }
21133
21134    #[test]
21135    fn c64_binary_mul_matches_complex_arithmetic() {
21136        // (1 + 2i)(3 + 4i) = 3 + 4i + 6i + 8i² = -5 + 10i.
21137        let a = [(1.0_f32, 2.0_f32)];
21138        let b = [(3.0_f32, 4.0_f32)];
21139        let out = run_c64_binary(BinaryOp::Mul, &a, &b);
21140        assert_close_c(out[0], (-5.0, 10.0), 1e-5, "mul");
21141    }
21142
21143    #[test]
21144    fn c64_binary_div_matches_complex_arithmetic() {
21145        // (1 + 2i) / (3 + 4i) = ((1·3 + 2·4) + (2·3 − 1·4)i) / 25
21146        //                     = (11 + 2i) / 25
21147        //                     = 0.44 + 0.08i
21148        let a = [(1.0_f32, 2.0_f32)];
21149        let b = [(3.0_f32, 4.0_f32)];
21150        let out = run_c64_binary(BinaryOp::Div, &a, &b);
21151        assert_close_c(out[0], (0.44, 0.08), 1e-5, "div");
21152    }
21153
21154    #[test]
21155    fn c64_binary_mul_identity_one_is_no_op() {
21156        // (a + bi) · (1 + 0i) = a + bi.
21157        let a = [(3.5_f32, -1.25_f32), (-2.0_f32, 7.0_f32)];
21158        let b = [(1.0_f32, 0.0_f32), (1.0_f32, 0.0_f32)];
21159        let out = run_c64_binary(BinaryOp::Mul, &a, &b);
21160        assert_close_c(out[0], a[0], 1e-6, "mul·1[0]");
21161        assert_close_c(out[1], a[1], 1e-6, "mul·1[1]");
21162    }
21163
21164    #[test]
21165    fn c64_binary_mul_by_i_rotates_90_degrees() {
21166        // (a + bi) · i = (a + bi)(0 + i) = -b + ai. 90° CCW rotation.
21167        let a = [(1.0_f32, 0.0_f32)];
21168        let b = [(0.0_f32, 1.0_f32)];
21169        let out = run_c64_binary(BinaryOp::Mul, &a, &b);
21170        assert_close_c(out[0], (0.0, 1.0), 1e-6, "1·i");
21171    }
21172
21173    #[test]
21174    fn c64_binary_div_by_self_gives_unity() {
21175        let a = [(2.5_f32, -1.5_f32), (-0.7_f32, 4.2_f32)];
21176        let out = run_c64_binary(BinaryOp::Div, &a, &a);
21177        assert_close_c(out[0], (1.0, 0.0), 1e-5, "div_self[0]");
21178        assert_close_c(out[1], (1.0, 0.0), 1e-5, "div_self[1]");
21179    }
21180
21181    #[test]
21182    #[should_panic(expected = "C64: complex max/min/pow")]
21183    fn c64_binary_max_is_rejected_at_lowering() {
21184        run_c64_binary(BinaryOp::Max, &[(1.0_f32, 2.0_f32)], &[(3.0_f32, 4.0_f32)]);
21185    }
21186
21187    fn run_c64_activation(act: Activation, a: &[(f32, f32)]) -> Vec<(f32, f32)> {
21188        let n = a.len();
21189        let s = Shape::new(&[n], DType::C64);
21190        let mut g = Graph::new("c64_act");
21191        let in_a = g.input("a", s.clone());
21192        let out = g.activation(act, in_a, s.clone());
21193        g.set_outputs(vec![out]);
21194        let plan = rlx_opt::memory::plan_memory(&g);
21195        let mut arena = crate::arena::Arena::from_plan(plan);
21196        let sched = compile_thunks(&g, &arena);
21197        let a_off = arena.byte_offset(in_a);
21198        let out_off = arena.byte_offset(out);
21199        let buf = arena.raw_buf_mut();
21200        unsafe {
21201            let pa = buf.as_mut_ptr().add(a_off) as *mut f32;
21202            for (i, &(re, im)) in a.iter().enumerate() {
21203                *pa.add(2 * i) = re;
21204                *pa.add(2 * i + 1) = im;
21205            }
21206        }
21207        execute_thunks(&sched, arena.raw_buf_mut());
21208        let raw: Vec<f32> = unsafe {
21209            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
21210            (0..(2 * n)).map(|i| *p.add(i)).collect()
21211        };
21212        (0..n).map(|i| (raw[2 * i], raw[2 * i + 1])).collect()
21213    }
21214
21215    #[test]
21216    fn c64_activation_neg_negates_both_components() {
21217        let inp = [(3.5_f32, -1.25_f32), (-2.0_f32, 0.0_f32)];
21218        let out = run_c64_activation(Activation::Neg, &inp);
21219        assert_close_c(out[0], (-3.5, 1.25), 1e-6, "neg[0]");
21220        assert_close_c(out[1], (2.0, 0.0), 1e-6, "neg[1]");
21221    }
21222
21223    #[test]
21224    fn c64_activation_exp_matches_euler() {
21225        // exp(0 + i·π) = -1 + 0i.
21226        // exp(1 + 0i) = e ≈ 2.71828.
21227        let inp = [(0.0_f32, std::f32::consts::PI), (1.0_f32, 0.0_f32)];
21228        let out = run_c64_activation(Activation::Exp, &inp);
21229        assert_close_c(out[0], (-1.0, 0.0), 1e-5, "exp(iπ)");
21230        assert_close_c(out[1], (std::f32::consts::E, 0.0), 1e-5, "exp(1)");
21231    }
21232
21233    #[test]
21234    fn c64_activation_log_matches_principal_branch() {
21235        // log(1 + 0i) = 0.
21236        // log(0 + i) = log(1) + i·π/2 = 0 + i·π/2.
21237        // log(-1 + 0i) = 0 + i·π.
21238        let inp = [(1.0_f32, 0.0_f32), (0.0_f32, 1.0_f32), (-1.0_f32, 0.0_f32)];
21239        let out = run_c64_activation(Activation::Log, &inp);
21240        assert_close_c(out[0], (0.0, 0.0), 1e-5, "log(1)");
21241        assert_close_c(out[1], (0.0, std::f32::consts::FRAC_PI_2), 1e-5, "log(i)");
21242        assert_close_c(out[2], (0.0, std::f32::consts::PI), 1e-5, "log(-1)");
21243    }
21244
21245    #[test]
21246    fn c64_activation_sqrt_squared_recovers_input() {
21247        // For positive-real-part inputs, sqrt(z)² should equal z exactly
21248        // to f32 noise.
21249        let inp = [(4.0_f32, 0.0_f32), (3.0_f32, 4.0_f32)];
21250        let roots = run_c64_activation(Activation::Sqrt, &inp);
21251        // sqrt(4) = 2 + 0i; sqrt(3+4i) = 2 + i (since (2+i)² = 4+4i-1 = 3+4i).
21252        assert_close_c(roots[0], (2.0, 0.0), 1e-5, "sqrt(4)");
21253        assert_close_c(roots[1], (2.0, 1.0), 1e-5, "sqrt(3+4i)");
21254    }
21255
21256    #[test]
21257    #[should_panic(expected = "no natural complex extension")]
21258    fn c64_activation_relu_is_rejected_at_lowering() {
21259        run_c64_activation(Activation::Relu, &[(1.0_f32, 2.0_f32)]);
21260    }
21261
21262    // ── ComplexNormSq + Wirtinger backward witnesses ───────────────
21263
21264    /// Forward `|z|²`: returns `[n]` f32.
21265    fn run_complex_norm_sq(z: &[(f32, f32)]) -> Vec<f32> {
21266        let n = z.len();
21267        let mut g = Graph::new("cns_fwd");
21268        let in_z = g.input("z", Shape::new(&[n], DType::C64));
21269        let out = g.complex_norm_sq(in_z);
21270        g.set_outputs(vec![out]);
21271        let plan = rlx_opt::memory::plan_memory(&g);
21272        let mut arena = crate::arena::Arena::from_plan(plan);
21273        let sched = compile_thunks(&g, &arena);
21274        let z_off = arena.byte_offset(in_z);
21275        let out_off = arena.byte_offset(out);
21276        let buf = arena.raw_buf_mut();
21277        unsafe {
21278            let pz = buf.as_mut_ptr().add(z_off) as *mut f32;
21279            for (i, &(re, im)) in z.iter().enumerate() {
21280                *pz.add(2 * i) = re;
21281                *pz.add(2 * i + 1) = im;
21282            }
21283        }
21284        execute_thunks(&sched, arena.raw_buf_mut());
21285        unsafe {
21286            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
21287            (0..n).map(|i| *p.add(i)).collect()
21288        }
21289    }
21290
21291    /// Backward: given z and upstream g, return dz = g·z element-wise (C64).
21292    fn run_complex_norm_sq_bwd(z: &[(f32, f32)], g: &[f32]) -> Vec<(f32, f32)> {
21293        let n = z.len();
21294        let mut gr = Graph::new("cns_bwd");
21295        let in_z = gr.input("z", Shape::new(&[n], DType::C64));
21296        let in_g = gr.input("g", Shape::new(&[n], DType::F32));
21297        let out = gr.complex_norm_sq_backward(in_z, in_g);
21298        gr.set_outputs(vec![out]);
21299        let plan = rlx_opt::memory::plan_memory(&gr);
21300        let mut arena = crate::arena::Arena::from_plan(plan);
21301        let sched = compile_thunks(&gr, &arena);
21302        let z_off = arena.byte_offset(in_z);
21303        let g_off = arena.byte_offset(in_g);
21304        let out_off = arena.byte_offset(out);
21305        let buf = arena.raw_buf_mut();
21306        unsafe {
21307            let pz = buf.as_mut_ptr().add(z_off) as *mut f32;
21308            let pg = buf.as_mut_ptr().add(g_off) as *mut f32;
21309            for (i, &(re, im)) in z.iter().enumerate() {
21310                *pz.add(2 * i) = re;
21311                *pz.add(2 * i + 1) = im;
21312            }
21313            for (i, &v) in g.iter().enumerate() {
21314                *pg.add(i) = v;
21315            }
21316        }
21317        execute_thunks(&sched, arena.raw_buf_mut());
21318        unsafe {
21319            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
21320            (0..n).map(|i| (*p.add(2 * i), *p.add(2 * i + 1))).collect()
21321        }
21322    }
21323
21324    #[test]
21325    fn complex_norm_sq_matches_textbook() {
21326        // |3 + 4i|² = 9 + 16 = 25.
21327        // |1 + 0i|² = 1.
21328        // |0 + 0i|² = 0.
21329        let z = [(3.0_f32, 4.0_f32), (1.0_f32, 0.0_f32), (0.0_f32, 0.0_f32)];
21330        let out = run_complex_norm_sq(&z);
21331        assert!((out[0] - 25.0).abs() < 1e-5);
21332        assert!((out[1] - 1.0).abs() < 1e-6);
21333        assert!(out[2].abs() < 1e-6);
21334    }
21335
21336    #[test]
21337    fn complex_norm_sq_backward_matches_wirtinger_formula() {
21338        // Wirtinger: ∂|z|²/∂z̄ = z. With upstream g = 1, dz = z.
21339        let z = [(3.0_f32, 4.0_f32), (1.5_f32, -2.5_f32)];
21340        let g = [1.0_f32, 1.0_f32];
21341        let dz = run_complex_norm_sq_bwd(&z, &g);
21342        assert_close_c(dz[0], z[0], 1e-6, "dz[0] = g·z[0]");
21343        assert_close_c(dz[1], z[1], 1e-6, "dz[1] = g·z[1]");
21344    }
21345
21346    #[test]
21347    fn complex_norm_sq_backward_scales_with_upstream() {
21348        // With upstream g[i] ≠ 1: dz[i] = g[i]·z[i].
21349        let z = [(2.0_f32, 1.0_f32), (-1.0_f32, 3.0_f32)];
21350        let g = [0.5_f32, -2.0_f32];
21351        let dz = run_complex_norm_sq_bwd(&z, &g);
21352        assert_close_c(dz[0], (1.0, 0.5), 1e-6, "g=0.5 · (2,1)");
21353        assert_close_c(dz[1], (2.0, -6.0), 1e-6, "g=-2 · (-1,3)");
21354    }
21355
21356    /// Multi-output Op::CustomFn via the concat-with-Narrow design
21357    /// (rlx-ir::Graph::custom_fn_multi). Build a custom_fn whose
21358    /// fwd_body returns two outputs (x², 2x), then materialize each
21359    /// via the MultiOutputHandle and verify both numerically.
21360    #[test]
21361    fn custom_fn_multi_extracts_each_subgraph_output() {
21362        use rlx_ir::ops::special::MultiOutputHandle;
21363
21364        let _ = MultiOutputHandle {
21365            source: NodeId(0),
21366            sub_shapes: vec![],
21367            offsets: vec![],
21368        }; // import sanity
21369
21370        // Inner body: input x [3] f32, outputs (x², 2x) both [3] f32.
21371        let mut body = Graph::new("multi_body");
21372        let s3 = Shape::new(&[3], DType::F32);
21373        let x = body.input("x", s3.clone());
21374        let x_sq = body.binary(BinaryOp::Mul, x, x, s3.clone());
21375        let two = body.add_node(
21376            Op::Constant {
21377                data: vec![
21378                    2.0_f32.to_le_bytes(),
21379                    2.0_f32.to_le_bytes(),
21380                    2.0_f32.to_le_bytes(),
21381                ]
21382                .into_iter()
21383                .flatten()
21384                .collect(),
21385            },
21386            vec![],
21387            s3.clone(),
21388        );
21389        let two_x = body.binary(BinaryOp::Mul, two, x, s3.clone());
21390        body.set_outputs(vec![x_sq, two_x]);
21391
21392        // Outer graph: feed in_x → custom_fn_multi → handle.output(0/1).
21393        let mut outer = Graph::new("multi_outer");
21394        let in_x = outer.input("xin", s3.clone());
21395        let handle = outer.custom_fn_multi(vec![in_x], body);
21396        assert_eq!(handle.n_outputs(), 2);
21397        let out0 = handle.output(&mut outer, 0); // x²
21398        let out1 = handle.output(&mut outer, 1); // 2x
21399        outer.set_outputs(vec![out0, out1]);
21400
21401        let plan = rlx_opt::memory::plan_memory(&outer);
21402        let mut arena = crate::arena::Arena::from_plan(plan);
21403        let sched = compile_thunks(&outer, &arena);
21404        let xin_off = arena.byte_offset(in_x);
21405        let out0_off = arena.byte_offset(out0);
21406        let out1_off = arena.byte_offset(out1);
21407        let xs = [1.0_f32, 2.0, 3.0];
21408        unsafe {
21409            let p = arena.raw_buf_mut().as_mut_ptr().add(xin_off) as *mut f32;
21410            for (i, &v) in xs.iter().enumerate() {
21411                *p.add(i) = v;
21412            }
21413        }
21414        execute_thunks(&sched, arena.raw_buf_mut());
21415        let out0_v: Vec<f32> = unsafe {
21416            let p = arena.raw_buf().as_ptr().add(out0_off) as *const f32;
21417            (0..3).map(|i| *p.add(i)).collect()
21418        };
21419        let out1_v: Vec<f32> = unsafe {
21420            let p = arena.raw_buf().as_ptr().add(out1_off) as *const f32;
21421            (0..3).map(|i| *p.add(i)).collect()
21422        };
21423        // x² = [1, 4, 9]; 2x = [2, 4, 6].
21424        for i in 0..3 {
21425            assert!(
21426                (out0_v[i] - xs[i] * xs[i]).abs() < 1e-5,
21427                "out0[{i}] = {} != x² = {}",
21428                out0_v[i],
21429                xs[i] * xs[i]
21430            );
21431            assert!(
21432                (out1_v[i] - 2.0 * xs[i]).abs() < 1e-5,
21433                "out1[{i}] = {} != 2x = {}",
21434                out1_v[i],
21435                2.0 * xs[i]
21436            );
21437        }
21438    }
21439
21440    #[test]
21441    fn complex_norm_sq_gradient_matches_finite_difference() {
21442        // Numerical sanity: perturb z[0].re by ε, observe Δ|z|² ≈ 2·re·ε.
21443        let z = [(3.0_f32, 4.0_f32)];
21444        let eps = 1e-3_f32;
21445        let v0 = run_complex_norm_sq(&z)[0];
21446        let z_pert = [(3.0_f32 + eps, 4.0_f32)];
21447        let v1 = run_complex_norm_sq(&z_pert)[0];
21448        let fd_re = (v1 - v0) / eps;
21449        let analytic_re = 2.0 * z[0].0;
21450        assert!((fd_re - analytic_re).abs() < 1e-2);
21451
21452        // ∂/∂im at z = (3, 4) is 2·im = 8.
21453        let z_pert_im = [(3.0_f32, 4.0_f32 + eps)];
21454        let v2 = run_complex_norm_sq(&z_pert_im)[0];
21455        let fd_im = (v2 - v0) / eps;
21456        let analytic_im = 2.0 * z[0].1;
21457        assert!((fd_im - analytic_im).abs() < 1e-2);
21458
21459        // Compare with the Wirtinger backward at upstream g = 1.
21460        // Wirtinger ∂/∂z̄ = z gives dz = (re, im). The "real
21461        // gradient" wrt (re, im) is 2·(re, im), i.e. 2·dz = (2·re,
21462        // 2·im) — that's the factor 2 difference between Wirtinger
21463        // ∂/∂z̄ and the real-vector gradient on (re, im).
21464        let dz = run_complex_norm_sq_bwd(&z, &[1.0_f32]);
21465        assert!((2.0 * dz[0].0 - analytic_re).abs() < 1e-5);
21466        assert!((2.0 * dz[0].1 - analytic_im).abs() < 1e-5);
21467    }
21468
21469    /// Direct regression test for the 5-D mid-shape singleton broadcast
21470    /// (SAM rel_pos pattern: `[bh, h, w, 1, w] + [bh, h, w, h, w]`).
21471    /// The SAM port worked around this by `concat`-tiling the rhs; this
21472    /// test verifies the in-graph broadcast path is bit-correct.
21473    #[test]
21474    fn binary_full_5d_mid_singleton_broadcast() {
21475        let bh = 2usize;
21476        let h = 3;
21477        let w = 4;
21478        let f = DType::F32;
21479
21480        let mut g = Graph::new("bcast_5d");
21481        let lhs = g.input("lhs", Shape::new(&[bh, h, w, h, w], f));
21482        // rhs shape with size-1 at axis 3 (mid-shape singleton).
21483        let rhs = g.input("rhs", Shape::new(&[bh, h, w, 1, w], f));
21484        let out = g.binary(BinaryOp::Add, lhs, rhs, Shape::new(&[bh, h, w, h, w], f));
21485        g.set_outputs(vec![out]);
21486
21487        // Deterministic data.
21488        let lhs_data: Vec<f32> = (0..bh * h * w * h * w).map(|i| i as f32 * 0.01).collect();
21489        let rhs_data: Vec<f32> = (0..bh * h * w * w)
21490            .map(|i| (i as f32 + 100.0) * 0.01)
21491            .collect();
21492
21493        // Compute expected output by hand.
21494        let mut expected = vec![0f32; bh * h * w * h * w];
21495        for b_ in 0..bh {
21496            for hq in 0..h {
21497                for wq in 0..w {
21498                    for hk in 0..h {
21499                        for wk in 0..w {
21500                            let li = (((b_ * h + hq) * w + wq) * h + hk) * w + wk;
21501                            // rhs has hk dim = 1, so it's always index 0 there.
21502                            let ri = ((b_ * h + hq) * w + wq) * w + wk;
21503                            expected[li] = lhs_data[li] + rhs_data[ri];
21504                        }
21505                    }
21506                }
21507            }
21508        }
21509
21510        let plan = rlx_opt::memory::plan_memory(&g);
21511        let mut arena = crate::arena::Arena::from_plan(plan);
21512        let sched = compile_thunks(&g, &arena);
21513        let lhs_off = arena.byte_offset(lhs);
21514        let rhs_off = arena.byte_offset(rhs);
21515        let out_off = arena.byte_offset(out);
21516        let buf = arena.raw_buf_mut();
21517        unsafe {
21518            let p = buf.as_mut_ptr().add(lhs_off) as *mut f32;
21519            for (i, &v) in lhs_data.iter().enumerate() {
21520                *p.add(i) = v;
21521            }
21522            let p = buf.as_mut_ptr().add(rhs_off) as *mut f32;
21523            for (i, &v) in rhs_data.iter().enumerate() {
21524                *p.add(i) = v;
21525            }
21526        }
21527        execute_thunks(&sched, arena.raw_buf_mut());
21528        let actual: Vec<f32> = unsafe {
21529            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
21530            (0..bh * h * w * h * w).map(|i| *p.add(i)).collect()
21531        };
21532
21533        // Bit-exact check.
21534        let mut max_diff = 0f32;
21535        let mut max_idx = 0;
21536        for i in 0..actual.len() {
21537            let d = (actual[i] - expected[i]).abs();
21538            if d > max_diff {
21539                max_diff = d;
21540                max_idx = i;
21541            }
21542        }
21543        assert!(
21544            max_diff < 1e-6,
21545            "5D mid-shape singleton broadcast wrong: max |Δ| = {max_diff} at idx {max_idx} \
21546             (actual={}, expected={})",
21547            actual[max_idx],
21548            expected[max_idx]
21549        );
21550    }
21551
21552    #[test]
21553    fn layer_norm2d_and_conv_transpose2d_kernels() {
21554        let mut out = vec![0f32; 8];
21555        crate::kernels::layer_norm2d_nchw(
21556            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
21557            &[1.0, 1.0],
21558            &[0.0, 0.0],
21559            &mut out,
21560            1,
21561            2,
21562            2,
21563            2,
21564            1e-5,
21565        );
21566        let mean0: f32 = (1.0 + 3.0) / 2.0;
21567        assert!((out[0] - mean0).abs() > 0.1);
21568
21569        let mut up = vec![0f32; 4];
21570        crate::kernels::conv_transpose2d_nchw(
21571            &[2.0],
21572            &[1.0, 0.0, 0.0, 1.0],
21573            &mut up,
21574            1,
21575            1,
21576            1,
21577            1,
21578            1,
21579            2,
21580            2,
21581            2,
21582            2,
21583            2,
21584            2,
21585            0,
21586            0,
21587            1,
21588            1,
21589            1,
21590        );
21591        assert!((up[0] - 2.0).abs() < 1e-5);
21592        assert!((up[3] - 2.0).abs() < 1e-5);
21593    }
21594}