Skip to main content

rlx_cpu/
thunk.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Thunks — pre-compiled kernel dispatch with zero per-call overhead.
17//!
18//! At compile time, the graph is lowered into a flat `Vec<Thunk>` where each
19//! thunk holds pre-computed arena offsets, dimensions, and kernel type.
20//! At runtime, the executor just iterates thunks and calls kernels directly.
21
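22// Edition 2024 warns (`unsafe_op_in_unsafe_fn`) on unsafe ops inside `unsafe fn` bodies that lack an explicit `unsafe {}` block; allowed file-wide so `sl`/`sl_mut` can stay plain `unsafe fn`.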
23#![allow(unsafe_op_in_unsafe_fn)]
24//! No match dispatch, no HashMap lookup, no dimension computation.
25
26use crate::arena::Arena;
27use crate::op_registry::CpuKernel;
28use rlx_ir::op::{Activation, BinaryOp, CmpOp, ReduceOp};
29use rlx_ir::{Graph, NodeId, Op, Shape};
30use std::collections::HashMap;
31use std::sync::Arc;
32
33/// A pre-compiled kernel call with all args resolved to arena offsets.
34#[derive(Clone)]
35pub enum Thunk {
36    /// Skip (Input/Param already in arena)
37    Nop,
38    /// C = A @ B (BLAS sgemm)
39    Sgemm {
40        a: usize,
41        b: usize,
42        c: usize,
43        m: u32,
44        k: u32,
45        n: u32,
46    },
47    /// Complex (C64) dense GEMM `C = A·B`. Operands are interleaved
48    /// `[re, im]` f32; `a`/`b`/`c` are byte offsets, `m`/`k`/`n` are
49    /// complex-element matrix dims (`A` `[m,k]`, `B` `[k,n]`, `C` `[m,n]`).
50    CgemmC64 {
51        a: usize,
52        b: usize,
53        c: usize,
54        m: u32,
55        k: u32,
56        n: u32,
57    },
58    /// f64 dense solve `x = A⁻¹·b` via LAPACK dgesv.
59    /// `a`, `b`, `x` are byte-offsets into the arena. `n` is the matrix
60    /// dimension; `nrhs` is 1 for a vector RHS or >1 for multi-RHS.
61    /// The kernel materializes scratch copies of A and b internally
62    /// (LAPACK overwrites both with LU factors and solution).
63    DenseSolveF64 {
64        a: usize,
65        b: usize,
66        x: usize,
67        n: u32,
68        nrhs: u32,
69    },
70    /// f32 twin of `DenseSolveF64`. Calls LAPACK `sgesv` (or the
71    /// no-blas Rust fallback). Same arena byte-offset contract.
72    DenseSolveF32 {
73        a: usize,
74        b: usize,
75        x: usize,
76        n: u32,
77        nrhs: u32,
78    },
79    /// Batched f64 dense solve. `a`, `b`, `x` are byte-offsets to
80    /// the leading slice; `batch` is the number of independent
81    /// systems. Per slice the kernel calls `dgesv(A_i, b_i, n, nrhs)`
82    /// — LAPACK has no batched dgesv on Accelerate, so we loop.
83    BatchedDenseSolveF64 {
84        a: usize,
85        b: usize,
86        x: usize,
87        batch: u32,
88        n: u32,
89        nrhs: u32,
90    },
91    /// Batched f32 dense solve — loop of `sgesv` per batch slice.
92    BatchedDenseSolveF32 {
93        a: usize,
94        b: usize,
95        x: usize,
96        batch: u32,
97        n: u32,
98        nrhs: u32,
99    },
100    /// Batched f64 matmul. Both inputs and output have a leading
101    /// batch axis of size `batch`. Per-batch independent dgemm:
102    /// `C[i] = A[i] @ B[i]` for `i in 0..batch`. Used by VJP rules
103    /// that emit per-batch outer products (e.g., BatchedDenseSolve
104    /// VJP). The unbatched `Dgemm` thunk handles the rank-2 case.
105    BatchedDgemmF64 {
106        a: usize,
107        b: usize,
108        c: usize,
109        batch: u32,
110        m: u32,
111        k: u32,
112        n: u32,
113    },
114    /// Batched f32 matmul — same loop-per-batch shape as
115    /// `BatchedDgemmF64` but calling `sgemm`. Needed for attention
116    /// patterns where both operands carry a batch dim (e.g. q@k^T
117    /// and attn@v in decomposed self-attention). The 2-D `Sgemm`
118    /// flatten trick is wrong in that case because it treats `b` as
119    /// a single shared RHS across every batch.
120    BatchedSgemm {
121        a: usize,
122        b: usize,
123        c: usize,
124        batch: u32,
125        m: u32,
126        k: u32,
127        n: u32,
128    },
129    /// C = A @ B via Accelerate cblas_dgemm. Mirror of `Sgemm` at f64.
130    Dgemm {
131        a: usize,
132        b: usize,
133        c: usize,
134        m: u32,
135        k: u32,
136        n: u32,
137    },
138    /// f64 N-D index walk used for both `Op::Transpose` and `Op::Expand`.
139    /// `in_strides` carries 0s on broadcast axes (Expand) or permuted
140    /// strides (Transpose). Mirror of `Thunk::Transpose` at f64.
141    TransposeF64 {
142        src: usize,
143        dst: usize,
144        in_total: u32,
145        out_dims: Vec<u32>,
146        in_strides: Vec<u32>,
147    },
148    /// f64 element-wise activation. Single-input, single-output. The
149    /// kernel always reads from `src` and writes to `dst`, so it works
150    /// whether or not the planner aliased the two slots.
151    ActivationF64 {
152        src: usize,
153        dst: usize,
154        len: u32,
155        kind: Activation,
156    },
157    /// Element-wise complex squared-magnitude: `|z|² = re² + im²`.
158    /// Reads the C64 input at `src` as `2·len` f32 ([re,im] pairs),
159    /// writes `len` f32 to `dst`.
160    ComplexNormSqF32 {
161        src: usize,
162        dst: usize,
163        /// Logical element count (number of complex values).
164        len: u32,
165    },
166    /// Wirtinger backward for [`ComplexNormSqF32`]: `dz = g · z` as
167    /// C64. Reads `z` at `2·len` f32 + `g` at `len` f32; writes
168    /// `2·len` f32 to `dz`.
169    ComplexNormSqBackwardF32 {
170        z: usize,
171        g: usize,
172        dz: usize,
173        len: u32,
174    },
175    /// Element-wise C64 conjugate: writes `[re_i, -im_i]` per element.
176    /// Layout matches the rest of C64 here ([re,im] interleaved f32).
177    ConjugateC64 {
178        src: usize,
179        dst: usize,
180        len: u32,
181    },
182    /// C64 element-wise activation. Only kinds with well-defined
183    /// complex extensions are supported: Neg, Exp, Log, Sqrt.
184    /// Everything else (Sigmoid, Tanh, Relu, Abs, Sin/Cos/Tan/Atan,
185    /// Round, GeLU family) is rejected at lowering — those don't have
186    /// single natural complex definitions. `len` is the **complex
187    /// element count** (the f32 buffer holds `2·len` floats).
188    ActivationC64 {
189        src: usize,
190        dst: usize,
191        len: u32,
192        kind: Activation,
193    },
194    /// f64 contiguous reduction along a single axis range. Layout
195    /// `[outer, reduced, inner]` in memory; output is `[outer, inner]`.
196    /// Sum only for now (Mean composes via 1/N multiply post-pass).
197    ReduceSumF64 {
198        src: usize,
199        dst: usize,
200        outer: u32,
201        reduced: u32,
202        inner: u32,
203    },
204    /// f64 plain copy (Reshape / Cast at the same dtype). Mirrors `Copy`
205    /// but at 8 bytes per element.
206    CopyF64 {
207        src: usize,
208        dst: usize,
209        len: u32,
210    },
211    /// i64 element copy (Reshape/Cast on i64 tensors).
212    CopyI64 {
213        src: usize,
214        dst: usize,
215        len: u32,
216    },
217    /// Round f32 → i64 (ONNX Cast on duration scalar).
218    CastF32ToI64 {
219        src: usize,
220        dst: usize,
221        len: u32,
222    },
223    CastF32ToF64 {
224        src: usize,
225        dst: usize,
226        len: u32,
227    },
228    CastF32ToI32 {
229        src: usize,
230        dst: usize,
231        len: u32,
232    },
233    /// i64 → f32 (ONNX Cast on shape scalars, e.g. Albert head-dim).
234    CastI64ToF32 {
235        src: usize,
236        dst: usize,
237        len: u32,
238    },
239    /// bool → i32 (BERT attention mask grid).
240    CastBoolToI32 {
241        src: usize,
242        dst: usize,
243        len: u32,
244    },
245    CastBoolToF32 {
246        src: usize,
247        dst: usize,
248        len: u32,
249    },
250    /// i32 → f32 (BERT attention mask cast before subtract).
251    CastI32ToF32 {
252        src: usize,
253        dst: usize,
254        len: u32,
255    },
256    /// f64 element-wise binary with broadcast. `len`/`lhs_len`/`rhs_len`
257    /// are element counts; kernel does `out[i] = lhs[i % lhs_len] OP rhs[i % rhs_len]`.
258    /// Mirror of `BinaryFull` at 8 bytes per element.
259    BinaryFullF64 {
260        lhs: usize,
261        rhs: usize,
262        dst: usize,
263        len: u32,
264        lhs_len: u32,
265        rhs_len: u32,
266        op: BinaryOp,
267        /// Output shape dims (row-major). Empty in the fast path. See
268        /// `BinaryFull` doc for the broadcast convention.
269        out_dims_bcast: Vec<u32>,
270        bcast_lhs_strides: Vec<u32>,
271        bcast_rhs_strides: Vec<u32>,
272    },
273    /// f64 concat — byte-for-byte mirror of `Concat` but copies
274    /// 8 bytes per element. Element-counted offsets/strides match
275    /// the f32 variant; the executor scales by elem_size internally.
276    ConcatF64 {
277        dst: usize,
278        outer: u32,
279        inner: u32,
280        total_axis: u32,
281        inputs: Vec<(usize, u32, u32)>,
282    },
283    /// C64 element-wise binary with broadcast. Same `len` /
284    /// `lhs_len` / `rhs_len` semantics as `BinaryFull` but each
285    /// "element" is one complex value (8 bytes = `[re, im]` as two
286    /// f32s). The executor reads the underlying f32 buffer at
287    /// `2·len` floats and walks element pairs. Supports Add / Sub /
288    /// Mul / Div; Max / Min / Pow have no single natural complex
289    /// definition and panic at lowering.
290    BinaryFullC64 {
291        lhs: usize,
292        rhs: usize,
293        dst: usize,
294        /// Complex element count (NOT f32 count). f32 buffer length
295        /// is `2·len`.
296        len: u32,
297        lhs_len: u32,
298        rhs_len: u32,
299        op: BinaryOp,
300        out_dims_bcast: Vec<u32>,
301        bcast_lhs_strides: Vec<u32>,
302        bcast_rhs_strides: Vec<u32>,
303    },
304    /// Bounded scan. Holds a recursively-compiled body schedule + a
305    /// pre-initialized body arena snapshot (constants filled). Each
306    /// outer execution clones the snapshot, copies the carry-in slot
307    /// from the outer arena, runs the body schedule `length` times,
308    /// then writes the final carry to the outer arena.
309    ///
310    /// Single-carry MVP — body has exactly one Input and one output,
311    /// both same shape and dtype.
312    Scan {
313        body: Arc<ThunkSchedule>,
314        body_init: Arc<Vec<u8>>, // pristine body arena bytes
315        body_input_off: usize,   // byte offset of the body's carry-Input slot
316        body_output_off: usize,  // byte offset of the body's output slot
317        outer_init_off: usize,   // outer-arena offset of the initial carry
318        outer_final_off: usize,  // outer-arena offset of the final carry / trajectory base
319        length: u32,
320        carry_bytes: u32, // carry size in bytes
321        /// When true, write each step's carry to the outer arena at
322        /// offset `outer_final_off + t * carry_bytes`, producing a
323        /// `[length, *carry]` stacked trajectory. When false, only the
324        /// final carry lands at `outer_final_off`.
325        save_trajectory: bool,
326        /// Per-step `xs` inputs. For each: (body_x_input_off,
327        /// outer_xs_base_off, per_step_bytes). Per iteration `t`, the
328        /// executor copies `outer_xs_base_off + t * per_step_bytes`
329        /// into `body_x_input_off`. Empty when the scan has no xs.
330        xs_inputs: Arc<Vec<(usize, usize, u32)>>,
331        /// Broadcast inputs — values constant across iterations. For
332        /// each: (body_bcast_input_off, outer_bcast_off, total_bytes).
333        /// Filled into `body_buf` ONCE before the scan loop starts
334        /// (xs in contrast are re-filled every iteration). Empty when
335        /// the scan has no bcasts.
336        bcast_inputs: Arc<Vec<(usize, usize, u32)>>,
337        /// Number of trajectory checkpoints (when `save_trajectory`).
338        /// `0` or `length` ⇒ save every iteration. Otherwise save only
339        /// `K` rows at indices `floor((k+1) * length / K) - 1` for
340        /// `k in 0..K`. Last index is always `length-1` so the final
341        /// carry is always cached.
342        num_checkpoints: u32,
343    },
344
345    /// Reverse-mode AD companion to `Thunk::Scan`. Walks `t = length-1
346    /// .. 0`, threading `dcarry` through the body's VJP. Per iteration:
347    /// writes `carry_t` (from outer init or trajectory), each `xs_i[t]`
348    /// slice, and the current `dcarry` into the body_vjp's Input
349    /// slots, runs body_vjp, reads new `dcarry` from its single output.
350    /// f64 carry only — the upstream-accumulation step in trajectory
351    /// mode does an element-wise f64 add.
352    ScanBackward {
353        body_vjp: Arc<ThunkSchedule>,
354        body_init: Arc<Vec<u8>>,
355        body_carry_in_off: usize, // body_vjp's mirrored body-carry-input slot
356        body_x_offs: Arc<Vec<usize>>, // body_vjp's mirrored x_t_i Input slots, in xs order
357        body_d_output_off: usize, // body_vjp's "d_output" Input slot
358        body_dcarry_out_off: usize, // body_vjp's gradient output
359        outer_init_off: usize,    // original init carry
360        outer_traj_off: usize,    // [length-or-K, *carry] trajectory base
361        outer_upstream_off: usize, // upstream gradient (carry shape, or [length, *carry])
362        /// Per-xs entries: (outer_xs_base_off, per_step_bytes). Read
363        /// `xs_i[t]` from `outer_xs_base_off + t * per_step_bytes`.
364        outer_xs_offs: Arc<Vec<(usize, u32)>>,
365        outer_dinit_off: usize, // output: dinit
366        length: u32,
367        carry_bytes: u32,
368        /// Bytes per element in the carry tensor: 4 for f32, 8 for f64.
369        /// Used to dispatch the trajectory-mode upstream accumulation
370        /// kernel (the dcarry += upstream\[t\] add must use the right
371        /// floating-point type — a hard-coded f64 add silently does
372        /// nothing for an f32 carry whose `cb` isn't divisible by 8).
373        carry_elem_size: u32,
374        save_trajectory: bool, // true → upstream is per-step; false → just final
375        /// Recursive checkpointing config. `0` or `length` ⇒ full
376        /// trajectory cached, no recompute (existing behavior).
377        /// `0 < K < length` ⇒ trajectory has only K rows; the executor
378        /// recomputes intermediate carries via `forward_body` between
379        /// checkpoints. Memory: O(K · carry_bytes); time: O(length).
380        num_checkpoints: u32,
381        /// Forward body schedule (same compiled body as the forward
382        /// Op::Scan), used for recompute when `num_checkpoints` is
383        /// active. `None` for the All strategy.
384        forward_body: Option<Arc<ThunkSchedule>>,
385        /// Pristine forward body arena bytes (constants filled).
386        forward_body_init: Option<Arc<Vec<u8>>>,
387        /// Forward body's carry-Input and output slot offsets — needed
388        /// to seed/read the body during recompute.
389        forward_body_carry_in_off: usize,
390        forward_body_output_off: usize,
391        /// Forward body's per-step xs Input slots (one per outer xs).
392        /// Same indexing convention as `body_x_offs`.
393        forward_body_x_offs: Arc<Vec<usize>>,
394    },
395
396    /// Companion to `ScanBackward` that materializes one stacked
397    /// `dxs_i`. Same backward loop; per iteration, after running
398    /// body_vjp, copies its `body_dxs_out_off` slot into the outer
399    /// arena at `outer_dxs_off + t * per_step_bytes`. dcarry threading
400    /// is identical — we still need it for the body_vjp recurrence
401    /// even though we don't write it back to the outer arena.
402    ScanBackwardXs {
403        body_vjp: Arc<ThunkSchedule>,
404        body_init: Arc<Vec<u8>>,
405        body_carry_in_off: usize,
406        body_x_offs: Arc<Vec<usize>>,
407        body_d_output_off: usize,
408        body_dcarry_out_off: usize,
409        body_dxs_out_off: usize, // the body_vjp output we extract per step
410        outer_init_off: usize,
411        outer_traj_off: usize,
412        outer_upstream_off: usize,
413        outer_xs_offs: Arc<Vec<(usize, u32)>>,
414        outer_dxs_off: usize, // base of the stacked [length, *per_step] output
415        length: u32,
416        carry_bytes: u32,
417        /// Same role as `Thunk::ScanBackward::carry_elem_size`.
418        carry_elem_size: u32,
419        per_step_bytes: u32, // bytes per row of the dxs output
420        save_trajectory: bool,
421        /// Recursive checkpointing config. Same semantics as
422        /// `Thunk::ScanBackward::num_checkpoints` — `0` or `length`
423        /// means "save every step's carry"; `0 < K < length` means
424        /// the trajectory has only K rows and the executor recomputes
425        /// intermediate carries via `forward_body` (which must be
426        /// `Some`). Implemented via segment-cached recompute,
427        /// mirroring the `ScanBackward` path.
428        num_checkpoints: u32,
429        forward_body: Option<Arc<ThunkSchedule>>,
430        forward_body_init: Option<Arc<Vec<u8>>>,
431        forward_body_carry_in_off: usize,
432        forward_body_output_off: usize,
433        forward_body_x_offs: Arc<Vec<usize>>,
434    },
435    /// User-defined sub-graph (`Op::CustomFn`) — runs `fwd_body` once.
436    /// Per execution: clone `body_init`, copy each primal input from the
437    /// outer arena into its body Input slot, run the body schedule,
438    /// copy the body's single output back to the outer arena.
439    CustomFn {
440        body: Arc<ThunkSchedule>,
441        body_init: Arc<Vec<u8>>,
442        /// Per primal input: (body_input_off, outer_input_off, bytes).
443        inputs: Arc<Vec<(usize, usize, u32)>>,
444        body_output_off: usize,
445        outer_output_off: usize,
446        out_bytes: u32,
447    },
448    /// C = A @ B; C += bias; C = act(C)
449    FusedMmBiasAct {
450        a: usize,
451        w: usize,
452        bias: usize,
453        c: usize,
454        m: u32,
455        k: u32,
456        n: u32,
457        act: Option<Activation>,
458    },
459    /// out = LN(x + residual + bias, gamma, beta)
460    FusedResidualLN {
461        x: usize,
462        res: usize,
463        bias: usize,
464        g: usize,
465        b: usize,
466        out: usize,
467        rows: u32,
468        h: u32,
469        eps: f32,
470        has_bias: bool,
471    },
472    /// out = RmsNorm(x + residual + bias, gamma, beta)
473    FusedResidualRmsNorm {
474        x: usize,
475        res: usize,
476        bias: usize,
477        g: usize,
478        b: usize,
479        out: usize,
480        rows: u32,
481        h: u32,
482        eps: f32,
483        has_bias: bool,
484    },
485    /// out = bias_add(data, bias, m, n) for Binary::Add with broadcast
486    BiasAdd {
487        src: usize,
488        bias: usize,
489        dst: usize,
490        m: u32,
491        n: u32,
492    },
493    /// Element-wise binary op with NumPy-style broadcast.
494    ///
495    /// Fast path (`lhs_len == rhs_len == len`): plain element-wise loop,
496    /// SIMD-vectorized on aarch64 for `Add`/`Mul`. `bcast_*` fields
497    /// are unused.
498    ///
499    /// Broadcast path: uses `out_dims_bcast` + `bcast_lhs_strides` +
500    /// `bcast_rhs_strides` to compute per-cell indices into each
501    /// operand. The strides are precomputed at thunk-construction
502    /// time from the operands' true shapes (with stride 0 on any axis
503    /// where the operand has size 1). This is the only correct way
504    /// to handle bidirectional broadcasts like `[N, 1] op [1, S]
505    /// → [N, S]`, which simple `i % lhs_len` modulo indexing maps to
506    /// wrong cells.
507    BinaryFull {
508        lhs: usize,
509        rhs: usize,
510        dst: usize,
511        len: u32,
512        lhs_len: u32,
513        rhs_len: u32,
514        op: BinaryOp,
515        /// Output shape dims (row-major). Empty in the fast path.
516        out_dims_bcast: Vec<u32>,
517        /// Per-dim stride into `lhs` (0 where lhs broadcasts).
518        bcast_lhs_strides: Vec<u32>,
519        /// Per-dim stride into `rhs`.
520        bcast_rhs_strides: Vec<u32>,
521        /// Element size (4 = F32, 8 = I64).
522        elem_bytes: u8,
523    },
524    /// Activation in-place
525    ActivationInPlace {
526        data: usize,
527        len: u32,
528        act: Activation,
529    },
530    /// Gather axis=0: table\[idx\] → out
531    Gather {
532        table: usize,
533        table_len: u32,
534        idx: usize,
535        dst: usize,
536        num_idx: u32,
537        trailing: u32,
538        /// 1 when the index tensor is i64 (ONNX Gather indices).
539        idx_i64: u8,
540        /// Element size of table/output (4 = f32, 8 = i64).
541        table_bytes: u8,
542    },
543    /// Narrow: copy slice (`elem_bytes` = source element size: 4 for f32, 8 for f64).
544    Narrow {
545        src: usize,
546        dst: usize,
547        outer: u32,
548        src_stride: u32,
549        dst_stride: u32,
550        inner: u32,
551        elem_bytes: u8,
552    },
553    /// Copy (reshape, expand)
554    Copy {
555        src: usize,
556        dst: usize,
557        len: u32,
558    },
559    /// LayerNorm standalone
560    LayerNorm {
561        src: usize,
562        g: usize,
563        b: usize,
564        dst: usize,
565        rows: u32,
566        h: u32,
567        eps: f32,
568    },
569    /// GroupNorm on NCHW `[N,C,H,W]`.
570    GroupNorm {
571        src: usize,
572        g: usize,
573        b: usize,
574        dst: usize,
575        n: u32,
576        c: u32,
577        h: u32,
578        w: u32,
579        num_groups: u32,
580        eps: f32,
581    },
582    /// BatchNorm inference: frozen mean/var, feature axis last.
583    BatchNormInference {
584        src: usize,
585        g: usize,
586        b: usize,
587        mean: usize,
588        var: usize,
589        dst: usize,
590        count: u32,
591        channels: u32,
592        eps: f32,
593    },
594    BatchNormInferenceBackwardInput {
595        x: usize,
596        gamma: usize,
597        mean: usize,
598        var: usize,
599        dy: usize,
600        dx: usize,
601        count: u32,
602        channels: u32,
603        eps: f32,
604    },
605    BatchNormInferenceBackwardGamma {
606        x: usize,
607        mean: usize,
608        var: usize,
609        dy: usize,
610        dgamma: usize,
611        count: u32,
612        channels: u32,
613        eps: f32,
614    },
615    BatchNormInferenceBackwardBeta {
616        dy: usize,
617        dbeta: usize,
618        count: u32,
619        channels: u32,
620    },
621    /// LayerNorm2d on NCHW (SAM / candle semantics).
622    LayerNorm2d {
623        src: usize,
624        g: usize,
625        b: usize,
626        dst: usize,
627        n: u32,
628        c: u32,
629        h: u32,
630        w: u32,
631        eps: f32,
632    },
633    /// ConvTranspose2d on NCHW.
634    ConvTranspose2d {
635        src: usize,
636        weight: usize,
637        dst: usize,
638        n: u32,
639        c_in: u32,
640        h: u32,
641        w_in: u32,
642        c_out: u32,
643        h_out: u32,
644        w_out: u32,
645        kh: u32,
646        kw: u32,
647        sh: u32,
648        sw: u32,
649        ph: u32,
650        pw: u32,
651        dh: u32,
652        dw: u32,
653        groups: u32,
654    },
655    /// Nearest 2× upsample on NCHW (per-batch slice).
656    ResizeNearest2x {
657        src: usize,
658        dst: usize,
659        n: u32,
660        c: u32,
661        h: u32,
662        w: u32,
663    },
664    /// SAM2 axial 2-D RoPE on `[batch, seq, num_heads * head_dim]`.
665    AxialRope2d {
666        src: usize,
667        dst: usize,
668        batch: u32,
669        seq: u32,
670        hidden: u32,
671        end_x: u32,
672        end_y: u32,
673        head_dim: u32,
674        num_heads: u32,
675        theta: f32,
676        repeat_factor: u32,
677    },
678    /// RMSNorm: out = (x / sqrt(mean(x^2) + eps)) * gamma + beta. No mean
679    /// subtraction, hence cheaper than LayerNorm. Used by Llama-class models.
680    RmsNorm {
681        src: usize,
682        g: usize,
683        b: usize,
684        dst: usize,
685        rows: u32,
686        h: u32,
687        eps: f32,
688    },
689    /// Softmax
690    Softmax {
691        data: usize,
692        rows: u32,
693        cols: u32,
694    },
695    /// Inclusive (or exclusive) cumulative sum along the last axis
696    /// (callers pre-flatten higher-dim cumsums via reshape views).
697    Cumsum {
698        src: usize,
699        dst: usize,
700        rows: u32,
701        cols: u32,
702        exclusive: bool,
703    },
704    /// Mamba-style selective scan (plan #15).
705    /// Inputs: x, delta \[b,s,h\], a \[h,n\], b \[b,s,n\], c \[b,s,n\].
706    /// Output: y \[b,s,h\]. State h carries through the seq.
707    SelectiveScan {
708        x: usize,
709        delta: usize,
710        a: usize,
711        b: usize,
712        c: usize,
713        dst: usize,
714        batch: u32,
715        seq: u32,
716        hidden: u32,
717        state_size: u32,
718    },
719
720    /// Gated DeltaNet linear-attention scan (Qwen3.5/3.6 trunk).
721    /// Inputs: q, k, v `[b, s, h, n]`; g, beta `[b, s, h]`. Output:
722    /// `[b, s, h, n]`. See `Op::GatedDeltaNet` for math.
723    GatedDeltaNet {
724        q: usize,
725        k: usize,
726        v: usize,
727        g: usize,
728        beta: usize,
729        /// When non-zero, load initial `[b, h, n, n]` state and write
730        /// the final state back in place after the scan.
731        state: usize,
732        dst: usize,
733        batch: u32,
734        seq: u32,
735        heads: u32,
736        state_size: u32,
737    },
738
739    /// Multi-layer (optionally bidirectional, optional carry) LSTM with
740    /// packed weights. See `Op::Lstm`. `h0`/`c0` are valid only when
741    /// `carry`; `dst` is `[b, s, D*h]`.
742    Lstm {
743        x: usize,
744        w_ih: usize,
745        w_hh: usize,
746        bias: usize,
747        h0: usize,
748        c0: usize,
749        dst: usize,
750        batch: u32,
751        seq: u32,
752        input_size: u32,
753        hidden: u32,
754        num_layers: u32,
755        bidirectional: bool,
756        carry: bool,
757    },
758
759    /// 1×1 conv fast path (plan #26). The general Conv2D thunk
760    /// runs the textbook 7-deep loop; a 1×1 stride-1 padding-0
761    /// groups-1 conv is mathematically a per-batch matmul, and
762    /// dispatching it through BLAS is 3-10× faster than the
763    /// scalar nest. Common case: ViT patch-projection follow-on,
764    /// transformer "expert" reductions in some MoE designs.
765    ///
766    /// Per batch: weight `[c_out, c_in]` × input `[c_in, h*w]`
767    ///         = output `[c_out, h*w]`.
768    Conv2D1x1 {
769        src: usize,
770        weight: usize,
771        dst: usize,
772        n: u32,
773        c_in: u32,
774        c_out: u32,
775        hw: u32,
776    },
777
778    /// Fused dequant + matmul (plan #5). Today supports
779    /// `QuantScheme::Int8Block` (symmetric); other schemes panic
780    /// at lowering time with a clear message until kernels are added.
781    DequantMatMul {
782        x: usize,
783        w_q: usize,   // packed i8 bytes for Int8 schemes
784        scale: usize, // [k/block, n] f32 scale
785        zp: usize,    // [k/block, n] f32 zero-point (0 for sym)
786        dst: usize,
787        m: u32,
788        k: u32,
789        n: u32,
790        block_size: u32,
791        is_asymmetric: bool,
792    },
793
794    /// GGUF-format dequant + matmul. Weight is a packed byte tensor
795    /// in one of the K-quant super-block layouts (Q4_K, Q5_K, Q6_K,
796    /// Q8_K). Scales / mins live inside the packed bytes — no
797    /// side-channel scale tensor.
798    ///
799    /// Today this is a "dequant-to-scratch then sgemm" kernel — it
800    /// keeps the *arena* memory footprint down (weights stay packed)
801    /// but the dequant itself happens per matmul. A future fully
802    /// fused tile-streaming kernel would close the compute gap.
803    DequantMatMulGguf {
804        x: usize,   // f32 activations [m, k]
805        w_q: usize, // packed weight bytes (k*n elements packed)
806        dst: usize, // f32 output [m, n]
807        m: u32,
808        k: u32,
809        n: u32,
810        scheme: rlx_ir::quant::QuantScheme,
811    },
812
813    /// Int4 block dequant + matmul (packed nibbles, side scale/zp).
814    DequantMatMulInt4 {
815        x: usize,
816        w_q: usize,
817        scale: usize,
818        zp: usize,
819        dst: usize,
820        m: u32,
821        k: u32,
822        n: u32,
823        block_size: u32,
824        is_asymmetric: bool,
825    },
826
827    /// FP8 dequant + matmul (per-tensor or per-column scale).
828    DequantMatMulFp8 {
829        x: usize,
830        w_q: usize,
831        scale: usize,
832        dst: usize,
833        m: u32,
834        k: u32,
835        n: u32,
836        e5m2: bool,
837    },
838
839    /// NVFP4 (E2M1) block dequant + matmul — 16-wide groups, FP8 scales.
840    DequantMatMulNvfp4 {
841        x: usize,
842        w_q: usize,
843        scale: usize,
844        global_scale: usize,
845        dst: usize,
846        m: u32,
847        k: u32,
848        n: u32,
849    },
850
851    /// Fused LoRA matmul (plan #9): out = x·W + scale * (x·A)·B.
852    /// `r` is the LoRA rank (typically 4-64) — the rank-r
853    /// intermediate `x·A` lives in scratch, never on the arena.
854    LoraMatMul {
855        x: usize,
856        w: usize,
857        a: usize,
858        b: usize,
859        dst: usize,
860        m: u32,
861        k: u32,
862        n: u32,
863        r: u32,
864        scale: f32,
865    },
866    /// Fused sample: logits [batch, vocab] → token ids \[batch\].
867    /// See Op::Sample. Output values are f32-encoded usize indices
868    /// (matches the rest of the IR's "ids as f32" convention).
869    Sample {
870        logits: usize,
871        dst: usize,
872        batch: u32,
873        vocab: u32,
874        top_k: u32,       // 0 = disabled
875        top_p: f32,       // 1.0 = disabled
876        temperature: f32, // 1.0 = neutral
877        seed: u64,
878    },
879    /// ONNX `RandomNormalLike` fill.
880    RngNormal {
881        dst: usize,
882        len: u32,
883        mean: f32,
884        scale: f32,
885        key: u64,
886        op_seed: Option<f32>,
887    },
888    /// ONNX `RandomUniformLike` fill.
889    RngUniform {
890        dst: usize,
891        len: u32,
892        low: f32,
893        high: f32,
894        key: u64,
895        op_seed: Option<f32>,
896    },
897    /// Attention SDPA. `mask` is the offset of the optional mask tensor
898    /// (only meaningful when `mask_kind == MaskKind::Custom`); other
899    /// kinds synthesize the mask in-kernel.
900    ///
901    /// Q/K/V each carry a `_row_stride` (elements per source row).
902    /// Defaults to `heads * head_dim` — matches the standalone
903    /// "Q/K/V are their own contiguous buffers" case. The Narrow→
904    /// Attention fusion below rewrites these to the parent QKV stride
905    /// (typically `3 * heads * head_dim`) so the kernel reads QKV
906    /// directly without materializing the per-head buffers (plan #46).
907    Attention {
908        q: usize,
909        k: usize,
910        v: usize,
911        mask: usize,
912        out: usize,
913        batch: u32,
914        /// Query sequence length.
915        seq: u32,
916        /// Key/value sequence length. Differs from `seq` during cached decode.
917        kv_seq: u32,
918        heads: u32,
919        head_dim: u32,
920        mask_kind: rlx_ir::op::MaskKind,
921        /// Softmax score scale (`Op::Attention::score_scale`). `head_dim^-0.5`
922        /// when the op left it unset. Must be honored — Gemma 4 uses `1.0`
923        /// (Q/K are per-head RMS-normed, so no `1/sqrt(d)` pre-scale).
924        scale: f32,
925        q_row_stride: u32,
926        k_row_stride: u32,
927        v_row_stride: u32,
928        /// Memory layout flag. `false` (the historical default) →
929        /// `[B, S, H, D]` row-major: per-head offset is
930        /// `bi*S*H*D + si*H*D + hi*D`. `true` → `[B, H, S, D]`
931        /// (head-major), matching the convention used by rlx-cuda /
932        /// rlx-rocm / rlx-tpu: per-head offset is
933        /// `bi*H*S*D + hi*S*D + si*D`. Detected at lowering time
934        /// from the input shape vs `num_heads` / `head_dim`.
935        bhsd: bool,
936    },
937    /// [`Op::AttentionBackward`] — emits dQ, dK, or dV (see `wrt`).
938    AttentionBackward {
939        q: usize,
940        k: usize,
941        v: usize,
942        dy: usize,
943        mask: usize,
944        out: usize,
945        batch: u32,
946        seq: u32,
947        kv_seq: u32,
948        heads: u32,
949        head_dim: u32,
950        mask_kind: rlx_ir::op::MaskKind,
951        wrt: rlx_ir::op::AttentionBwdWrt,
952        bhsd: bool,
953    },
954    /// RoPE (rotary position embeddings).
955    /// `src_row_stride` is elements per source row (defaults to `hidden`
956    /// for the standalone case; set to `qkv_axis * inner` when the
957    /// thunk fusion pass below rewires Rope to read directly from the
958    /// fused QKV buffer — plan #45).
959    Rope {
960        src: usize,
961        cos: usize,
962        sin: usize,
963        dst: usize,
964        batch: u32,
965        seq: u32,
966        hidden: u32,
967        head_dim: u32,
968        n_rot: u32,
969        cos_len: u32,
970        src_row_stride: u32,
971    },
972    /// Fused attention block: QKV proj → split → \[RoPE\] → SDPA → output proj.
973    /// All intermediates stay in L1 cache. Zero arena writes between ops.
974    FusedAttnBlock {
975        hidden: usize,
976        qkv_w: usize,
977        out_w: usize,
978        mask: usize,
979        out: usize,
980        qkv_b: usize,
981        out_b: usize, // 0 = no bias
982        cos: usize,
983        sin: usize,
984        cos_len: u32, // 0 = no RoPE
985        batch: u32,
986        seq: u32,
987        hs: u32,
988        nh: u32,
989        dh: u32,
990        has_bias: bool,
991        has_rope: bool,
992    },
993    /// Fused ENTIRE transformer layer: attention + residual + LN + FFN + residual + LN.
994    /// Combines ~10 thunks into 1. All intermediates on stack. Zero arena traffic.
995    FusedBertLayer {
996        // attention
997        hidden: usize,
998        qkv_w: usize,
999        qkv_b: usize,
1000        out_w: usize,
1001        out_b: usize,
1002        mask: usize,
1003        // LN1
1004        ln1_g: usize,
1005        ln1_b: usize,
1006        eps1: f32,
1007        // FFN (GELU)
1008        fc1_w: usize,
1009        fc1_b: usize,
1010        fc2_w: usize,
1011        fc2_b: usize,
1012        // LN2
1013        ln2_g: usize,
1014        ln2_b: usize,
1015        eps2: f32,
1016        // output
1017        out: usize,
1018        // dims
1019        batch: u32,
1020        seq: u32,
1021        hs: u32,
1022        nh: u32,
1023        dh: u32,
1024        int_dim: u32,
1025    },
1026    /// Fused Nomic transformer layer: attention+RoPE + residual + LN + SwiGLU FFN + residual + LN.
1027    FusedNomicLayer {
1028        hidden: usize,
1029        qkv_w: usize,
1030        out_w: usize,
1031        mask: usize,
1032        cos: usize,
1033        sin: usize,
1034        cos_len: u32,
1035        ln1_g: usize,
1036        ln1_b: usize,
1037        eps1: f32,
1038        fc11_w: usize,
1039        fc12_w: usize,
1040        fc2_w: usize,
1041        ln2_g: usize,
1042        ln2_b: usize,
1043        eps2: f32,
1044        out: usize,
1045        batch: u32,
1046        seq: u32,
1047        hs: u32,
1048        nh: u32,
1049        dh: u32,
1050        int_dim: u32,
1051    },
1052    /// Fused SwiGLU: out\[r,i\] = x\[r,i\] * silu(x[r, n_half+i]).
1053    /// Input: [outer, 2*n_half] — concatenated up||gate per row.
1054    /// Output: [outer, n_half].
1055    FusedSwiGLU {
1056        src: usize,
1057        dst: usize,
1058        n_half: u32,
1059        total: u32,
1060        gate_first: bool,
1061    },
1062    /// Concat along an axis: output[outer, axis, inner] = inputs concatenated.
1063    /// Each entry of `inputs` is (src_offset, axis_len_for_that_input) in u32
1064    /// elements. `outer`, `inner`, and `total_axis_len` are pre-computed
1065    /// at compile time to avoid per-run shape work.
1066    Concat {
1067        dst: usize,
1068        outer: u32,
1069        inner: u32,
1070        total_axis: u32,
1071        /// `(src_offset, axis_extent, input_numel)` — `input_numel` enables
1072        /// outer-dim broadcast when rank-deficient inputs are concatenated.
1073        inputs: Vec<(usize, u32, u32)>,
1074    },
1075    /// Element-wise comparison: out = (lhs CMP rhs) ? 1 : 0 (Bool u8 or F32 0/1).
1076    Compare {
1077        lhs: usize,
1078        rhs: usize,
1079        dst: usize,
1080        len: u32,
1081        op: CmpOp,
1082        /// Nonzero when lhs/rhs are i64 (mask/range ops).
1083        inputs_i64: u8,
1084        /// Input element size (1 = Bool, 4 = F32, 8 = I64).
1085        inputs_elem_bytes: u8,
1086        /// Output element size (1 = Bool, 4 = F32).
1087        dst_elem_bytes: u8,
1088    },
1089    /// Reduction along a contiguous range of axes. Input layout (after
1090    /// shape decomposition) is `[outer, reduced, inner]`; output is
1091    /// `[outer, inner]`. The single-axis cases (axis=0 → outer=1;
1092    /// axis=last → inner=1) and contiguous multi-axis (e.g. reduce over
1093    /// [0, 1] of an [N, C, H, W] tensor → outer=1, reduced=N*C, inner=H*W)
1094    /// all map onto this triplet. Non-contiguous axes are not supported
1095    /// and bail to Nop in the compile pass.
1096    Reduce {
1097        src: usize,
1098        dst: usize,
1099        outer: u32,
1100        reduced: u32,
1101        inner: u32,
1102        op: ReduceOp,
1103    },
1104    /// Index of the max (`is_max`) or min along the reduced axis; writes the
1105    /// winning index as an f32 into `dst`.
1106    ArgReduce {
1107        src: usize,
1108        dst: usize,
1109        outer: u32,
1110        reduced: u32,
1111        inner: u32,
1112        is_max: bool,
1113    },
1114    /// Top-K **indices** along the last axis. Input shape `[outer, axis_dim]`,
1115    /// output `[outer, k]` (f32 or i64 per `indices_i64`). Ties broken by
1116    /// smaller index. Used by MoE gating + beam search.
1117    TopK {
1118        src: usize,
1119        dst: usize,
1120        outer: u32,
1121        axis_dim: u32,
1122        k: u32,
1123        indices_i64: u8,
1124    },
1125    /// Indexed batched matmul: out\[i\] = input\[i\] @ weight[expert_idx\[i\]].
1126    /// Naive impl per token; for real MoE workloads, sort-by-expert + run
1127    /// segmented GEMM would amortize. Done when there's a workload.
1128    GroupedMatMul {
1129        input: usize,
1130        weight: usize,
1131        expert_idx: usize,
1132        dst: usize,
1133        m: u32,
1134        k_dim: u32,
1135        n: u32,
1136        num_experts: u32,
1137    },
1138    /// GGUF K-quant packed expert stack + grouped matmul (MoE FFN).
1139    DequantGroupedMatMulGguf {
1140        input: usize,
1141        w_q: usize,
1142        expert_idx: usize,
1143        dst: usize,
1144        m: u32,
1145        k_dim: u32,
1146        n: u32,
1147        num_experts: u32,
1148        scheme: rlx_ir::quant::QuantScheme,
1149    },
1150    /// Materialize packed MoE weights to F32 `[E, K, N]` (autodiff helper).
1151    DequantMoEWeightsGguf {
1152        w_q: usize,
1153        dst: usize,
1154        k_dim: u32,
1155        n: u32,
1156        num_experts: u32,
1157        scheme: rlx_ir::quant::QuantScheme,
1158    },
1159    /// Scatter-add: dst[indices\[i\] * trailing + j] += updates[i * trailing + j].
1160    /// Output is zeroed first; multiple updates to the same row accumulate.
1161    ScatterAdd {
1162        updates: usize,
1163        indices: usize,
1164        dst: usize,
1165        num_updates: u32,
1166        out_dim: u32,
1167        trailing: u32,
1168    },
1169    /// Ternary select: out = cond != 0 ? on_true : on_false
1170    Where {
1171        cond: usize,
1172        on_true: usize,
1173        on_false: usize,
1174        dst: usize,
1175        len: u32,
1176        elem_bytes: u8,
1177        /// Element size for cond (1 = Bool mask, 4 = F32 0/1).
1178        cond_elem_bytes: u8,
1179    },
1180    /// General N-D transpose / broadcast. `out_dims[i]` is the output's dim
1181    /// i length; `in_strides[i]` is the input stride (in elements) used to
1182    /// index that dim — 0 for broadcast dims (Expand). `in_total` is the
1183    /// total element count in the source buffer (≤ output total when
1184    /// broadcasting). Strides are pre-computed at compile time.
1185    Transpose {
1186        src: usize,
1187        dst: usize,
1188        in_total: u32,
1189        out_dims: Vec<u32>,
1190        in_strides: Vec<u32>,
1191        elem_bytes: u8,
1192    },
1193    /// Gather along an arbitrary axis. `outer = product(dims[..axis])`,
1194    /// `trailing = product(dims[axis+1..])`, `axis_dim` = the dimension
1195    /// being indexed into. Output: outer × num_idx × trailing.
1196    /// (axis=0 still routes to the simpler Thunk::Gather fast path.)
1197    GatherAxis {
1198        table: usize,
1199        idx: usize,
1200        dst: usize,
1201        outer: u32,
1202        axis_dim: u32,
1203        num_idx: u32,
1204        trailing: u32,
1205        idx_i64: u8,
1206        table_bytes: u8,
1207    },
1208    /// 2D pooling (Max or Mean). Input layout [N, C, H, W], output
1209    /// [N, C, H_out, W_out]. Padding is implicit-zero; Mean divides by
1210    /// the full kernel area (matches torch's `count_include_pad=True`).
1211    Pool2D {
1212        src: usize,
1213        dst: usize,
1214        n: u32,
1215        c: u32,
1216        h: u32,
1217        w: u32,
1218        h_out: u32,
1219        w_out: u32,
1220        kh: u32,
1221        kw: u32,
1222        sh: u32,
1223        sw: u32,
1224        ph: u32,
1225        pw: u32,
1226        kind: ReduceOp,
1227    },
1228    /// 2D convolution. Input [N, C_in, H, W], weight [C_out, C_in_per_group, kH, kW],
1229    /// output [N, C_out, H_out, W_out]. Bias is a separate Op::Binary::Add
1230    /// after the conv (matching the IR's input layout — Op::Conv has 2 inputs).
1231    /// Naive direct convolution; sufficient for correctness, not optimised.
1232    Conv2D {
1233        src: usize,
1234        weight: usize,
1235        dst: usize,
1236        n: u32,
1237        c_in: u32,
1238        h: u32,
1239        w: u32,
1240        c_out: u32,
1241        h_out: u32,
1242        w_out: u32,
1243        kh: u32,
1244        kw: u32,
1245        sh: u32,
1246        sw: u32,
1247        ph: u32,
1248        pw: u32,
1249        dh: u32,
1250        dw: u32,
1251        groups: u32,
1252    },
1253
1254    // ── Backward / training kernels ─────────────────────────────
1255    /// Real INT8 matmul with i32 accumulation.
1256    ///   `out[m, n] = requantize(bias[n] + Σₖ (x[m,k]-x_zp)·(w[k,n]-w_zp), mult, out_zp)`
1257    /// Reads `x` and `w` as i8, `bias` as i32; writes `out` as i8.
1258    /// Same kernel shape as `rlx_cortexm::dense::dense_i8` — promoted
1259    /// to a desktop thunk so a quantized graph compiled here doesn't
1260    /// have to round-trip through fake-quant.
1261    QMatMul {
1262        x: usize,
1263        w: usize,
1264        bias: usize,
1265        out: usize,
1266        m: u32,
1267        k: u32,
1268        n: u32,
1269        x_zp: i32,
1270        w_zp: i32,
1271        out_zp: i32,
1272        mult: f32,
1273    },
1274
1275    /// Real INT8 conv2d, NCHW layout. Same loop shape as `Thunk::Conv2D`
1276    /// but with i8 reads, i32 accumulation, and per-output requantize
1277    /// to i8. Bias is i32 in the accumulator scale.
1278    QConv2d {
1279        x: usize,
1280        w: usize,
1281        bias: usize,
1282        out: usize,
1283        n: u32,
1284        c_in: u32,
1285        h: u32,
1286        w_in: u32,
1287        c_out: u32,
1288        h_out: u32,
1289        w_out: u32,
1290        kh: u32,
1291        kw: u32,
1292        sh: u32,
1293        sw: u32,
1294        ph: u32,
1295        pw: u32,
1296        dh: u32,
1297        dw: u32,
1298        groups: u32,
1299        x_zp: i32,
1300        w_zp: i32,
1301        out_zp: i32,
1302        mult: f32,
1303    },
1304
1305    /// INT8 quantize. Reads `x` as f32, writes `q` as i8.
1306    /// `chan = (i / inner) % chan_dim` selects the per-channel
1307    /// scale/zp; `chan_axis` is informational only (the kernel uses
1308    /// `chan_dim` and `inner` directly).
1309    /// For per-tensor, `chan_dim = 1` and `inner = len` so `chan` is
1310    /// always 0.
1311    Quantize {
1312        x: usize,
1313        q: usize,
1314        len: u32,
1315        chan_axis: u32,
1316        chan_dim: u32,
1317        inner: u32,
1318        scales: Vec<f32>,
1319        zero_points: Vec<i32>,
1320    },
1321
1322    /// INT8 dequantize — inverse of `Thunk::Quantize`.
1323    Dequantize {
1324        q: usize,
1325        x: usize,
1326        len: u32,
1327        chan_axis: u32,
1328        chan_dim: u32,
1329        inner: u32,
1330        scales: Vec<f32>,
1331        zero_points: Vec<i32>,
1332    },
1333
1334    /// QAT fake-quantize. Per-channel (or per-tensor) symmetric
1335    /// quantize-then-dequantize on the fly. Computes
1336    ///   `s[c] = max(|x[..., c, ...]|) / q_max`
1337    /// then
1338    ///   `out[i] = clamp(round(x[i]/s[c]), -q_max, q_max) * s[c]`
1339    /// with `q_max = {127, 7, 1}` for `bits = {8, 4, 2}`. Same
1340    /// channel-layout convention as `Thunk::Quantize`: every
1341    /// element's channel is `(i / inner) % chan_dim`. The kernel
1342    /// does two passes — one to scan max-abs per channel, one to
1343    /// quant-dequant per element.
1344    FakeQuantize {
1345        x: usize,
1346        out: usize,
1347        len: u32,
1348        chan_axis: u32,
1349        chan_dim: u32,
1350        inner: u32,
1351        bits: u8,
1352        /// STE variant — informational on the forward side (output is
1353        /// the same regardless), kernel-relevant in the matching
1354        /// `FakeQuantizeBackward` thunk.
1355        ste: rlx_ir::op::SteKind,
1356        /// Scale-tracking strategy. `PerBatch` recomputes
1357        /// `max_abs/q_max` every call (the original path). `EMA{decay}`
1358        /// blends per-batch max-abs into the `state_off` buffer; `Fixed`
1359        /// reads `state_off` and never updates it.
1360        scale_mode: rlx_ir::op::ScaleMode,
1361        /// `Some(off)` for `EMA` and `Fixed`; `None` for `PerBatch`.
1362        /// Points at a `[chan_dim]` f32 buffer holding the running scale
1363        /// per channel.
1364        state_off: Option<usize>,
1365    },
1366
1367    /// Backward pass for `Op::FakeQuantize` under one of four STE
1368    /// variants. Computes `dx[i]` from the f32 forward input `x` and
1369    /// the upstream gradient `dy`, using the same per-channel scale
1370    /// scheme as the forward.
1371    FakeQuantizeBackward {
1372        x: usize,
1373        dy: usize,
1374        dx: usize,
1375        len: u32,
1376        chan_axis: u32,
1377        chan_dim: u32,
1378        inner: u32,
1379        bits: u8,
1380        ste: rlx_ir::op::SteKind,
1381    },
1382
1383    /// LSQ forward — same kernel shape as `FakeQuantize` Fixed mode.
1384    /// Reads scale from `scale_off` (a `[chan_dim]` Param tensor).
1385    FakeQuantizeLSQ {
1386        x: usize,
1387        scale_off: usize,
1388        out: usize,
1389        len: u32,
1390        chan_axis: u32,
1391        chan_dim: u32,
1392        inner: u32,
1393        bits: u8,
1394    },
1395
1396    /// LSQ backward, x-gradient. STE-clipped: passes upstream
1397    /// through inside the quantization range, zeros outside.
1398    FakeQuantizeLSQBackwardX {
1399        x: usize,
1400        scale_off: usize,
1401        dy: usize,
1402        dx: usize,
1403        len: u32,
1404        chan_axis: u32,
1405        chan_dim: u32,
1406        inner: u32,
1407        bits: u8,
1408    },
1409
1410    /// LSQ backward, scale-gradient. Per-channel:
1411    ///   `dscale[c] = sum_i ψ(x[i]/s[c]) · upstream[i]`
1412    /// where `ψ(z) = -z + round(z)` if `|z| ≤ q_max` else
1413    /// `sign(z) · q_max`. Output shape: `[chan_dim]`.
1414    FakeQuantizeLSQBackwardScale {
1415        x: usize,
1416        scale_off: usize,
1417        dy: usize,
1418        dscale: usize,
1419        len: u32,
1420        chan_axis: u32,
1421        chan_dim: u32,
1422        inner: u32,
1423        bits: u8,
1424    },
1425
1426    /// ReLU backward: `dx[i] = dy[i] if x[i] > 0 else 0`.
1427    ReluBackward {
1428        x: usize,
1429        dy: usize,
1430        dx: usize,
1431        len: u32,
1432    },
1433    /// f64 sibling of `ReluBackward` — same shape as the f32 variant
1434    /// but reads/writes 8 bytes per element. Required because
1435    /// `ReluBackward`'s `&[f32]` slot view returns half of every f64
1436    /// otherwise → backward silently produces 0 gradients on an f64
1437    /// graph. Mirrors the `ActivationBackwardF64` split.
1438    ReluBackwardF64 {
1439        x: usize,
1440        dy: usize,
1441        dx: usize,
1442        len: u32,
1443    },
1444
1445    /// Generic element-wise activation backward.
1446    /// `dx[i] = (d/dx act(x))[i] · dy[i]`. The closure dispatch is
1447    /// per-element; expensive activations (Gelu) recompute internals
1448    /// inline rather than threading an extra "saved y" tensor through.
1449    ActivationBackward {
1450        x: usize,
1451        dy: usize,
1452        dx: usize,
1453        len: u32,
1454        kind: Activation,
1455    },
1456    /// f64 sibling of `ActivationBackward` — slot offsets, len in
1457    /// elements; kernel reads/writes 8 bytes per element. Required
1458    /// because `ActivationBackward`'s `&[f32]` slot view silently
1459    /// returns garbage on an f64 graph (cb % 4 still works but every
1460    /// loaded value is half of an f64 → wrong gradient).
1461    ActivationBackwardF64 {
1462        x: usize,
1463        dy: usize,
1464        dx: usize,
1465        len: u32,
1466        kind: Activation,
1467    },
1468
1469    /// LayerNorm backward — input gradient. Recomputes mean/var/x̂ from
1470    /// `x` and emits the closed-form `d_x` per row.
1471    LayerNormBackwardInput {
1472        x: usize,
1473        gamma: usize,
1474        dy: usize,
1475        dx: usize,
1476        rows: u32,
1477        h: u32,
1478        eps: f32,
1479    },
1480
1481    /// LayerNorm backward — gamma gradient. `d_gamma[d] = Σ_row dy·x̂`.
1482    LayerNormBackwardGamma {
1483        x: usize,
1484        dy: usize,
1485        dgamma: usize,
1486        rows: u32,
1487        h: u32,
1488        eps: f32,
1489    },
1490
1491    RmsNormBackwardInput {
1492        x: usize,
1493        gamma: usize,
1494        beta: usize,
1495        dy: usize,
1496        dx: usize,
1497        rows: u32,
1498        h: u32,
1499        eps: f32,
1500    },
1501    RmsNormBackwardGamma {
1502        x: usize,
1503        gamma: usize,
1504        beta: usize,
1505        dy: usize,
1506        dgamma: usize,
1507        rows: u32,
1508        h: u32,
1509        eps: f32,
1510    },
1511    RmsNormBackwardBeta {
1512        x: usize,
1513        gamma: usize,
1514        beta: usize,
1515        dy: usize,
1516        dbeta: usize,
1517        rows: u32,
1518        h: u32,
1519        eps: f32,
1520    },
1521    RopeBackward {
1522        dy: usize,
1523        cos: usize,
1524        sin: usize,
1525        dx: usize,
1526        batch: u32,
1527        seq: u32,
1528        hidden: u32,
1529        head_dim: u32,
1530        n_rot: u32,
1531        cos_len: u32,
1532    },
1533    CumsumBackward {
1534        dy: usize,
1535        dx: usize,
1536        rows: u32,
1537        cols: u32,
1538        exclusive: bool,
1539    },
1540    GatherBackward {
1541        dy: usize,
1542        indices: usize,
1543        dst: usize,
1544        outer: u32,
1545        axis_dim: u32,
1546        num_idx: u32,
1547        trailing: u32,
1548    },
1549
1550    GroupNormBackwardInput {
1551        x: usize,
1552        gamma: usize,
1553        beta: usize,
1554        dy: usize,
1555        dx: usize,
1556        n: u32,
1557        c: u32,
1558        h: u32,
1559        w: u32,
1560        num_groups: u32,
1561        eps: f32,
1562    },
1563    GroupNormBackwardGamma {
1564        x: usize,
1565        dy: usize,
1566        dgamma: usize,
1567        n: u32,
1568        c: u32,
1569        h: u32,
1570        w: u32,
1571        num_groups: u32,
1572        eps: f32,
1573    },
1574    GroupNormBackwardBeta {
1575        dy: usize,
1576        dbeta: usize,
1577        n: u32,
1578        c: u32,
1579        h: u32,
1580        w: u32,
1581    },
1582
1583    /// 2D max-pool backward (NCHW). Recomputes the argmax position
1584    /// inside each window and accumulates `dy` into `dx` at that
1585    /// position. Output is zeroed first; ties resolve to the first
1586    /// hit (lowest (kh,kw) index), matching what the forward kernel
1587    /// does with `acc.max(v)`.
1588    MaxPool2dBackward {
1589        x: usize,
1590        dy: usize,
1591        dx: usize,
1592        n: u32,
1593        c: u32,
1594        h: u32,
1595        w: u32,
1596        h_out: u32,
1597        w_out: u32,
1598        kh: u32,
1599        kw: u32,
1600        sh: u32,
1601        sw: u32,
1602        ph: u32,
1603        pw: u32,
1604    },
1605
1606    /// 2D conv backward w.r.t. input (`dx = conv_transpose(dy, w)`).
1607    /// `dy [N, C_out, H_out, W_out]`, `w [C_out, C_in_per_group, kH, kW]`,
1608    /// `dx [N, C_in, H, W]`.
1609    Conv2dBackwardInput {
1610        dy: usize,
1611        w: usize,
1612        dx: usize,
1613        n: u32,
1614        c_in: u32,
1615        h: u32,
1616        w_in: u32,
1617        c_out: u32,
1618        h_out: u32,
1619        w_out: u32,
1620        kh: u32,
1621        kw: u32,
1622        sh: u32,
1623        sw: u32,
1624        ph: u32,
1625        pw: u32,
1626        dh: u32,
1627        dw: u32,
1628        groups: u32,
1629    },
1630
1631    /// 2D conv backward w.r.t. weight. `x [N, C_in, H, W]`,
1632    /// `dy [N, C_out, H_out, W_out]`, `dw [C_out, C_in_per_group, kH, kW]`.
1633    /// `dw` is zeroed before accumulation.
1634    Conv2dBackwardWeight {
1635        x: usize,
1636        dy: usize,
1637        dw: usize,
1638        n: u32,
1639        c_in: u32,
1640        h: u32,
1641        w: u32,
1642        c_out: u32,
1643        h_out: u32,
1644        w_out: u32,
1645        kh: u32,
1646        kw: u32,
1647        sh: u32,
1648        sw: u32,
1649        ph: u32,
1650        pw: u32,
1651        dh: u32,
1652        dw_dil: u32,
1653        groups: u32,
1654    },
1655
1656    /// NCHW im2col for conv backward-weight matmul. Output `[M, C·kH·kW]`
1657    /// with `M = N · H_out · W_out`. `n == 0` means infer batch from `x`.
1658    Im2Col {
1659        x: usize,
1660        col: usize,
1661        n: u32,
1662        c_in: u32,
1663        h: u32,
1664        w: u32,
1665        h_out: u32,
1666        w_out: u32,
1667        kh: u32,
1668        kw: u32,
1669        sh: u32,
1670        sw: u32,
1671        ph: u32,
1672        pw: u32,
1673        dh: u32,
1674        dw_dil: u32,
1675    },
1676
1677    /// Fused softmax + cross-entropy loss with f32-encoded integer
1678    /// labels. `logits [N, C]`, `labels [N]`, output `[N]` per-row loss.
1679    /// Numerically stable (max-subtract before exp).
1680    SoftmaxCrossEntropy {
1681        logits: usize,
1682        labels: usize,
1683        dst: usize,
1684        n: u32,
1685        c: u32,
1686    },
1687
1688    /// Backward of the fused loss above.
1689    /// `dlogits[n, k] = (softmax(logits[n])[k] - one_hot(labels[n])[k]) * d_loss[n]`.
1690    SoftmaxCrossEntropyBackward {
1691        logits: usize,
1692        labels: usize,
1693        d_loss: usize,
1694        dlogits: usize,
1695        n: u32,
1696        c: u32,
1697    },
1698
1699    /// User-registered custom op (CPU side). Lowered from `Op::Custom`.
1700    /// `kernel` is resolved against the global CPU kernel registry at
1701    /// compile time and stored as `Arc<dyn CpuKernel>` so execution
1702    /// avoids per-call lookups. v1: f32 contiguous only — see
1703    /// `op_registry::CpuKernel::execute_f32`.
1704    CustomOp {
1705        kernel: Arc<dyn CpuKernel>,
1706        inputs: Vec<(usize, u32, Shape)>, // (offset, len_elements, shape)
1707        output: (usize, u32, Shape),      // (offset, len_elements, shape)
1708        attrs: Vec<u8>,
1709    },
1710
1711    /// 1D FFT along the last axis. Input/output are `[..., 2N]`
1712    /// real-block layout (first N real, second N imag along the
1713    /// transformed axis). `outer` is the product of all leading axes;
1714    /// `n_complex` is N (the number of complex points). Both halves
1715    /// of the real-block layout are read together by the kernel.
1716    /// `dtype` selects the f32 or f64 path; the two share structure
1717    /// but not buffers, so a flag at compile time avoids per-row
1718    /// dispatch.
1719    /// CPU reference 3D Gaussian splat render ([`rlx_ir::Op::GaussianSplatRender`]).
1720    GaussianSplatRender {
1721        positions_off: usize,
1722        positions_len: usize,
1723        scales_off: usize,
1724        scales_len: usize,
1725        rotations_off: usize,
1726        rotations_len: usize,
1727        opacities_off: usize,
1728        opacities_len: usize,
1729        colors_off: usize,
1730        colors_len: usize,
1731        sh_coeffs_off: usize,
1732        sh_coeffs_len: usize,
1733        meta_off: usize,
1734        dst_off: usize,
1735        dst_len: usize,
1736        width: u32,
1737        height: u32,
1738        tile_size: u32,
1739        radius_scale: f32,
1740        alpha_cutoff: f32,
1741        max_splat_steps: u32,
1742        transmittance_threshold: f32,
1743        max_list_entries: u32,
1744    },
1745    GaussianSplatRenderBackward {
1746        positions_off: usize,
1747        positions_len: usize,
1748        scales_off: usize,
1749        scales_len: usize,
1750        rotations_off: usize,
1751        rotations_len: usize,
1752        opacities_off: usize,
1753        opacities_len: usize,
1754        colors_off: usize,
1755        colors_len: usize,
1756        sh_coeffs_off: usize,
1757        sh_coeffs_len: usize,
1758        meta_off: usize,
1759        d_loss_off: usize,
1760        d_loss_len: usize,
1761        packed_off: usize,
1762        packed_len: usize,
1763        width: u32,
1764        height: u32,
1765        tile_size: u32,
1766        radius_scale: f32,
1767        alpha_cutoff: f32,
1768        max_splat_steps: u32,
1769        transmittance_threshold: f32,
1770        max_list_entries: u32,
1771        loss_grad_clip: f32,
1772        sh_band: u32,
1773        max_anisotropy: f32,
1774    },
1775    /// Strict IR stage 1 — project + bin + sort + rays ([`Op::GaussianSplatPrepare`]).
1776    GaussianSplatPrepare {
1777        positions_off: usize,
1778        positions_len: usize,
1779        scales_off: usize,
1780        scales_len: usize,
1781        rotations_off: usize,
1782        rotations_len: usize,
1783        opacities_off: usize,
1784        opacities_len: usize,
1785        colors_off: usize,
1786        colors_len: usize,
1787        sh_coeffs_off: usize,
1788        sh_coeffs_len: usize,
1789        meta_off: usize,
1790        meta_len: usize,
1791        prep_off: usize,
1792        prep_len: usize,
1793        width: u32,
1794        height: u32,
1795        tile_size: u32,
1796        radius_scale: f32,
1797        alpha_cutoff: f32,
1798        max_splat_steps: u32,
1799        transmittance_threshold: f32,
1800        max_list_entries: u32,
1801    },
1802    /// Strict IR stage 2 — tile raster from prepare buffer ([`Op::GaussianSplatRasterize`]).
1803    GaussianSplatRasterize {
1804        prep_off: usize,
1805        prep_len: usize,
1806        meta_off: usize,
1807        meta_len: usize,
1808        dst_off: usize,
1809        dst_len: usize,
1810        count: usize,
1811        width: u32,
1812        height: u32,
1813        tile_size: u32,
1814        alpha_cutoff: f32,
1815        max_splat_steps: u32,
1816        transmittance_threshold: f32,
1817        max_list_entries: u32,
1818    },
1819    Fft1d {
1820        src: usize,
1821        dst: usize,
1822        outer: u32,
1823        n_complex: u32,
1824        inverse: bool,
1825        norm_tag: u32,
1826        dtype: rlx_ir::DType,
1827    },
1828    FftButterflyStage {
1829        state_src: usize,
1830        state_dst: usize,
1831        gate_src: usize,
1832        rev_src: usize,
1833        tw_re_src: usize,
1834        tw_im_src: usize,
1835        batch: u32,
1836        n_fft: u32,
1837        stage: u32,
1838    },
1839    LogMel {
1840        spec: usize,
1841        filters: usize,
1842        dst: usize,
1843        outer: u32,
1844        n_fft: u32,
1845        n_bins: u32,
1846        n_mels: u32,
1847    },
1848    LogMelBackward {
1849        spec: usize,
1850        filters: usize,
1851        dy: usize,
1852        dst: usize,
1853        outer: u32,
1854        n_fft: u32,
1855        n_bins: u32,
1856        n_mels: u32,
1857    },
1858    WelchPeaks {
1859        spec: usize,
1860        dst: usize,
1861        welch_batch: u32,
1862        n_fft: u32,
1863        n_segments: u32,
1864        k: u32,
1865    },
1866}
1867
1868/// Compiled thunk schedule — the runtime hot path.
1869/// Nop thunks are filtered out at compile time for zero iteration overhead.
1870#[derive(Clone)]
1871pub struct ThunkSchedule {
1872    pub thunks: Vec<Thunk>,
1873    /// TIDE merged placement mask (union across layers).
1874    pub moe_resident: Option<std::sync::Arc<[bool]>>,
1875    /// Per MoE layer placement (`layer[e]`); preferred when set.
1876    pub moe_resident_layers: Option<std::sync::Arc<Vec<std::sync::Arc<[bool]>>>>,
1877    /// MoE router TopK capture (per-layer refresh).
1878    pub moe_topk_capture: Option<std::sync::Arc<crate::moe_topk_capture::MoeTopkCapture>>,
1879    /// Cached config values.
1880    pub mask_threshold: f32,
1881    pub mask_neg_inf: f32,
1882    pub score_skip: f32,
1883    /// Pre-compiled closure dispatch (zero match overhead). `Arc` (not
1884    /// `Box`) so the schedule can be `Clone` — multiple parallel
1885    /// executors share the same compiled closures (they're read-only
1886    /// `Fn(*mut u8)` so concurrent dispatch is safe; the arena pointer
1887    /// they receive is the only mutable state and is per-executor).
1888    pub compiled_fns: Vec<Arc<dyn Fn(*mut u8) + Send + Sync>>,
1889    /// Runtime-mutable RNG policy for [`Thunk::RngNormal`] / [`Thunk::RngUniform`].
1890    pub rng: Arc<std::sync::RwLock<rlx_ir::RngOptions>>,
1891}
1892
1893impl ThunkSchedule {
1894    pub fn strip_nops(&mut self) {
1895        self.thunks.retain(|t| !matches!(t, Thunk::Nop));
1896        // compiled_fns must be rebuilt after stripping — caller should
1897        // call strip_nops() before compile_closures().
1898        self.compiled_fns.clear();
1899    }
1900}
1901
1902/// Get the arena byte offset for a node.
1903fn node_offset(arena: &Arena, id: NodeId) -> usize {
1904    if arena.has_buffer(id) {
1905        arena.byte_offset(id)
1906    } else {
1907        usize::MAX
1908    }
1909}
1910
1911/// Every byte-offset that a thunk reads from. Used by the Narrow→Rope
1912/// fusion (#45) to verify a Narrow's dst has exactly one consumer
1913/// before eliding it. Conservative: when in doubt about reads (an op
1914/// not yet listed here), the fusion will skip — correctness over
1915/// completeness.
1916fn thunk_read_offsets(t: &Thunk) -> Vec<usize> {
1917    match t {
1918        Thunk::Sgemm { a, b, .. } => vec![*a, *b],
1919        Thunk::DenseSolveF64 { a, b, .. } => vec![*a, *b],
1920        Thunk::DenseSolveF32 { a, b, .. } => vec![*a, *b],
1921        Thunk::BatchedDenseSolveF64 { a, b, .. } => vec![*a, *b],
1922        Thunk::BatchedDgemmF64 { a, b, .. } => vec![*a, *b],
1923        Thunk::BatchedSgemm { a, b, .. } => vec![*a, *b],
1924        Thunk::FusedMmBiasAct { a, w, bias, .. } => vec![*a, *w, *bias],
1925        Thunk::BiasAdd { src, bias, .. } => vec![*src, *bias],
1926        Thunk::BinaryFull { lhs, rhs, .. } => vec![*lhs, *rhs],
1927        Thunk::BinaryFullF64 { lhs, rhs, .. } => vec![*lhs, *rhs],
1928        Thunk::BinaryFullC64 { lhs, rhs, .. } => vec![*lhs, *rhs],
1929        Thunk::ComplexNormSqF32 { src, .. } => vec![*src],
1930        Thunk::ComplexNormSqBackwardF32 { z, g, .. } => vec![*z, *g],
1931        Thunk::ConjugateC64 { src, .. } => vec![*src],
1932        Thunk::Scan {
1933            outer_init_off,
1934            xs_inputs,
1935            ..
1936        } => {
1937            let mut v = vec![*outer_init_off];
1938            for (_, outer_xs_off, _) in xs_inputs.iter() {
1939                v.push(*outer_xs_off);
1940            }
1941            v
1942        }
1943        Thunk::ScanBackward {
1944            outer_init_off,
1945            outer_traj_off,
1946            outer_upstream_off,
1947            outer_xs_offs,
1948            ..
1949        } => {
1950            let mut v = vec![*outer_init_off, *outer_traj_off, *outer_upstream_off];
1951            for (off, _) in outer_xs_offs.iter() {
1952                v.push(*off);
1953            }
1954            v
1955        }
1956        Thunk::ScanBackwardXs {
1957            outer_init_off,
1958            outer_traj_off,
1959            outer_upstream_off,
1960            outer_xs_offs,
1961            ..
1962        } => {
1963            let mut v = vec![*outer_init_off, *outer_traj_off, *outer_upstream_off];
1964            for (off, _) in outer_xs_offs.iter() {
1965                v.push(*off);
1966            }
1967            v
1968        }
1969        Thunk::CustomFn { inputs, .. } => {
1970            inputs.iter().map(|(_, outer_off, _)| *outer_off).collect()
1971        }
1972        Thunk::ActivationInPlace { data, .. } => vec![*data],
1973        Thunk::LayerNorm { src, g, b, .. } | Thunk::GroupNorm { src, g, b, .. } => {
1974            vec![*src, *g, *b]
1975        }
1976        Thunk::BatchNormInference {
1977            src,
1978            g,
1979            b,
1980            mean,
1981            var,
1982            ..
1983        } => vec![*src, *g, *b, *mean, *var],
1984        Thunk::ResizeNearest2x { src, .. } => vec![*src],
1985        Thunk::AxialRope2d { src, .. } => vec![*src],
1986        Thunk::FusedResidualLN {
1987            x, res, bias, g, b, ..
1988        } => vec![*x, *res, *bias, *g, *b],
1989        Thunk::FusedResidualRmsNorm {
1990            x, res, bias, g, b, ..
1991        } => vec![*x, *res, *bias, *g, *b],
1992        Thunk::RmsNorm { src, g, b, .. } => vec![*src, *g, *b],
1993        Thunk::Softmax { data, .. } => vec![*data],
1994        Thunk::Cumsum { src, .. } => vec![*src],
1995        Thunk::Sample { logits, .. } => vec![*logits],
1996        Thunk::RngNormal { .. } | Thunk::RngUniform { .. } => vec![],
1997        Thunk::LoraMatMul { x, w, a, b, .. } => vec![*x, *w, *a, *b],
1998        Thunk::DequantMatMul {
1999            x, w_q, scale, zp, ..
2000        } => vec![*x, *w_q, *scale, *zp],
2001        Thunk::DequantMatMulGguf { x, w_q, .. } => vec![*x, *w_q],
2002        Thunk::DequantMatMulInt4 {
2003            x, w_q, scale, zp, ..
2004        } => vec![*x, *w_q, *scale, *zp],
2005        Thunk::DequantMatMulFp8 { x, w_q, scale, .. } => vec![*x, *w_q, *scale],
2006        Thunk::DequantMatMulNvfp4 {
2007            x,
2008            w_q,
2009            scale,
2010            global_scale,
2011            ..
2012        } => vec![*x, *w_q, *scale, *global_scale],
2013        Thunk::Conv2D1x1 { src, weight, .. } => vec![*src, *weight],
2014        Thunk::SelectiveScan {
2015            x, delta, a, b, c, ..
2016        } => vec![*x, *delta, *a, *b, *c],
2017        Thunk::GatedDeltaNet {
2018            q,
2019            k,
2020            v,
2021            g,
2022            beta,
2023            state,
2024            ..
2025        } => {
2026            let mut v = vec![*q, *k, *v, *g, *beta];
2027            if *state != 0 {
2028                v.push(*state);
2029            }
2030            v
2031        }
2032        Thunk::Attention { q, k, v, mask, .. } => vec![*q, *k, *v, *mask],
2033        Thunk::AttentionBackward {
2034            q, k, v, dy, mask, ..
2035        } => {
2036            let mut v = vec![*q, *k, *v, *dy];
2037            if *mask != 0 {
2038                v.push(*mask);
2039            }
2040            v
2041        }
2042        Thunk::Rope { src, cos, sin, .. } => vec![*src, *cos, *sin],
2043        Thunk::FusedAttnBlock {
2044            hidden,
2045            qkv_w,
2046            out_w,
2047            mask,
2048            qkv_b,
2049            out_b,
2050            cos,
2051            sin,
2052            ..
2053        } => vec![*hidden, *qkv_w, *out_w, *mask, *qkv_b, *out_b, *cos, *sin],
2054        Thunk::FusedSwiGLU { src, .. } => vec![*src],
2055        Thunk::Concat { inputs, .. } => inputs.iter().map(|(off, _, _)| *off).collect(),
2056        Thunk::ConcatF64 { inputs, .. } => inputs.iter().map(|(off, _, _)| *off).collect(),
2057        Thunk::Narrow { src, .. } => vec![*src],
2058        Thunk::Copy { src, .. } => vec![*src],
2059        Thunk::Gather { table, idx, .. } => vec![*table, *idx],
2060        // Anything not enumerated → return the dst as a "read" too,
2061        // forcing the fusion to bail (read_count >= 2 → skip). Keeps
2062        // this list safe to be incomplete.
2063        _ => vec![],
2064    }
2065}
2066
2067/// Fused dequant + matmul (plan #5). Int8-blockwise weights: each
2068/// `block_size` consecutive elements of a column share one f32
2069/// scale (and optionally a zero-point). The dequant happens inside
2070/// the inner accumulate so the f32 weight is never materialized.
2071///
2072/// `w_bytes` is the row-major i8 weight matrix `[k, n]`. `scales`
2073/// and `zps` are `[k/block, n]`. When `asym=false`, `zps` may be
2074/// empty.
2075///
2076/// Today this is the reference scalar implementation — the win is
2077/// memory bandwidth, not flops, since LLM weights dominate the
2078/// working set. A NEON SIMD path that loads 16 i8 → splat-scale →
2079/// fused-multiply-add is the natural follow-on.
2080#[allow(clippy::too_many_arguments)]
2081pub fn dequant_matmul_int8(
2082    x: &[f32],       // [m, k]
2083    w_bytes: &[i8],  // [k, n]
2084    scales: &[f32],  // [k/block, n]
2085    zps: &[f32],     // [k/block, n] or empty
2086    out: &mut [f32], // [m, n]
2087    m: usize,
2088    k: usize,
2089    n: usize,
2090    block_size: usize,
2091    asym: bool,
2092) {
2093    let blocks_per_col = k.div_ceil(block_size);
2094    for i in 0..m {
2095        for j in 0..n {
2096            let mut acc = 0f32;
2097            for p in 0..k {
2098                let block = p / block_size;
2099                let s = scales[block * n + j];
2100                let z = if asym { zps[block * n + j] } else { 0.0 };
2101                let q = w_bytes[p * n + j] as f32;
2102                let dequantized = (q - z) * s;
2103                acc += x[i * k + p] * dequantized;
2104            }
2105            out[i * n + j] = acc;
2106        }
2107    }
2108    let _ = blocks_per_col;
2109}
2110
2111#[allow(clippy::too_many_arguments)]
2112fn dequant_matmul_int4(
2113    x: &[f32],
2114    w_bytes: &[u8],
2115    scales: &[f32],
2116    zps: &[f32],
2117    out: &mut [f32],
2118    m: usize,
2119    k: usize,
2120    n: usize,
2121    block_size: usize,
2122    asym: bool,
2123) {
2124    for i in 0..m {
2125        for j in 0..n {
2126            let mut acc = 0f32;
2127            for p in 0..k {
2128                let block = p / block_size;
2129                let s = scales[block * n + j];
2130                let z = if asym { zps[block * n + j] } else { 0.0 };
2131                let byte_idx = (p * n + j) / 2;
2132                let nibble = if (p * n + j) & 1 == 0 {
2133                    w_bytes[byte_idx] & 0x0F
2134                } else {
2135                    w_bytes[byte_idx] >> 4
2136                };
2137                let dequantized = (nibble as f32 - z) * s;
2138                acc += x[i * k + p] * dequantized;
2139            }
2140            out[i * n + j] = acc;
2141        }
2142    }
2143}
2144
2145fn fp8_e4m3_to_f32(b: u8) -> f32 {
2146    let sign = if b & 0x80 != 0 { -1.0 } else { 1.0 };
2147    let exp = (b >> 3) & 0x0F;
2148    let mant = b & 0x07;
2149    if exp == 0 {
2150        if mant == 0 {
2151            return 0.0;
2152        }
2153        return sign * (mant as f32) * 2f32.powi(-9);
2154    }
2155    if exp == 0x0F {
2156        return if mant == 0 {
2157            sign * f32::INFINITY
2158        } else {
2159            f32::NAN
2160        };
2161    }
2162    sign * (1.0 + mant as f32 / 8.0) * 2f32.powi(exp as i32 - 7)
2163}
2164
2165fn fp8_e5m2_to_f32(b: u8) -> f32 {
2166    let sign = if b & 0x80 != 0 { -1.0 } else { 1.0 };
2167    let exp = (b >> 2) & 0x1F;
2168    let mant = b & 0x03;
2169    if exp == 0 {
2170        if mant == 0 {
2171            return 0.0;
2172        }
2173        return sign * (mant as f32) * 2f32.powi(-16);
2174    }
2175    if exp == 0x1F {
2176        return if mant == 0 {
2177            sign * f32::INFINITY
2178        } else {
2179            f32::NAN
2180        };
2181    }
2182    sign * (1.0 + mant as f32 / 4.0) * 2f32.powi(exp as i32 - 15)
2183}
2184
2185#[allow(clippy::too_many_arguments)]
2186fn dequant_matmul_fp8(
2187    x: &[f32],
2188    w_bytes: &[u8],
2189    scales: &[f32],
2190    out: &mut [f32],
2191    m: usize,
2192    k: usize,
2193    n: usize,
2194    e5m2: bool,
2195) {
2196    let dequant = if e5m2 {
2197        fp8_e5m2_to_f32
2198    } else {
2199        fp8_e4m3_to_f32
2200    };
2201    for i in 0..m {
2202        for j in 0..n {
2203            let mut acc = 0f32;
2204            for p in 0..k {
2205                let w = dequant(w_bytes[p * n + j]);
2206                let s = scales.get(j).copied().unwrap_or(1.0);
2207                acc += x[i * k + p] * w * s;
2208            }
2209            out[i * n + j] = acc;
2210        }
2211    }
2212}
2213
2214#[allow(clippy::too_many_arguments)]
2215pub fn dequant_matmul_nvfp4(
2216    x: &[f32],
2217    w_bytes: &[u8],
2218    scale_bytes: &[u8],
2219    global_scale: f32,
2220    out: &mut [f32],
2221    m: usize,
2222    k: usize,
2223    n: usize,
2224) {
2225    use rlx_ir::{NVFP4_GROUP_SIZE, fp4_e2m1_to_f32, fp8_e4m3_scale_to_f32};
2226    let gs = NVFP4_GROUP_SIZE;
2227    for i in 0..m {
2228        for j in 0..n {
2229            let mut acc = 0f32;
2230            for p in 0..k {
2231                let byte_idx = (p * n + j) / 2;
2232                let nibble = if (p * n + j) & 1 == 0 {
2233                    w_bytes[byte_idx] & 0x0F
2234                } else {
2235                    w_bytes[byte_idx] >> 4
2236                };
2237                let block = p / gs;
2238                let scale = fp8_e4m3_scale_to_f32(scale_bytes[block * n + j]);
2239                let w = fp4_e2m1_to_f32(nibble) * scale * global_scale;
2240                acc += x[i * k + p] * w;
2241            }
2242            out[i * n + j] = acc;
2243        }
2244    }
2245}
2246
2247/// Fused sampling step: logits → top-k filter → top-p truncation
2248/// → softmax → multinomial sample. Operates on one row of length
2249/// `vocab` and returns the sampled index. Plan #42.
2250///
2251/// Internal scratch is on the stack via SmallVec-style fallback —
2252/// for `vocab > 8192` we heap-allocate a working buffer; below
2253/// that we keep things in a fixed array. (TODO: thread the
2254/// scratch through ThunkSchedule like sdpa_scores does.)
2255fn sample_row(
2256    logits: &[f32],
2257    top_k: usize,
2258    top_p: f32,
2259    temperature: f32,
2260    rng: &mut rlx_ir::Philox4x32,
2261) -> usize {
2262    let v = logits.len();
2263    if v == 0 {
2264        return 0;
2265    }
2266    let temp = temperature.max(1e-6);
2267    // Copy + temperature-scale into a working buffer.
2268    let mut scaled: Vec<f32> = logits.iter().map(|&x| x / temp).collect();
2269
2270    // Top-k: zero out everything but the k largest by setting to -inf.
2271    if top_k > 0 && top_k < v {
2272        // Partial selection: find k-th largest then mask below.
2273        let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
2274        // Sort descending; partial would be O(n log k), full sort is fine
2275        // for typical vocab sizes (32k-128k) — single-row work.
2276        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2277        let cutoff = indexed[top_k - 1].1;
2278        for x in scaled.iter_mut() {
2279            if *x < cutoff {
2280                *x = f32::NEG_INFINITY;
2281            }
2282        }
2283    }
2284
2285    // Stable softmax.
2286    let mut max_l = f32::NEG_INFINITY;
2287    for &x in &scaled {
2288        if x > max_l {
2289            max_l = x;
2290        }
2291    }
2292    let mut sum = 0.0f32;
2293    for x in scaled.iter_mut() {
2294        *x = (*x - max_l).exp();
2295        sum += *x;
2296    }
2297    let inv = 1.0 / sum.max(f32::MIN_POSITIVE);
2298    for x in scaled.iter_mut() {
2299        *x *= inv;
2300    }
2301
2302    // Top-p: keep the smallest set of tokens whose cumulative
2303    // probability exceeds top_p (after sorting descending).
2304    if top_p < 1.0 {
2305        let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
2306        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2307        let mut cum = 0.0f32;
2308        let mut keep = vec![false; v];
2309        for (idx, p) in indexed.iter() {
2310            keep[*idx] = true;
2311            cum += *p;
2312            if cum >= top_p {
2313                break;
2314            }
2315        }
2316        let mut new_sum = 0.0f32;
2317        for (i, x) in scaled.iter_mut().enumerate() {
2318            if !keep[i] {
2319                *x = 0.0;
2320            }
2321            new_sum += *x;
2322        }
2323        let inv = 1.0 / new_sum.max(f32::MIN_POSITIVE);
2324        for x in scaled.iter_mut() {
2325            *x *= inv;
2326        }
2327    }
2328
2329    // Multinomial sample via inverse-CDF.
2330    let r = rng.next_f32();
2331    let mut acc = 0.0f32;
2332    for (i, &p) in scaled.iter().enumerate() {
2333        acc += p;
2334        if r <= acc {
2335            return i;
2336        }
2337    }
2338    v - 1 // floating-point edge case fallback
2339}
2340
2341/// Apply a synthetic (kernel-generated) attention mask to a `[q_seq, k_seq]`
2342/// scores matrix. Custom masks are read from a tensor and not handled here.
2343/// `None` is a no-op so callers don't need to special-case it.
2344#[inline]
2345fn apply_synthetic_mask(
2346    scores: &mut [f32],
2347    q_seq: usize,
2348    k_seq: usize,
2349    kind: rlx_ir::op::MaskKind,
2350) {
2351    let neg = crate::config::RuntimeConfig::global().attn_mask_neg_inf;
2352    let q_offset = k_seq.saturating_sub(q_seq);
2353    match kind {
2354        rlx_ir::op::MaskKind::None | rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias => {}
2355        rlx_ir::op::MaskKind::Causal => {
2356            for qi in 0..q_seq {
2357                let abs_q = q_offset + qi;
2358                for ki in (abs_q + 1)..k_seq {
2359                    scores[qi * k_seq + ki] = neg;
2360                }
2361            }
2362        }
2363        rlx_ir::op::MaskKind::SlidingWindow(w) => {
2364            for qi in 0..q_seq {
2365                let abs_q = q_offset + qi;
2366                let lo = abs_q.saturating_sub(w);
2367                for ki in 0..k_seq {
2368                    if ki < lo || ki > abs_q {
2369                        scores[qi * k_seq + ki] = neg;
2370                    }
2371                }
2372            }
2373        }
2374    }
2375}
2376
2377/// NCL `[N,C,L]` or NCHW `[N,C,H,W]` → `(n, c, h, w)` for 2D conv/norm thunks.
2378fn conv_nchw_dims(shape: &Shape) -> (u32, u32, u32, u32) {
2379    match shape.rank() {
2380        3 => (
2381            shape.dim(0).unwrap_static() as u32,
2382            shape.dim(1).unwrap_static() as u32,
2383            1,
2384            shape.dim(2).unwrap_static() as u32,
2385        ),
2386        4 => (
2387            shape.dim(0).unwrap_static() as u32,
2388            shape.dim(1).unwrap_static() as u32,
2389            shape.dim(2).unwrap_static() as u32,
2390            shape.dim(3).unwrap_static() as u32,
2391        ),
2392        r => panic!("conv_nchw_dims: expected rank 3 or 4, got {r}"),
2393    }
2394}
2395
2396/// Compile graph into thunk schedule.
2397pub fn compile_thunks(graph: &Graph, arena: &Arena) -> ThunkSchedule {
2398    compile_thunks_with_rng(graph, arena, rlx_ir::RngOptions::default())
2399}
2400
2401/// Compile graph into thunk schedule with explicit RNG policy.
2402pub fn compile_thunks_with_rng(
2403    graph: &Graph,
2404    arena: &Arena,
2405    rng: rlx_ir::RngOptions,
2406) -> ThunkSchedule {
2407    let rng_shared = Arc::new(std::sync::RwLock::new(rng));
2408    let mut thunks = Vec::with_capacity(graph.len());
2409
2410    for node in graph.nodes() {
2411        // View ops (Reshape / same-dtype Cast / axis-0 Narrow) are aliased
2412        // to their parent's slot by the memory planner — no copy needed.
2413        // Plan #46.
2414        if rlx_opt::is_pure_view(graph, node) {
2415            thunks.push(Thunk::Nop);
2416            continue;
2417        }
2418        let t = match &node.op {
2419            Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => Thunk::Nop,
2420
2421            Op::FusedMatMulBiasAct { activation } => {
2422                let shape = &node.shape;
2423                let n = shape.dim(shape.rank() - 1).unwrap_static();
2424                let total = shape.num_elements().unwrap();
2425                let m = total / n;
2426                let a_len = get_len(graph, node.inputs[0]);
2427                let k = a_len / m;
2428                Thunk::FusedMmBiasAct {
2429                    a: node_offset(arena, node.inputs[0]),
2430                    w: node_offset(arena, node.inputs[1]),
2431                    bias: node_offset(arena, node.inputs[2]),
2432                    c: node_offset(arena, node.id),
2433                    m: m as u32,
2434                    k: k as u32,
2435                    n: n as u32,
2436                    act: *activation,
2437                }
2438            }
2439
2440            Op::FusedResidualLN { has_bias, eps } => {
2441                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2442                let total = node.shape.num_elements().unwrap();
2443                let rows = total / h;
2444                let (g_idx, b_idx) = if *has_bias { (3, 4) } else { (2, 3) };
2445                Thunk::FusedResidualLN {
2446                    x: node_offset(arena, node.inputs[0]),
2447                    res: node_offset(arena, node.inputs[1]),
2448                    bias: if *has_bias {
2449                        node_offset(arena, node.inputs[2])
2450                    } else {
2451                        0
2452                    },
2453                    g: node_offset(arena, node.inputs[g_idx]),
2454                    b: node_offset(arena, node.inputs[b_idx]),
2455                    out: node_offset(arena, node.id),
2456                    rows: rows as u32,
2457                    h: h as u32,
2458                    eps: *eps,
2459                    has_bias: *has_bias,
2460                }
2461            }
2462
2463            Op::FusedResidualRmsNorm { has_bias, eps } => {
2464                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2465                let total = node.shape.num_elements().unwrap();
2466                let rows = total / h;
2467                let (g_idx, b_idx) = if *has_bias { (3, 4) } else { (2, 3) };
2468                Thunk::FusedResidualRmsNorm {
2469                    x: node_offset(arena, node.inputs[0]),
2470                    res: node_offset(arena, node.inputs[1]),
2471                    bias: if *has_bias {
2472                        node_offset(arena, node.inputs[2])
2473                    } else {
2474                        0
2475                    },
2476                    g: node_offset(arena, node.inputs[g_idx]),
2477                    b: node_offset(arena, node.inputs[b_idx]),
2478                    out: node_offset(arena, node.id),
2479                    rows: rows as u32,
2480                    h: h as u32,
2481                    eps: *eps,
2482                    has_bias: *has_bias,
2483                }
2484            }
2485
2486            Op::MatMul => {
2487                let shape = &node.shape;
2488                let a_shape = &graph.node(node.inputs[0]).shape;
2489                let b_shape = &graph.node(node.inputs[1]).shape;
2490                // Prefer inferred matmul shape from operands — ONNX bundle
2491                // meta often over-ranks outputs (e.g. [seq, seq, H]).
2492                let eff =
2493                    rlx_ir::shape::matmul_shape(a_shape, b_shape).unwrap_or_else(|_| shape.clone());
2494                let rank = eff.rank().max(2);
2495                let n = eff.dim(rank - 1).unwrap_static();
2496                let k_dim = a_shape.dim(a_shape.rank().max(2) - 1).unwrap_static();
2497                if shape.dtype() == rlx_ir::DType::C64 {
2498                    // Complex GEMM (interleaved re/im). Handles 2D and
2499                    // 3D×2D (flatten M); both-operand batched C64 is not
2500                    // yet wired.
2501                    let both = a_shape.rank() >= 3 && b_shape.rank() >= 3;
2502                    assert!(!both, "batched (both-operand) C64 matmul not yet supported");
2503                    let m: usize = if a_shape.rank() >= 3 {
2504                        (0..a_shape.rank() - 1)
2505                            .map(|d| a_shape.dim(d).unwrap_static())
2506                            .product()
2507                    } else {
2508                        a_shape.dim(a_shape.rank() - 2).unwrap_static()
2509                    };
2510                    Thunk::CgemmC64 {
2511                        a: node_offset(arena, node.inputs[0]),
2512                        b: node_offset(arena, node.inputs[1]),
2513                        c: node_offset(arena, node.id),
2514                        m: m as u32,
2515                        k: k_dim as u32,
2516                        n: n as u32,
2517                    }
2518                } else {
2519                    // Batched GEMM only when both operands carry batch dimensions.
2520                    // 3D×2D (activations × shared weight) must flatten to one Sgemm.
2521                    let both_batched = a_shape.rank() >= 3 && b_shape.rank() >= 3;
2522                    let batched_3d =
2523                        rank >= 3 && both_batched && a_shape.rank() + b_shape.rank() > 4;
2524                    if batched_3d && shape.dtype() == rlx_ir::DType::F64 {
2525                        let mut batch_prod = 1usize;
2526                        for d in 0..rank - 2 {
2527                            batch_prod *= eff.dim(d).unwrap_static();
2528                        }
2529                        let m_dim = eff.dim(rank - 2).unwrap_static();
2530                        Thunk::BatchedDgemmF64 {
2531                            a: node_offset(arena, node.inputs[0]),
2532                            b: node_offset(arena, node.inputs[1]),
2533                            c: node_offset(arena, node.id),
2534                            batch: batch_prod as u32,
2535                            m: m_dim as u32,
2536                            k: k_dim as u32,
2537                            n: n as u32,
2538                        }
2539                    } else if batched_3d && shape.dtype() == rlx_ir::DType::F32 {
2540                        let mut batch_prod = 1usize;
2541                        for d in 0..rank - 2 {
2542                            batch_prod *= eff.dim(d).unwrap_static();
2543                        }
2544                        let m_dim = eff.dim(rank - 2).unwrap_static();
2545                        Thunk::BatchedSgemm {
2546                            a: node_offset(arena, node.inputs[0]),
2547                            b: node_offset(arena, node.inputs[1]),
2548                            c: node_offset(arena, node.id),
2549                            batch: batch_prod as u32,
2550                            m: m_dim as u32,
2551                            k: k_dim as u32,
2552                            n: n as u32,
2553                        }
2554                    } else {
2555                        let m = if a_shape.rank() >= 3 && b_shape.rank() <= 2 {
2556                            let mut m_prod = 1usize;
2557                            for d in 0..a_shape.rank() - 1 {
2558                                m_prod *= a_shape.dim(d).unwrap_static();
2559                            }
2560                            m_prod
2561                        } else if a_shape.rank() >= 2 {
2562                            a_shape.dim(a_shape.rank() - 2).unwrap_static()
2563                        } else {
2564                            eff.num_elements().unwrap_or(1) / n.max(1)
2565                        };
2566                        match shape.dtype() {
2567                            rlx_ir::DType::F64 => Thunk::Dgemm {
2568                                a: node_offset(arena, node.inputs[0]),
2569                                b: node_offset(arena, node.inputs[1]),
2570                                c: node_offset(arena, node.id),
2571                                m: m as u32,
2572                                k: k_dim as u32,
2573                                n: n as u32,
2574                            },
2575                            _ => Thunk::Sgemm {
2576                                a: node_offset(arena, node.inputs[0]),
2577                                b: node_offset(arena, node.inputs[1]),
2578                                c: node_offset(arena, node.id),
2579                                m: m as u32,
2580                                k: k_dim as u32,
2581                                n: n as u32,
2582                            },
2583                        }
2584                    }
2585                }
2586            }
2587
2588            Op::Binary(op) => {
2589                let lhs_len = get_len(graph, node.inputs[0]);
2590                let rhs_len = get_len(graph, node.inputs[1]);
2591                let out_len = node.shape.num_elements().unwrap();
2592                if node.shape.dtype() == rlx_ir::DType::C64 {
2593                    // Native C64 element-wise. Add/Sub/Mul/Div lower
2594                    // to `BinaryFullC64`; the rest don't have a
2595                    // single natural complex definition.
2596                    match op {
2597                        BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div => {}
2598                        BinaryOp::Max | BinaryOp::Min | BinaryOp::Pow => panic!(
2599                            "Op::Binary({op:?}) on DType::C64: complex \
2600                             max/min/pow have no single natural definition \
2601                             — caller should drop to 2N-real-block (see \
2602                             spike-ac) and pick a convention there"
2603                        ),
2604                    }
2605                }
2606                // Compute broadcast strides for the slow path. Empty
2607                // vectors when no broadcast is needed (the fast-path
2608                // kernel ignores them anyway).
2609                let (out_dims_bcast, bcast_lhs_strides, bcast_rhs_strides) =
2610                    if lhs_len == out_len && rhs_len == out_len {
2611                        (Vec::new(), Vec::new(), Vec::new())
2612                    } else {
2613                        let lhs_dims = get_static_dims(graph, node.inputs[0]);
2614                        let rhs_dims = get_static_dims(graph, node.inputs[1]);
2615                        let out_dims_v = get_static_dims(graph, node.id);
2616                        if lhs_dims.is_empty() || rhs_dims.is_empty() || out_dims_v.is_empty() {
2617                            // Dynamic shape — fall back to the legacy
2618                            // modulo path (correct for scalar / last-
2619                            // axis broadcast, which is the only
2620                            // dynamic case in practice).
2621                            (Vec::new(), Vec::new(), Vec::new())
2622                        } else {
2623                            let ls = broadcast_strides(&lhs_dims, &out_dims_v);
2624                            let rs = broadcast_strides(&rhs_dims, &out_dims_v);
2625                            let od: Vec<u32> = out_dims_v.iter().map(|x| *x as u32).collect();
2626                            (od, ls, rs)
2627                        }
2628                    };
2629                if node.shape.dtype() == rlx_ir::DType::C64 {
2630                    Thunk::BinaryFullC64 {
2631                        lhs: node_offset(arena, node.inputs[0]),
2632                        rhs: node_offset(arena, node.inputs[1]),
2633                        dst: node_offset(arena, node.id),
2634                        len: out_len as u32,
2635                        lhs_len: lhs_len as u32,
2636                        rhs_len: rhs_len as u32,
2637                        op: *op,
2638                        out_dims_bcast,
2639                        bcast_lhs_strides,
2640                        bcast_rhs_strides,
2641                    }
2642                } else if node.shape.dtype() == rlx_ir::DType::F64 {
2643                    // f64 path — no BiasAdd fast-path (yet); use the
2644                    // general binary-with-broadcast kernel.
2645                    Thunk::BinaryFullF64 {
2646                        lhs: node_offset(arena, node.inputs[0]),
2647                        rhs: node_offset(arena, node.inputs[1]),
2648                        dst: node_offset(arena, node.id),
2649                        len: out_len as u32,
2650                        lhs_len: lhs_len as u32,
2651                        rhs_len: rhs_len as u32,
2652                        op: *op,
2653                        out_dims_bcast,
2654                        bcast_lhs_strides,
2655                        bcast_rhs_strides,
2656                    }
2657                } else if matches!(op, BinaryOp::Add)
2658                    && rhs_len < out_len
2659                    && out_len % rhs_len == 0
2660                    && is_trailing_bias_broadcast(
2661                        graph.node(node.inputs[1]).shape.dims(),
2662                        graph.node(node.id).shape.dims(),
2663                    )
2664                {
2665                    // `BiasAdd` is only correct when the bias is a
2666                    // *trailing* broadcast — rhs dims match the right-
2667                    // hand side of the output dims (with size-1 only
2668                    // allowed in left-padded outer positions).
2669                    // SAM's rel-pos `[bh, h, w, 1, w] + [bh, h, w, h, w]`
2670                    // has rhs_len divide out_len cleanly but is a
2671                    // mid-shape singleton, NOT a trailing broadcast.
2672                    // Routing it through BiasAdd silently treats it as
2673                    // last-`rhs_len`-cols repeated — wrong values.
2674                    Thunk::BiasAdd {
2675                        src: node_offset(arena, node.inputs[0]),
2676                        bias: node_offset(arena, node.inputs[1]),
2677                        dst: node_offset(arena, node.id),
2678                        m: (out_len / rhs_len) as u32,
2679                        n: rhs_len as u32,
2680                    }
2681                } else {
2682                    let lhs_len = get_len(graph, node.inputs[0]);
2683                    Thunk::BinaryFull {
2684                        lhs: node_offset(arena, node.inputs[0]),
2685                        rhs: node_offset(arena, node.inputs[1]),
2686                        dst: node_offset(arena, node.id),
2687                        len: out_len as u32,
2688                        lhs_len: lhs_len as u32,
2689                        rhs_len: rhs_len as u32,
2690                        op: *op,
2691                        out_dims_bcast,
2692                        bcast_lhs_strides,
2693                        bcast_rhs_strides,
2694                        elem_bytes: node.shape.dtype().size_bytes() as u8,
2695                    }
2696                }
2697            }
2698
2699            Op::Activation(act) => {
2700                let len = node.shape.num_elements().unwrap();
2701                let in_off = node_offset(arena, node.inputs[0]);
2702                let out_off = node_offset(arena, node.id);
2703                if node.shape.dtype() == rlx_ir::DType::C64 {
2704                    // Only Neg/Exp/Log/Sqrt have natural complex
2705                    // extensions used in signal-processing graphs.
2706                    // Everything else (Sigmoid, Tanh, Relu, Abs,
2707                    // Sin/Cos/Tan/Atan, Round, GeLU family) is rejected.
2708                    match act {
2709                        Activation::Neg | Activation::Exp | Activation::Log | Activation::Sqrt => {}
2710                        other => panic!(
2711                            "Op::Activation({other:?}) on DType::C64: no \
2712                             natural complex extension — supported on C64: \
2713                             Neg, Exp, Log, Sqrt"
2714                        ),
2715                    }
2716                    Thunk::ActivationC64 {
2717                        src: in_off,
2718                        dst: out_off,
2719                        len: len as u32,
2720                        kind: *act,
2721                    }
2722                } else if node.shape.dtype() == rlx_ir::DType::F64 {
2723                    Thunk::ActivationF64 {
2724                        src: in_off,
2725                        dst: out_off,
2726                        len: len as u32,
2727                        kind: *act,
2728                    }
2729                } else if in_off == out_off {
2730                    // ActivationInPlace operates on a single buffer. When the
2731                    // planner has assigned input and output the same slot
2732                    // (typical post-fusion case), we just run on that slot.
2733                    Thunk::ActivationInPlace {
2734                        data: out_off,
2735                        len: len as u32,
2736                        act: *act,
2737                    }
2738                } else {
2739                    // Two-step: copy input → output, then activate output in place.
2740                    // The schedule executes them in this order; downstream
2741                    // thunks see the activated output at out_off.
2742                    thunks.push(Thunk::Copy {
2743                        src: in_off,
2744                        dst: out_off,
2745                        len: len as u32,
2746                    });
2747                    Thunk::ActivationInPlace {
2748                        data: out_off,
2749                        len: len as u32,
2750                        act: *act,
2751                    }
2752                }
2753            }
2754
2755            Op::Gather { axis } if *axis == 0 => {
2756                let table_shape = &graph.node(node.inputs[0]).shape;
2757                let table_total = table_shape.num_elements().unwrap();
2758                let trailing: usize = (1..table_shape.rank())
2759                    .map(|i| table_shape.dim(i).unwrap_static())
2760                    .product();
2761                let idx_len = get_len(graph, node.inputs[1]);
2762                let idx_i64 =
2763                    u8::from(graph.node(node.inputs[1]).shape.dtype() == rlx_ir::DType::I64);
2764                let table_bytes = graph.node(node.inputs[0]).shape.dtype().size_bytes() as u8;
2765                Thunk::Gather {
2766                    table: node_offset(arena, node.inputs[0]),
2767                    table_len: table_total as u32,
2768                    idx: node_offset(arena, node.inputs[1]),
2769                    dst: node_offset(arena, node.id),
2770                    num_idx: idx_len as u32,
2771                    trailing: trailing as u32,
2772                    idx_i64,
2773                    table_bytes,
2774                }
2775            }
2776
2777            Op::Gather { axis } => {
2778                // Non-zero axis: outer × num_idx × trailing layout.
2779                let table_shape = &graph.node(node.inputs[0]).shape;
2780                let rank = table_shape.rank();
2781                let outer: usize = (0..*axis)
2782                    .map(|i| table_shape.dim(i).unwrap_static())
2783                    .product::<usize>()
2784                    .max(1);
2785                let trailing: usize = (*axis + 1..rank)
2786                    .map(|i| table_shape.dim(i).unwrap_static())
2787                    .product::<usize>()
2788                    .max(1);
2789                let axis_dim = table_shape.dim(*axis).unwrap_static();
2790                let idx_len = get_len(graph, node.inputs[1]);
2791                let idx_i64 =
2792                    u8::from(graph.node(node.inputs[1]).shape.dtype() == rlx_ir::DType::I64);
2793                let table_bytes = graph.node(node.inputs[0]).shape.dtype().size_bytes() as u8;
2794                Thunk::GatherAxis {
2795                    table: node_offset(arena, node.inputs[0]),
2796                    idx: node_offset(arena, node.inputs[1]),
2797                    dst: node_offset(arena, node.id),
2798                    outer: outer as u32,
2799                    axis_dim: axis_dim as u32,
2800                    num_idx: idx_len as u32,
2801                    trailing: trailing as u32,
2802                    idx_i64,
2803                    table_bytes,
2804                }
2805            }
2806
2807            Op::Narrow { axis, start, len } => {
2808                let in_shape = &graph.node(node.inputs[0]).shape;
2809                let elem_bytes = in_shape.dtype().size_bytes() as u8;
2810                let rank = in_shape.rank();
2811                let outer: usize = (0..*axis)
2812                    .map(|i| in_shape.dim(i).unwrap_static())
2813                    .product::<usize>()
2814                    .max(1);
2815                let inner: usize = (*axis + 1..rank)
2816                    .map(|i| in_shape.dim(i).unwrap_static())
2817                    .product::<usize>()
2818                    .max(1);
2819                let in_axis = in_shape.dim(*axis).unwrap_static();
2820                let src_byte_offset =
2821                    node_offset(arena, node.inputs[0]) + start * inner * elem_bytes as usize;
2822                Thunk::Narrow {
2823                    src: src_byte_offset,
2824                    dst: node_offset(arena, node.id),
2825                    outer: outer as u32,
2826                    src_stride: (in_axis * inner) as u32, // elements per outer step in source
2827                    dst_stride: (*len * inner) as u32,    // elements per outer step in dest
2828                    inner: (*len * inner) as u32,         // elements to copy per outer step
2829                    elem_bytes,
2830                }
2831            }
2832
2833            Op::Reshape { .. } | Op::StopGradient => {
2834                // Pure layout change: same total element count, plain copy.
2835                let len = node.shape.num_elements().unwrap();
2836                let src = node_offset(arena, node.inputs[0]);
2837                let dst = node_offset(arena, node.id);
2838                match node.shape.dtype() {
2839                    rlx_ir::DType::F64 => Thunk::CopyF64 {
2840                        src,
2841                        dst,
2842                        len: len as u32,
2843                    },
2844                    rlx_ir::DType::I64 => Thunk::CopyI64 {
2845                        src,
2846                        dst,
2847                        len: len as u32,
2848                    },
2849                    _ => Thunk::Copy {
2850                        src,
2851                        dst,
2852                        len: len as u32,
2853                    },
2854                }
2855            }
2856
2857            Op::Cast { to } => {
2858                let in_node = graph.node(node.inputs[0]);
2859                let in_dtype = in_node.shape.dtype();
2860                let out_dtype = *to;
2861                let len = node.shape.num_elements().unwrap();
2862                let src = node_offset(arena, node.inputs[0]);
2863                let dst = node_offset(arena, node.id);
2864                if in_dtype == rlx_ir::DType::F32 && out_dtype == rlx_ir::DType::I64 {
2865                    Thunk::CastF32ToI64 {
2866                        src,
2867                        dst,
2868                        len: len as u32,
2869                    }
2870                } else if in_dtype == rlx_ir::DType::F32 && out_dtype == rlx_ir::DType::F64 {
2871                    Thunk::CastF32ToF64 {
2872                        src,
2873                        dst,
2874                        len: len as u32,
2875                    }
2876                } else if in_dtype == rlx_ir::DType::F32 && out_dtype == rlx_ir::DType::I32 {
2877                    Thunk::CastF32ToI32 {
2878                        src,
2879                        dst,
2880                        len: len as u32,
2881                    }
2882                } else if in_dtype == rlx_ir::DType::I64 && out_dtype == rlx_ir::DType::F32 {
2883                    Thunk::CastI64ToF32 {
2884                        src,
2885                        dst,
2886                        len: len as u32,
2887                    }
2888                } else if in_dtype == rlx_ir::DType::Bool && out_dtype == rlx_ir::DType::I32 {
2889                    Thunk::CastBoolToI32 {
2890                        src,
2891                        dst,
2892                        len: len as u32,
2893                    }
2894                } else if in_dtype == rlx_ir::DType::Bool && out_dtype == rlx_ir::DType::F32 {
2895                    // Bool is 1 byte; the generic f32 Copy below would misread it as
2896                    // 4-byte f32. VITS sequence masks are `Cast(Less(...), f32)`.
2897                    Thunk::CastBoolToF32 {
2898                        src,
2899                        dst,
2900                        len: len as u32,
2901                    }
2902                } else if in_dtype == rlx_ir::DType::I32 && out_dtype == rlx_ir::DType::F32 {
2903                    Thunk::CastI32ToF32 {
2904                        src,
2905                        dst,
2906                        len: len as u32,
2907                    }
2908                } else if in_dtype == out_dtype {
2909                    match out_dtype {
2910                        rlx_ir::DType::F64 => Thunk::CopyF64 {
2911                            src,
2912                            dst,
2913                            len: len as u32,
2914                        },
2915                        rlx_ir::DType::I64 => Thunk::CopyI64 {
2916                            src,
2917                            dst,
2918                            len: len as u32,
2919                        },
2920                        _ => Thunk::Copy {
2921                            src,
2922                            dst,
2923                            len: len as u32,
2924                        },
2925                    }
2926                } else {
2927                    Thunk::Copy {
2928                        src,
2929                        dst,
2930                        len: len as u32,
2931                    }
2932                }
2933            }
2934
2935            Op::Quantize {
2936                axis,
2937                scales,
2938                zero_points,
2939            } => {
2940                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2941                Thunk::Quantize {
2942                    x: node_offset(arena, node.inputs[0]),
2943                    q: node_offset(arena, node.id),
2944                    len: node.shape.num_elements().unwrap() as u32,
2945                    chan_axis: chan_axis as u32,
2946                    chan_dim: chan_dim as u32,
2947                    inner: inner as u32,
2948                    scales: scales.clone(),
2949                    zero_points: zero_points.clone(),
2950                }
2951            }
2952
2953            Op::FakeQuantize {
2954                bits,
2955                axis,
2956                ste,
2957                scale_mode,
2958            } => {
2959                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2960                let state_off = match scale_mode {
2961                    rlx_ir::op::ScaleMode::PerBatch => None,
2962                    rlx_ir::op::ScaleMode::EMA { .. } | rlx_ir::op::ScaleMode::Fixed => {
2963                        // Second input carries the [chan_dim] scale state.
2964                        debug_assert_eq!(
2965                            node.inputs.len(),
2966                            2,
2967                            "EMA/Fixed FakeQuantize needs a state input"
2968                        );
2969                        Some(node_offset(arena, node.inputs[1]))
2970                    }
2971                };
2972                Thunk::FakeQuantize {
2973                    x: node_offset(arena, node.inputs[0]),
2974                    out: node_offset(arena, node.id),
2975                    len: node.shape.num_elements().unwrap() as u32,
2976                    chan_axis: chan_axis as u32,
2977                    chan_dim: chan_dim as u32,
2978                    inner: inner as u32,
2979                    bits: *bits,
2980                    ste: *ste,
2981                    scale_mode: *scale_mode,
2982                    state_off,
2983                }
2984            }
2985
2986            Op::FakeQuantizeLSQ { bits, axis } => {
2987                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2988                Thunk::FakeQuantizeLSQ {
2989                    x: node_offset(arena, node.inputs[0]),
2990                    scale_off: node_offset(arena, node.inputs[1]),
2991                    out: node_offset(arena, node.id),
2992                    len: node.shape.num_elements().unwrap() as u32,
2993                    chan_axis: chan_axis as u32,
2994                    chan_dim: chan_dim as u32,
2995                    inner: inner as u32,
2996                    bits: *bits,
2997                }
2998            }
2999
3000            Op::FakeQuantizeLSQBackwardX { bits, axis } => {
3001                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
3002                Thunk::FakeQuantizeLSQBackwardX {
3003                    x: node_offset(arena, node.inputs[0]),
3004                    scale_off: node_offset(arena, node.inputs[1]),
3005                    dy: node_offset(arena, node.inputs[2]),
3006                    dx: node_offset(arena, node.id),
3007                    len: node.shape.num_elements().unwrap() as u32,
3008                    chan_axis: chan_axis as u32,
3009                    chan_dim: chan_dim as u32,
3010                    inner: inner as u32,
3011                    bits: *bits,
3012                }
3013            }
3014
3015            Op::FakeQuantizeLSQBackwardScale { bits, axis } => {
3016                // Output shape is [chan_dim] — node.shape doesn't
3017                // describe the input data layout, but inputs[0] does.
3018                let in_shape = &graph.node(node.inputs[0]).shape;
3019                let (chan_axis, chan_dim, inner) = quant_layout(in_shape, *axis);
3020                Thunk::FakeQuantizeLSQBackwardScale {
3021                    x: node_offset(arena, node.inputs[0]),
3022                    scale_off: node_offset(arena, node.inputs[1]),
3023                    dy: node_offset(arena, node.inputs[2]),
3024                    dscale: node_offset(arena, node.id),
3025                    len: in_shape.num_elements().unwrap() as u32,
3026                    chan_axis: chan_axis as u32,
3027                    chan_dim: chan_dim as u32,
3028                    inner: inner as u32,
3029                    bits: *bits,
3030                }
3031            }
3032
3033            Op::FakeQuantizeBackward { bits, axis, ste } => {
3034                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
3035                Thunk::FakeQuantizeBackward {
3036                    x: node_offset(arena, node.inputs[0]),
3037                    dy: node_offset(arena, node.inputs[1]),
3038                    dx: node_offset(arena, node.id),
3039                    len: node.shape.num_elements().unwrap() as u32,
3040                    chan_axis: chan_axis as u32,
3041                    chan_dim: chan_dim as u32,
3042                    inner: inner as u32,
3043                    bits: *bits,
3044                    ste: *ste,
3045                }
3046            }
3047
3048            Op::Dequantize {
3049                axis,
3050                scales,
3051                zero_points,
3052            } => {
3053                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
3054                Thunk::Dequantize {
3055                    q: node_offset(arena, node.inputs[0]),
3056                    x: node_offset(arena, node.id),
3057                    len: node.shape.num_elements().unwrap() as u32,
3058                    chan_axis: chan_axis as u32,
3059                    chan_dim: chan_dim as u32,
3060                    inner: inner as u32,
3061                    scales: scales.clone(),
3062                    zero_points: zero_points.clone(),
3063                }
3064            }
3065
3066            Op::Expand { .. } => {
3067                // Broadcast: build per-output-dim strides where any input dim
3068                // of size 1 has stride 0 (read the same element repeatedly).
3069                // Reuses the Thunk::Transpose runtime — N-D walk with strides
3070                // is identical; only the strides differ.
3071                let in_shape = &graph.node(node.inputs[0]).shape;
3072                let out_shape = &node.shape;
3073                let in_rank = in_shape.rank();
3074                let out_rank = out_shape.rank();
3075                // Implicit leading 1s if input has lower rank.
3076                let pad = out_rank.saturating_sub(in_rank);
3077                let in_dims: Vec<usize> = (0..out_rank)
3078                    .map(|i| {
3079                        if i < pad {
3080                            1
3081                        } else {
3082                            in_shape.dim(i - pad).unwrap_static()
3083                        }
3084                    })
3085                    .collect();
3086                // Row-major input strides (over the padded shape).
3087                let mut in_strides_full = vec![1usize; out_rank];
3088                for d in (0..out_rank.saturating_sub(1)).rev() {
3089                    in_strides_full[d] = in_strides_full[d + 1] * in_dims[d + 1];
3090                }
3091                let out_dims: Vec<u32> = (0..out_rank)
3092                    .map(|i| out_shape.dim(i).unwrap_static() as u32)
3093                    .collect();
3094                // Stride is 0 for broadcast dims (in_dim == 1 && out_dim > 1).
3095                let in_strides: Vec<u32> = (0..out_rank)
3096                    .map(|i| {
3097                        if in_dims[i] == 1 && (out_dims[i] as usize) > 1 {
3098                            0
3099                        } else {
3100                            in_strides_full[i] as u32
3101                        }
3102                    })
3103                    .collect();
3104                let in_total = in_dims.iter().product::<usize>() as u32;
3105                let src = node_offset(arena, node.inputs[0]);
3106                let dst = node_offset(arena, node.id);
3107                let elem_bytes = node.shape.dtype().size_bytes() as u8;
3108                match node.shape.dtype() {
3109                    rlx_ir::DType::F64 => Thunk::TransposeF64 {
3110                        src,
3111                        dst,
3112                        in_total,
3113                        out_dims,
3114                        in_strides,
3115                    },
3116                    _ => Thunk::Transpose {
3117                        src,
3118                        dst,
3119                        in_total,
3120                        out_dims,
3121                        in_strides,
3122                        elem_bytes,
3123                    },
3124                }
3125            }
3126
3127            Op::RmsNorm { eps, .. } => {
3128                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3129                let total = node.shape.num_elements().unwrap();
3130                Thunk::RmsNorm {
3131                    src: node_offset(arena, node.inputs[0]),
3132                    g: node_offset(arena, node.inputs[1]),
3133                    b: node_offset(arena, node.inputs[2]),
3134                    dst: node_offset(arena, node.id),
3135                    rows: (total / h) as u32,
3136                    h: h as u32,
3137                    eps: *eps,
3138                }
3139            }
3140
3141            Op::LayerNorm { eps, .. } => {
3142                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3143                let total = node.shape.num_elements().unwrap();
3144                Thunk::LayerNorm {
3145                    src: node_offset(arena, node.inputs[0]),
3146                    g: node_offset(arena, node.inputs[1]),
3147                    b: node_offset(arena, node.inputs[2]),
3148                    dst: node_offset(arena, node.id),
3149                    rows: (total / h) as u32,
3150                    h: h as u32,
3151                    eps: *eps,
3152                }
3153            }
3154
3155            Op::GroupNorm { num_groups, eps } => {
3156                let in_shape = &graph.node(node.inputs[0]).shape;
3157                let (n, c, h, w) = conv_nchw_dims(in_shape);
3158                Thunk::GroupNorm {
3159                    src: node_offset(arena, node.inputs[0]),
3160                    g: node_offset(arena, node.inputs[1]),
3161                    b: node_offset(arena, node.inputs[2]),
3162                    dst: node_offset(arena, node.id),
3163                    n,
3164                    c,
3165                    h,
3166                    w,
3167                    num_groups: *num_groups as u32,
3168                    eps: *eps,
3169                }
3170            }
3171
3172            Op::BatchNormInference { eps } => {
3173                let in_shape = &graph.node(node.inputs[0]).shape;
3174                let rank = in_shape.rank();
3175                let channels = in_shape.dim(rank - 1).unwrap_static();
3176                let total = in_shape.num_elements().unwrap_or(0);
3177                let count = (total / channels.max(1)) as u32;
3178                Thunk::BatchNormInference {
3179                    src: node_offset(arena, node.inputs[0]),
3180                    g: node_offset(arena, node.inputs[1]),
3181                    b: node_offset(arena, node.inputs[2]),
3182                    mean: node_offset(arena, node.inputs[3]),
3183                    var: node_offset(arena, node.inputs[4]),
3184                    dst: node_offset(arena, node.id),
3185                    count,
3186                    channels: channels as u32,
3187                    eps: *eps,
3188                }
3189            }
3190
3191            Op::BatchNormInferenceBackwardInput { eps } => {
3192                let x_shape = &graph.node(node.inputs[0]).shape;
3193                let rank = x_shape.rank();
3194                let channels = x_shape.dim(rank - 1).unwrap_static();
3195                let total = x_shape.num_elements().unwrap_or(0);
3196                Thunk::BatchNormInferenceBackwardInput {
3197                    x: node_offset(arena, node.inputs[0]),
3198                    gamma: node_offset(arena, node.inputs[1]),
3199                    mean: node_offset(arena, node.inputs[2]),
3200                    var: node_offset(arena, node.inputs[3]),
3201                    dy: node_offset(arena, node.inputs[4]),
3202                    dx: node_offset(arena, node.id),
3203                    count: (total / channels.max(1)) as u32,
3204                    channels: channels as u32,
3205                    eps: *eps,
3206                }
3207            }
3208
3209            Op::BatchNormInferenceBackwardGamma { eps } => {
3210                let x_shape = &graph.node(node.inputs[0]).shape;
3211                let rank = x_shape.rank();
3212                let channels = x_shape.dim(rank - 1).unwrap_static();
3213                let total = x_shape.num_elements().unwrap_or(0);
3214                let _gamma_shape = &graph.node(node.id).shape;
3215                Thunk::BatchNormInferenceBackwardGamma {
3216                    x: node_offset(arena, node.inputs[0]),
3217                    mean: node_offset(arena, node.inputs[1]),
3218                    var: node_offset(arena, node.inputs[2]),
3219                    dy: node_offset(arena, node.inputs[3]),
3220                    dgamma: node_offset(arena, node.id),
3221                    count: (total / channels.max(1)) as u32,
3222                    channels: channels as u32,
3223                    eps: *eps,
3224                }
3225            }
3226
3227            Op::BatchNormInferenceBackwardBeta => {
3228                let dy_shape = &graph.node(node.inputs[0]).shape;
3229                let rank = dy_shape.rank();
3230                let channels = dy_shape.dim(rank - 1).unwrap_static();
3231                let total = dy_shape.num_elements().unwrap_or(0);
3232                Thunk::BatchNormInferenceBackwardBeta {
3233                    dy: node_offset(arena, node.inputs[0]),
3234                    dbeta: node_offset(arena, node.id),
3235                    count: (total / channels.max(1)) as u32,
3236                    channels: channels as u32,
3237                }
3238            }
3239
3240            Op::LayerNorm2d { eps } => {
3241                let in_shape = &graph.node(node.inputs[0]).shape;
3242                let (n, c, h, w) = conv_nchw_dims(in_shape);
3243                Thunk::LayerNorm2d {
3244                    src: node_offset(arena, node.inputs[0]),
3245                    g: node_offset(arena, node.inputs[1]),
3246                    b: node_offset(arena, node.inputs[2]),
3247                    dst: node_offset(arena, node.id),
3248                    n,
3249                    c,
3250                    h,
3251                    w,
3252                    eps: *eps,
3253                }
3254            }
3255
3256            Op::ConvTranspose2d {
3257                kernel_size,
3258                stride,
3259                padding,
3260                dilation,
3261                output_padding: _,
3262                groups,
3263            } => {
3264                let in_shape = &graph.node(node.inputs[0]).shape;
3265                let out_shape = &node.shape;
3266                let (n, c_in, h, w_in) = conv_nchw_dims(in_shape);
3267                let (_, c_out, h_out, w_out) = conv_nchw_dims(out_shape);
3268                Thunk::ConvTranspose2d {
3269                    src: node_offset(arena, node.inputs[0]),
3270                    weight: node_offset(arena, node.inputs[1]),
3271                    dst: node_offset(arena, node.id),
3272                    n,
3273                    c_in,
3274                    h,
3275                    w_in,
3276                    c_out,
3277                    h_out,
3278                    w_out,
3279                    kh: kernel_size[0] as u32,
3280                    kw: kernel_size[1] as u32,
3281                    sh: stride.first().copied().unwrap_or(1) as u32,
3282                    sw: stride.get(1).copied().unwrap_or(1) as u32,
3283                    ph: padding.first().copied().unwrap_or(0) as u32,
3284                    pw: padding.get(1).copied().unwrap_or(0) as u32,
3285                    dh: dilation.first().copied().unwrap_or(1) as u32,
3286                    dw: dilation.get(1).copied().unwrap_or(1) as u32,
3287                    groups: *groups as u32,
3288                }
3289            }
3290
3291            Op::ResizeNearest2x => {
3292                let in_shape = &graph.node(node.inputs[0]).shape;
3293                let (n, c, h, w) = conv_nchw_dims(in_shape);
3294                Thunk::ResizeNearest2x {
3295                    src: node_offset(arena, node.inputs[0]),
3296                    dst: node_offset(arena, node.id),
3297                    n,
3298                    c,
3299                    h,
3300                    w,
3301                }
3302            }
3303
3304            Op::AxialRope2d {
3305                end_x,
3306                end_y,
3307                head_dim,
3308                num_heads,
3309                theta,
3310                repeat_factor,
3311            } => {
3312                let in_shape = &graph.node(node.inputs[0]).shape;
3313                let batch = in_shape.dim(0).unwrap_static() as u32;
3314                let seq = in_shape.dim(1).unwrap_static() as u32;
3315                let hidden = in_shape.dim(2).unwrap_static() as u32;
3316                Thunk::AxialRope2d {
3317                    src: node_offset(arena, node.inputs[0]),
3318                    dst: node_offset(arena, node.id),
3319                    batch,
3320                    seq,
3321                    hidden,
3322                    end_x: *end_x as u32,
3323                    end_y: *end_y as u32,
3324                    head_dim: *head_dim as u32,
3325                    num_heads: *num_heads as u32,
3326                    theta: *theta,
3327                    repeat_factor: *repeat_factor as u32,
3328                }
3329            }
3330
3331            Op::Softmax { axis } => {
3332                let rank = node.shape.rank();
3333                let ax = if *axis < 0 {
3334                    (rank as i32 + axis) as usize
3335                } else {
3336                    *axis as usize
3337                };
3338                let cols = node.shape.dim(ax).unwrap_static();
3339                let total = node.shape.num_elements().unwrap();
3340                let in_off = node_offset(arena, node.inputs[0]);
3341                let out_off = node_offset(arena, node.id);
3342                // Softmax kernel runs in-place on its data buffer. If the
3343                // planner gave input and output separate slots (their live
3344                // ranges overlap, so no aliasing), the output starts
3345                // uninitialized — emit a Copy first so the data is there.
3346                // Same pattern as Op::Activation.
3347                if in_off != out_off {
3348                    thunks.push(Thunk::Copy {
3349                        src: in_off,
3350                        dst: out_off,
3351                        len: total as u32,
3352                    });
3353                }
3354                Thunk::Softmax {
3355                    data: out_off,
3356                    rows: (total / cols) as u32,
3357                    cols: cols as u32,
3358                }
3359            }
3360
3361            Op::SelectiveScan { state_size } => {
3362                let in_shape = &graph.node(node.inputs[0]).shape;
3363                let (batch, seq, hidden) = (
3364                    in_shape.dim(0).unwrap_static(),
3365                    in_shape.dim(1).unwrap_static(),
3366                    in_shape.dim(2).unwrap_static(),
3367                );
3368                Thunk::SelectiveScan {
3369                    x: node_offset(arena, node.inputs[0]),
3370                    delta: node_offset(arena, node.inputs[1]),
3371                    a: node_offset(arena, node.inputs[2]),
3372                    b: node_offset(arena, node.inputs[3]),
3373                    c: node_offset(arena, node.inputs[4]),
3374                    dst: node_offset(arena, node.id),
3375                    batch: batch as u32,
3376                    seq: seq as u32,
3377                    hidden: hidden as u32,
3378                    state_size: *state_size as u32,
3379                }
3380            }
3381
3382            Op::GatedDeltaNet {
3383                state_size,
3384                carry_state,
3385            } => {
3386                let q_shape = &graph.node(node.inputs[0]).shape;
3387                let (batch, seq, heads) = (
3388                    q_shape.dim(0).unwrap_static(),
3389                    q_shape.dim(1).unwrap_static(),
3390                    q_shape.dim(2).unwrap_static(),
3391                );
3392                let state_off = if *carry_state {
3393                    node_offset(arena, node.inputs[5])
3394                } else {
3395                    0
3396                };
3397                Thunk::GatedDeltaNet {
3398                    q: node_offset(arena, node.inputs[0]),
3399                    k: node_offset(arena, node.inputs[1]),
3400                    v: node_offset(arena, node.inputs[2]),
3401                    g: node_offset(arena, node.inputs[3]),
3402                    beta: node_offset(arena, node.inputs[4]),
3403                    state: state_off,
3404                    dst: node_offset(arena, node.id),
3405                    batch: batch as u32,
3406                    seq: seq as u32,
3407                    heads: heads as u32,
3408                    state_size: *state_size as u32,
3409                }
3410            }
3411
3412            Op::Lstm {
3413                hidden_size,
3414                num_layers,
3415                bidirectional,
3416                carry,
3417            } => {
3418                let x_shape = &graph.node(node.inputs[0]).shape;
3419                let (batch, seq, input_size) = (
3420                    x_shape.dim(0).unwrap_static(),
3421                    x_shape.dim(1).unwrap_static(),
3422                    x_shape.dim(2).unwrap_static(),
3423                );
3424                let (h0, c0) = if *carry {
3425                    (
3426                        node_offset(arena, node.inputs[4]),
3427                        node_offset(arena, node.inputs[5]),
3428                    )
3429                } else {
3430                    (0, 0)
3431                };
3432                Thunk::Lstm {
3433                    x: node_offset(arena, node.inputs[0]),
3434                    w_ih: node_offset(arena, node.inputs[1]),
3435                    w_hh: node_offset(arena, node.inputs[2]),
3436                    bias: node_offset(arena, node.inputs[3]),
3437                    h0,
3438                    c0,
3439                    dst: node_offset(arena, node.id),
3440                    batch: batch as u32,
3441                    seq: seq as u32,
3442                    input_size: input_size as u32,
3443                    hidden: *hidden_size as u32,
3444                    num_layers: *num_layers as u32,
3445                    bidirectional: *bidirectional,
3446                    carry: *carry,
3447                }
3448            }
3449
3450            Op::QMatMul {
3451                x_zp,
3452                w_zp,
3453                out_zp,
3454                mult,
3455            } => {
3456                let x_shape = &graph.node(node.inputs[0]).shape;
3457                let w_shape = &graph.node(node.inputs[1]).shape;
3458                let m = x_shape.dim(0).unwrap_static();
3459                let k = x_shape.dim(1).unwrap_static();
3460                let n = w_shape.dim(1).unwrap_static();
3461                Thunk::QMatMul {
3462                    x: node_offset(arena, node.inputs[0]),
3463                    w: node_offset(arena, node.inputs[1]),
3464                    bias: node_offset(arena, node.inputs[2]),
3465                    out: node_offset(arena, node.id),
3466                    m: m as u32,
3467                    k: k as u32,
3468                    n: n as u32,
3469                    x_zp: *x_zp,
3470                    w_zp: *w_zp,
3471                    out_zp: *out_zp,
3472                    mult: *mult,
3473                }
3474            }
3475
3476            Op::QConv2d {
3477                kernel_size,
3478                stride,
3479                padding,
3480                dilation,
3481                groups,
3482                x_zp,
3483                w_zp,
3484                out_zp,
3485                mult,
3486            } => {
3487                let in_shape = &graph.node(node.inputs[0]).shape;
3488                let w_shape = &graph.node(node.inputs[1]).shape;
3489                let out_shape = &node.shape;
3490                if kernel_size.len() == 2
3491                    && in_shape.rank() == 4
3492                    && w_shape.rank() == 4
3493                    && out_shape.rank() == 4
3494                {
3495                    Thunk::QConv2d {
3496                        x: node_offset(arena, node.inputs[0]),
3497                        w: node_offset(arena, node.inputs[1]),
3498                        bias: node_offset(arena, node.inputs[2]),
3499                        out: node_offset(arena, node.id),
3500                        n: in_shape.dim(0).unwrap_static() as u32,
3501                        c_in: in_shape.dim(1).unwrap_static() as u32,
3502                        h: in_shape.dim(2).unwrap_static() as u32,
3503                        w_in: in_shape.dim(3).unwrap_static() as u32,
3504                        c_out: out_shape.dim(1).unwrap_static() as u32,
3505                        h_out: out_shape.dim(2).unwrap_static() as u32,
3506                        w_out: out_shape.dim(3).unwrap_static() as u32,
3507                        kh: kernel_size[0] as u32,
3508                        kw: kernel_size[1] as u32,
3509                        sh: stride.first().copied().unwrap_or(1) as u32,
3510                        sw: stride.get(1).copied().unwrap_or(1) as u32,
3511                        ph: padding.first().copied().unwrap_or(0) as u32,
3512                        pw: padding.get(1).copied().unwrap_or(0) as u32,
3513                        dh: dilation.first().copied().unwrap_or(1) as u32,
3514                        dw: dilation.get(1).copied().unwrap_or(1) as u32,
3515                        groups: *groups as u32,
3516                        x_zp: *x_zp,
3517                        w_zp: *w_zp,
3518                        out_zp: *out_zp,
3519                        mult: *mult,
3520                    }
3521                } else {
3522                    Thunk::Nop
3523                }
3524            }
3525
3526            Op::DequantMatMul { scheme } => {
3527                use rlx_ir::quant::QuantScheme;
3528                let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3529                let total = node.shape.num_elements().unwrap();
3530                let m = total / n.max(1);
3531                let x_total = graph.node(node.inputs[0]).shape.num_elements().unwrap();
3532                let k = x_total / m.max(1);
3533                if scheme.is_gguf() {
3534                    Thunk::DequantMatMulGguf {
3535                        x: node_offset(arena, node.inputs[0]),
3536                        w_q: node_offset(arena, node.inputs[1]),
3537                        dst: node_offset(arena, node.id),
3538                        m: m as u32,
3539                        k: k as u32,
3540                        n: n as u32,
3541                        scheme: *scheme,
3542                    }
3543                } else {
3544                    match scheme {
3545                        QuantScheme::Nvfp4Block => Thunk::DequantMatMulNvfp4 {
3546                            x: node_offset(arena, node.inputs[0]),
3547                            w_q: node_offset(arena, node.inputs[1]),
3548                            scale: node_offset(arena, node.inputs[2]),
3549                            global_scale: node_offset(arena, node.inputs[3]),
3550                            dst: node_offset(arena, node.id),
3551                            m: m as u32,
3552                            k: k as u32,
3553                            n: n as u32,
3554                        },
3555                        QuantScheme::Int4Block { block_size } => Thunk::DequantMatMulInt4 {
3556                            x: node_offset(arena, node.inputs[0]),
3557                            w_q: node_offset(arena, node.inputs[1]),
3558                            scale: node_offset(arena, node.inputs[2]),
3559                            zp: node_offset(arena, node.inputs[3]),
3560                            dst: node_offset(arena, node.id),
3561                            m: m as u32,
3562                            k: k as u32,
3563                            n: n as u32,
3564                            block_size: *block_size,
3565                            is_asymmetric: false,
3566                        },
3567                        QuantScheme::Fp8E4m3 => Thunk::DequantMatMulFp8 {
3568                            x: node_offset(arena, node.inputs[0]),
3569                            w_q: node_offset(arena, node.inputs[1]),
3570                            scale: node_offset(arena, node.inputs[2]),
3571                            dst: node_offset(arena, node.id),
3572                            m: m as u32,
3573                            k: k as u32,
3574                            n: n as u32,
3575                            e5m2: false,
3576                        },
3577                        QuantScheme::Fp8E5m2 => Thunk::DequantMatMulFp8 {
3578                            x: node_offset(arena, node.inputs[0]),
3579                            w_q: node_offset(arena, node.inputs[1]),
3580                            scale: node_offset(arena, node.inputs[2]),
3581                            dst: node_offset(arena, node.id),
3582                            m: m as u32,
3583                            k: k as u32,
3584                            n: n as u32,
3585                            e5m2: true,
3586                        },
3587                        QuantScheme::Int8Block { block_size } => Thunk::DequantMatMul {
3588                            x: node_offset(arena, node.inputs[0]),
3589                            w_q: node_offset(arena, node.inputs[1]),
3590                            scale: node_offset(arena, node.inputs[2]),
3591                            zp: node_offset(arena, node.inputs[3]),
3592                            dst: node_offset(arena, node.id),
3593                            m: m as u32,
3594                            k: k as u32,
3595                            n: n as u32,
3596                            block_size: *block_size,
3597                            is_asymmetric: false,
3598                        },
3599                        QuantScheme::Int8BlockAsym { block_size } => Thunk::DequantMatMul {
3600                            x: node_offset(arena, node.inputs[0]),
3601                            w_q: node_offset(arena, node.inputs[1]),
3602                            scale: node_offset(arena, node.inputs[2]),
3603                            zp: node_offset(arena, node.inputs[3]),
3604                            dst: node_offset(arena, node.id),
3605                            m: m as u32,
3606                            k: k as u32,
3607                            n: n as u32,
3608                            block_size: *block_size,
3609                            is_asymmetric: true,
3610                        },
3611                        other => panic!(
3612                            "DequantMatMul on CPU supports Int8/Int4/FP8/NVFP4 legacy or GGUF schemes; got {other}"
3613                        ),
3614                    }
3615                }
3616            }
3617
3618            Op::LoraMatMul { scale } => {
3619                // x [m, k], w [k, n], a [k, r], b [r, n].
3620                let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3621                let total = node.shape.num_elements().unwrap();
3622                let m = total / n.max(1);
3623                let x_total = graph.node(node.inputs[0]).shape.num_elements().unwrap();
3624                let k = x_total / m.max(1);
3625                let a_total = graph.node(node.inputs[2]).shape.num_elements().unwrap();
3626                let r = a_total / k.max(1);
3627                Thunk::LoraMatMul {
3628                    x: node_offset(arena, node.inputs[0]),
3629                    w: node_offset(arena, node.inputs[1]),
3630                    a: node_offset(arena, node.inputs[2]),
3631                    b: node_offset(arena, node.inputs[3]),
3632                    dst: node_offset(arena, node.id),
3633                    m: m as u32,
3634                    k: k as u32,
3635                    n: n as u32,
3636                    r: r as u32,
3637                    scale: *scale,
3638                }
3639            }
3640
3641            Op::Sample {
3642                top_k,
3643                top_p,
3644                temperature,
3645                seed,
3646            } => {
3647                let in_shape = &graph.node(node.inputs[0]).shape;
3648                // Logits are [batch, vocab] (or [vocab] → batch=1).
3649                let (batch, vocab) = if in_shape.rank() >= 2 {
3650                    (
3651                        in_shape.dim(0).unwrap_static(),
3652                        in_shape.dim(in_shape.rank() - 1).unwrap_static(),
3653                    )
3654                } else {
3655                    (1, in_shape.num_elements().unwrap_or(0))
3656                };
3657                Thunk::Sample {
3658                    logits: node_offset(arena, node.inputs[0]),
3659                    dst: node_offset(arena, node.id),
3660                    batch: batch as u32,
3661                    vocab: vocab as u32,
3662                    top_k: *top_k as u32,
3663                    top_p: *top_p,
3664                    temperature: *temperature,
3665                    seed: *seed,
3666                }
3667            }
3668
3669            Op::RngNormal {
3670                mean,
3671                scale,
3672                key,
3673                op_seed,
3674            } => Thunk::RngNormal {
3675                dst: node_offset(arena, node.id),
3676                len: node.shape.num_elements().unwrap_or(0) as u32,
3677                mean: *mean,
3678                scale: *scale,
3679                key: *key,
3680                op_seed: *op_seed,
3681            },
3682
3683            Op::RngUniform {
3684                low,
3685                high,
3686                key,
3687                op_seed,
3688            } => Thunk::RngUniform {
3689                dst: node_offset(arena, node.id),
3690                len: node.shape.num_elements().unwrap_or(0) as u32,
3691                low: *low,
3692                high: *high,
3693                key: *key,
3694                op_seed: *op_seed,
3695            },
3696
3697            Op::Cumsum { axis, exclusive } => {
3698                // For now CPU only supports last-axis cumsum (the
3699                // common case for sampling / ragged offsets).
3700                // Other axes can lower via Transpose → Cumsum →
3701                // Transpose; not on the hot path today.
3702                let rank = node.shape.rank();
3703                let ax = if *axis < 0 {
3704                    (rank as i32 + axis) as usize
3705                } else {
3706                    *axis as usize
3707                };
3708                assert_eq!(
3709                    ax,
3710                    rank - 1,
3711                    "Cumsum only supports the last axis on CPU today"
3712                );
3713                let cols = node.shape.dim(ax).unwrap_static();
3714                let total = node.shape.num_elements().unwrap();
3715                Thunk::Cumsum {
3716                    src: node_offset(arena, node.inputs[0]),
3717                    dst: node_offset(arena, node.id),
3718                    rows: (total / cols) as u32,
3719                    cols: cols as u32,
3720                    exclusive: *exclusive,
3721                }
3722            }
3723
3724            Op::Attention {
3725                num_heads,
3726                head_dim,
3727                mask_kind,
3728                score_scale,
3729                attn_logit_softcap: _,
3730            } => {
3731                // Layout dispatch: rank-4 input could be either
3732                // `[B, S, H, D]` (CPU's historical convention) or
3733                // `[B, H, S, D]` (the convention the GPU/TPU backends
3734                // share). Disambiguate by which axis matches
3735                // `num_heads`. Rank-3 is always `[B, S, H*D]`.
3736                let q_shape = &graph.node(node.inputs[0]).shape;
3737                let k_shape = &graph.node(node.inputs[1]).shape;
3738                let rank = q_shape.rank();
3739                let (batch, seq, kv_seq, bhsd) = if rank == 4 {
3740                    let d1 = q_shape.dim(1).unwrap_static();
3741                    let d2 = q_shape.dim(2).unwrap_static();
3742                    if d1 == *num_heads {
3743                        // [B, H, S, D]
3744                        (
3745                            q_shape.dim(0).unwrap_static(),
3746                            d2,
3747                            k_shape.dim(2).unwrap_static(),
3748                            true,
3749                        )
3750                    } else {
3751                        // [B, S, H, D]
3752                        (
3753                            q_shape.dim(0).unwrap_static(),
3754                            d1,
3755                            k_shape.dim(1).unwrap_static(),
3756                            false,
3757                        )
3758                    }
3759                } else if rank >= 3 {
3760                    (
3761                        q_shape.dim(0).unwrap_static(),
3762                        q_shape.dim(1).unwrap_static(),
3763                        k_shape.dim(1).unwrap_static(),
3764                        false,
3765                    )
3766                } else {
3767                    (
3768                        1,
3769                        q_shape.dim(0).unwrap_static(),
3770                        k_shape.dim(0).unwrap_static(),
3771                        false,
3772                    )
3773                };
3774                let mask_off = if matches!(
3775                    mask_kind,
3776                    rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
3777                ) {
3778                    node_offset(arena, node.inputs[3])
3779                } else {
3780                    0
3781                };
3782                let hs = (*num_heads * *head_dim) as u32;
3783                Thunk::Attention {
3784                    q: node_offset(arena, node.inputs[0]),
3785                    k: node_offset(arena, node.inputs[1]),
3786                    v: node_offset(arena, node.inputs[2]),
3787                    mask: mask_off,
3788                    out: node_offset(arena, node.id),
3789                    batch: batch as u32,
3790                    seq: seq as u32,
3791                    kv_seq: kv_seq as u32,
3792                    heads: *num_heads as u32,
3793                    head_dim: *head_dim as u32,
3794                    mask_kind: *mask_kind,
3795                    scale: score_scale.unwrap_or((*head_dim as f32).powf(-0.5)),
3796                    // Defaults: each input is its own contiguous buffer
3797                    // with row stride = hidden. Rewritten by the
3798                    // Narrow→Attention fusion when applicable.
3799                    q_row_stride: hs,
3800                    k_row_stride: hs,
3801                    v_row_stride: hs,
3802                    bhsd,
3803                }
3804            }
3805
3806            Op::AttentionBackward {
3807                num_heads,
3808                head_dim,
3809                mask_kind,
3810                wrt,
3811            } => {
3812                let q_shape = &graph.node(node.inputs[0]).shape;
3813                let k_shape = &graph.node(node.inputs[1]).shape;
3814                let rank = q_shape.rank();
3815                let (batch, seq, kv_seq, bhsd) = if rank == 4 {
3816                    let d1 = q_shape.dim(1).unwrap_static();
3817                    let d2 = q_shape.dim(2).unwrap_static();
3818                    if d1 == *num_heads {
3819                        (
3820                            q_shape.dim(0).unwrap_static(),
3821                            d2,
3822                            k_shape.dim(2).unwrap_static(),
3823                            true,
3824                        )
3825                    } else {
3826                        (
3827                            q_shape.dim(0).unwrap_static(),
3828                            d1,
3829                            k_shape.dim(1).unwrap_static(),
3830                            false,
3831                        )
3832                    }
3833                } else if rank >= 3 {
3834                    (
3835                        q_shape.dim(0).unwrap_static(),
3836                        q_shape.dim(1).unwrap_static(),
3837                        k_shape.dim(1).unwrap_static(),
3838                        false,
3839                    )
3840                } else {
3841                    (
3842                        1,
3843                        q_shape.dim(0).unwrap_static(),
3844                        k_shape.dim(0).unwrap_static(),
3845                        false,
3846                    )
3847                };
3848                let mask_off = if matches!(
3849                    mask_kind,
3850                    rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
3851                ) {
3852                    node_offset(arena, node.inputs[4])
3853                } else {
3854                    0
3855                };
3856                Thunk::AttentionBackward {
3857                    q: node_offset(arena, node.inputs[0]),
3858                    k: node_offset(arena, node.inputs[1]),
3859                    v: node_offset(arena, node.inputs[2]),
3860                    dy: node_offset(arena, node.inputs[3]),
3861                    mask: mask_off,
3862                    out: node_offset(arena, node.id),
3863                    batch: batch as u32,
3864                    seq: seq as u32,
3865                    kv_seq: kv_seq as u32,
3866                    heads: *num_heads as u32,
3867                    head_dim: *head_dim as u32,
3868                    mask_kind: *mask_kind,
3869                    wrt: *wrt,
3870                    bhsd,
3871                }
3872            }
3873
3874            Op::FusedAttentionBlock {
3875                num_heads,
3876                head_dim,
3877                has_bias,
3878                has_rope,
3879            } => {
3880                let x_shape = &graph.node(node.inputs[0]).shape;
3881                let (batch, seq) = if x_shape.rank() >= 3 {
3882                    (
3883                        x_shape.dim(0).unwrap_static(),
3884                        x_shape.dim(1).unwrap_static(),
3885                    )
3886                } else {
3887                    let total = x_shape.num_elements().unwrap();
3888                    let s = x_shape.dim(x_shape.rank() - 2).unwrap_static();
3889                    (total / (s * num_heads * head_dim), s)
3890                };
3891                let hs = (*num_heads * *head_dim) as u32;
3892                // Inputs: hidden, qkv_w, out_w, mask, [qkv_b, out_b], [cos, sin]
3893                let mut idx = 4;
3894                let (qkv_b_off, out_b_off) = if *has_bias {
3895                    let qb = node_offset(arena, node.inputs[idx]);
3896                    let ob = node_offset(arena, node.inputs[idx + 1]);
3897                    idx += 2;
3898                    (qb, ob)
3899                } else {
3900                    (0, 0)
3901                };
3902                let (cos_off, sin_off, cl) = if *has_rope {
3903                    let c = node_offset(arena, node.inputs[idx]);
3904                    let s = node_offset(arena, node.inputs[idx + 1]);
3905                    let clen = get_len(graph, node.inputs[idx]);
3906                    (c, s, clen as u32)
3907                } else {
3908                    (0, 0, 0)
3909                };
3910
3911                Thunk::FusedAttnBlock {
3912                    hidden: node_offset(arena, node.inputs[0]),
3913                    qkv_w: node_offset(arena, node.inputs[1]),
3914                    out_w: node_offset(arena, node.inputs[2]),
3915                    mask: node_offset(arena, node.inputs[3]),
3916                    out: node_offset(arena, node.id),
3917                    qkv_b: qkv_b_off,
3918                    out_b: out_b_off,
3919                    cos: cos_off,
3920                    sin: sin_off,
3921                    cos_len: cl,
3922                    batch: batch as u32,
3923                    seq: seq as u32,
3924                    hs,
3925                    nh: *num_heads as u32,
3926                    dh: *head_dim as u32,
3927                    has_bias: *has_bias,
3928                    has_rope: *has_rope,
3929                }
3930            }
3931
3932            Op::Rope { head_dim, n_rot } => {
3933                let x_shape = &graph.node(node.inputs[0]).shape;
3934                let (batch, seq, hidden) = if x_shape.rank() >= 3 {
3935                    (
3936                        x_shape.dim(0).unwrap_static(),
3937                        x_shape.dim(1).unwrap_static(),
3938                        x_shape.dim(2).unwrap_static(),
3939                    )
3940                } else {
3941                    let total = x_shape.num_elements().unwrap();
3942                    (
3943                        1,
3944                        x_shape.dim(0).unwrap_static(),
3945                        total / x_shape.dim(0).unwrap_static(),
3946                    )
3947                };
3948                let cos_len = get_len(graph, node.inputs[1]);
3949                Thunk::Rope {
3950                    src: node_offset(arena, node.inputs[0]),
3951                    cos: node_offset(arena, node.inputs[1]),
3952                    sin: node_offset(arena, node.inputs[2]),
3953                    dst: node_offset(arena, node.id),
3954                    batch: batch as u32,
3955                    seq: seq as u32,
3956                    hidden: hidden as u32,
3957                    head_dim: *head_dim as u32,
3958                    n_rot: *n_rot as u32,
3959                    cos_len: cos_len as u32,
3960                    // Default: source rows are tightly packed (rewritten
3961                    // by the Narrow→Rope fusion pass below if Rope ends
3962                    // up reading from a wider parent like QKV).
3963                    src_row_stride: hidden as u32,
3964                }
3965            }
3966
3967            Op::FusedSwiGLU {
3968                cast_to: _,
3969                gate_first,
3970            } => {
3971                let n_half = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3972                let total = node.shape.num_elements().unwrap();
3973                Thunk::FusedSwiGLU {
3974                    src: node_offset(arena, node.inputs[0]),
3975                    dst: node_offset(arena, node.id),
3976                    n_half: n_half as u32,
3977                    total: total as u32,
3978                    gate_first: *gate_first,
3979                }
3980            }
3981
3982            Op::Conv {
3983                kernel_size,
3984                stride,
3985                padding,
3986                dilation,
3987                groups,
3988            } => {
3989                let in_shape = &graph.node(node.inputs[0]).shape;
3990                let w_shape = &graph.node(node.inputs[1]).shape;
3991                let out_shape = &node.shape;
3992                // 1×1 fast path (plan #26): kH=kW=1, stride=1,
3993                // padding=0, dilation=1, groups=1. Emits a single
3994                // Conv2D1x1 thunk that BLAS-dispatches per batch.
3995                let is_1x1_simple = kernel_size.len() == 2
3996                    && kernel_size[0] == 1
3997                    && kernel_size[1] == 1
3998                    && stride.iter().all(|&s| s == 1)
3999                    && padding.iter().all(|&p| p == 0)
4000                    && dilation.iter().all(|&d| d == 1)
4001                    && *groups == 1;
4002                if is_1x1_simple
4003                    && in_shape.rank() >= 3
4004                    && out_shape.rank() >= 3
4005                    && w_shape.rank() >= 2
4006                {
4007                    let (n, c_in, h, w) = conv_nchw_dims(in_shape);
4008                    let (_, c_out, _, _) = conv_nchw_dims(out_shape);
4009                    Thunk::Conv2D1x1 {
4010                        src: node_offset(arena, node.inputs[0]),
4011                        weight: node_offset(arena, node.inputs[1]),
4012                        dst: node_offset(arena, node.id),
4013                        n,
4014                        c_in,
4015                        c_out,
4016                        hw: h.saturating_mul(w),
4017                    }
4018                } else if kernel_size.len() == 2
4019                    && in_shape.rank() >= 3
4020                    && w_shape.rank() >= 2
4021                    && out_shape.rank() >= 3
4022                {
4023                    let (n, c_in, h, w_in) = conv_nchw_dims(in_shape);
4024                    let (_, c_out, h_out, w_out) = conv_nchw_dims(out_shape);
4025                    // rlx lowers ONNX 1D convs as 2D NCHW with a unit H axis and the
4026                    // length in W (`[N,C,1,L]`), but keeps the length kernel/stride/pad/
4027                    // dilation at index 0 (`kernel=[k,1]`). A literal 2D conv would run
4028                    // the k-tap kernel over the singleton H axis and ignore the length.
4029                    // Since `[N,C,1,L]` and `[N,C,L,1]` share the same row-major layout,
4030                    // relabel the length onto the H axis (no data copy) so the kernel
4031                    // convolves it — matching the MLX 1D path and onnxruntime.
4032                    let one_d_w = h == 1
4033                        && w_in > 1
4034                        && kernel_size[0] > 1
4035                        && kernel_size.get(1).copied().unwrap_or(1) == 1;
4036                    let (h, w_in, h_out, w_out, kh, kw, sh, sw, ph, pw, dh, dw) = if one_d_w {
4037                        (
4038                            w_in,
4039                            1,
4040                            w_out,
4041                            1,
4042                            kernel_size[0] as u32,
4043                            1,
4044                            stride.first().copied().unwrap_or(1) as u32,
4045                            1,
4046                            padding.first().copied().unwrap_or(0) as u32,
4047                            0,
4048                            dilation.first().copied().unwrap_or(1) as u32,
4049                            1,
4050                        )
4051                    } else {
4052                        (
4053                            h,
4054                            w_in,
4055                            h_out,
4056                            w_out,
4057                            kernel_size[0] as u32,
4058                            kernel_size[1] as u32,
4059                            stride.first().copied().unwrap_or(1) as u32,
4060                            stride.get(1).copied().unwrap_or(1) as u32,
4061                            padding.first().copied().unwrap_or(0) as u32,
4062                            padding.get(1).copied().unwrap_or(0) as u32,
4063                            dilation.first().copied().unwrap_or(1) as u32,
4064                            dilation.get(1).copied().unwrap_or(1) as u32,
4065                        )
4066                    };
4067                    Thunk::Conv2D {
4068                        src: node_offset(arena, node.inputs[0]),
4069                        weight: node_offset(arena, node.inputs[1]),
4070                        dst: node_offset(arena, node.id),
4071                        n,
4072                        c_in,
4073                        h,
4074                        w: w_in,
4075                        c_out,
4076                        h_out,
4077                        w_out,
4078                        kh,
4079                        kw,
4080                        sh,
4081                        sw,
4082                        ph,
4083                        pw,
4084                        dh,
4085                        dw,
4086                        groups: *groups as u32,
4087                    }
4088                } else {
4089                    Thunk::Nop
4090                }
4091            }
4092
4093            Op::Pool {
4094                kind,
4095                kernel_size,
4096                stride,
4097                padding,
4098            } => {
4099                // Currently support 2D pooling on rank-4 NCHW tensors.
4100                let in_shape = &graph.node(node.inputs[0]).shape;
4101                let out_shape = &node.shape;
4102                if kernel_size.len() == 2 && in_shape.rank() == 4 && out_shape.rank() == 4 {
4103                    Thunk::Pool2D {
4104                        src: node_offset(arena, node.inputs[0]),
4105                        dst: node_offset(arena, node.id),
4106                        n: in_shape.dim(0).unwrap_static() as u32,
4107                        c: in_shape.dim(1).unwrap_static() as u32,
4108                        h: in_shape.dim(2).unwrap_static() as u32,
4109                        w: in_shape.dim(3).unwrap_static() as u32,
4110                        h_out: out_shape.dim(2).unwrap_static() as u32,
4111                        w_out: out_shape.dim(3).unwrap_static() as u32,
4112                        kh: kernel_size[0] as u32,
4113                        kw: kernel_size[1] as u32,
4114                        sh: stride.first().copied().unwrap_or(1) as u32,
4115                        sw: stride.get(1).copied().unwrap_or(1) as u32,
4116                        ph: padding.first().copied().unwrap_or(0) as u32,
4117                        pw: padding.get(1).copied().unwrap_or(0) as u32,
4118                        kind: *kind,
4119                    }
4120                } else {
4121                    Thunk::Nop
4122                }
4123            }
4124
4125            Op::Transpose { perm } => {
4126                // Pre-compute (out_dims, in_strides_for_each_out_dim) so the
4127                // runtime loop is just an N-D index walk + scatter.
4128                let in_shape = &graph.node(node.inputs[0]).shape;
4129                let in_rank = in_shape.rank();
4130                if perm.iter().any(|&p| p >= in_rank) {
4131                    Thunk::Nop
4132                } else {
4133                    let in_dims: Vec<usize> = (0..in_rank)
4134                        .map(|i| in_shape.dim(i).unwrap_static())
4135                        .collect();
4136                    // Row-major input strides: stride[d] = product of dims[d+1..].
4137                    let mut in_strides_full = vec![1usize; in_rank];
4138                    for d in (0..in_rank.saturating_sub(1)).rev() {
4139                        in_strides_full[d] = in_strides_full[d + 1] * in_dims[d + 1];
4140                    }
4141                    let out_dims: Vec<u32> = perm.iter().map(|&p| in_dims[p] as u32).collect();
4142                    let in_strides: Vec<u32> =
4143                        perm.iter().map(|&p| in_strides_full[p] as u32).collect();
4144                    let in_total = in_dims.iter().product::<usize>() as u32;
4145                    let src = node_offset(arena, node.inputs[0]);
4146                    let dst = node_offset(arena, node.id);
4147                    let elem_bytes = node.shape.dtype().size_bytes() as u8;
4148                    match node.shape.dtype() {
4149                        rlx_ir::DType::F64 => Thunk::TransposeF64 {
4150                            src,
4151                            dst,
4152                            in_total,
4153                            out_dims,
4154                            in_strides,
4155                        },
4156                        _ => Thunk::Transpose {
4157                            src,
4158                            dst,
4159                            in_total,
4160                            out_dims,
4161                            in_strides,
4162                            elem_bytes,
4163                        },
4164                    }
4165                }
4166            }
4167
4168            Op::ScatterAdd => {
4169                // updates: [num_updates, ...trailing], indices: [num_updates],
4170                // output: [out_dim, ...trailing]
4171                let upd_shape = &graph.node(node.inputs[0]).shape;
4172                let out_shape = &node.shape;
4173                let num_updates = upd_shape.dim(0).unwrap_static();
4174                let out_dim = out_shape.dim(0).unwrap_static();
4175                let trailing: usize = (1..out_shape.rank())
4176                    .map(|i| out_shape.dim(i).unwrap_static())
4177                    .product::<usize>()
4178                    .max(1);
4179                Thunk::ScatterAdd {
4180                    updates: node_offset(arena, node.inputs[0]),
4181                    indices: node_offset(arena, node.inputs[1]),
4182                    dst: node_offset(arena, node.id),
4183                    num_updates: num_updates as u32,
4184                    out_dim: out_dim as u32,
4185                    trailing: trailing as u32,
4186                }
4187            }
4188
4189            Op::GroupedMatMul => {
4190                // Inputs: [input(M, K), weight(E, K, N), expert_idx(M)]
4191                let in_shape = &graph.node(node.inputs[0]).shape;
4192                let w_shape = &graph.node(node.inputs[1]).shape;
4193                let m = in_shape.dim(in_shape.rank() - 2).unwrap_static();
4194                let k_dim = in_shape.dim(in_shape.rank() - 1).unwrap_static();
4195                let num_experts = w_shape.dim(0).unwrap_static();
4196                let n = w_shape.dim(2).unwrap_static();
4197                Thunk::GroupedMatMul {
4198                    input: node_offset(arena, node.inputs[0]),
4199                    weight: node_offset(arena, node.inputs[1]),
4200                    expert_idx: node_offset(arena, node.inputs[2]),
4201                    dst: node_offset(arena, node.id),
4202                    m: m as u32,
4203                    k_dim: k_dim as u32,
4204                    n: n as u32,
4205                    num_experts: num_experts as u32,
4206                }
4207            }
4208
4209            Op::DequantGroupedMatMul { scheme } => {
4210                let in_shape = &graph.node(node.inputs[0]).shape;
4211                let w_shape = &graph.node(node.inputs[1]).shape;
4212                let m = in_shape.dim(in_shape.rank() - 2).unwrap_static();
4213                let k_dim = in_shape.dim(in_shape.rank() - 1).unwrap_static();
4214                let out_shape = &node.shape;
4215                let n = out_shape.dim(out_shape.rank() - 1).unwrap_static();
4216                let block_elems = scheme.gguf_block_size() as usize;
4217                let block_bytes = scheme.gguf_block_bytes() as usize;
4218                let slab_bytes = (k_dim * n) / block_elems * block_bytes;
4219                let total_bytes = w_shape.num_elements().unwrap();
4220                let num_experts = total_bytes / slab_bytes.max(1);
4221                Thunk::DequantGroupedMatMulGguf {
4222                    input: node_offset(arena, node.inputs[0]),
4223                    w_q: node_offset(arena, node.inputs[1]),
4224                    expert_idx: node_offset(arena, node.inputs[2]),
4225                    dst: node_offset(arena, node.id),
4226                    m: m as u32,
4227                    k_dim: k_dim as u32,
4228                    n: n as u32,
4229                    num_experts: num_experts as u32,
4230                    scheme: *scheme,
4231                }
4232            }
4233
4234            Op::DequantMoEWeights { scheme } => {
4235                let w_shape = &graph.node(node.inputs[0]).shape;
4236                let out_shape = &node.shape;
4237                let num_experts = out_shape.dim(0).unwrap_static();
4238                let k_dim = out_shape.dim(1).unwrap_static();
4239                let n = out_shape.dim(2).unwrap_static();
4240                let block_elems = scheme.gguf_block_size() as usize;
4241                let block_bytes = scheme.gguf_block_bytes() as usize;
4242                let slab_bytes = (k_dim * n) / block_elems * block_bytes;
4243                let total_bytes = w_shape.num_elements().unwrap();
4244                assert_eq!(
4245                    total_bytes,
4246                    num_experts * slab_bytes,
4247                    "DequantMoEWeights packed bytes mismatch"
4248                );
4249                Thunk::DequantMoEWeightsGguf {
4250                    w_q: node_offset(arena, node.inputs[0]),
4251                    dst: node_offset(arena, node.id),
4252                    k_dim: k_dim as u32,
4253                    n: n as u32,
4254                    num_experts: num_experts as u32,
4255                    scheme: *scheme,
4256                }
4257            }
4258
4259            Op::TopK { k } => {
4260                let in_shape = &graph.node(node.inputs[0]).shape;
4261                let rank = in_shape.rank();
4262                let axis_dim = in_shape.dim(rank - 1).unwrap_static();
4263                let outer = in_shape.num_elements().unwrap() / axis_dim;
4264                let indices_i64 = u8::from(graph.node(node.id).shape.dtype() == rlx_ir::DType::I64);
4265                Thunk::TopK {
4266                    src: node_offset(arena, node.inputs[0]),
4267                    dst: node_offset(arena, node.id),
4268                    outer: outer as u32,
4269                    axis_dim: axis_dim as u32,
4270                    k: *k as u32,
4271                    indices_i64,
4272                }
4273            }
4274
4275            Op::Reduce {
4276                op,
4277                axes,
4278                keep_dim: _,
4279            } => {
4280                // Decompose the input shape into [outer, reduced, inner]
4281                // around the reduced axis range. Non-contiguous reduced
4282                // axes aren't supported here — caller must transpose them
4283                // contiguous first (the coverage tool would surface the
4284                // gap if a model needs it).
4285                let in_shape = &graph.node(node.inputs[0]).shape;
4286                let rank = in_shape.rank();
4287                let mut sorted = axes.clone();
4288                sorted.sort();
4289                sorted.dedup();
4290                let contiguous = sorted.windows(2).all(|w| w[1] == w[0] + 1)
4291                    && !sorted.is_empty()
4292                    && *sorted.last().unwrap() < rank;
4293                if !contiguous {
4294                    Thunk::Nop
4295                } else {
4296                    let first = sorted[0];
4297                    let last = *sorted.last().unwrap();
4298                    let outer: usize = (0..first)
4299                        .map(|i| in_shape.dim(i).unwrap_static())
4300                        .product::<usize>()
4301                        .max(1);
4302                    let reduced: usize = (first..=last)
4303                        .map(|i| in_shape.dim(i).unwrap_static())
4304                        .product();
4305                    let inner: usize = (last + 1..rank)
4306                        .map(|i| in_shape.dim(i).unwrap_static())
4307                        .product::<usize>()
4308                        .max(1);
4309                    let src = node_offset(arena, node.inputs[0]);
4310                    let dst = node_offset(arena, node.id);
4311                    if node.shape.dtype() == rlx_ir::DType::F64 && matches!(op, ReduceOp::Sum) {
4312                        Thunk::ReduceSumF64 {
4313                            src,
4314                            dst,
4315                            outer: outer as u32,
4316                            reduced: reduced as u32,
4317                            inner: inner as u32,
4318                        }
4319                    } else {
4320                        Thunk::Reduce {
4321                            src,
4322                            dst,
4323                            outer: outer as u32,
4324                            reduced: reduced as u32,
4325                            inner: inner as u32,
4326                            op: *op,
4327                        }
4328                    }
4329                }
4330            }
4331
4332            Op::ArgMax { axis, keep_dim: _ } | Op::ArgMin { axis, keep_dim: _ } => {
4333                let in_shape = &graph.node(node.inputs[0]).shape;
4334                let rank = in_shape.rank();
4335                let outer: usize = (0..*axis)
4336                    .map(|i| in_shape.dim(i).unwrap_static())
4337                    .product::<usize>()
4338                    .max(1);
4339                let reduced = in_shape.dim(*axis).unwrap_static();
4340                let inner: usize = (*axis + 1..rank)
4341                    .map(|i| in_shape.dim(i).unwrap_static())
4342                    .product::<usize>()
4343                    .max(1);
4344                Thunk::ArgReduce {
4345                    src: node_offset(arena, node.inputs[0]),
4346                    dst: node_offset(arena, node.id),
4347                    outer: outer as u32,
4348                    reduced: reduced as u32,
4349                    inner: inner as u32,
4350                    is_max: matches!(node.op, Op::ArgMax { .. }),
4351                }
4352            }
4353
4354            Op::Compare(cmp) => {
4355                let len = node.shape.num_elements().unwrap();
4356                let in_dtype = graph.node(node.inputs[0]).shape.dtype();
4357                let inputs_i64 = u8::from(in_dtype == rlx_ir::DType::I64);
4358                Thunk::Compare {
4359                    lhs: node_offset(arena, node.inputs[0]),
4360                    rhs: node_offset(arena, node.inputs[1]),
4361                    dst: node_offset(arena, node.id),
4362                    len: len as u32,
4363                    op: *cmp,
4364                    inputs_i64,
4365                    inputs_elem_bytes: in_dtype.size_bytes() as u8,
4366                    dst_elem_bytes: node.shape.dtype().size_bytes() as u8,
4367                }
4368            }
4369
4370            Op::Where => {
4371                let len = node.shape.num_elements().unwrap();
4372                let elem_bytes = node.shape.dtype().size_bytes() as u8;
4373                let cond_elem_bytes = graph.node(node.inputs[0]).shape.dtype().size_bytes() as u8;
4374                Thunk::Where {
4375                    cond: node_offset(arena, node.inputs[0]),
4376                    on_true: node_offset(arena, node.inputs[1]),
4377                    on_false: node_offset(arena, node.inputs[2]),
4378                    dst: node_offset(arena, node.id),
4379                    len: len as u32,
4380                    elem_bytes,
4381                    cond_elem_bytes,
4382                }
4383            }
4384
4385            Op::ReluBackward => {
4386                let len: usize = (0..node.shape.rank())
4387                    .map(|i| node.shape.dim(i).unwrap_static())
4388                    .product();
4389                let x = node_offset(arena, node.inputs[0]);
4390                let dy = node_offset(arena, node.inputs[1]);
4391                let dx = node_offset(arena, node.id);
4392                match node.shape.dtype() {
4393                    rlx_ir::DType::F64 => Thunk::ReluBackwardF64 {
4394                        x,
4395                        dy,
4396                        dx,
4397                        len: len as u32,
4398                    },
4399                    _ => Thunk::ReluBackward {
4400                        x,
4401                        dy,
4402                        dx,
4403                        len: len as u32,
4404                    },
4405                }
4406            }
4407
4408            Op::ComplexNormSq => {
4409                let len: usize = (0..node.shape.rank())
4410                    .map(|i| node.shape.dim(i).unwrap_static())
4411                    .product();
4412                let src = node_offset(arena, node.inputs[0]);
4413                let dst = node_offset(arena, node.id);
4414                Thunk::ComplexNormSqF32 {
4415                    src,
4416                    dst,
4417                    len: len as u32,
4418                }
4419            }
4420
4421            Op::ComplexNormSqBackward => {
4422                let len: usize = (0..node.shape.rank())
4423                    .map(|i| node.shape.dim(i).unwrap_static())
4424                    .product();
4425                let z = node_offset(arena, node.inputs[0]);
4426                let g = node_offset(arena, node.inputs[1]);
4427                let dz = node_offset(arena, node.id);
4428                Thunk::ComplexNormSqBackwardF32 {
4429                    z,
4430                    g,
4431                    dz,
4432                    len: len as u32,
4433                }
4434            }
4435
4436            Op::Conjugate => {
4437                let len: usize = (0..node.shape.rank())
4438                    .map(|i| node.shape.dim(i).unwrap_static())
4439                    .product();
4440                Thunk::ConjugateC64 {
4441                    src: node_offset(arena, node.inputs[0]),
4442                    dst: node_offset(arena, node.id),
4443                    len: len as u32,
4444                }
4445            }
4446
4447            Op::ActivationBackward { kind } => {
4448                let len: usize = (0..node.shape.rank())
4449                    .map(|i| node.shape.dim(i).unwrap_static())
4450                    .product();
4451                let x = node_offset(arena, node.inputs[0]);
4452                let dy = node_offset(arena, node.inputs[1]);
4453                let dx = node_offset(arena, node.id);
4454                match node.shape.dtype() {
4455                    rlx_ir::DType::F64 => Thunk::ActivationBackwardF64 {
4456                        x,
4457                        dy,
4458                        dx,
4459                        len: len as u32,
4460                        kind: *kind,
4461                    },
4462                    _ => Thunk::ActivationBackward {
4463                        x,
4464                        dy,
4465                        dx,
4466                        len: len as u32,
4467                        kind: *kind,
4468                    },
4469                }
4470            }
4471
4472            Op::LayerNormBackwardInput { eps, .. } => {
4473                // axis = -1 only (matches forward LayerNorm thunk).
4474                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
4475                let total = node.shape.num_elements().unwrap();
4476                Thunk::LayerNormBackwardInput {
4477                    x: node_offset(arena, node.inputs[0]),
4478                    gamma: node_offset(arena, node.inputs[1]),
4479                    dy: node_offset(arena, node.inputs[2]),
4480                    dx: node_offset(arena, node.id),
4481                    rows: (total / h) as u32,
4482                    h: h as u32,
4483                    eps: *eps,
4484                }
4485            }
4486
4487            Op::LayerNormBackwardGamma { eps, .. } => {
4488                let x_shape = &graph.node(node.inputs[0]).shape;
4489                let h = x_shape.dim(x_shape.rank() - 1).unwrap_static();
4490                let x_total = x_shape.num_elements().unwrap();
4491                Thunk::LayerNormBackwardGamma {
4492                    x: node_offset(arena, node.inputs[0]),
4493                    dy: node_offset(arena, node.inputs[1]),
4494                    dgamma: node_offset(arena, node.id),
4495                    rows: (x_total / h) as u32,
4496                    h: h as u32,
4497                    eps: *eps,
4498                }
4499            }
4500
4501            Op::RmsNormBackwardInput { eps, .. }
4502            | Op::RmsNormBackwardGamma { eps, .. }
4503            | Op::RmsNormBackwardBeta { eps, .. } => {
4504                let x_shape = &graph.node(node.inputs[0]).shape;
4505                let h = x_shape.dim(x_shape.rank() - 1).unwrap_static();
4506                let rows = (x_shape.num_elements().unwrap() / h) as u32;
4507                let off = |i: usize| node_offset(arena, node.inputs[i]);
4508                let common = (off(0), off(1), off(2), off(3), rows, h as u32, *eps);
4509                match &node.op {
4510                    Op::RmsNormBackwardInput { .. } => Thunk::RmsNormBackwardInput {
4511                        x: common.0,
4512                        gamma: common.1,
4513                        beta: common.2,
4514                        dy: common.3,
4515                        dx: node_offset(arena, node.id),
4516                        rows: common.4,
4517                        h: common.5,
4518                        eps: common.6,
4519                    },
4520                    Op::RmsNormBackwardGamma { .. } => Thunk::RmsNormBackwardGamma {
4521                        x: common.0,
4522                        gamma: common.1,
4523                        beta: common.2,
4524                        dy: common.3,
4525                        dgamma: node_offset(arena, node.id),
4526                        rows: common.4,
4527                        h: common.5,
4528                        eps: common.6,
4529                    },
4530                    Op::RmsNormBackwardBeta { .. } => Thunk::RmsNormBackwardBeta {
4531                        x: common.0,
4532                        gamma: common.1,
4533                        beta: common.2,
4534                        dy: common.3,
4535                        dbeta: node_offset(arena, node.id),
4536                        rows: common.4,
4537                        h: common.5,
4538                        eps: common.6,
4539                    },
4540                    _ => unreachable!(),
4541                }
4542            }
4543
4544            Op::RopeBackward { head_dim, n_rot } => {
4545                let dy_shape = &graph.node(node.inputs[0]).shape;
4546                let (batch, seq, hidden) = if dy_shape.rank() >= 3 {
4547                    (
4548                        dy_shape.dim(0).unwrap_static(),
4549                        dy_shape.dim(1).unwrap_static(),
4550                        dy_shape.dim(2).unwrap_static(),
4551                    )
4552                } else {
4553                    (
4554                        1,
4555                        dy_shape.dim(0).unwrap_static(),
4556                        dy_shape.dim(1).unwrap_static(),
4557                    )
4558                };
4559                let cos_shape = &graph.node(node.inputs[1]).shape;
4560                let cos_len = cos_shape.num_elements().unwrap();
4561                Thunk::RopeBackward {
4562                    dy: node_offset(arena, node.inputs[0]),
4563                    cos: node_offset(arena, node.inputs[1]),
4564                    sin: node_offset(arena, node.inputs[2]),
4565                    dx: node_offset(arena, node.id),
4566                    batch: batch as u32,
4567                    seq: seq as u32,
4568                    hidden: hidden as u32,
4569                    head_dim: *head_dim as u32,
4570                    n_rot: *n_rot as u32,
4571                    cos_len: cos_len as u32,
4572                }
4573            }
4574
4575            Op::CumsumBackward { exclusive, .. } => {
4576                let dy_shape = &graph.node(node.inputs[0]).shape;
4577                let rank = dy_shape.rank();
4578                let cols = dy_shape.dim(rank - 1).unwrap_static();
4579                let rows = dy_shape.num_elements().unwrap() / cols;
4580                Thunk::CumsumBackward {
4581                    dy: node_offset(arena, node.inputs[0]),
4582                    dx: node_offset(arena, node.id),
4583                    rows: rows as u32,
4584                    cols: cols as u32,
4585                    exclusive: *exclusive,
4586                }
4587            }
4588
4589            Op::GatherBackward { .. } => {
4590                let dy_shape = &graph.node(node.inputs[0]).shape;
4591                let idx_shape = &graph.node(node.inputs[1]).shape;
4592                let out_shape = &node.shape;
4593                let rank = out_shape.rank();
4594                let axis = match &node.op {
4595                    Op::GatherBackward { axis } => *axis,
4596                    _ => 0,
4597                };
4598                let axis_u = if axis < 0 {
4599                    (rank as i32 + axis) as usize
4600                } else {
4601                    axis as usize
4602                };
4603                let outer: usize = (0..axis_u)
4604                    .map(|i| dy_shape.dim(i).unwrap_static())
4605                    .product::<usize>()
4606                    .max(1);
4607                let num_idx = idx_shape.dim(axis_u).unwrap_static();
4608                let trailing: usize = (axis_u + 1..dy_shape.rank())
4609                    .map(|i| dy_shape.dim(i).unwrap_static())
4610                    .product::<usize>()
4611                    .max(1);
4612                let axis_dim = out_shape.dim(axis_u).unwrap_static();
4613                Thunk::GatherBackward {
4614                    dy: node_offset(arena, node.inputs[0]),
4615                    indices: node_offset(arena, node.inputs[1]),
4616                    dst: node_offset(arena, node.id),
4617                    outer: outer as u32,
4618                    axis_dim: axis_dim as u32,
4619                    num_idx: num_idx as u32,
4620                    trailing: trailing as u32,
4621                }
4622            }
4623
4624            Op::GroupNormBackwardInput { num_groups, eps }
4625            | Op::GroupNormBackwardGamma { num_groups, eps }
4626            | Op::GroupNormBackwardBeta { num_groups, eps } => {
4627                let x_shape = &graph.node(node.inputs[0]).shape;
4628                let n = x_shape.dim(0).unwrap_static() as u32;
4629                let c = x_shape.dim(1).unwrap_static() as u32;
4630                let h = x_shape.dim(2).unwrap_static() as u32;
4631                let w = x_shape.dim(3).unwrap_static() as u32;
4632                match &node.op {
4633                    Op::GroupNormBackwardInput { .. } => Thunk::GroupNormBackwardInput {
4634                        x: node_offset(arena, node.inputs[0]),
4635                        gamma: node_offset(arena, node.inputs[1]),
4636                        beta: node_offset(arena, node.inputs[2]),
4637                        dy: node_offset(arena, node.inputs[3]),
4638                        dx: node_offset(arena, node.id),
4639                        n,
4640                        c,
4641                        h,
4642                        w,
4643                        num_groups: *num_groups as u32,
4644                        eps: *eps,
4645                    },
4646                    Op::GroupNormBackwardGamma { .. } => Thunk::GroupNormBackwardGamma {
4647                        x: node_offset(arena, node.inputs[0]),
4648                        dy: node_offset(arena, node.inputs[1]),
4649                        dgamma: node_offset(arena, node.id),
4650                        n,
4651                        c,
4652                        h,
4653                        w,
4654                        num_groups: *num_groups as u32,
4655                        eps: *eps,
4656                    },
4657                    Op::GroupNormBackwardBeta { .. } => Thunk::GroupNormBackwardBeta {
4658                        dy: node_offset(arena, node.inputs[1]),
4659                        dbeta: node_offset(arena, node.id),
4660                        n,
4661                        c,
4662                        h,
4663                        w,
4664                    },
4665                    _ => unreachable!(),
4666                }
4667            }
4668
4669            Op::MaxPool2dBackward {
4670                kernel_size,
4671                stride,
4672                padding,
4673            } => {
4674                let x_shape = &graph.node(node.inputs[0]).shape;
4675                let dy_shape = &graph.node(node.inputs[1]).shape;
4676                if kernel_size.len() == 2 && x_shape.rank() == 4 && dy_shape.rank() == 4 {
4677                    Thunk::MaxPool2dBackward {
4678                        x: node_offset(arena, node.inputs[0]),
4679                        dy: node_offset(arena, node.inputs[1]),
4680                        dx: node_offset(arena, node.id),
4681                        n: x_shape.dim(0).unwrap_static() as u32,
4682                        c: x_shape.dim(1).unwrap_static() as u32,
4683                        h: x_shape.dim(2).unwrap_static() as u32,
4684                        w: x_shape.dim(3).unwrap_static() as u32,
4685                        h_out: dy_shape.dim(2).unwrap_static() as u32,
4686                        w_out: dy_shape.dim(3).unwrap_static() as u32,
4687                        kh: kernel_size[0] as u32,
4688                        kw: kernel_size[1] as u32,
4689                        sh: stride.first().copied().unwrap_or(1) as u32,
4690                        sw: stride.get(1).copied().unwrap_or(1) as u32,
4691                        ph: padding.first().copied().unwrap_or(0) as u32,
4692                        pw: padding.get(1).copied().unwrap_or(0) as u32,
4693                    }
4694                } else {
4695                    Thunk::Nop
4696                }
4697            }
4698
4699            Op::Conv2dBackwardInput {
4700                kernel_size,
4701                stride,
4702                padding,
4703                dilation,
4704                groups,
4705            } => {
4706                let dy_shape = &graph.node(node.inputs[0]).shape;
4707                let w_shape = &graph.node(node.inputs[1]).shape;
4708                let out_shape = &node.shape;
4709                if kernel_size.len() == 2
4710                    && dy_shape.rank() == 4
4711                    && w_shape.rank() == 4
4712                    && out_shape.rank() == 4
4713                {
4714                    Thunk::Conv2dBackwardInput {
4715                        dy: node_offset(arena, node.inputs[0]),
4716                        w: node_offset(arena, node.inputs[1]),
4717                        dx: node_offset(arena, node.id),
4718                        n: out_shape.dim(0).unwrap_static() as u32,
4719                        c_in: out_shape.dim(1).unwrap_static() as u32,
4720                        h: out_shape.dim(2).unwrap_static() as u32,
4721                        w_in: out_shape.dim(3).unwrap_static() as u32,
4722                        c_out: dy_shape.dim(1).unwrap_static() as u32,
4723                        h_out: dy_shape.dim(2).unwrap_static() as u32,
4724                        w_out: dy_shape.dim(3).unwrap_static() as u32,
4725                        kh: kernel_size[0] as u32,
4726                        kw: kernel_size[1] as u32,
4727                        sh: stride.first().copied().unwrap_or(1) as u32,
4728                        sw: stride.get(1).copied().unwrap_or(1) as u32,
4729                        ph: padding.first().copied().unwrap_or(0) as u32,
4730                        pw: padding.get(1).copied().unwrap_or(0) as u32,
4731                        dh: dilation.first().copied().unwrap_or(1) as u32,
4732                        dw: dilation.get(1).copied().unwrap_or(1) as u32,
4733                        groups: *groups as u32,
4734                    }
4735                } else {
4736                    Thunk::Nop
4737                }
4738            }
4739
4740            Op::Conv2dBackwardWeight {
4741                kernel_size,
4742                stride,
4743                padding,
4744                dilation,
4745                groups,
4746            } => {
4747                let x_shape = &graph.node(node.inputs[0]).shape;
4748                let dy_shape = &graph.node(node.inputs[1]).shape;
4749                let dw_shape = &node.shape;
4750                if kernel_size.len() == 2
4751                    && x_shape.rank() == 4
4752                    && dy_shape.rank() == 4
4753                    && dw_shape.rank() == 4
4754                {
4755                    Thunk::Conv2dBackwardWeight {
4756                        x: node_offset(arena, node.inputs[0]),
4757                        dy: node_offset(arena, node.inputs[1]),
4758                        dw: node_offset(arena, node.id),
4759                        n: x_shape.dim(0).unwrap_static() as u32,
4760                        c_in: x_shape.dim(1).unwrap_static() as u32,
4761                        h: x_shape.dim(2).unwrap_static() as u32,
4762                        w: x_shape.dim(3).unwrap_static() as u32,
4763                        c_out: dy_shape.dim(1).unwrap_static() as u32,
4764                        h_out: dy_shape.dim(2).unwrap_static() as u32,
4765                        w_out: dy_shape.dim(3).unwrap_static() as u32,
4766                        kh: kernel_size[0] as u32,
4767                        kw: kernel_size[1] as u32,
4768                        sh: stride.first().copied().unwrap_or(1) as u32,
4769                        sw: stride.get(1).copied().unwrap_or(1) as u32,
4770                        ph: padding.first().copied().unwrap_or(0) as u32,
4771                        pw: padding.get(1).copied().unwrap_or(0) as u32,
4772                        dh: dilation.first().copied().unwrap_or(1) as u32,
4773                        dw_dil: dilation.get(1).copied().unwrap_or(1) as u32,
4774                        groups: *groups as u32,
4775                    }
4776                } else {
4777                    Thunk::Nop
4778                }
4779            }
4780
4781            Op::Im2Col {
4782                kernel_size,
4783                stride,
4784                padding,
4785                dilation,
4786            } => {
4787                let x_shape = &graph.node(node.inputs[0]).shape;
4788                let out_shape = &node.shape;
4789                if kernel_size.len() == 2 && x_shape.rank() == 4 && out_shape.rank() == 2 {
4790                    let n = match x_shape.dim(0) {
4791                        rlx_ir::shape::Dim::Static(v) => v as u32,
4792                        _ => 0,
4793                    };
4794                    let c_in = x_shape.dim(1).unwrap_static() as u32;
4795                    let h = x_shape.dim(2).unwrap_static() as u32;
4796                    let w = x_shape.dim(3).unwrap_static() as u32;
4797                    let kh = kernel_size[0] as u32;
4798                    let kw = kernel_size[1] as u32;
4799                    let sh = stride.first().copied().unwrap_or(1) as u32;
4800                    let sw = stride.get(1).copied().unwrap_or(1) as u32;
4801                    let ph = padding.first().copied().unwrap_or(0) as u32;
4802                    let pw = padding.get(1).copied().unwrap_or(0) as u32;
4803                    let dh = dilation.first().copied().unwrap_or(1) as u32;
4804                    let dw_dil = dilation.get(1).copied().unwrap_or(1) as u32;
4805                    let h_out = rlx_ir::shape::conv2d_spatial_output(
4806                        h as usize,
4807                        kh as usize,
4808                        sh as usize,
4809                        ph as usize,
4810                        dh as usize,
4811                    ) as u32;
4812                    let w_out = rlx_ir::shape::conv2d_spatial_output(
4813                        w as usize,
4814                        kw as usize,
4815                        sw as usize,
4816                        pw as usize,
4817                        dw_dil as usize,
4818                    ) as u32;
4819                    Thunk::Im2Col {
4820                        x: node_offset(arena, node.inputs[0]),
4821                        col: node_offset(arena, node.id),
4822                        n,
4823                        c_in,
4824                        h,
4825                        w,
4826                        h_out,
4827                        w_out,
4828                        kh,
4829                        kw,
4830                        sh,
4831                        sw,
4832                        ph,
4833                        pw,
4834                        dh,
4835                        dw_dil,
4836                    }
4837                } else {
4838                    Thunk::Nop
4839                }
4840            }
4841
4842            Op::SoftmaxCrossEntropyWithLogits => {
4843                let logits_shape = &graph.node(node.inputs[0]).shape;
4844                if logits_shape.rank() == 2 {
4845                    Thunk::SoftmaxCrossEntropy {
4846                        logits: node_offset(arena, node.inputs[0]),
4847                        labels: node_offset(arena, node.inputs[1]),
4848                        dst: node_offset(arena, node.id),
4849                        n: logits_shape.dim(0).unwrap_static() as u32,
4850                        c: logits_shape.dim(1).unwrap_static() as u32,
4851                    }
4852                } else {
4853                    Thunk::Nop
4854                }
4855            }
4856
4857            Op::SoftmaxCrossEntropyBackward => {
4858                let logits_shape = &graph.node(node.inputs[0]).shape;
4859                if logits_shape.rank() == 2 {
4860                    Thunk::SoftmaxCrossEntropyBackward {
4861                        logits: node_offset(arena, node.inputs[0]),
4862                        labels: node_offset(arena, node.inputs[1]),
4863                        d_loss: node_offset(arena, node.inputs[2]),
4864                        dlogits: node_offset(arena, node.id),
4865                        n: logits_shape.dim(0).unwrap_static() as u32,
4866                        c: logits_shape.dim(1).unwrap_static() as u32,
4867                    }
4868                } else {
4869                    Thunk::Nop
4870                }
4871            }
4872
4873            Op::DenseSolve => {
4874                // A: [n, n], b: [n] or [n, nrhs]. Output matches b.
4875                let a_shape = &graph.node(node.inputs[0]).shape;
4876                let n = a_shape.dim(0).unwrap_static();
4877                debug_assert_eq!(
4878                    n,
4879                    a_shape.dim(1).unwrap_static(),
4880                    "DenseSolve: A must be square"
4881                );
4882                let b_elems = node.shape.num_elements().unwrap();
4883                let nrhs = b_elems / n;
4884                match node.shape.dtype() {
4885                    rlx_ir::DType::F64 => Thunk::DenseSolveF64 {
4886                        a: node_offset(arena, node.inputs[0]),
4887                        b: node_offset(arena, node.inputs[1]),
4888                        x: node_offset(arena, node.id),
4889                        n: n as u32,
4890                        nrhs: nrhs as u32,
4891                    },
4892                    rlx_ir::DType::F32 => Thunk::DenseSolveF32 {
4893                        a: node_offset(arena, node.inputs[0]),
4894                        b: node_offset(arena, node.inputs[1]),
4895                        x: node_offset(arena, node.id),
4896                        n: n as u32,
4897                        nrhs: nrhs as u32,
4898                    },
4899                    other => panic!(
4900                        "DenseSolve: F32 + F64 lowered; got {other:?}. \
4901                         Add another variant when needed."
4902                    ),
4903                }
4904            }
4905
4906            Op::BatchedDenseSolve => {
4907                // A: [B, N, N], b: [B, N] or [B, N, K]. Output matches b.
4908                let a_shape = &graph.node(node.inputs[0]).shape;
4909                assert_eq!(a_shape.rank(), 3, "BatchedDenseSolve: A rank must be 3");
4910                let batch = a_shape.dim(0).unwrap_static();
4911                let n = a_shape.dim(1).unwrap_static();
4912                debug_assert_eq!(
4913                    n,
4914                    a_shape.dim(2).unwrap_static(),
4915                    "BatchedDenseSolve: A's last two dims must match"
4916                );
4917                let total = node.shape.num_elements().unwrap();
4918                let nrhs = total / (batch * n);
4919                match node.shape.dtype() {
4920                    rlx_ir::DType::F32 => Thunk::BatchedDenseSolveF32 {
4921                        a: node_offset(arena, node.inputs[0]),
4922                        b: node_offset(arena, node.inputs[1]),
4923                        x: node_offset(arena, node.id),
4924                        batch: batch as u32,
4925                        n: n as u32,
4926                        nrhs: nrhs as u32,
4927                    },
4928                    rlx_ir::DType::F64 => Thunk::BatchedDenseSolveF64 {
4929                        a: node_offset(arena, node.inputs[0]),
4930                        b: node_offset(arena, node.inputs[1]),
4931                        x: node_offset(arena, node.id),
4932                        batch: batch as u32,
4933                        n: n as u32,
4934                        nrhs: nrhs as u32,
4935                    },
4936                    other => panic!("BatchedDenseSolve: F32 + F64 only, got {other:?}"),
4937                }
4938            }
4939
4940            Op::Scan {
4941                body,
4942                length,
4943                save_trajectory,
4944                num_bcast,
4945                num_xs,
4946                num_checkpoints,
4947            } => {
4948                assert!(
4949                    *num_checkpoints == 0 || *num_checkpoints <= *length,
4950                    "Op::Scan: num_checkpoints={} must be 0 or ≤ length={}",
4951                    *num_checkpoints,
4952                    *length
4953                );
4954                if *num_checkpoints != 0 && *num_checkpoints != *length {
4955                    assert!(
4956                        *save_trajectory,
4957                        "Op::Scan: num_checkpoints<length only meaningful when save_trajectory=true"
4958                    );
4959                }
4960                // Plan + compile the body sub-graph standalone. The body
4961                // gets its own Arena; per execution we clone its
4962                // pristine bytes, copy the outer carry (and per-step xs
4963                // slices, if any) into the body's Input slots, run the
4964                // body schedule N times, then copy the body's output
4965                // back to the outer arena.
4966                //
4967                // Body invariants: 1 + num_xs Op::Inputs in NodeId order
4968                // — first declared is the carry, rest are x_t_i. Single
4969                // graph output (the next carry), same shape as carry.
4970                let body_plan = rlx_opt::memory::plan_memory(body);
4971                let _body_arena_size = body_plan.arena_size;
4972                // Snapshot per-input byte offsets before plan_memory
4973                // moves into the Arena below.
4974                let body_offsets: HashMap<NodeId, usize> = body_plan
4975                    .assignments
4976                    .iter()
4977                    .map(|(id, slot)| (*id, slot.offset))
4978                    .collect();
4979
4980                // Collect body Input nodes in NodeId order; first is
4981                // carry, rest are per-step xs in matching order.
4982                let mut body_inputs: Vec<NodeId> = body
4983                    .nodes()
4984                    .iter()
4985                    .filter(|n| matches!(n.op, Op::Input { .. }))
4986                    .map(|n| n.id)
4987                    .collect();
4988                body_inputs.sort();
4989                let n_body_inputs = body_inputs.len();
4990                let expected = 1 + *num_bcast as usize + *num_xs as usize;
4991                if n_body_inputs != expected {
4992                    let names: Vec<String> = body
4993                        .nodes()
4994                        .iter()
4995                        .filter_map(|n| match &n.op {
4996                            Op::Input { name } => Some(format!("{}={}", n.id, name)),
4997                            _ => None,
4998                        })
4999                        .collect();
5000                    panic!(
5001                        "Op::Scan body has {} Op::Input nodes; expected {} \
5002                            (1 carry + {} bcast + {} xs). Inputs by NodeId: [{}]",
5003                        n_body_inputs,
5004                        expected,
5005                        *num_bcast,
5006                        *num_xs,
5007                        names.join(", ")
5008                    );
5009                }
5010
5011                let body_input_id = body_inputs[0];
5012                let body_input_off = body_offsets[&body_input_id];
5013                let body_output_id = body
5014                    .outputs
5015                    .first()
5016                    .copied()
5017                    .expect("Op::Scan body must declare one output");
5018                let body_output_off = body_offsets[&body_output_id];
5019
5020                let mut body_arena = crate::arena::Arena::from_plan(body_plan);
5021                // Fill body Constant nodes — mirror the outer-graph logic
5022                // in rlx-runtime/src/backend.rs (dtype-aware).
5023                for n in body.nodes() {
5024                    if let Op::Constant { data } = &n.op
5025                        && body_arena.has_buffer(n.id)
5026                        && !data.is_empty()
5027                    {
5028                        match n.shape.dtype() {
5029                            rlx_ir::DType::F64 => {
5030                                let off = body_arena.byte_offset(n.id);
5031                                let buf = body_arena.raw_buf_mut();
5032                                let nbytes = (buf.len() - off).min(data.len());
5033                                buf[off..off + nbytes].copy_from_slice(&data[..nbytes]);
5034                            }
5035                            _ => {
5036                                let buf = body_arena.slice_mut(n.id);
5037                                let n_floats = data.len() / 4;
5038                                let n_lim = buf.len().min(n_floats);
5039                                for i in 0..n_lim {
5040                                    let bytes = [
5041                                        data[i * 4],
5042                                        data[i * 4 + 1],
5043                                        data[i * 4 + 2],
5044                                        data[i * 4 + 3],
5045                                    ];
5046                                    buf[i] = f32::from_le_bytes(bytes);
5047                                }
5048                            }
5049                        }
5050                    }
5051                }
5052                let body_init = body_arena.raw_buf().to_vec();
5053                let body_schedule = compile_thunks_with_rng(body, &body_arena, rng);
5054
5055                // Carry bytes — for trajectory mode, the outer node's
5056                // shape is [length, *carry_shape], so dividing by length
5057                // gives one row's bytes; the body's input slot still
5058                // holds carry_shape bytes.
5059                let carry_bytes = if *save_trajectory {
5060                    let total = node
5061                        .shape
5062                        .size_bytes()
5063                        .expect("Op::Scan trajectory output must have static shape");
5064                    total / *length as usize
5065                } else {
5066                    node.shape
5067                        .size_bytes()
5068                        .expect("Op::Scan carry must have static shape")
5069                };
5070
5071                // Bcast inputs occupy body_inputs[1..1+num_bcast] and
5072                // outer node.inputs[1..1+num_bcast]. They keep their
5073                // natural shape (no [length, ...] prefix) and are
5074                // copied into body_buf ONCE before the scan loop.
5075                let mut bcast_inputs: Vec<(usize, usize, u32)> =
5076                    Vec::with_capacity(*num_bcast as usize);
5077                for i in 0..*num_bcast as usize {
5078                    let body_b_id = body_inputs[1 + i];
5079                    let body_b_off = body_offsets[&body_b_id];
5080                    let outer_b_id = node.inputs[1 + i];
5081                    let outer_b_off = node_offset(arena, outer_b_id);
5082                    let outer_b_shape = &graph.node(outer_b_id).shape;
5083                    let total = outer_b_shape
5084                        .size_bytes()
5085                        .expect("Op::Scan bcast must have static shape");
5086                    bcast_inputs.push((body_b_off, outer_b_off, total as u32));
5087                }
5088
5089                // xs occupy body_inputs[1+num_bcast..] and node.inputs
5090                // [1+num_bcast..]. Each has shape [length, *per_step];
5091                // per-step bytes = total / length.
5092                let mut xs_inputs: Vec<(usize, usize, u32)> = Vec::with_capacity(*num_xs as usize);
5093                let xs_base = 1 + *num_bcast as usize;
5094                for i in 0..*num_xs as usize {
5095                    let body_x_id = body_inputs[xs_base + i];
5096                    let body_x_off = body_offsets[&body_x_id];
5097                    let outer_xs_id = node.inputs[xs_base + i];
5098                    let outer_xs_off = node_offset(arena, outer_xs_id);
5099                    let outer_xs_shape = &graph.node(outer_xs_id).shape;
5100                    let total = outer_xs_shape
5101                        .size_bytes()
5102                        .expect("Op::Scan xs must have static shape");
5103                    let per_step = total / *length as usize;
5104                    xs_inputs.push((body_x_off, outer_xs_off, per_step as u32));
5105                }
5106
5107                Thunk::Scan {
5108                    body: Arc::new(body_schedule),
5109                    body_init: Arc::new(body_init),
5110                    body_input_off,
5111                    body_output_off,
5112                    outer_init_off: node_offset(arena, node.inputs[0]),
5113                    outer_final_off: node_offset(arena, node.id),
5114                    length: *length,
5115                    carry_bytes: carry_bytes as u32,
5116                    save_trajectory: *save_trajectory,
5117                    xs_inputs: Arc::new(xs_inputs),
5118                    bcast_inputs: Arc::new(bcast_inputs),
5119                    num_checkpoints: *num_checkpoints,
5120                }
5121            }
5122
5123            Op::ScanBackward {
5124                body_vjp,
5125                length,
5126                save_trajectory,
5127                num_xs,
5128                num_checkpoints,
5129                forward_body,
5130            } => {
5131                let is_recursive = *num_checkpoints != 0 && *num_checkpoints != *length;
5132                if is_recursive {
5133                    assert!(
5134                        forward_body.is_some(),
5135                        "Op::ScanBackward with num_checkpoints<length requires forward_body"
5136                    );
5137                }
5138                // body_vjp has signature
5139                //   (carry, x_t_0, ..., x_t_{num_xs-1}, d_output) → dcarry
5140                // Identify slots:
5141                //   * "d_output" by exact name (AD-introduced seed Input).
5142                //   * Remaining Inputs sorted by NodeId — first is the
5143                //     carry mirror, rest are x_t_i mirrors in body's
5144                //     original Op::Input declaration order.
5145                let body_plan = rlx_opt::memory::plan_memory(body_vjp);
5146                let body_offsets: HashMap<NodeId, usize> = body_plan
5147                    .assignments
5148                    .iter()
5149                    .map(|(id, slot)| (*id, slot.offset))
5150                    .collect();
5151                let mut body_d_output_off: Option<usize> = None;
5152                let mut body_other_inputs: Vec<(NodeId, usize)> = Vec::new();
5153                for n in body_vjp.nodes() {
5154                    if let Op::Input { name } = &n.op {
5155                        let off = body_offsets[&n.id];
5156                        if name == "d_output" {
5157                            body_d_output_off = Some(off);
5158                        } else {
5159                            body_other_inputs.push((n.id, off));
5160                        }
5161                    }
5162                }
5163                body_other_inputs.sort_by_key(|(id, _)| *id);
5164                let body_d_output_off =
5165                    body_d_output_off.expect("ScanBackward body_vjp missing 'd_output' Input");
5166                let expected_others = 1 + *num_xs as usize;
5167                assert_eq!(
5168                    body_other_inputs.len(),
5169                    expected_others,
5170                    "ScanBackward body_vjp has {} non-d_output Inputs; \
5171                     expected {} (1 carry + {} xs)",
5172                    body_other_inputs.len(),
5173                    expected_others,
5174                    num_xs
5175                );
5176                let body_carry_in_off = body_other_inputs[0].1;
5177                let body_x_offs: Vec<usize> = body_other_inputs
5178                    .iter()
5179                    .skip(1)
5180                    .map(|(_, off)| *off)
5181                    .collect();
5182                let body_dcarry_out_off = body_offsets[&body_vjp.outputs[0]];
5183
5184                let mut body_arena = crate::arena::Arena::from_plan(body_plan);
5185                // Fill body_vjp's Constants (mirrors the Scan lowering).
5186                for n in body_vjp.nodes() {
5187                    if let Op::Constant { data } = &n.op
5188                        && body_arena.has_buffer(n.id)
5189                        && !data.is_empty()
5190                    {
5191                        match n.shape.dtype() {
5192                            rlx_ir::DType::F64 => {
5193                                let off = body_arena.byte_offset(n.id);
5194                                let buf = body_arena.raw_buf_mut();
5195                                let nb = (buf.len() - off).min(data.len());
5196                                buf[off..off + nb].copy_from_slice(&data[..nb]);
5197                            }
5198                            _ => {
5199                                let buf = body_arena.slice_mut(n.id);
5200                                let nf = data.len() / 4;
5201                                let nl = buf.len().min(nf);
5202                                for i in 0..nl {
5203                                    let bytes = [
5204                                        data[i * 4],
5205                                        data[i * 4 + 1],
5206                                        data[i * 4 + 2],
5207                                        data[i * 4 + 3],
5208                                    ];
5209                                    buf[i] = f32::from_le_bytes(bytes);
5210                                }
5211                            }
5212                        }
5213                    }
5214                }
5215                let body_init = body_arena.raw_buf().to_vec();
5216                let body_schedule = compile_thunks_with_rng(body_vjp, &body_arena, rng);
5217
5218                // Carry bytes from the dcarry output node (== carry shape).
5219                let carry_bytes = body_vjp
5220                    .node(body_vjp.outputs[0])
5221                    .shape
5222                    .size_bytes()
5223                    .expect("ScanBackward dcarry must be statically shaped");
5224                let carry_elem_size = body_vjp
5225                    .node(body_vjp.outputs[0])
5226                    .shape
5227                    .dtype()
5228                    .size_bytes() as u32;
5229
5230                // For each xs input on the outer node:
5231                // (outer_xs_base, per_step_bytes).
5232                let mut outer_xs_offs: Vec<(usize, u32)> = Vec::with_capacity(*num_xs as usize);
5233                for i in 0..*num_xs as usize {
5234                    let outer_xs_id = node.inputs[3 + i];
5235                    let outer_xs_off = node_offset(arena, outer_xs_id);
5236                    let outer_xs_shape = &graph.node(outer_xs_id).shape;
5237                    let total = outer_xs_shape
5238                        .size_bytes()
5239                        .expect("ScanBackward xs must have static shape");
5240                    let per_step = total / *length as usize;
5241                    outer_xs_offs.push((outer_xs_off, per_step as u32));
5242                }
5243
5244                // If recursive checkpointing is active, we also compile
5245                // the forward body so the executor can recompute
5246                // intermediate carries. The forward body is supplied
5247                // by the AD pass via `forward_body: Some(_)`.
5248                let (fb_schedule, fb_init, fb_carry_in_off, fb_output_off, fb_x_offs) =
5249                    if is_recursive {
5250                        let fb = forward_body.as_ref().unwrap();
5251                        let fb_plan = rlx_opt::memory::plan_memory(fb);
5252                        let fb_offsets: HashMap<NodeId, usize> = fb_plan
5253                            .assignments
5254                            .iter()
5255                            .map(|(id, slot)| (*id, slot.offset))
5256                            .collect();
5257                        let mut fb_inputs: Vec<NodeId> = fb
5258                            .nodes()
5259                            .iter()
5260                            .filter(|n| matches!(n.op, Op::Input { .. }))
5261                            .map(|n| n.id)
5262                            .collect();
5263                        fb_inputs.sort();
5264                        let fb_carry = fb_offsets[&fb_inputs[0]];
5265                        let fb_xs: Vec<usize> = (1..fb_inputs.len())
5266                            .map(|i| fb_offsets[&fb_inputs[i]])
5267                            .collect();
5268                        let fb_out = fb_offsets[&fb.outputs[0]];
5269                        let mut fb_arena = crate::arena::Arena::from_plan(fb_plan);
5270                        for n in fb.nodes() {
5271                            if let Op::Constant { data } = &n.op
5272                                && fb_arena.has_buffer(n.id)
5273                                && !data.is_empty()
5274                            {
5275                                // Byte-copy works for any
5276                                // numeric dtype as long as the
5277                                // arena slot is sized to hold
5278                                // it — the Constant's `data`
5279                                // already encodes the right
5280                                // bytes per element.
5281                                let off = fb_arena.byte_offset(n.id);
5282                                let buf = fb_arena.raw_buf_mut();
5283                                let nb = (buf.len() - off).min(data.len());
5284                                buf[off..off + nb].copy_from_slice(&data[..nb]);
5285                            }
5286                        }
5287                        let fb_init_bytes = fb_arena.raw_buf().to_vec();
5288                        let fb_sched = compile_thunks_with_rng(fb, &fb_arena, rng);
5289                        (
5290                            Some(Arc::new(fb_sched)),
5291                            Some(Arc::new(fb_init_bytes)),
5292                            fb_carry,
5293                            fb_out,
5294                            fb_xs,
5295                        )
5296                    } else {
5297                        (None, None, 0, 0, Vec::new())
5298                    };
5299
5300                Thunk::ScanBackward {
5301                    body_vjp: Arc::new(body_schedule),
5302                    body_init: Arc::new(body_init),
5303                    body_carry_in_off,
5304                    body_x_offs: Arc::new(body_x_offs),
5305                    body_d_output_off,
5306                    body_dcarry_out_off,
5307                    outer_init_off: node_offset(arena, node.inputs[0]),
5308                    outer_traj_off: node_offset(arena, node.inputs[1]),
5309                    outer_upstream_off: node_offset(arena, node.inputs[2]),
5310                    outer_xs_offs: Arc::new(outer_xs_offs),
5311                    outer_dinit_off: node_offset(arena, node.id),
5312                    length: *length,
5313                    carry_bytes: carry_bytes as u32,
5314                    carry_elem_size,
5315                    save_trajectory: *save_trajectory,
5316                    num_checkpoints: *num_checkpoints,
5317                    forward_body: fb_schedule,
5318                    forward_body_init: fb_init,
5319                    forward_body_carry_in_off: fb_carry_in_off,
5320                    forward_body_output_off: fb_output_off,
5321                    forward_body_x_offs: Arc::new(fb_x_offs),
5322                }
5323            }
5324
5325            Op::ScanBackwardXs {
5326                body_vjp,
5327                length,
5328                save_trajectory,
5329                num_xs,
5330                xs_idx,
5331                num_checkpoints,
5332                forward_body,
5333            } => {
5334                assert!(
5335                    *num_checkpoints == 0 || *num_checkpoints <= *length,
5336                    "Op::ScanBackwardXs: num_checkpoints={} must be 0 or ≤ length={}",
5337                    *num_checkpoints,
5338                    *length
5339                );
5340                let is_recursive = *num_checkpoints != 0 && *num_checkpoints != *length;
5341                if is_recursive {
5342                    assert!(
5343                        forward_body.is_some(),
5344                        "Op::ScanBackwardXs with num_checkpoints<length \
5345                         requires forward_body"
5346                    );
5347                }
5348                // Mirror ScanBackward's body_vjp slot identification +
5349                // arena prep, then add: per-iteration extraction of the
5350                // body_vjp output that corresponds to the chosen xs.
5351                //
5352                // body_vjp's outputs (from `grad(body, [carry, xs_0, ..., xs_{num_xs-1}])`):
5353                //   outputs[0]      = dcarry
5354                //   outputs[1 + i]  = dx_t_i
5355                let body_plan = rlx_opt::memory::plan_memory(body_vjp);
5356                let body_offsets: HashMap<NodeId, usize> = body_plan
5357                    .assignments
5358                    .iter()
5359                    .map(|(id, slot)| (*id, slot.offset))
5360                    .collect();
5361                let mut body_d_output_off: Option<usize> = None;
5362                let mut body_other_inputs: Vec<(NodeId, usize)> = Vec::new();
5363                for n in body_vjp.nodes() {
5364                    if let Op::Input { name } = &n.op {
5365                        let off = body_offsets[&n.id];
5366                        if name == "d_output" {
5367                            body_d_output_off = Some(off);
5368                        } else {
5369                            body_other_inputs.push((n.id, off));
5370                        }
5371                    }
5372                }
5373                body_other_inputs.sort_by_key(|(id, _)| *id);
5374                let body_d_output_off =
5375                    body_d_output_off.expect("ScanBackwardXs body_vjp missing 'd_output' Input");
5376                let expected_others = 1 + *num_xs as usize;
5377                assert_eq!(
5378                    body_other_inputs.len(),
5379                    expected_others,
5380                    "ScanBackwardXs body_vjp has {} non-d_output Inputs; expected {}",
5381                    body_other_inputs.len(),
5382                    expected_others
5383                );
5384                let body_carry_in_off = body_other_inputs[0].1;
5385                let body_x_offs: Vec<usize> = body_other_inputs
5386                    .iter()
5387                    .skip(1)
5388                    .map(|(_, off)| *off)
5389                    .collect();
5390                let body_dcarry_out_off = body_offsets[&body_vjp.outputs[0]];
5391                let dxs_out_node = body_vjp.outputs[1 + *xs_idx as usize];
5392                let body_dxs_out_off = body_offsets[&dxs_out_node];
5393
5394                let mut body_arena = crate::arena::Arena::from_plan(body_plan);
5395                for n in body_vjp.nodes() {
5396                    if let Op::Constant { data } = &n.op
5397                        && body_arena.has_buffer(n.id)
5398                        && !data.is_empty()
5399                    {
5400                        match n.shape.dtype() {
5401                            rlx_ir::DType::F64 => {
5402                                let off = body_arena.byte_offset(n.id);
5403                                let buf = body_arena.raw_buf_mut();
5404                                let nb = (buf.len() - off).min(data.len());
5405                                buf[off..off + nb].copy_from_slice(&data[..nb]);
5406                            }
5407                            _ => {
5408                                let buf = body_arena.slice_mut(n.id);
5409                                let nf = data.len() / 4;
5410                                let nl = buf.len().min(nf);
5411                                for i in 0..nl {
5412                                    let bytes = [
5413                                        data[i * 4],
5414                                        data[i * 4 + 1],
5415                                        data[i * 4 + 2],
5416                                        data[i * 4 + 3],
5417                                    ];
5418                                    buf[i] = f32::from_le_bytes(bytes);
5419                                }
5420                            }
5421                        }
5422                    }
5423                }
5424                let body_init = body_arena.raw_buf().to_vec();
5425                let body_schedule = compile_thunks_with_rng(body_vjp, &body_arena, rng);
5426
5427                let carry_bytes = body_vjp
5428                    .node(body_vjp.outputs[0])
5429                    .shape
5430                    .size_bytes()
5431                    .expect("ScanBackwardXs dcarry must be statically shaped");
5432                let carry_elem_size = body_vjp
5433                    .node(body_vjp.outputs[0])
5434                    .shape
5435                    .dtype()
5436                    .size_bytes() as u32;
5437                let per_step_bytes = body_vjp
5438                    .node(dxs_out_node)
5439                    .shape
5440                    .size_bytes()
5441                    .expect("ScanBackwardXs dxs body output must be statically shaped");
5442
5443                let mut outer_xs_offs: Vec<(usize, u32)> = Vec::with_capacity(*num_xs as usize);
5444                for i in 0..*num_xs as usize {
5445                    let outer_xs_id = node.inputs[3 + i];
5446                    let outer_xs_off = node_offset(arena, outer_xs_id);
5447                    let outer_xs_shape = &graph.node(outer_xs_id).shape;
5448                    let total = outer_xs_shape
5449                        .size_bytes()
5450                        .expect("ScanBackwardXs xs must have static shape");
5451                    let per_step = total / *length as usize;
5452                    outer_xs_offs.push((outer_xs_off, per_step as u32));
5453                }
5454
5455                // Compile forward_body for recompute when checkpointed.
5456                // Mirrors the same code path in the ScanBackward arm.
5457                let (fb_schedule, fb_init, fb_carry_in_off, fb_output_off, fb_x_offs) =
5458                    if is_recursive {
5459                        let fb = forward_body.as_ref().unwrap();
5460                        let fb_plan = rlx_opt::memory::plan_memory(fb);
5461                        let fb_offsets: HashMap<NodeId, usize> = fb_plan
5462                            .assignments
5463                            .iter()
5464                            .map(|(id, slot)| (*id, slot.offset))
5465                            .collect();
5466                        let mut fb_inputs: Vec<NodeId> = fb
5467                            .nodes()
5468                            .iter()
5469                            .filter(|n| matches!(n.op, Op::Input { .. }))
5470                            .map(|n| n.id)
5471                            .collect();
5472                        fb_inputs.sort();
5473                        let fb_carry = fb_offsets[&fb_inputs[0]];
5474                        let fb_xs: Vec<usize> = (1..fb_inputs.len())
5475                            .map(|i| fb_offsets[&fb_inputs[i]])
5476                            .collect();
5477                        let fb_out = fb_offsets[&fb.outputs[0]];
5478                        let mut fb_arena = crate::arena::Arena::from_plan(fb_plan);
5479                        for n in fb.nodes() {
5480                            if let Op::Constant { data } = &n.op
5481                                && fb_arena.has_buffer(n.id)
5482                                && !data.is_empty()
5483                            {
5484                                // Byte-copy works for any
5485                                // numeric dtype as long as the
5486                                // arena slot is sized to hold
5487                                // it — the Constant's `data`
5488                                // already encodes the right
5489                                // bytes per element.
5490                                let off = fb_arena.byte_offset(n.id);
5491                                let buf = fb_arena.raw_buf_mut();
5492                                let nb = (buf.len() - off).min(data.len());
5493                                buf[off..off + nb].copy_from_slice(&data[..nb]);
5494                            }
5495                        }
5496                        let fb_init_bytes = fb_arena.raw_buf().to_vec();
5497                        let fb_sched = compile_thunks_with_rng(fb, &fb_arena, rng);
5498                        (
5499                            Some(Arc::new(fb_sched)),
5500                            Some(Arc::new(fb_init_bytes)),
5501                            fb_carry,
5502                            fb_out,
5503                            fb_xs,
5504                        )
5505                    } else {
5506                        (None, None, 0, 0, Vec::new())
5507                    };
5508
5509                Thunk::ScanBackwardXs {
5510                    body_vjp: Arc::new(body_schedule),
5511                    body_init: Arc::new(body_init),
5512                    body_carry_in_off,
5513                    body_x_offs: Arc::new(body_x_offs),
5514                    body_d_output_off,
5515                    body_dcarry_out_off,
5516                    body_dxs_out_off,
5517                    outer_init_off: node_offset(arena, node.inputs[0]),
5518                    outer_traj_off: node_offset(arena, node.inputs[1]),
5519                    outer_upstream_off: node_offset(arena, node.inputs[2]),
5520                    outer_xs_offs: Arc::new(outer_xs_offs),
5521                    outer_dxs_off: node_offset(arena, node.id),
5522                    length: *length,
5523                    carry_bytes: carry_bytes as u32,
5524                    carry_elem_size,
5525                    per_step_bytes: per_step_bytes as u32,
5526                    save_trajectory: *save_trajectory,
5527                    num_checkpoints: *num_checkpoints,
5528                    forward_body: fb_schedule,
5529                    forward_body_init: fb_init,
5530                    forward_body_carry_in_off: fb_carry_in_off,
5531                    forward_body_output_off: fb_output_off,
5532                    forward_body_x_offs: Arc::new(fb_x_offs),
5533                }
5534            }
5535
5536            Op::Concat { axis } => {
5537                // Compute outer/inner from the OUTPUT shape: all inputs share
5538                // the same shape except along `axis`. The output's leading
5539                // and trailing dims match.
5540                let out_shape = &node.shape;
5541                let rank = out_shape.rank();
5542                let outer: usize = (0..*axis)
5543                    .map(|i| out_shape.dim(i).unwrap_static())
5544                    .product::<usize>()
5545                    .max(1);
5546                let inner: usize = (*axis + 1..rank)
5547                    .map(|i| out_shape.dim(i).unwrap_static())
5548                    .product::<usize>()
5549                    .max(1);
5550                let total_axis = out_shape.dim(*axis).unwrap_static();
5551                let inputs: Vec<(usize, u32, u32)> = node
5552                    .inputs
5553                    .iter()
5554                    .map(|&in_id| {
5555                        let in_shape = &graph.node(in_id).shape;
5556                        let in_axis = concat_axis_extent(in_shape, *axis, rank);
5557                        let in_numel = in_shape.num_elements().unwrap_or(0) as u32;
5558                        (node_offset(arena, in_id), in_axis as u32, in_numel)
5559                    })
5560                    .collect();
5561                let dst = node_offset(arena, node.id);
5562                match out_shape.dtype() {
5563                    rlx_ir::DType::F64 => Thunk::ConcatF64 {
5564                        dst,
5565                        outer: outer as u32,
5566                        inner: inner as u32,
5567                        total_axis: total_axis as u32,
5568                        inputs,
5569                    },
5570                    _ => Thunk::Concat {
5571                        dst,
5572                        outer: outer as u32,
5573                        inner: inner as u32,
5574                        total_axis: total_axis as u32,
5575                        inputs,
5576                    },
5577                }
5578            }
5579
5580            Op::GaussianSplatRender {
5581                width,
5582                height,
5583                tile_size,
5584                radius_scale,
5585                alpha_cutoff,
5586                max_splat_steps,
5587                transmittance_threshold,
5588                max_list_entries,
5589            } => {
5590                let elem_len =
5591                    |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5592                Thunk::GaussianSplatRender {
5593                    positions_off: node_offset(arena, node.inputs[0]),
5594                    positions_len: elem_len(node.inputs[0]),
5595                    scales_off: node_offset(arena, node.inputs[1]),
5596                    scales_len: elem_len(node.inputs[1]),
5597                    rotations_off: node_offset(arena, node.inputs[2]),
5598                    rotations_len: elem_len(node.inputs[2]),
5599                    opacities_off: node_offset(arena, node.inputs[3]),
5600                    opacities_len: elem_len(node.inputs[3]),
5601                    colors_off: node_offset(arena, node.inputs[4]),
5602                    colors_len: elem_len(node.inputs[4]),
5603                    sh_coeffs_off: node_offset(arena, node.inputs[5]),
5604                    sh_coeffs_len: elem_len(node.inputs[5]),
5605                    meta_off: node_offset(arena, node.inputs[6]),
5606                    dst_off: node_offset(arena, node.id),
5607                    dst_len: node.shape.num_elements().unwrap_or(0),
5608                    width: *width,
5609                    height: *height,
5610                    tile_size: *tile_size,
5611                    radius_scale: *radius_scale,
5612                    alpha_cutoff: *alpha_cutoff,
5613                    max_splat_steps: *max_splat_steps,
5614                    transmittance_threshold: *transmittance_threshold,
5615                    max_list_entries: *max_list_entries,
5616                }
5617            }
5618
5619            Op::GaussianSplatRenderBackward {
5620                width,
5621                height,
5622                tile_size,
5623                radius_scale,
5624                alpha_cutoff,
5625                max_splat_steps,
5626                transmittance_threshold,
5627                max_list_entries,
5628                loss_grad_clip,
5629                sh_band,
5630                max_anisotropy,
5631            } => {
5632                let elem_len =
5633                    |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5634                Thunk::GaussianSplatRenderBackward {
5635                    positions_off: node_offset(arena, node.inputs[0]),
5636                    positions_len: elem_len(node.inputs[0]),
5637                    scales_off: node_offset(arena, node.inputs[1]),
5638                    scales_len: elem_len(node.inputs[1]),
5639                    rotations_off: node_offset(arena, node.inputs[2]),
5640                    rotations_len: elem_len(node.inputs[2]),
5641                    opacities_off: node_offset(arena, node.inputs[3]),
5642                    opacities_len: elem_len(node.inputs[3]),
5643                    colors_off: node_offset(arena, node.inputs[4]),
5644                    colors_len: elem_len(node.inputs[4]),
5645                    sh_coeffs_off: node_offset(arena, node.inputs[5]),
5646                    sh_coeffs_len: elem_len(node.inputs[5]),
5647                    meta_off: node_offset(arena, node.inputs[6]),
5648                    d_loss_off: node_offset(arena, node.inputs[7]),
5649                    d_loss_len: elem_len(node.inputs[7]),
5650                    packed_off: node_offset(arena, node.id),
5651                    packed_len: node.shape.num_elements().unwrap_or(0),
5652                    width: *width,
5653                    height: *height,
5654                    tile_size: *tile_size,
5655                    radius_scale: *radius_scale,
5656                    alpha_cutoff: *alpha_cutoff,
5657                    max_splat_steps: *max_splat_steps,
5658                    transmittance_threshold: *transmittance_threshold,
5659                    max_list_entries: *max_list_entries,
5660                    loss_grad_clip: *loss_grad_clip,
5661                    sh_band: *sh_band,
5662                    max_anisotropy: *max_anisotropy,
5663                }
5664            }
5665
5666            Op::GaussianSplatPrepare {
5667                width,
5668                height,
5669                tile_size,
5670                radius_scale,
5671                alpha_cutoff,
5672                max_splat_steps,
5673                transmittance_threshold,
5674                max_list_entries,
5675            } => {
5676                let elem_len =
5677                    |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5678                Thunk::GaussianSplatPrepare {
5679                    positions_off: node_offset(arena, node.inputs[0]),
5680                    positions_len: elem_len(node.inputs[0]),
5681                    scales_off: node_offset(arena, node.inputs[1]),
5682                    scales_len: elem_len(node.inputs[1]),
5683                    rotations_off: node_offset(arena, node.inputs[2]),
5684                    rotations_len: elem_len(node.inputs[2]),
5685                    opacities_off: node_offset(arena, node.inputs[3]),
5686                    opacities_len: elem_len(node.inputs[3]),
5687                    colors_off: node_offset(arena, node.inputs[4]),
5688                    colors_len: elem_len(node.inputs[4]),
5689                    sh_coeffs_off: node_offset(arena, node.inputs[5]),
5690                    sh_coeffs_len: elem_len(node.inputs[5]),
5691                    meta_off: node_offset(arena, node.inputs[6]),
5692                    meta_len: elem_len(node.inputs[6]),
5693                    prep_off: node_offset(arena, node.id),
5694                    prep_len: node.shape.num_elements().unwrap_or(0),
5695                    width: *width,
5696                    height: *height,
5697                    tile_size: *tile_size,
5698                    radius_scale: *radius_scale,
5699                    alpha_cutoff: *alpha_cutoff,
5700                    max_splat_steps: *max_splat_steps,
5701                    transmittance_threshold: *transmittance_threshold,
5702                    max_list_entries: *max_list_entries,
5703                }
5704            }
5705
5706            Op::GaussianSplatRasterize {
5707                width,
5708                height,
5709                tile_size,
5710                alpha_cutoff,
5711                max_splat_steps,
5712                transmittance_threshold,
5713                max_list_entries,
5714            } => {
5715                let elem_len =
5716                    |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5717                let prep_id = node.inputs[0];
5718                let count = match &graph.node(prep_id).op {
5719                    rlx_ir::Op::GaussianSplatPrepare { .. } => {
5720                        elem_len(graph.node(prep_id).inputs[0]) / 3
5721                    }
5722                    _ => 1,
5723                };
5724                Thunk::GaussianSplatRasterize {
5725                    prep_off: node_offset(arena, prep_id),
5726                    prep_len: elem_len(prep_id),
5727                    meta_off: node_offset(arena, node.inputs[1]),
5728                    meta_len: elem_len(node.inputs[1]),
5729                    dst_off: node_offset(arena, node.id),
5730                    dst_len: node.shape.num_elements().unwrap_or(0),
5731                    count,
5732                    width: *width,
5733                    height: *height,
5734                    tile_size: *tile_size,
5735                    alpha_cutoff: *alpha_cutoff,
5736                    max_splat_steps: *max_splat_steps,
5737                    transmittance_threshold: *transmittance_threshold,
5738                    max_list_entries: *max_list_entries,
5739                }
5740            }
5741
5742            Op::Custom { name, attrs, .. } => {
5743                let kernel = crate::op_registry::lookup_cpu_kernel(name).unwrap_or_else(|| {
5744                    panic!(
5745                        "compile_thunks: no CPU kernel registered for \
5746                         Op::Custom('{name}'). Register one via \
5747                         rlx_cpu::op_registry::register_cpu_kernel \
5748                         before compiling on the CPU backend."
5749                    )
5750                });
5751                let inputs_v: Vec<(usize, u32, Shape)> = node
5752                    .inputs
5753                    .iter()
5754                    .map(|&in_id| {
5755                        let s = graph.node(in_id).shape.clone();
5756                        let len = s.num_elements().unwrap_or(0) as u32;
5757                        (node_offset(arena, in_id), len, s)
5758                    })
5759                    .collect();
5760                let out_len = node.shape.num_elements().unwrap_or(0) as u32;
5761                Thunk::CustomOp {
5762                    kernel,
5763                    inputs: inputs_v,
5764                    output: (node_offset(arena, node.id), out_len, node.shape.clone()),
5765                    attrs: attrs.clone(),
5766                }
5767            }
5768
5769            Op::Fft { inverse, norm } => {
5770                let shape = &node.shape;
5771                let meta = rlx_ir::fft::fft_meta(shape);
5772                let dtype = shape.dtype();
5773                assert!(
5774                    matches!(
5775                        dtype,
5776                        rlx_ir::DType::F32 | rlx_ir::DType::F64 | rlx_ir::DType::C64
5777                    ),
5778                    "Op::Fft on CPU requires F32, F64, or C64, got {dtype:?}"
5779                );
5780                Thunk::Fft1d {
5781                    src: node_offset(arena, node.inputs[0]),
5782                    dst: node_offset(arena, node.id),
5783                    outer: meta.outer as u32,
5784                    n_complex: meta.n_complex as u32,
5785                    inverse: *inverse,
5786                    norm_tag: norm.tag(),
5787                    dtype,
5788                }
5789            }
5790
5791            Op::FftButterflyStage { stage, n_fft } => {
5792                let state_shape = graph.node(node.inputs[0]).shape.clone();
5793                assert_eq!(
5794                    state_shape.dtype(),
5795                    rlx_ir::DType::F32,
5796                    "Op::FftButterflyStage requires F32 state"
5797                );
5798                let batch = state_shape.dim(0).unwrap_static() as u32;
5799                Thunk::FftButterflyStage {
5800                    state_src: node_offset(arena, node.inputs[0]),
5801                    state_dst: node_offset(arena, node.id),
5802                    gate_src: node_offset(arena, node.inputs[1]),
5803                    rev_src: node_offset(arena, node.inputs[2]),
5804                    tw_re_src: node_offset(arena, node.inputs[3]),
5805                    tw_im_src: node_offset(arena, node.inputs[4]),
5806                    batch,
5807                    n_fft: *n_fft,
5808                    stage: *stage,
5809                }
5810            }
5811
5812            Op::LogMel => {
5813                let spec_shape = graph.node(node.inputs[0]).shape.clone();
5814                let filt_shape = graph.node(node.inputs[1]).shape.clone();
5815                let meta = rlx_ir::audio::log_mel_meta(&spec_shape, &filt_shape)
5816                    .unwrap_or_else(|e| panic!("Op::LogMel: {e}"));
5817                Thunk::LogMel {
5818                    spec: node_offset(arena, node.inputs[0]),
5819                    filters: node_offset(arena, node.inputs[1]),
5820                    dst: node_offset(arena, node.id),
5821                    outer: meta.outer as u32,
5822                    n_fft: meta.n_fft as u32,
5823                    n_bins: meta.n_bins as u32,
5824                    n_mels: meta.n_mels as u32,
5825                }
5826            }
5827
5828            Op::LogMelBackward => {
5829                let spec_shape = graph.node(node.inputs[0]).shape.clone();
5830                let filt_shape = graph.node(node.inputs[1]).shape.clone();
5831                let meta = rlx_ir::audio::log_mel_meta(&spec_shape, &filt_shape)
5832                    .unwrap_or_else(|e| panic!("Op::LogMelBackward: {e}"));
5833                Thunk::LogMelBackward {
5834                    spec: node_offset(arena, node.inputs[0]),
5835                    filters: node_offset(arena, node.inputs[1]),
5836                    dy: node_offset(arena, node.inputs[2]),
5837                    dst: node_offset(arena, node.id),
5838                    outer: meta.outer as u32,
5839                    n_fft: meta.n_fft as u32,
5840                    n_bins: meta.n_bins as u32,
5841                    n_mels: meta.n_mels as u32,
5842                }
5843            }
5844
5845            Op::WelchPeaks { k, n_segments } => {
5846                let spec_shape = graph.node(node.inputs[0]).shape.clone();
5847                let meta = rlx_ir::audio::welch_peaks_meta(&spec_shape, *k, *n_segments)
5848                    .unwrap_or_else(|e| panic!("Op::WelchPeaks: {e}"));
5849                Thunk::WelchPeaks {
5850                    spec: node_offset(arena, node.inputs[0]),
5851                    dst: node_offset(arena, node.id),
5852                    welch_batch: meta.welch_batch as u32,
5853                    n_fft: meta.n_fft as u32,
5854                    n_segments: meta.n_segments as u32,
5855                    k: meta.k as u32,
5856                }
5857            }
5858
5859            Op::CustomFn {
5860                fwd_body,
5861                num_inputs,
5862                ..
5863            } => {
5864                // Plan + compile the body sub-graph standalone, fill its
5865                // Constants (mirrors the Op::Scan body lowering), then
5866                // capture per-input copy specs and the output spec.
5867                // Body Inputs in NodeId order match the outer node's
5868                // operand vector by position.
5869                let body_plan = rlx_opt::memory::plan_memory(fwd_body);
5870                let body_offsets: HashMap<NodeId, usize> = body_plan
5871                    .assignments
5872                    .iter()
5873                    .map(|(id, slot)| (*id, slot.offset))
5874                    .collect();
5875
5876                let mut body_input_ids: Vec<NodeId> = fwd_body
5877                    .nodes()
5878                    .iter()
5879                    .filter(|n| matches!(n.op, Op::Input { .. }))
5880                    .map(|n| n.id)
5881                    .collect();
5882                body_input_ids.sort();
5883                assert_eq!(
5884                    body_input_ids.len(),
5885                    *num_inputs as usize,
5886                    "Op::CustomFn fwd_body has {} Op::Input(s); declared num_inputs={}",
5887                    body_input_ids.len(),
5888                    *num_inputs,
5889                );
5890
5891                let mut body_arena = crate::arena::Arena::from_plan(body_plan);
5892                for n in fwd_body.nodes() {
5893                    if let Op::Constant { data } = &n.op
5894                        && body_arena.has_buffer(n.id)
5895                        && !data.is_empty()
5896                    {
5897                        match n.shape.dtype() {
5898                            rlx_ir::DType::F64 => {
5899                                let off = body_arena.byte_offset(n.id);
5900                                let buf = body_arena.raw_buf_mut();
5901                                let nb = (buf.len() - off).min(data.len());
5902                                buf[off..off + nb].copy_from_slice(&data[..nb]);
5903                            }
5904                            _ => {
5905                                let buf = body_arena.slice_mut(n.id);
5906                                let nf = data.len() / 4;
5907                                let nl = buf.len().min(nf);
5908                                for i in 0..nl {
5909                                    let bytes = [
5910                                        data[i * 4],
5911                                        data[i * 4 + 1],
5912                                        data[i * 4 + 2],
5913                                        data[i * 4 + 3],
5914                                    ];
5915                                    buf[i] = f32::from_le_bytes(bytes);
5916                                }
5917                            }
5918                        }
5919                    }
5920                }
5921                let body_init = body_arena.raw_buf().to_vec();
5922                let body_schedule = compile_thunks_with_rng(fwd_body, &body_arena, rng);
5923
5924                // Per primal input: (body_input_off, outer_input_off, bytes).
5925                let inputs_v: Vec<(usize, usize, u32)> = (0..*num_inputs as usize)
5926                    .map(|i| {
5927                        let body_in = body_input_ids[i];
5928                        let body_off = body_offsets[&body_in];
5929                        let outer_in = node.inputs[i];
5930                        let outer_off = node_offset(arena, outer_in);
5931                        let bytes = graph
5932                            .node(outer_in)
5933                            .shape
5934                            .size_bytes()
5935                            .expect("Op::CustomFn primal input must have static shape");
5936                        (body_off, outer_off, bytes as u32)
5937                    })
5938                    .collect();
5939
5940                let body_output_id = fwd_body
5941                    .outputs
5942                    .first()
5943                    .copied()
5944                    .expect("Op::CustomFn fwd_body must declare exactly one output");
5945                let body_output_off = body_offsets[&body_output_id];
5946                let out_bytes = node
5947                    .shape
5948                    .size_bytes()
5949                    .expect("Op::CustomFn output must have static shape");
5950
5951                Thunk::CustomFn {
5952                    body: Arc::new(body_schedule),
5953                    body_init: Arc::new(body_init),
5954                    inputs: Arc::new(inputs_v),
5955                    body_output_off,
5956                    outer_output_off: node_offset(arena, node.id),
5957                    out_bytes: out_bytes as u32,
5958                }
5959            }
5960
5961            _ => Thunk::Nop,
5962        };
5963        thunks.push(t);
5964    }
5965
5966    let cfg = crate::config::RuntimeConfig::global();
5967    let mask_thr = cfg.mask_binary_threshold;
5968    let mask_neg = cfg.attn_mask_neg_inf;
5969    let score_skip = cfg.score_skip_threshold;
5970
5971    // Pre-compile closures (skip Nops — they're filtered out)
5972    let compiled_fns: Vec<Arc<dyn Fn(*mut u8) + Send + Sync>> = thunks
5973        .iter()
5974        .filter(|t| !matches!(t, Thunk::Nop))
5975        .map(|thunk| {
5976            match thunk.clone() {
5977                Thunk::Nop => Arc::new(|_: *mut u8| {}) as Arc<dyn Fn(*mut u8) + Send + Sync>,
5978
5979                Thunk::Sgemm { a, b, c, m, k, n } => {
5980                    let (m, k, n) = (m as usize, k as usize, n as usize);
5981                    Arc::new(move |base: *mut u8| unsafe {
5982                        crate::blas::sgemm(
5983                            sl(a, base, m * k),
5984                            sl(b, base, k * n),
5985                            sl_mut(c, base, m * n),
5986                            m,
5987                            k,
5988                            n,
5989                        );
5990                    })
5991                }
5992
5993                Thunk::CgemmC64 { a, b, c, m, k, n } => {
5994                    let (m, k, n) = (m as usize, k as usize, n as usize);
5995                    Arc::new(move |base: *mut u8| unsafe {
5996                        cgemm_c64(a, b, c, m, k, n, base);
5997                    })
5998                }
5999
6000                Thunk::DenseSolveF64 { a, b, x, n, nrhs } => {
6001                    let (n_, nrhs_) = (n as usize, nrhs as usize);
6002                    Arc::new(move |base: *mut u8| unsafe {
6003                        let a_src = sl_f64(a, base, n_ * n_);
6004                        let b_src = sl_f64(b, base, n_ * nrhs_);
6005                        let mut a_scratch: Vec<f64> = a_src.to_vec();
6006                        let mut x_buf: Vec<f64> = b_src.to_vec();
6007                        let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
6008                        if info != 0 {
6009                            panic!("DenseSolveF64: singular (info={info})");
6010                        }
6011                        sl_mut_f64(x, base, n_ * nrhs_).copy_from_slice(&x_buf);
6012                    })
6013                }
6014
6015                Thunk::DenseSolveF32 { a, b, x, n, nrhs } => {
6016                    let (n_, nrhs_) = (n as usize, nrhs as usize);
6017                    Arc::new(move |base: *mut u8| unsafe {
6018                        let a_src = sl(a, base, n_ * n_);
6019                        let b_src = sl(b, base, n_ * nrhs_);
6020                        let mut a_scratch: Vec<f32> = a_src.to_vec();
6021                        let mut x_buf: Vec<f32> = b_src.to_vec();
6022                        let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
6023                        if info != 0 {
6024                            panic!("DenseSolveF32: singular (info={info})");
6025                        }
6026                        sl_mut(x, base, n_ * nrhs_).copy_from_slice(&x_buf);
6027                    })
6028                }
6029
6030                Thunk::FusedMmBiasAct {
6031                    a,
6032                    w,
6033                    bias,
6034                    c,
6035                    m,
6036                    k,
6037                    n,
6038                    act,
6039                } => {
6040                    let (m, k, n) = (m as usize, k as usize, n as usize);
6041                    Arc::new(move |base: *mut u8| unsafe {
6042                        let out = sl_mut(c, base, m * n);
6043                        crate::blas::sgemm(sl(a, base, m * k), sl(w, base, k * n), out, m, k, n);
6044                        // Bias + activation epilogue. Gelu uses the fused
6045                        // `par_bias_gelu` kernel (bias add + Gelu in one
6046                        // pass). For everything else, do the bias add first
6047                        // and then apply the activation per-element. The
6048                        // pre-fix code dispatched `_ => bias_add` and dropped
6049                        // the activation entirely — silent correctness bug
6050                        // for Silu/Relu/Sigmoid/etc.
6051                        match act {
6052                            Some(Activation::Gelu) => {
6053                                crate::kernels::par_bias_gelu(out, sl(bias, base, n), m, n)
6054                            }
6055                            Some(other) => {
6056                                crate::blas::bias_add(out, sl(bias, base, n), m, n);
6057                                apply_activation_inplace(out, other);
6058                            }
6059                            None => crate::blas::bias_add(out, sl(bias, base, n), m, n),
6060                        }
6061                    })
6062                }
6063
6064                Thunk::FusedResidualLN {
6065                    x,
6066                    res,
6067                    bias,
6068                    g,
6069                    b,
6070                    out,
6071                    rows,
6072                    h,
6073                    eps,
6074                    has_bias,
6075                } => {
6076                    let (rows, h) = (rows as usize, h as usize);
6077                    Arc::new(move |base: *mut u8| unsafe {
6078                        let zero = vec![0f32; h]; // closure only — not hot path
6079                        let bi = if has_bias { sl(bias, base, h) } else { &zero };
6080                        let xp = sl(x, base, rows * h).as_ptr() as usize;
6081                        let rp = sl(res, base, rows * h).as_ptr() as usize;
6082                        let op = sl_mut(out, base, rows * h).as_mut_ptr() as usize;
6083                        let bp = bi.as_ptr() as usize;
6084                        let gp = sl(g, base, h).as_ptr() as usize;
6085                        let bbp = sl(b, base, h).as_ptr() as usize;
6086                        crate::pool::par_for(rows, 4, &|off, cnt| {
6087                            let xs = std::slice::from_raw_parts(
6088                                (xp as *const f32).add(off * h),
6089                                cnt * h,
6090                            );
6091                            let rs = std::slice::from_raw_parts(
6092                                (rp as *const f32).add(off * h),
6093                                cnt * h,
6094                            );
6095                            let os = std::slice::from_raw_parts_mut(
6096                                (op as *mut f32).add(off * h),
6097                                cnt * h,
6098                            );
6099                            let bi = std::slice::from_raw_parts(bp as *const f32, h);
6100                            let g = std::slice::from_raw_parts(gp as *const f32, h);
6101                            let b = std::slice::from_raw_parts(bbp as *const f32, h);
6102                            crate::kernels::residual_bias_layer_norm(
6103                                xs, rs, bi, g, b, os, cnt, h, eps,
6104                            );
6105                        });
6106                    })
6107                }
6108
6109                Thunk::BiasAdd {
6110                    src,
6111                    bias,
6112                    dst,
6113                    m,
6114                    n,
6115                } => {
6116                    let (m, n) = (m as usize, n as usize);
6117                    let len = m * n;
6118                    Arc::new(move |base: *mut u8| unsafe {
6119                        let out = sl_mut(dst, base, len);
6120                        if src != dst {
6121                            let src_ptr = base.add(src) as *const f32;
6122                            let dst_ptr = base.add(dst) as *mut f32;
6123                            if src_ptr != dst_ptr {
6124                                std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len);
6125                            }
6126                        }
6127                        crate::blas::bias_add(out, sl(bias, base, n), m, n);
6128                    })
6129                }
6130
6131                Thunk::Gather {
6132                    table,
6133                    table_len,
6134                    idx,
6135                    dst,
6136                    num_idx,
6137                    trailing,
6138                    idx_i64,
6139                    table_bytes,
6140                } => {
6141                    let (ni, tr, tl) = (num_idx as usize, trailing as usize, table_len as usize);
6142                    let rows = tl / tr.max(1);
6143                    let (idx_i64, table_bytes) = (idx_i64, table_bytes);
6144                    Arc::new(move |base: *mut u8| unsafe {
6145                        if table_bytes == 8 {
6146                            let tab = sl_i64(table, base, tl);
6147                            let out = sl_mut_i64(dst, base, ni * tr);
6148                            if idx_i64 != 0 {
6149                                let ids = sl_i64(idx, base, ni);
6150                                for i in 0..ni {
6151                                    let row = ids[i].max(0) as usize;
6152                                    if row < rows {
6153                                        out[i * tr..(i + 1) * tr]
6154                                            .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
6155                                    }
6156                                }
6157                            } else {
6158                                let ids = sl(idx, base, ni);
6159                                for i in 0..ni {
6160                                    let row = ids[i] as usize;
6161                                    if row < rows {
6162                                        out[i * tr..(i + 1) * tr]
6163                                            .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
6164                                    }
6165                                }
6166                            }
6167                        } else {
6168                            let tab = sl(table, base, tl);
6169                            let out = sl_mut(dst, base, ni * tr);
6170                            if idx_i64 != 0 {
6171                                let ids = sl_i64(idx, base, ni);
6172                                for i in 0..ni {
6173                                    let row = ids[i].max(0) as usize;
6174                                    if row < rows {
6175                                        out[i * tr..(i + 1) * tr]
6176                                            .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
6177                                    }
6178                                }
6179                            } else {
6180                                let ids = sl(idx, base, ni);
6181                                for i in 0..ni {
6182                                    let row = ids[i] as usize;
6183                                    if row < rows {
6184                                        out[i * tr..(i + 1) * tr]
6185                                            .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
6186                                    }
6187                                }
6188                            }
6189                        }
6190                    })
6191                }
6192
6193                Thunk::Narrow {
6194                    src,
6195                    dst,
6196                    outer,
6197                    src_stride,
6198                    dst_stride,
6199                    inner,
6200                    elem_bytes,
6201                } => {
6202                    narrow_thunk_closure(src, dst, outer, src_stride, dst_stride, inner, elem_bytes)
6203                }
6204
6205                Thunk::Copy { src, dst, len } => {
6206                    let len = len as usize;
6207                    Arc::new(move |base: *mut u8| unsafe {
6208                        if src == dst || len == 0 {
6209                            return;
6210                        }
6211                        let src_ptr = base.add(src) as *const f32;
6212                        let dst_ptr = base.add(dst) as *mut f32;
6213                        if src_ptr == dst_ptr {
6214                            return;
6215                        }
6216                        std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len);
6217                    })
6218                }
6219
6220                Thunk::Softmax { data, rows, cols } => {
6221                    let (rows, cols) = (rows as usize, cols as usize);
6222                    Arc::new(move |base: *mut u8| unsafe {
6223                        crate::naive::softmax(sl_mut(data, base, rows * cols), rows, cols);
6224                    })
6225                }
6226
6227                Thunk::Cumsum {
6228                    src,
6229                    dst,
6230                    rows,
6231                    cols,
6232                    exclusive,
6233                } => {
6234                    let (rows, cols) = (rows as usize, cols as usize);
6235                    Arc::new(move |base: *mut u8| unsafe {
6236                        let s = sl(src, base, rows * cols);
6237                        let d = sl_mut(dst, base, rows * cols);
6238                        if exclusive {
6239                            for r in 0..rows {
6240                                let mut acc = 0.0f32;
6241                                for c in 0..cols {
6242                                    d[r * cols + c] = acc;
6243                                    acc += s[r * cols + c];
6244                                }
6245                            }
6246                        } else {
6247                            for r in 0..rows {
6248                                let mut acc = 0.0f32;
6249                                for c in 0..cols {
6250                                    acc += s[r * cols + c];
6251                                    d[r * cols + c] = acc;
6252                                }
6253                            }
6254                        }
6255                    })
6256                }
6257
6258                Thunk::Sample {
6259                    logits,
6260                    dst,
6261                    batch,
6262                    vocab,
6263                    top_k,
6264                    top_p,
6265                    temperature,
6266                    seed,
6267                } => {
6268                    let (b, v) = (batch as usize, vocab as usize);
6269                    let k = (top_k as usize).min(v);
6270                    Arc::new(move |base: *mut u8| unsafe {
6271                        let lg = sl(logits, base, b * v);
6272                        let out = sl_mut(dst, base, b);
6273                        let mut rng =
6274                            rlx_ir::Philox4x32::new(if seed == 0 { 0xDEADBEEF } else { seed });
6275                        for bi in 0..b {
6276                            let row = &lg[bi * v..(bi + 1) * v];
6277                            out[bi] = sample_row(row, k, top_p, temperature, &mut rng) as f32;
6278                        }
6279                    })
6280                }
6281
6282                Thunk::RngNormal {
6283                    dst,
6284                    len,
6285                    mean,
6286                    scale,
6287                    key,
6288                    op_seed,
6289                } => {
6290                    let n = len as usize;
6291                    let rng = rng_shared.clone();
6292                    Arc::new(move |base: *mut u8| unsafe {
6293                        let out = sl_mut(dst, base, n);
6294                        let opts = *rng.read().unwrap();
6295                        rlx_ir::fill_normal_like(out, mean, scale, opts, key, op_seed);
6296                    })
6297                }
6298
6299                Thunk::RngUniform {
6300                    dst,
6301                    len,
6302                    low,
6303                    high,
6304                    key,
6305                    op_seed,
6306                } => {
6307                    let n = len as usize;
6308                    let rng = rng_shared.clone();
6309                    Arc::new(move |base: *mut u8| unsafe {
6310                        let out = sl_mut(dst, base, n);
6311                        let opts = *rng.read().unwrap();
6312                        rlx_ir::fill_uniform_like(out, low, high, opts, key, op_seed);
6313                    })
6314                }
6315
6316                Thunk::DequantMatMul {
6317                    x,
6318                    w_q,
6319                    scale,
6320                    zp,
6321                    dst,
6322                    m,
6323                    k,
6324                    n,
6325                    block_size,
6326                    is_asymmetric,
6327                } => {
6328                    let (m, k, n, bs) = (m as usize, k as usize, n as usize, block_size as usize);
6329                    let n_blocks_per_col = k.div_ceil(bs);
6330                    Arc::new(move |base: *mut u8| unsafe {
6331                        let xs = sl(x, base, m * k);
6332                        // w_q is packed i8 — use raw byte slice + reinterpret.
6333                        let raw = base.add(w_q);
6334                        let w_bytes = std::slice::from_raw_parts(raw as *const i8, k * n);
6335                        let scales = sl(scale, base, n_blocks_per_col * n);
6336                        let zps = if is_asymmetric {
6337                            sl(zp, base, n_blocks_per_col * n)
6338                        } else {
6339                            &[][..]
6340                        };
6341                        let out = sl_mut(dst, base, m * n);
6342                        dequant_matmul_int8(
6343                            xs,
6344                            w_bytes,
6345                            scales,
6346                            zps,
6347                            out,
6348                            m,
6349                            k,
6350                            n,
6351                            bs,
6352                            is_asymmetric,
6353                        );
6354                    })
6355                }
6356
6357                Thunk::DequantMatMulGguf {
6358                    x,
6359                    w_q,
6360                    dst,
6361                    m,
6362                    k,
6363                    n,
6364                    scheme,
6365                } => {
6366                    let (m, k, n) = (m as usize, k as usize, n as usize);
6367                    let block_bytes = scheme.gguf_block_bytes() as usize;
6368                    let block_elems = scheme.gguf_block_size() as usize;
6369                    let total_bytes = (k * n) / block_elems * block_bytes;
6370                    Arc::new(move |base: *mut u8| unsafe {
6371                        let xs = sl(x, base, m * k);
6372                        let w_bytes =
6373                            std::slice::from_raw_parts(base.add(w_q) as *const u8, total_bytes);
6374                        let out = sl_mut(dst, base, m * n);
6375                        crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, scheme);
6376                    })
6377                }
6378
6379                Thunk::DequantMatMulInt4 {
6380                    x,
6381                    w_q,
6382                    scale,
6383                    zp,
6384                    dst,
6385                    m,
6386                    k,
6387                    n,
6388                    block_size,
6389                    is_asymmetric,
6390                } => {
6391                    let (m, k, n, bs) = (m as usize, k as usize, n as usize, block_size as usize);
6392                    let n_blocks = k.div_ceil(bs);
6393                    Arc::new(move |base: *mut u8| unsafe {
6394                        let xs = sl(x, base, m * k);
6395                        let w_bytes = std::slice::from_raw_parts(
6396                            base.add(w_q) as *const u8,
6397                            (k * n).div_ceil(2),
6398                        );
6399                        let scales = sl(scale, base, n_blocks * n);
6400                        let zps = if is_asymmetric {
6401                            sl(zp, base, n_blocks * n)
6402                        } else {
6403                            &[][..]
6404                        };
6405                        let out = sl_mut(dst, base, m * n);
6406                        dequant_matmul_int4(
6407                            xs,
6408                            w_bytes,
6409                            scales,
6410                            zps,
6411                            out,
6412                            m,
6413                            k,
6414                            n,
6415                            bs,
6416                            is_asymmetric,
6417                        );
6418                    })
6419                }
6420
6421                Thunk::DequantMatMulFp8 {
6422                    x,
6423                    w_q,
6424                    scale,
6425                    dst,
6426                    m,
6427                    k,
6428                    n,
6429                    e5m2,
6430                } => {
6431                    let (m, k, n) = (m as usize, k as usize, n as usize);
6432                    Arc::new(move |base: *mut u8| unsafe {
6433                        let xs = sl(x, base, m * k);
6434                        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, k * n);
6435                        let scales = sl(scale, base, n);
6436                        let out = sl_mut(dst, base, m * n);
6437                        dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, e5m2);
6438                    })
6439                }
6440
6441                Thunk::DequantMatMulNvfp4 {
6442                    x,
6443                    w_q,
6444                    scale,
6445                    global_scale,
6446                    dst,
6447                    m,
6448                    k,
6449                    n,
6450                } => {
6451                    let (m, k, n) = (m as usize, k as usize, n as usize);
6452                    let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
6453                    Arc::new(move |base: *mut u8| unsafe {
6454                        let xs = sl(x, base, m * k);
6455                        let w_bytes = std::slice::from_raw_parts(
6456                            base.add(w_q) as *const u8,
6457                            (k * n).div_ceil(2),
6458                        );
6459                        let scale_bytes =
6460                            std::slice::from_raw_parts(base.add(scale) as *const u8, n_scale);
6461                        let gs = sl(global_scale, base, 1)[0];
6462                        let out = sl_mut(dst, base, m * n);
6463                        dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
6464                    })
6465                }
6466
6467                Thunk::LoraMatMul {
6468                    x,
6469                    w,
6470                    a,
6471                    b,
6472                    dst,
6473                    m,
6474                    k,
6475                    n,
6476                    r,
6477                    scale,
6478                } => {
6479                    let (m, k, n, r) = (m as usize, k as usize, n as usize, r as usize);
6480                    Arc::new(move |base: *mut u8| unsafe {
6481                        let xs = sl(x, base, m * k);
6482                        let ws = sl(w, base, k * n);
6483                        let a_s = sl(a, base, k * r);
6484                        let bs = sl(b, base, r * n);
6485                        let out = sl_mut(dst, base, m * n);
6486                        // Step 1: out = x · W.
6487                        crate::blas::sgemm(xs, ws, out, m, k, n);
6488                        // Step 2: tmp = x · A (rank-r intermediate; tiny).
6489                        let mut tmp = vec![0f32; m * r];
6490                        crate::blas::sgemm(xs, a_s, &mut tmp, m, k, r);
6491                        // Step 3: out += scale * (tmp · B).
6492                        // sgemm_accumulate uses alpha=1.0 internally, so
6493                        // scale tmp first.
6494                        if scale != 1.0 {
6495                            for v in tmp.iter_mut() {
6496                                *v *= scale;
6497                            }
6498                        }
6499                        crate::blas::sgemm_accumulate(&tmp, bs, out, m, r, n);
6500                    })
6501                }
6502
6503                Thunk::LayerNorm {
6504                    src,
6505                    g,
6506                    b,
6507                    dst,
6508                    rows,
6509                    h,
6510                    eps,
6511                } => {
6512                    let (rows, h) = (rows as usize, h as usize);
6513                    Arc::new(move |base: *mut u8| unsafe {
6514                        let inp = sl(src, base, rows * h);
6515                        let gamma = sl(g, base, h);
6516                        let beta = sl(b, base, h);
6517                        let out = sl_mut(dst, base, rows * h);
6518                        for row in 0..rows {
6519                            crate::kernels::layer_norm_row(
6520                                &inp[row * h..(row + 1) * h],
6521                                gamma,
6522                                beta,
6523                                &mut out[row * h..(row + 1) * h],
6524                                h,
6525                                eps,
6526                            );
6527                        }
6528                    })
6529                }
6530
6531                Thunk::BatchNormInference {
6532                    src,
6533                    g,
6534                    b,
6535                    mean,
6536                    var,
6537                    dst,
6538                    count,
6539                    channels,
6540                    eps,
6541                } => {
6542                    let count = count as usize;
6543                    let c = channels as usize;
6544                    let n = count * c;
6545                    let (src, g, b, mean, var, dst) = (src, g, b, mean, var, dst);
6546                    Arc::new(move |base: *mut u8| unsafe {
6547                        crate::kernels::batch_norm_inference(
6548                            sl(src, base, n),
6549                            sl(g, base, c),
6550                            sl(b, base, c),
6551                            sl(mean, base, c),
6552                            sl(var, base, c),
6553                            sl_mut(dst, base, n),
6554                            c,
6555                            eps,
6556                        );
6557                    })
6558                }
6559
6560                Thunk::Attention {
6561                    q,
6562                    k,
6563                    v,
6564                    mask,
6565                    out,
6566                    batch,
6567                    seq,
6568                    kv_seq,
6569                    heads,
6570                    head_dim,
6571                    mask_kind,
6572                    scale,
6573                    q_row_stride,
6574                    k_row_stride,
6575                    v_row_stride,
6576                    bhsd,
6577                } => {
6578                    if std::env::var("RLX_ATTN_DEBUG").is_ok() {
6579                        eprintln!("[attn-compile] batch={batch} seq={seq} kv_seq={kv_seq} heads={heads} bhsd={bhsd}");
6580                    }
6581                    // Q seq length (`q_s`) and K/V seq length (`k_s`) differ
6582                    // during cached decode (`q_s=1`, `k_s=past_seq+1`). The
6583                    // earlier version of this kernel destructured
6584                    // `kv_seq: _` and used a single `s = seq` for both axes,
6585                    // so cached decode only scored 1×1 instead of 1×k_s —
6586                    // attention couldn't see the past K cache and decode
6587                    // collapsed into repetitive fragments
6588                    // (`Self-based on [1\nAnswer: Self-based on [1…`).
6589                    let (b, q_s, k_s, nh, dh) = (
6590                        batch as usize,
6591                        seq as usize,
6592                        kv_seq as usize,
6593                        heads as usize,
6594                        head_dim as usize,
6595                    );
6596                    let hs = nh * dh;
6597                    let qrs = q_row_stride as usize;
6598                    let krs = k_row_stride as usize;
6599                    let vrs = v_row_stride as usize;
6600                    // honor Op::Attention::score_scale (e.g. Gemma 4 = 1.0)
6601                    Arc::new(move |base: *mut u8| unsafe {
6602                        if std::env::var("RLX_ATTN_DEBUG").is_ok() {
6603                            eprintln!("[attn] b={b} q_s={q_s} k_s={k_s} nh={nh} dh={dh} bhsd={bhsd} mask_kind={:?}", mask_kind);
6604                        }
6605                        // Slice lengths use the source's row stride so the
6606                        // compiler-emitted bounds checks cover the whole
6607                        // strided span (the kernel walks with q/k/v_rs).
6608                        // For [B, H, S, D] the buffer is dense B*H*S*D.
6609                        let (q_len, k_len, v_len, o_len) = if bhsd {
6610                            let qn = b * nh * q_s * dh;
6611                            let kn = b * nh * k_s * dh;
6612                            (qn, kn, kn, qn)
6613                        } else {
6614                            (b * q_s * qrs, b * k_s * krs, b * k_s * vrs, b * q_s * hs)
6615                        };
6616                        let q_d = sl(q, base, q_len);
6617                        let k_d = sl(k, base, k_len);
6618                        let v_d = sl(v, base, v_len);
6619                        let m_d: &[f32] = match mask_kind {
6620                            rlx_ir::op::MaskKind::Custom => sl(mask, base, b * k_s),
6621                            rlx_ir::op::MaskKind::Bias => sl(mask, base, b * nh * q_s * k_s),
6622                            _ => &[],
6623                        };
6624                        let o_d = sl_mut(out, base, o_len);
6625                        let mut qh = vec![0f32; q_s * dh];
6626                        let mut kh = vec![0f32; k_s * dh];
6627                        let mut vh = vec![0f32; k_s * dh];
6628                        let mut sc = vec![0f32; q_s * k_s];
6629                        let mut oh = vec![0f32; q_s * dh];
6630                        for bi in 0..b {
6631                            for hi in 0..nh {
6632                                // Gather per-head Q.
6633                                for si in 0..q_s {
6634                                    let q_off = if bhsd {
6635                                        bi * nh * q_s * dh + hi * q_s * dh + si * dh
6636                                    } else {
6637                                        bi * q_s * qrs + si * qrs + hi * dh
6638                                    };
6639                                    qh[si * dh..(si + 1) * dh]
6640                                        .copy_from_slice(&q_d[q_off..q_off + dh]);
6641                                }
6642                                // Gather per-head K, V.
6643                                for si in 0..k_s {
6644                                    let (k_off, v_off) = if bhsd {
6645                                        (
6646                                            bi * nh * k_s * dh + hi * k_s * dh + si * dh,
6647                                            bi * nh * k_s * dh + hi * k_s * dh + si * dh,
6648                                        )
6649                                    } else {
6650                                        (
6651                                            bi * k_s * krs + si * krs + hi * dh,
6652                                            bi * k_s * vrs + si * vrs + hi * dh,
6653                                        )
6654                                    };
6655                                    kh[si * dh..(si + 1) * dh]
6656                                        .copy_from_slice(&k_d[k_off..k_off + dh]);
6657                                    vh[si * dh..(si + 1) * dh]
6658                                        .copy_from_slice(&v_d[v_off..v_off + dh]);
6659                                }
6660                                for qi in 0..q_s {
6661                                    for ki in 0..k_s {
6662                                        let mut dot = 0f32;
6663                                        for d in 0..dh {
6664                                            dot += qh[qi * dh + d] * kh[ki * dh + d];
6665                                        }
6666                                        sc[qi * k_s + ki] = dot * scale;
6667                                    }
6668                                }
6669                                // Apply mask. Causal/SlidingWindow use absolute
6670                                // positions so they handle Lq != Lk (decode mode
6671                                // with cached K/V): q_offset = k_s - q_s.
6672                                let q_offset = k_s.saturating_sub(q_s);
6673                                match mask_kind {
6674                                    rlx_ir::op::MaskKind::None => {}
6675                                    rlx_ir::op::MaskKind::Causal => {
6676                                        for qi in 0..q_s {
6677                                            let abs_q = q_offset + qi;
6678                                            for ki in (abs_q + 1)..k_s {
6679                                                sc[qi * k_s + ki] = mask_neg;
6680                                            }
6681                                        }
6682                                    }
6683                                    rlx_ir::op::MaskKind::SlidingWindow(w) => {
6684                                        for qi in 0..q_s {
6685                                            let abs_q = q_offset + qi;
6686                                            let lo = abs_q.saturating_sub(w);
6687                                            for ki in 0..k_s {
6688                                                if ki < lo || ki > abs_q {
6689                                                    sc[qi * k_s + ki] = mask_neg;
6690                                                }
6691                                            }
6692                                        }
6693                                    }
6694                                    rlx_ir::op::MaskKind::Custom => {
6695                                        for qi in 0..q_s {
6696                                            for ki in 0..k_s {
6697                                                if m_d[bi * k_s + ki] < mask_thr {
6698                                                    sc[qi * k_s + ki] = mask_neg;
6699                                                }
6700                                            }
6701                                        }
6702                                    }
6703                                    rlx_ir::op::MaskKind::Bias => {
6704                                        let per_bh = q_s * k_s;
6705                                        let off = (bi * nh + hi) * per_bh;
6706                                        for i in 0..per_bh {
6707                                            sc[i] += m_d[off + i];
6708                                        }
6709                                    }
6710                                }
6711                                crate::naive::softmax(&mut sc, q_s, k_s);
6712                                oh.fill(0.0);
6713                                for qi in 0..q_s {
6714                                    for ki in 0..k_s {
6715                                        let w = sc[qi * k_s + ki];
6716                                        if w > score_skip {
6717                                            for d in 0..dh {
6718                                                oh[qi * dh + d] += w * vh[ki * dh + d];
6719                                            }
6720                                        }
6721                                    }
6722                                }
6723                                for si in 0..q_s {
6724                                    let off = if bhsd {
6725                                        bi * nh * q_s * dh + hi * q_s * dh + si * dh
6726                                    } else {
6727                                        bi * q_s * hs + si * hs + hi * dh
6728                                    };
6729                                    o_d[off..off + dh].copy_from_slice(&oh[si * dh..(si + 1) * dh]);
6730                                }
6731                            }
6732                        }
6733                    })
6734                }
6735
6736                Thunk::FusedSwiGLU {
6737                    src,
6738                    dst,
6739                    n_half,
6740                    total,
6741                    gate_first,
6742                } => {
6743                    let n = n_half as usize;
6744                    let t = total as usize;
6745                    let outer = t / n;
6746                    let in_total = outer * 2 * n;
6747                    Arc::new(move |base: *mut u8| unsafe {
6748                        let inp = sl(src, base, in_total);
6749                        let out = sl_mut(dst, base, t);
6750                        for o in 0..outer {
6751                            let in_row = &inp[o * 2 * n..(o + 1) * 2 * n];
6752                            let out_row = &mut out[o * n..(o + 1) * n];
6753                            for i in 0..n {
6754                                let (up, gate) = if gate_first {
6755                                    (in_row[n + i], in_row[i])
6756                                } else {
6757                                    (in_row[i], in_row[n + i])
6758                                };
6759                                out_row[i] = up * (gate / (1.0 + (-gate).exp()));
6760                            }
6761                        }
6762                    })
6763                }
6764
6765                Thunk::Concat {
6766                    dst,
6767                    outer,
6768                    inner,
6769                    total_axis,
6770                    inputs,
6771                } => {
6772                    let outer = outer as usize;
6773                    let inner = inner as usize;
6774                    let total_axis = total_axis as usize;
6775                    let out_total = outer * total_axis * inner;
6776                    let mut layout: Vec<(usize, usize, usize, usize)> =
6777                        Vec::with_capacity(inputs.len());
6778                    let mut cum: usize = 0;
6779                    for (src_off, in_axis, in_numel) in &inputs {
6780                        let in_axis = *in_axis as usize;
6781                        layout.push((*src_off, cum * inner, in_axis * inner, *in_numel as usize));
6782                        cum += in_axis;
6783                    }
6784                    Arc::new(move |base: *mut u8| unsafe {
6785                        let out = sl_mut(dst, base, out_total);
6786                        let row_stride = total_axis * inner;
6787                        for (src_off, dst_col_off, copy_per_row, in_numel) in &layout {
6788                            let inp = sl(*src_off, base, (*in_numel).max(1));
6789                            concat_copy_rows_f32(
6790                                out,
6791                                inp,
6792                                outer,
6793                                *copy_per_row,
6794                                row_stride,
6795                                *dst_col_off,
6796                                *in_numel,
6797                            );
6798                        }
6799                    })
6800                }
6801
6802                Thunk::CustomOp {
6803                    kernel,
6804                    inputs,
6805                    output,
6806                    attrs,
6807                } => {
6808                    // Capture-by-move: clone the Arc and Vecs once into the
6809                    // closure. Dispatch by output dtype each call (the
6810                    // dtype is fixed at compile time but it's cheaper to
6811                    // branch once per execution than to monomorphize a
6812                    // dozen closure variants).
6813                    let kernel = kernel.clone();
6814                    let attrs = attrs.clone();
6815                    let inputs = inputs.clone();
6816                    let (out_off, out_len, out_shape) = output.clone();
6817                    Arc::new(move |base: *mut u8| unsafe {
6818                        dispatch_custom_op(
6819                            &*kernel, &inputs, out_off, out_len, &out_shape, &attrs, base,
6820                        );
6821                    })
6822                }
6823
                // Gaussian-splat forward render. Every captured offset, length and
                // render parameter is forwarded positionally, in declaration order,
                // to `crate::splat::execute_gaussian_splat_render`, followed by the
                // arena base pointer. This arm performs no validation of its own,
                // so the argument order here must match that function's signature.
                Thunk::GaussianSplatRender {
                    positions_off,
                    positions_len,
                    scales_off,
                    scales_len,
                    rotations_off,
                    rotations_len,
                    opacities_off,
                    opacities_len,
                    colors_off,
                    colors_len,
                    sh_coeffs_off,
                    sh_coeffs_len,
                    meta_off,
                    dst_off,
                    dst_len,
                    width,
                    height,
                    tile_size,
                    radius_scale,
                    alpha_cutoff,
                    max_splat_steps,
                    transmittance_threshold,
                    max_list_entries,
                } => Arc::new(move |base: *mut u8| unsafe {
                    crate::splat::execute_gaussian_splat_render(
                        positions_off,
                        positions_len,
                        scales_off,
                        scales_len,
                        rotations_off,
                        rotations_len,
                        opacities_off,
                        opacities_len,
                        colors_off,
                        colors_len,
                        sh_coeffs_off,
                        sh_coeffs_len,
                        meta_off,
                        dst_off,
                        dst_len,
                        width,
                        height,
                        tile_size,
                        radius_scale,
                        alpha_cutoff,
                        max_splat_steps,
                        transmittance_threshold,
                        max_list_entries,
                        base,
                    );
                }),
6876
                // Gaussian-splat render backward pass. Forwards the scene buffers,
                // the incoming loss gradient (`d_loss_*`), the `packed_*` output
                // region and the render/gradient parameters positionally to
                // `crate::splat::execute_gaussian_splat_render_backward`, followed
                // by the arena base pointer. No checks happen here, so the argument
                // order must track that function's signature exactly.
                // NOTE(review): `packed_*` presumably receives the packed
                // per-parameter gradients; confirm against `crate::splat`.
                Thunk::GaussianSplatRenderBackward {
                    positions_off,
                    positions_len,
                    scales_off,
                    scales_len,
                    rotations_off,
                    rotations_len,
                    opacities_off,
                    opacities_len,
                    colors_off,
                    colors_len,
                    sh_coeffs_off,
                    sh_coeffs_len,
                    meta_off,
                    d_loss_off,
                    d_loss_len,
                    packed_off,
                    packed_len,
                    width,
                    height,
                    tile_size,
                    radius_scale,
                    alpha_cutoff,
                    max_splat_steps,
                    transmittance_threshold,
                    max_list_entries,
                    loss_grad_clip,
                    sh_band,
                    max_anisotropy,
                } => Arc::new(move |base: *mut u8| unsafe {
                    crate::splat::execute_gaussian_splat_render_backward(
                        positions_off,
                        positions_len,
                        scales_off,
                        scales_len,
                        rotations_off,
                        rotations_len,
                        opacities_off,
                        opacities_len,
                        colors_off,
                        colors_len,
                        sh_coeffs_off,
                        sh_coeffs_len,
                        meta_off,
                        d_loss_off,
                        d_loss_len,
                        packed_off,
                        packed_len,
                        width,
                        height,
                        tile_size,
                        radius_scale,
                        alpha_cutoff,
                        max_splat_steps,
                        transmittance_threshold,
                        max_list_entries,
                        loss_grad_clip,
                        sh_band,
                        max_anisotropy,
                        base,
                    );
                }),
6939
                // Gaussian-splat preparation stage. Forwards the scene buffers, the
                // `meta_*` region, the `prep_*` output region and the render
                // parameters positionally to
                // `crate::splat::execute_gaussian_splat_prepare`, followed by the
                // arena base pointer. No checks happen here; keep the argument
                // order in lock-step with that function's signature.
                Thunk::GaussianSplatPrepare {
                    positions_off,
                    positions_len,
                    scales_off,
                    scales_len,
                    rotations_off,
                    rotations_len,
                    opacities_off,
                    opacities_len,
                    colors_off,
                    colors_len,
                    sh_coeffs_off,
                    sh_coeffs_len,
                    meta_off,
                    meta_len,
                    prep_off,
                    prep_len,
                    width,
                    height,
                    tile_size,
                    radius_scale,
                    alpha_cutoff,
                    max_splat_steps,
                    transmittance_threshold,
                    max_list_entries,
                } => Arc::new(move |base: *mut u8| unsafe {
                    crate::splat::execute_gaussian_splat_prepare(
                        positions_off,
                        positions_len,
                        scales_off,
                        scales_len,
                        rotations_off,
                        rotations_len,
                        opacities_off,
                        opacities_len,
                        colors_off,
                        colors_len,
                        sh_coeffs_off,
                        sh_coeffs_len,
                        meta_off,
                        meta_len,
                        prep_off,
                        prep_len,
                        width,
                        height,
                        tile_size,
                        radius_scale,
                        alpha_cutoff,
                        max_splat_steps,
                        transmittance_threshold,
                        max_list_entries,
                        base,
                    );
                }),
6994
                // Gaussian-splat rasterization stage. Forwards the `prep_*` and
                // `meta_*` input regions, the `dst_*` output region, `count` and
                // the raster parameters positionally to
                // `crate::splat::execute_gaussian_splat_rasterize`, followed by the
                // arena base pointer. No checks happen here; the argument order
                // must match that function's signature.
                // NOTE(review): `prep_*` presumably holds the output of the
                // GaussianSplatPrepare thunk; confirm with the graph lowering.
                Thunk::GaussianSplatRasterize {
                    prep_off,
                    prep_len,
                    meta_off,
                    meta_len,
                    dst_off,
                    dst_len,
                    count,
                    width,
                    height,
                    tile_size,
                    alpha_cutoff,
                    max_splat_steps,
                    transmittance_threshold,
                    max_list_entries,
                } => Arc::new(move |base: *mut u8| unsafe {
                    crate::splat::execute_gaussian_splat_rasterize(
                        prep_off,
                        prep_len,
                        meta_off,
                        meta_len,
                        dst_off,
                        dst_len,
                        count,
                        width,
                        height,
                        tile_size,
                        alpha_cutoff,
                        max_splat_steps,
                        transmittance_threshold,
                        max_list_entries,
                        base,
                    );
                }),
7029
7030                Thunk::Fft1d {
7031                    src,
7032                    dst,
7033                    outer,
7034                    n_complex,
7035                    inverse,
7036                    norm_tag,
7037                    dtype,
7038                } => {
7039                    let f: Arc<dyn Fn(*mut u8) + Send + Sync> = match dtype {
7040                        rlx_ir::DType::F64 => Arc::new(move |base: *mut u8| unsafe {
7041                            execute_fft1d_f64(
7042                                src,
7043                                dst,
7044                                outer as usize,
7045                                n_complex as usize,
7046                                inverse,
7047                                norm_tag,
7048                                base,
7049                            );
7050                        }),
7051                        rlx_ir::DType::F32 => Arc::new(move |base: *mut u8| unsafe {
7052                            execute_fft1d_f32(
7053                                src,
7054                                dst,
7055                                outer as usize,
7056                                n_complex as usize,
7057                                inverse,
7058                                norm_tag,
7059                                base,
7060                            );
7061                        }),
7062                        rlx_ir::DType::C64 => Arc::new(move |base: *mut u8| unsafe {
7063                            execute_fft1d_c64(
7064                                src,
7065                                dst,
7066                                outer as usize,
7067                                n_complex as usize,
7068                                inverse,
7069                                norm_tag,
7070                                base,
7071                            );
7072                        }),
7073                        other => panic!("Op::Fft on CPU requires F32/F64/C64, got {other:?}"),
7074                    };
7075                    f
7076                }
7077
7078                Thunk::FftButterflyStage {
7079                    state_src,
7080                    state_dst,
7081                    gate_src,
7082                    rev_src,
7083                    tw_re_src,
7084                    tw_im_src,
7085                    batch,
7086                    n_fft,
7087                    stage,
7088                } => Arc::new(move |base: *mut u8| unsafe {
7089                    execute_fft_butterfly_stage_f32(
7090                        state_src,
7091                        state_dst,
7092                        gate_src,
7093                        rev_src,
7094                        tw_re_src,
7095                        tw_im_src,
7096                        batch as usize,
7097                        n_fft as usize,
7098                        stage as usize,
7099                        base,
7100                    );
7101                }),
7102
7103                Thunk::LogMel {
7104                    spec,
7105                    filters,
7106                    dst,
7107                    outer,
7108                    n_fft,
7109                    n_bins,
7110                    n_mels,
7111                } => Arc::new(move |base: *mut u8| unsafe {
7112                    execute_log_mel_f32(
7113                        spec,
7114                        filters,
7115                        dst,
7116                        outer as usize,
7117                        n_fft as usize,
7118                        n_bins as usize,
7119                        n_mels as usize,
7120                        base,
7121                    );
7122                }),
7123
7124                Thunk::LogMelBackward {
7125                    spec,
7126                    filters,
7127                    dy,
7128                    dst,
7129                    outer,
7130                    n_fft,
7131                    n_bins,
7132                    n_mels,
7133                } => Arc::new(move |base: *mut u8| unsafe {
7134                    execute_log_mel_backward_f32(
7135                        spec,
7136                        filters,
7137                        dy,
7138                        dst,
7139                        outer as usize,
7140                        n_fft as usize,
7141                        n_bins as usize,
7142                        n_mels as usize,
7143                        base,
7144                    );
7145                }),
7146
7147                Thunk::WelchPeaks {
7148                    spec,
7149                    dst,
7150                    welch_batch,
7151                    n_fft,
7152                    n_segments,
7153                    k,
7154                } => Arc::new(move |base: *mut u8| unsafe {
7155                    execute_welch_peaks_f32(
7156                        spec,
7157                        dst,
7158                        welch_batch as usize,
7159                        n_fft as usize,
7160                        n_segments as usize,
7161                        k as usize,
7162                        base,
7163                    );
7164                }),
7165
                // Fallback: any thunk kind not matched by an earlier arm compiles
                // to a closure that does nothing when executed.
                _ => Arc::new(|_: *mut u8| {}),
7167            }
7168        })
7169        .collect();
7170
7171    // ── Thunk-level attention fusion ──────────────────────
7172    // For small batch*seq, fuse QKV→Narrow×3→[Rope×2]→Attention→OutProj
7173    // into a single FusedAttnBlock. Auto-detects from Attention thunks.
7174    let fuse_threshold: usize = rlx_ir::env::var("RLX_FUSE_ATTN_THRESHOLD")
7175        .and_then(|v| v.parse().ok())
7176        .unwrap_or(64);
7177    let should_fuse = thunks.iter().any(|t| match t {
7178        Thunk::Attention { batch, seq, .. } => {
7179            (*batch as usize) * (*seq as usize) <= fuse_threshold
7180        }
7181        _ => false,
7182    });
7183
7184    if should_fuse {
7185        // Build non-Nop index for pattern matching across Nop gaps
7186        let active: Vec<usize> = thunks
7187            .iter()
7188            .enumerate()
7189            .filter(|(_, t)| !matches!(t, Thunk::Nop))
7190            .map(|(i, _)| i)
7191            .collect();
7192
7193        let mut kill = vec![false; thunks.len()]; // mark thunks to remove
7194        let mut insertions: Vec<(usize, Thunk)> = Vec::new(); // (position, replacement)
7195
7196        let mut ai = 0;
7197        while ai < active.len() {
7198            // Helper: get active thunk at offset from current
7199            let a = |off: usize| -> Option<(usize, &Thunk)> {
7200                active.get(ai + off).map(|&idx| (idx, &thunks[idx]))
7201            };
7202
7203            // Try BERT pattern: FusedMmBiasAct(QKV) → Narrow×3 → Attention → FusedMmBiasAct(out)
7204            let matched = (|| {
7205                let (_i0, t0) = a(0)?;
7206                let (_, t1) = a(1)?;
7207                let (_, t2) = a(2)?;
7208                let (_, t3) = a(3)?;
7209
7210                // a[0] must be FusedMmBiasAct or Sgemm (QKV projection)
7211                let (hidden, qkv_w, qkv_b, has_b) = match t0 {
7212                    Thunk::FusedMmBiasAct {
7213                        a,
7214                        w,
7215                        bias,
7216                        n: _,
7217                        act: None,
7218                        ..
7219                    } => (*a, *w, *bias, true),
7220                    Thunk::Sgemm { a, b, n: _, .. } => (*a, *b, 0, false),
7221                    _ => return None,
7222                };
7223
7224                // a[1..3] must be Narrows
7225                if !matches!(t1, Thunk::Narrow { .. }) {
7226                    return None;
7227                }
7228                if !matches!(t2, Thunk::Narrow { .. }) {
7229                    return None;
7230                }
7231                if !matches!(t3, Thunk::Narrow { .. }) {
7232                    return None;
7233                }
7234
7235                // Look for optional Rope×2 then Attention
7236                let (has_rope, attn_ai, cos_off, sin_off, cl) = if let Some((
7237                    _,
7238                    Thunk::Rope {
7239                        cos, sin, cos_len, ..
7240                    },
7241                )) = a(4)
7242                {
7243                    if matches!(a(5).map(|x| x.1), Some(Thunk::Rope { .. })) {
7244                        if matches!(a(6).map(|x| x.1), Some(Thunk::Attention { .. })) {
7245                            (true, 6, *cos, *sin, *cos_len)
7246                        } else {
7247                            return None;
7248                        }
7249                    } else {
7250                        return None;
7251                    }
7252                } else if matches!(a(4).map(|x| x.1), Some(Thunk::Attention { .. })) {
7253                    (false, 4, 0, 0, 0)
7254                } else {
7255                    return None;
7256                };
7257
7258                let (_attn_real_idx, attn_t) = a(attn_ai)?;
7259                let (batch, seq, heads, head_dim, mask) = match attn_t {
7260                    Thunk::Attention {
7261                        batch,
7262                        seq,
7263                        heads,
7264                        head_dim,
7265                        mask,
7266                        ..
7267                    } => (*batch, *seq, *heads, *head_dim, *mask),
7268                    _ => return None,
7269                };
7270
7271                // Next active must be out projection (FusedMmBiasAct or Sgemm)
7272                let (_out_real_idx, out_t) = a(attn_ai + 1)?;
7273                let (out_w, out_b, out_dst) = match out_t {
7274                    Thunk::FusedMmBiasAct {
7275                        w,
7276                        bias,
7277                        c,
7278                        act: None,
7279                        ..
7280                    } => (*w, *bias, *c),
7281                    Thunk::Sgemm { b: w, c, .. } => (*w, 0, *c),
7282                    _ => return None,
7283                };
7284
7285                let hs = heads * head_dim;
7286                let total_active = attn_ai + 2; // number of active thunks consumed
7287
7288                Some((
7289                    total_active,
7290                    Thunk::FusedAttnBlock {
7291                        hidden,
7292                        qkv_w,
7293                        out_w,
7294                        mask,
7295                        out: out_dst,
7296                        qkv_b: if has_b { qkv_b } else { 0 },
7297                        out_b: if has_b { out_b } else { 0 },
7298                        cos: cos_off,
7299                        sin: sin_off,
7300                        cos_len: cl,
7301                        batch,
7302                        seq,
7303                        hs,
7304                        nh: heads,
7305                        dh: head_dim,
7306                        has_bias: has_b,
7307                        has_rope,
7308                    },
7309                ))
7310            })();
7311
7312            if let Some((count, fused_thunk)) = matched {
7313                // Mark consumed thunks for removal
7314                for off in 0..count {
7315                    if let Some(&idx) = active.get(ai + off) {
7316                        kill[idx] = true;
7317                    }
7318                }
7319                // Insert replacement at position of the QKV thunk
7320                insertions.push((active[ai], fused_thunk));
7321                ai += count;
7322            } else {
7323                ai += 1;
7324            }
7325        }
7326
7327        // Rebuild thunk list: keep non-killed, insert fused at right positions
7328        if !insertions.is_empty() {
7329            let mut new_thunks = Vec::with_capacity(thunks.len());
7330            let mut insert_idx = 0;
7331            for (i, t) in thunks.into_iter().enumerate() {
7332                if insert_idx < insertions.len() && insertions[insert_idx].0 == i {
7333                    new_thunks.push(insertions[insert_idx].1.clone());
7334                    insert_idx += 1;
7335                }
7336                if !kill[i] {
7337                    new_thunks.push(t);
7338                }
7339            }
7340            if cfg.verbose >= 1 {
7341                eprintln!(
7342                    "[rlx] fused_attention: {} attention blocks fused",
7343                    insertions.len()
7344                );
7345            }
7346            thunks = new_thunks;
7347        }
7348    }
7349
7350    // ── Full layer fusion ──────────────────────────────────
7351    // After attention blocks are fused, scan for full layer patterns:
7352    // BERT:  FusedAttnBlock → FusedResidualLN → FusedMmBiasAct(gelu) → Sgemm → BiasAdd → FusedResidualLN
7353    // Nomic: FusedAttnBlock → BinaryFull(add) → LayerNorm → Sgemm → [Narrow×2 → Silu → BinaryFull(mul)] → Sgemm → BinaryFull(add) → LayerNorm
7354    if should_fuse {
7355        let active: Vec<usize> = thunks
7356            .iter()
7357            .enumerate()
7358            .filter(|(_, t)| !matches!(t, Thunk::Nop))
7359            .map(|(i, _)| i)
7360            .collect();
7361
7362        let mut kill = vec![false; thunks.len()];
7363        let mut insertions: Vec<(usize, Thunk)> = Vec::new();
7364
7365        let a = |ai: usize| -> Option<&Thunk> { active.get(ai).map(|&i| &thunks[i]) };
7366
7367        let mut ai = 0;
7368        while ai < active.len() {
7369            // BERT pattern: FusedAttnBlock → FusedResidualLN → FusedMmBiasAct(gelu) → FusedMmBiasAct(none) → FusedResidualLN
7370            let bert_match = (|| -> Option<usize> {
7371                let fab = a(ai)?;
7372                let rln1 = a(ai + 1)?;
7373                let ffn1 = a(ai + 2)?;
7374                let ffn2 = a(ai + 3)?;
7375                let rln2 = a(ai + 4)?;
7376
7377                let (hidden, qkv_w, qkv_b, out_w, out_b, mask, batch, seq, hs, nh, dh) = match fab {
7378                    Thunk::FusedAttnBlock {
7379                        hidden,
7380                        qkv_w,
7381                        qkv_b,
7382                        out_w,
7383                        out_b,
7384                        mask,
7385                        batch,
7386                        seq,
7387                        hs,
7388                        nh,
7389                        dh,
7390                        has_bias: true,
7391                        has_rope: false,
7392                        ..
7393                    } => (
7394                        *hidden, *qkv_w, *qkv_b, *out_w, *out_b, *mask, *batch, *seq, *hs, *nh, *dh,
7395                    ),
7396                    _ => return None,
7397                };
7398                let (ln1_g, ln1_b, eps1) = match rln1 {
7399                    Thunk::FusedResidualLN { g, b, eps, .. } => (*g, *b, *eps),
7400                    _ => return None,
7401                };
7402                let (fc1_w, fc1_b, int_dim) = match ffn1 {
7403                    Thunk::FusedMmBiasAct {
7404                        w,
7405                        bias,
7406                        n,
7407                        act: Some(Activation::Gelu),
7408                        ..
7409                    } => (*w, *bias, *n),
7410                    _ => return None,
7411                };
7412                let (fc2_w, fc2_b) = match ffn2 {
7413                    Thunk::FusedMmBiasAct {
7414                        w, bias, act: None, ..
7415                    } => (*w, *bias),
7416                    _ => return None,
7417                };
7418                let (ln2_g, ln2_b, eps2, out) = match rln2 {
7419                    Thunk::FusedResidualLN { g, b, eps, out, .. } => (*g, *b, *eps, *out),
7420                    _ => return None,
7421                };
7422
7423                for off in 0..5 {
7424                    kill[active[ai + off]] = true;
7425                }
7426                insertions.push((
7427                    active[ai],
7428                    Thunk::FusedBertLayer {
7429                        hidden,
7430                        qkv_w,
7431                        qkv_b,
7432                        out_w,
7433                        out_b,
7434                        mask,
7435                        ln1_g,
7436                        ln1_b,
7437                        eps1,
7438                        fc1_w,
7439                        fc1_b,
7440                        fc2_w,
7441                        fc2_b,
7442                        ln2_g,
7443                        ln2_b,
7444                        eps2,
7445                        out,
7446                        batch,
7447                        seq,
7448                        hs,
7449                        nh,
7450                        dh,
7451                        int_dim,
7452                    },
7453                ));
7454                Some(5)
7455            })();
7456            if let Some(n) = bert_match {
7457                ai += n;
7458                continue;
7459            }
7460
7461            // Nomic full layer fusion — disabled pending SwiGLU stride debugging.
7462            // Nomic still benefits from FusedAttnBlock (attention-level fusion).
7463            // The body below is kept as reference for when the stride bug is fixed.
7464            #[allow(unreachable_code)]
7465            let nomic_match = (|| -> Option<usize> {
7466                return None; // TODO: fix SwiGLU strided fc2 output mismatch
7467                let fab = a(ai)?;
7468                let (hidden, qkv_w, out_w, mask, cos, sin, cos_len, batch, seq, hs, nh, dh) =
7469                    match fab {
7470                        Thunk::FusedAttnBlock {
7471                            hidden,
7472                            qkv_w,
7473                            out_w,
7474                            mask,
7475                            cos,
7476                            sin,
7477                            cos_len,
7478                            batch,
7479                            seq,
7480                            hs,
7481                            nh,
7482                            dh,
7483                            has_bias: false,
7484                            has_rope: true,
7485                            ..
7486                        } => (
7487                            *hidden, *qkv_w, *out_w, *mask, *cos, *sin, *cos_len, *batch, *seq,
7488                            *hs, *nh, *dh,
7489                        ),
7490                        _ => return None,
7491                    };
7492                // FusedResidualLN for LN1
7493                let (ln1_g, ln1_b, eps1) = match a(ai + 1)? {
7494                    Thunk::FusedResidualLN { g, b, eps, .. } => (*g, *b, *eps),
7495                    _ => return None,
7496                };
7497                // Sgemm (fused fc11+fc12)
7498                let fused_fc_w = match a(ai + 2)? {
7499                    Thunk::Sgemm { b: w, .. } => *w,
7500                    _ => return None,
7501                };
7502                // Narrow×2 for split
7503                if !matches!(a(ai + 3)?, Thunk::Narrow { .. }) {
7504                    return None;
7505                }
7506                if !matches!(a(ai + 4)?, Thunk::Narrow { .. }) {
7507                    return None;
7508                }
7509                // SiLU
7510                if !matches!(
7511                    a(ai + 5)?,
7512                    Thunk::ActivationInPlace {
7513                        act: Activation::Silu,
7514                        ..
7515                    }
7516                ) {
7517                    return None;
7518                }
7519                // BinaryFull(Mul) for gate
7520                if !matches!(
7521                    a(ai + 6)?,
7522                    Thunk::BinaryFull {
7523                        op: BinaryOp::Mul,
7524                        ..
7525                    }
7526                ) {
7527                    return None;
7528                }
7529                // Sgemm (fc2)
7530                let fc2_w = match a(ai + 7)? {
7531                    Thunk::Sgemm { b: w, .. } => *w,
7532                    _ => return None,
7533                };
7534                // Get int_dim from the Narrow (inner = int_dim for last-axis narrow)
7535                let int_dim = match a(ai + 3)? {
7536                    Thunk::Narrow { inner, .. } => *inner,
7537                    _ => return None,
7538                };
7539                // FusedResidualLN for LN2
7540                let (ln2_g, ln2_b, eps2, out) = match a(ai + 8)? {
7541                    Thunk::FusedResidualLN { g, b, eps, out, .. } => (*g, *b, *eps, *out),
7542                    _ => return None,
7543                };
7544
7545                for off in 0..9 {
7546                    kill[active[ai + off]] = true;
7547                }
7548                insertions.push((
7549                    active[ai],
7550                    Thunk::FusedNomicLayer {
7551                        hidden,
7552                        qkv_w,
7553                        out_w,
7554                        mask,
7555                        cos,
7556                        sin,
7557                        cos_len,
7558                        ln1_g,
7559                        ln1_b,
7560                        eps1,
7561                        fc11_w: fused_fc_w,
7562                        fc12_w: 0,
7563                        fc2_w,
7564                        ln2_g,
7565                        ln2_b,
7566                        eps2,
7567                        out,
7568                        batch,
7569                        seq,
7570                        hs,
7571                        nh,
7572                        dh,
7573                        int_dim,
7574                    },
7575                ));
7576                Some(9)
7577            })();
7578            if let Some(n) = nomic_match {
7579                ai += n;
7580                continue;
7581            }
7582
7583            ai += 1;
7584        }
7585
7586        if !insertions.is_empty() {
7587            let mut new_thunks = Vec::with_capacity(thunks.len());
7588            let mut ins_idx = 0;
7589            for (i, t) in thunks.into_iter().enumerate() {
7590                if ins_idx < insertions.len() && insertions[ins_idx].0 == i {
7591                    new_thunks.push(insertions[ins_idx].1.clone());
7592                    ins_idx += 1;
7593                }
7594                if !kill[i] {
7595                    new_thunks.push(t);
7596                }
7597            }
7598            if cfg.verbose >= 1 {
7599                eprintln!(
7600                    "[rlx] fused_layer: {} full transformer layers fused",
7601                    insertions.len()
7602                );
7603            }
7604            thunks = new_thunks;
7605        }
7606    }
7607
7608    // ── Narrow → Rope thunk fusion (plan #45) ──────────────
7609    // Runs *after* FusedAttnBlock fusion so it only catches the medium-
7610    // batch path (batch*seq > 64) where the bigger fusion didn't fire.
7611    // Pattern: a Rope thunk whose `src` is the dst of an immediately-
7612    // preceding Narrow whose dst has no other consumer in this schedule.
7613    // Rewrite Rope to read directly from the parent buffer with the
7614    // parent's row stride; the Narrow becomes a Nop.
7615    //
7616    // Skipping the Narrow's write saves one full pass over Q/K (B*S*hs
7617    // f32) per Rope. For Nomic h=768 / batch=8 / seq=15 / 12 layers
7618    // that's 2 ropes/layer × 369 KB = ~8.9 MB of write traffic gone.
7619    {
7620        // Collect every byte-offset that's read as a thunk's `src` so
7621        // we know whether a Narrow's dst has consumers other than Rope.
7622        let mut read_offsets: HashMap<usize, usize> = HashMap::new();
7623        for t in &thunks {
7624            for off in thunk_read_offsets(t) {
7625                *read_offsets.entry(off).or_insert(0) += 1;
7626            }
7627        }
7628
7629        let mut fused_count = 0usize;
7630        for i in 0..thunks.len().saturating_sub(1) {
7631            // Look for Rope at i+1 reading from Narrow at i (skip Nops
7632            // between them since the planner left them in place).
7633            let narrow = match &thunks[i] {
7634                Thunk::Narrow { .. } => i,
7635                _ => continue,
7636            };
7637            // Find the next non-Nop thunk
7638            let mut j = narrow + 1;
7639            while j < thunks.len() && matches!(thunks[j], Thunk::Nop) {
7640                j += 1;
7641            }
7642            if j >= thunks.len() {
7643                continue;
7644            }
7645            // Must be Rope reading Narrow's dst
7646            let (n_src, n_dst, n_src_stride) = match &thunks[narrow] {
7647                Thunk::Narrow {
7648                    src,
7649                    dst,
7650                    src_stride,
7651                    ..
7652                } => (*src, *dst, *src_stride),
7653                _ => continue,
7654            };
7655            let rope_reads_narrow = matches!(&thunks[j],
7656                Thunk::Rope { src, .. } if *src == n_dst);
7657            if !rope_reads_narrow {
7658                continue;
7659            }
7660            // Conservatively require that the Narrow's dst has exactly
7661            // one reader (the Rope). Anything else and rewriting would
7662            // skip a needed write.
7663            if read_offsets.get(&n_dst).copied().unwrap_or(0) != 1 {
7664                continue;
7665            }
7666
7667            // Rewire: Rope reads from Narrow's adjusted source with the
7668            // parent buffer's row stride.
7669            if let Thunk::Rope {
7670                src,
7671                src_row_stride,
7672                ..
7673            } = &mut thunks[j]
7674            {
7675                *src = n_src;
7676                *src_row_stride = n_src_stride;
7677            }
7678            thunks[narrow] = Thunk::Nop;
7679            fused_count += 1;
7680        }
7681
7682        if fused_count > 0 && cfg.verbose >= 1 {
7683            eprintln!(
7684                "[rlx] fused_qk_rope: {} Narrow→Rope pairs collapsed",
7685                fused_count
7686            );
7687        }
7688    }
7689
7690    // ── Narrow×3 → Attention thunk fusion (plan #46 deep) ────
7691    // For each Attention thunk in the schedule, look up the producers
7692    // of its q/k/v inputs. If each is a Narrow whose dst has exactly
7693    // one consumer (the Attention), rewire Attention to read directly
7694    // from the parent buffer with the parent's row stride. The three
7695    // Narrows become Nops.
7696    //
7697    // This catches the BERT/Nomic QKV split path that FusedAttnBlock
7698    // misses (batch*seq > 64) — eliminates Q/K/V copies entirely.
7699    // For minilm6 batch=32 seq=16 hs=384: 3 × 32*16*384*4 = 2.3 MB
7700    // per layer × 6 layers = ~14 MB of write traffic gone.
7701    {
7702        let mut read_counts: HashMap<usize, usize> = HashMap::new();
7703        for t in &thunks {
7704            for off in thunk_read_offsets(t) {
7705                *read_counts.entry(off).or_insert(0) += 1;
7706            }
7707        }
7708        // Build dst→index map for fast producer lookup.
7709        let mut dst_to_idx: HashMap<usize, usize> = HashMap::new();
7710        for (i, t) in thunks.iter().enumerate() {
7711            if let Thunk::Narrow { dst, .. } = t {
7712                dst_to_idx.insert(*dst, i);
7713            }
7714        }
7715
7716        let mut fused_count = 0usize;
7717        for i in 0..thunks.len() {
7718            let (q_off, k_off, v_off) = match &thunks[i] {
7719                Thunk::Attention { q, k, v, .. } => (*q, *k, *v),
7720                _ => continue,
7721            };
7722            // All three inputs must come from Narrows.
7723            let q_n = match dst_to_idx.get(&q_off).copied() {
7724                Some(x) => x,
7725                None => continue,
7726            };
7727            let k_n = match dst_to_idx.get(&k_off).copied() {
7728                Some(x) => x,
7729                None => continue,
7730            };
7731            let v_n = match dst_to_idx.get(&v_off).copied() {
7732                Some(x) => x,
7733                None => continue,
7734            };
7735            // Each Narrow's dst must have exactly one reader (this Attn).
7736            if read_counts.get(&q_off).copied().unwrap_or(0) != 1 {
7737                continue;
7738            }
7739            if read_counts.get(&k_off).copied().unwrap_or(0) != 1 {
7740                continue;
7741            }
7742            if read_counts.get(&v_off).copied().unwrap_or(0) != 1 {
7743                continue;
7744            }
7745
7746            let (q_src, q_stride) = match &thunks[q_n] {
7747                Thunk::Narrow {
7748                    src, src_stride, ..
7749                } => (*src, *src_stride),
7750                _ => continue,
7751            };
7752            let (k_src, k_stride) = match &thunks[k_n] {
7753                Thunk::Narrow {
7754                    src, src_stride, ..
7755                } => (*src, *src_stride),
7756                _ => continue,
7757            };
7758            let (v_src, v_stride) = match &thunks[v_n] {
7759                Thunk::Narrow {
7760                    src, src_stride, ..
7761                } => (*src, *src_stride),
7762                _ => continue,
7763            };
7764
7765            if let Thunk::Attention {
7766                q,
7767                k,
7768                v,
7769                q_row_stride,
7770                k_row_stride,
7771                v_row_stride,
7772                ..
7773            } = &mut thunks[i]
7774            {
7775                *q = q_src;
7776                *k = k_src;
7777                *v = v_src;
7778                *q_row_stride = q_stride;
7779                *k_row_stride = k_stride;
7780                *v_row_stride = v_stride;
7781            }
7782            thunks[q_n] = Thunk::Nop;
7783            thunks[k_n] = Thunk::Nop;
7784            thunks[v_n] = Thunk::Nop;
7785            fused_count += 1;
7786        }
7787
7788        if fused_count > 0 && cfg.verbose >= 1 {
7789            eprintln!(
7790                "[rlx] fused_strided_attn: {} Narrow×3→Attention rewrites",
7791                fused_count
7792            );
7793        }
7794    }
7795
7796    ThunkSchedule {
7797        thunks,
7798        moe_resident: None,
7799        moe_resident_layers: None,
7800        moe_topk_capture: None,
7801        mask_threshold: cfg.mask_binary_threshold,
7802        mask_neg_inf: cfg.attn_mask_neg_inf,
7803        score_skip: cfg.score_skip_threshold,
7804        compiled_fns,
7805        rng: rng_shared,
7806    }
7807}
7808
7809fn get_len(graph: &Graph, id: NodeId) -> usize {
7810    graph.node(id).shape.num_elements().unwrap_or(0)
7811}
7812
7813/// Static `usize` dims of a node's shape, or empty if any dim is dynamic.
7814fn get_static_dims(graph: &Graph, id: NodeId) -> Vec<usize> {
7815    let dims = graph.node(id).shape.dims();
7816    let mut out = Vec::with_capacity(dims.len());
7817    for d in dims {
7818        if let Some(s) = match d {
7819            rlx_ir::Dim::Static(s) => Some(*s),
7820            _ => None,
7821        } {
7822            out.push(s);
7823        } else {
7824            return Vec::new();
7825        }
7826    }
7827    out
7828}
7829
7830/// Extent along `axis` for a concat input, treating leading implicit 1s when
7831/// `input.rank() < output.rank()` (ONNX / numpy concat broadcast rules).
7832fn concat_axis_extent(input: &rlx_ir::Shape, axis: usize, out_rank: usize) -> usize {
7833    let in_rank = input.rank();
7834    if axis >= out_rank {
7835        return 1;
7836    }
7837    if axis < in_rank {
7838        input.dim(axis).unwrap_static()
7839    } else {
7840        1
7841    }
7842}
7843
/// Wrap a flat index onto an input of `in_len` elements (`src_idx % in_len`).
/// An empty input (`in_len == 0`) maps every index to `0` instead of
/// dividing by zero.
fn broadcast_src_index(src_idx: usize, in_len: usize) -> usize {
    src_idx.checked_rem(in_len).unwrap_or(0)
}
7847
/// Element-type-generic core shared by `concat_copy_rows_f32` and
/// `concat_copy_rows_f64` (previously two byte-identical copies).
///
/// Writes one concat input into its column window of `out`. `out` is viewed
/// as `outer` rows of `row_stride` elements, and this input owns the
/// `copy_per_row` columns starting at `dst_col_off` in every row.
///
/// If `in_numel >= outer * max(copy_per_row, 1)`, input row `o`
/// (`inp[o * copy_per_row..][..copy_per_row]`) is copied into output row `o`.
/// Otherwise the input is broadcast across all rows:
/// - `in_numel == 1`: every window is filled with `inp[0]`;
/// - `copy_per_row <= inp.len()`: `inp[..copy_per_row]` is repeated in every row;
/// - otherwise: every window is filled with `inp[0]`, or left untouched if
///   `inp` is empty.
///
/// NOTE(review): the prefix case compares against `inp.len()`, not
/// `in_numel`; if a caller passes an `inp` slice longer than the tensor, the
/// repeated prefix may include elements past `in_numel` — confirm at call sites.
///
/// # Panics
/// If a destination window lies outside `out`, a source row lies outside
/// `inp`, or `in_numel == 1` while `inp` is empty.
fn concat_copy_rows_generic<T: Copy>(
    out: &mut [T],
    inp: &[T],
    outer: usize,
    copy_per_row: usize,
    row_stride: usize,
    dst_col_off: usize,
    in_numel: usize,
) {
    let need = outer.saturating_mul(copy_per_row.max(1));
    let broadcast_outer = in_numel < need;
    for o in 0..outer {
        let dst_start = o * row_stride + dst_col_off;
        let dst_end = dst_start + copy_per_row;
        if !broadcast_outer {
            let src_start = o * copy_per_row;
            out[dst_start..dst_end].copy_from_slice(&inp[src_start..src_start + copy_per_row]);
        } else if in_numel == 1 {
            // Scalar input: replicate the single value. A 1-wide fill is the
            // same store as `out[dst_start] = inp[0]`, so no special case.
            out[dst_start..dst_end].fill(inp[0]);
        } else if copy_per_row <= inp.len() {
            out[dst_start..dst_end].copy_from_slice(&inp[..copy_per_row]);
        } else if let Some(&first) = inp.first() {
            out[dst_start..dst_end].fill(first);
        }
    }
}

/// `f32` concat row copy; see `concat_copy_rows_generic` for the semantics.
fn concat_copy_rows_f32(
    out: &mut [f32],
    inp: &[f32],
    outer: usize,
    copy_per_row: usize,
    row_stride: usize,
    dst_col_off: usize,
    in_numel: usize,
) {
    concat_copy_rows_generic(out, inp, outer, copy_per_row, row_stride, dst_col_off, in_numel);
}

/// `f64` concat row copy; see `concat_copy_rows_generic` for the semantics.
fn concat_copy_rows_f64(
    out: &mut [f64],
    inp: &[f64],
    outer: usize,
    copy_per_row: usize,
    row_stride: usize,
    dst_col_off: usize,
    in_numel: usize,
) {
    concat_copy_rows_generic(out, inp, outer, copy_per_row, row_stride, dst_col_off, in_numel);
}
7915
7916/// NumPy-style broadcast strides for one operand into the flat output
7917/// buffer. Returns a length-`out_dims.len()` `Vec<u32>` where entry
7918/// `d` is `0` if the input is size-1 (broadcast) at output dim `d`
7919/// (after left-padding with size-1 to match ranks), otherwise the
7920/// natural row-major stride into the *input* buffer.
7921///
7922/// Caller iterates output flat index `i` → output coords (row-major)
7923/// → input flat index = dot(coords, strides). The result is correct
7924/// for any broadcast pattern (scalar, last-axis, middle-axis,
7925/// bidirectional).
7926/// True when `rhs_dims` describes a *trailing* broadcast of `out_dims`
7927/// — i.e. every rhs dim either equals the corresponding output dim
7928/// (counting from the right) or rhs is shorter (left-padded with 1s).
7929/// Mid-shape singletons (e.g. rhs `[a, b, 1, d]` into out `[a, b, c, d]`
7930/// where `c > 1`) are NOT trailing broadcasts and require the
7931/// shape-aware `BinaryFull` slow path — `BiasAdd`'s linear bias-replicated
7932/// kernel silently miscomputes them.
7933fn is_trailing_bias_broadcast(rhs_dims: &[rlx_ir::Dim], out_dims: &[rlx_ir::Dim]) -> bool {
7934    if rhs_dims.len() > out_dims.len() {
7935        return false;
7936    }
7937    let off = out_dims.len() - rhs_dims.len();
7938    for i in 0..rhs_dims.len() {
7939        let r = match rhs_dims[i] {
7940            rlx_ir::Dim::Static(n) => n,
7941            _ => return false,
7942        };
7943        let o = match out_dims[off + i] {
7944            rlx_ir::Dim::Static(n) => n,
7945            _ => return false,
7946        };
7947        if r != o {
7948            return false;
7949        }
7950    }
7951    true
7952}
7953
7954fn broadcast_strides(in_dims: &[usize], out_dims: &[usize]) -> Vec<u32> {
7955    let r_out = out_dims.len();
7956    let r_in = in_dims.len();
7957    assert!(
7958        r_in <= r_out,
7959        "broadcast: input rank {r_in} > output rank {r_out}"
7960    );
7961    let pad = r_out - r_in;
7962    let mut strides = vec![0u32; r_out];
7963    let mut acc: usize = 1;
7964    for d in (0..r_out).rev() {
7965        let in_size = if d < pad { 1 } else { in_dims[d - pad] };
7966        if in_size == 1 {
7967            strides[d] = 0;
7968        } else {
7969            assert_eq!(
7970                in_size, out_dims[d],
7971                "broadcast: input dim {in_size} doesn't match output dim {} at axis {d}",
7972                out_dims[d]
7973            );
7974            strides[d] = acc as u32;
7975            acc *= in_size;
7976        }
7977    }
7978    strides
7979}
7980
7981/// Execute a thunk schedule on a raw arena buffer.
7982/// Fastest executor: call pre-compiled closures sequentially.
7983/// Zero match dispatch — each closure is a direct kernel call.
7984pub fn execute_compiled(schedule: &ThunkSchedule, arena_buf: &mut [u8]) {
7985    let base = arena_buf.as_mut_ptr();
7986    for f in &schedule.compiled_fns {
7987        f(base);
7988    }
7989}
7990
7991/// Active-extent execution stub. The runtime calls this when it has an
7992/// active-extent hint set. CPU doesn't implement per-thunk active-extent
7993/// scaling yet — return false so the caller falls back to the full
7994/// `execute_thunks` path.
7995pub fn execute_thunks_active(
7996    schedule: &ThunkSchedule,
7997    _arena_buf: &mut [u8],
7998    _actual: usize,
7999    _upper: usize,
8000) -> bool {
8001    let _ = schedule;
8002    false
8003}
8004
8005/// Match-based executor (fallback, used by tests).
8006struct MoeResidencyGuard;
8007impl Drop for MoeResidencyGuard {
8008    fn drop(&mut self) {
8009        if let Some(stats) = crate::moe_residency::take_stats() {
8010            crate::moe_residency::stash_last_forward_stats(stats);
8011        } else {
8012            crate::moe_residency::clear_mask();
8013        }
8014    }
8015}
8016
8017fn thunk_kind_name(t: &Thunk) -> &'static str {
8018    match t {
8019        Thunk::Nop => "Nop",
8020        Thunk::Gather { .. } => "Gather",
8021        Thunk::GatherAxis { .. } => "GatherAxis",
8022        Thunk::TopK { .. } => "TopK",
8023        Thunk::Copy { .. } => "Copy",
8024        Thunk::CopyF64 { .. } => "CopyF64",
8025        Thunk::CopyI64 { .. } => "CopyI64",
8026        Thunk::CastF32ToI64 { .. } => "CastF32ToI64",
8027        Thunk::CastI64ToF32 { .. } => "CastI64ToF32",
8028        Thunk::CastBoolToI32 { .. } => "CastBoolToI32",
8029        Thunk::CastBoolToF32 { .. } => "CastBoolToF32",
8030        Thunk::CastI32ToF32 { .. } => "CastI32ToF32",
8031        Thunk::Transpose { .. } => "Transpose",
8032        Thunk::TransposeF64 { .. } => "TransposeF64",
8033        Thunk::Where { .. } => "Where",
8034        Thunk::Compare { .. } => "Compare",
8035        Thunk::BinaryFull { .. } => "BinaryFull",
8036        Thunk::BinaryFullF64 { .. } => "BinaryFullF64",
8037        Thunk::Sgemm { .. } => "Sgemm",
8038        Thunk::Dgemm { .. } => "Dgemm",
8039        Thunk::FusedMmBiasAct { .. } => "FusedMmBiasAct",
8040        Thunk::BiasAdd { .. } => "BiasAdd",
8041        Thunk::LayerNorm { .. } => "LayerNorm",
8042        Thunk::Softmax { .. } => "Softmax",
8043        Thunk::Conv2D { .. } => "Conv2D",
8044        Thunk::Conv2D1x1 { .. } => "Conv2D1x1",
8045        Thunk::CustomOp { .. } => "CustomOp",
8046        Thunk::ActivationInPlace { .. } => "ActivationInPlace",
8047        Thunk::Narrow { .. } => "Narrow",
8048        Thunk::Cumsum { .. } => "Cumsum",
8049        Thunk::Reduce { .. } => "Reduce",
8050        Thunk::BatchedSgemm { .. } => "BatchedSgemm",
8051        Thunk::DequantMatMul { .. } => "DequantMatMul",
8052        Thunk::Quantize { .. } => "Quantize",
8053        Thunk::Dequantize { .. } => "Dequantize",
8054        Thunk::ConvTranspose2d { .. } => "ConvTranspose2d",
8055        Thunk::ResizeNearest2x { .. } => "ResizeNearest2x",
8056        _ => "Other",
8057    }
8058}
8059
8060pub fn execute_thunks(schedule: &ThunkSchedule, arena_buf: &mut [u8]) {
8061    crate::moe_residency::reset_gmm_counters();
8062    if let Some(layers) = schedule.moe_resident_layers.clone() {
8063        crate::moe_residency::set_per_layer_masks(Some(layers));
8064    } else {
8065        crate::moe_residency::set_mask(schedule.moe_resident.clone());
8066    }
8067    if let Some(cap) = schedule.moe_topk_capture.as_ref() {
8068        cap.clear();
8069    }
8070    let _moe_guard = MoeResidencyGuard;
8071    let base = arena_buf.as_mut_ptr();
8072    let mask_thr = schedule.mask_threshold;
8073    let mask_neg = schedule.mask_neg_inf;
8074    let score_thr = schedule.score_skip;
8075    let thunks = &schedule.thunks;
8076    let len = thunks.len();
8077
8078    // Pre-allocate ALL reusable buffers once (zero per-call allocation)
8079    let max_h = thunks
8080        .iter()
8081        .filter_map(|t| match t {
8082            Thunk::FusedResidualLN { h, .. }
8083            | Thunk::FusedResidualRmsNorm { h, .. }
8084            | Thunk::LayerNorm { h, .. } => Some(*h as usize),
8085            _ => None,
8086        })
8087        .max()
8088        .unwrap_or(0);
8089    let zero_bias = vec![0f32; max_h];
8090
8091    // Pre-allocate per-(batch,head) score buffers for parallel SDPA.
8092    // Q/K/V/out are accessed via strided BLAS — no deinterleave copy needed.
8093    let max_sdpa = thunks
8094        .iter()
8095        .filter_map(|t| match t {
8096            Thunk::Attention {
8097                batch,
8098                seq,
8099                kv_seq,
8100                heads,
8101                head_dim,
8102                ..
8103            } => Some((
8104                *batch as usize,
8105                (*seq as usize).max(*kv_seq as usize),
8106                *heads as usize,
8107                *head_dim as usize,
8108            )),
8109            _ => None,
8110        })
8111        .fold((0, 0, 0, 0), |(mb, ms, mh, md), (b, s, h, d)| {
8112            (mb.max(b), ms.max(s), mh.max(h), md.max(d))
8113        });
8114    let (max_batch, max_seq, max_heads, _max_dh) = max_sdpa;
8115    let max_units = max_batch * max_heads;
8116    let mut sdpa_scores = vec![0f32; max_units * max_seq * max_seq];
8117
8118    // Pre-allocate fused layer buffers (reused across all 12+ layers — zero malloc per layer)
8119    let fl = thunks
8120        .iter()
8121        .filter_map(|t| match t {
8122            Thunk::FusedBertLayer {
8123                batch,
8124                seq,
8125                hs,
8126                int_dim,
8127                ..
8128            } => {
8129                let m = (*batch as usize) * (*seq as usize);
8130                let h = *hs as usize;
8131                let id = *int_dim as usize;
8132                Some((m, h, id, m * (*seq as usize)))
8133            }
8134            Thunk::FusedNomicLayer {
8135                batch,
8136                seq,
8137                hs,
8138                int_dim,
8139                ..
8140            } => {
8141                let m = (*batch as usize) * (*seq as usize);
8142                let h = *hs as usize;
8143                let id = *int_dim as usize;
8144                Some((m, h, id, m * (*seq as usize)))
8145            }
8146            _ => None,
8147        })
8148        .fold((0, 0, 0, 0), |(mm, mh, mi, ms), (m, h, id, ss)| {
8149            (mm.max(m), mh.max(h), mi.max(id), ms.max(ss))
8150        });
8151    let (fl_m, fl_h, fl_int, fl_ss) = fl;
8152    let mut fl_qkv = vec![0f32; fl_m * 3 * fl_h];
8153    let mut fl_attn = vec![0f32; fl_m * fl_h];
8154    let mut fl_res = vec![0f32; fl_m * fl_h];
8155    let mut fl_normed = vec![0f32; fl_m * fl_h];
8156    let mut fl_ffn = vec![0f32; fl_m * fl_int.max(2 * fl_int)]; // Nomic needs 2×int for fused fc11+fc12
8157    let mut fl_sc = vec![0f32; fl_ss.max(1)];
8158
8159    let trace_thunks = std::env::var_os("RLX_TRACE_THUNK").is_some();
8160    if trace_thunks {
8161        eprintln!(
8162            "[thunk] prealloc max_h={max_h} sdpa={} fl_m={fl_m} fl_h={fl_h} fl_int={fl_int}",
8163            max_units * max_seq * max_seq
8164        );
8165    }
8166    for i in 0..len {
8167        let thunk = unsafe { thunks.get_unchecked(i) };
8168        if trace_thunks && (i < 120 || i % 200 == 0 || i + 1 == len) {
8169            eprintln!("[thunk {i}/{len}] {}", thunk_kind_name(thunk));
8170        }
8171        let trace_done = trace_thunks && i < 120;
8172        match thunk {
8173            Thunk::Nop => {}
8174
8175            Thunk::GaussianSplatRender {
8176                positions_off,
8177                positions_len,
8178                scales_off,
8179                scales_len,
8180                rotations_off,
8181                rotations_len,
8182                opacities_off,
8183                opacities_len,
8184                colors_off,
8185                colors_len,
8186                sh_coeffs_off,
8187                sh_coeffs_len,
8188                meta_off,
8189                dst_off,
8190                dst_len,
8191                width,
8192                height,
8193                tile_size,
8194                radius_scale,
8195                alpha_cutoff,
8196                max_splat_steps,
8197                transmittance_threshold,
8198                max_list_entries,
8199            } => unsafe {
8200                crate::splat::execute_gaussian_splat_render(
8201                    *positions_off,
8202                    *positions_len,
8203                    *scales_off,
8204                    *scales_len,
8205                    *rotations_off,
8206                    *rotations_len,
8207                    *opacities_off,
8208                    *opacities_len,
8209                    *colors_off,
8210                    *colors_len,
8211                    *sh_coeffs_off,
8212                    *sh_coeffs_len,
8213                    *meta_off,
8214                    *dst_off,
8215                    *dst_len,
8216                    *width,
8217                    *height,
8218                    *tile_size,
8219                    *radius_scale,
8220                    *alpha_cutoff,
8221                    *max_splat_steps,
8222                    *transmittance_threshold,
8223                    *max_list_entries,
8224                    base,
8225                );
8226            },
8227
8228            Thunk::GaussianSplatRenderBackward {
8229                positions_off,
8230                positions_len,
8231                scales_off,
8232                scales_len,
8233                rotations_off,
8234                rotations_len,
8235                opacities_off,
8236                opacities_len,
8237                colors_off,
8238                colors_len,
8239                sh_coeffs_off,
8240                sh_coeffs_len,
8241                meta_off,
8242                d_loss_off,
8243                d_loss_len,
8244                packed_off,
8245                packed_len,
8246                width,
8247                height,
8248                tile_size,
8249                radius_scale,
8250                alpha_cutoff,
8251                max_splat_steps,
8252                transmittance_threshold,
8253                max_list_entries,
8254                loss_grad_clip,
8255                sh_band,
8256                max_anisotropy,
8257            } => unsafe {
8258                crate::splat::execute_gaussian_splat_render_backward(
8259                    *positions_off,
8260                    *positions_len,
8261                    *scales_off,
8262                    *scales_len,
8263                    *rotations_off,
8264                    *rotations_len,
8265                    *opacities_off,
8266                    *opacities_len,
8267                    *colors_off,
8268                    *colors_len,
8269                    *sh_coeffs_off,
8270                    *sh_coeffs_len,
8271                    *meta_off,
8272                    *d_loss_off,
8273                    *d_loss_len,
8274                    *packed_off,
8275                    *packed_len,
8276                    *width,
8277                    *height,
8278                    *tile_size,
8279                    *radius_scale,
8280                    *alpha_cutoff,
8281                    *max_splat_steps,
8282                    *transmittance_threshold,
8283                    *max_list_entries,
8284                    *loss_grad_clip,
8285                    *sh_band,
8286                    *max_anisotropy,
8287                    base,
8288                );
8289            },
8290
8291            Thunk::GaussianSplatPrepare {
8292                positions_off,
8293                positions_len,
8294                scales_off,
8295                scales_len,
8296                rotations_off,
8297                rotations_len,
8298                opacities_off,
8299                opacities_len,
8300                colors_off,
8301                colors_len,
8302                sh_coeffs_off,
8303                sh_coeffs_len,
8304                meta_off,
8305                meta_len,
8306                prep_off,
8307                prep_len,
8308                width,
8309                height,
8310                tile_size,
8311                radius_scale,
8312                alpha_cutoff,
8313                max_splat_steps,
8314                transmittance_threshold,
8315                max_list_entries,
8316            } => unsafe {
8317                crate::splat::execute_gaussian_splat_prepare(
8318                    *positions_off,
8319                    *positions_len,
8320                    *scales_off,
8321                    *scales_len,
8322                    *rotations_off,
8323                    *rotations_len,
8324                    *opacities_off,
8325                    *opacities_len,
8326                    *colors_off,
8327                    *colors_len,
8328                    *sh_coeffs_off,
8329                    *sh_coeffs_len,
8330                    *meta_off,
8331                    *meta_len,
8332                    *prep_off,
8333                    *prep_len,
8334                    *width,
8335                    *height,
8336                    *tile_size,
8337                    *radius_scale,
8338                    *alpha_cutoff,
8339                    *max_splat_steps,
8340                    *transmittance_threshold,
8341                    *max_list_entries,
8342                    base,
8343                );
8344            },
8345
8346            Thunk::GaussianSplatRasterize {
8347                prep_off,
8348                prep_len,
8349                meta_off,
8350                meta_len,
8351                dst_off,
8352                dst_len,
8353                count,
8354                width,
8355                height,
8356                tile_size,
8357                alpha_cutoff,
8358                max_splat_steps,
8359                transmittance_threshold,
8360                max_list_entries,
8361            } => unsafe {
8362                crate::splat::execute_gaussian_splat_rasterize(
8363                    *prep_off,
8364                    *prep_len,
8365                    *meta_off,
8366                    *meta_len,
8367                    *dst_off,
8368                    *dst_len,
8369                    *count,
8370                    *width,
8371                    *height,
8372                    *tile_size,
8373                    *alpha_cutoff,
8374                    *max_splat_steps,
8375                    *transmittance_threshold,
8376                    *max_list_entries,
8377                    base,
8378                );
8379            },
8380
8381            Thunk::Fft1d {
8382                src,
8383                dst,
8384                outer,
8385                n_complex,
8386                inverse,
8387                norm_tag,
8388                dtype,
8389            } => unsafe {
8390                match dtype {
8391                    rlx_ir::DType::F64 => execute_fft1d_f64(
8392                        *src,
8393                        *dst,
8394                        *outer as usize,
8395                        *n_complex as usize,
8396                        *inverse,
8397                        *norm_tag,
8398                        base,
8399                    ),
8400                    rlx_ir::DType::F32 => execute_fft1d_f32(
8401                        *src,
8402                        *dst,
8403                        *outer as usize,
8404                        *n_complex as usize,
8405                        *inverse,
8406                        *norm_tag,
8407                        base,
8408                    ),
8409                    rlx_ir::DType::C64 => execute_fft1d_c64(
8410                        *src,
8411                        *dst,
8412                        *outer as usize,
8413                        *n_complex as usize,
8414                        *inverse,
8415                        *norm_tag,
8416                        base,
8417                    ),
8418                    other => panic!("Op::Fft on CPU requires F32/F64/C64, got {other:?}"),
8419                }
8420            },
8421
8422            Thunk::FftButterflyStage {
8423                state_src,
8424                state_dst,
8425                gate_src,
8426                rev_src,
8427                tw_re_src,
8428                tw_im_src,
8429                batch,
8430                n_fft,
8431                stage,
8432            } => unsafe {
8433                execute_fft_butterfly_stage_f32(
8434                    *state_src,
8435                    *state_dst,
8436                    *gate_src,
8437                    *rev_src,
8438                    *tw_re_src,
8439                    *tw_im_src,
8440                    *batch as usize,
8441                    *n_fft as usize,
8442                    *stage as usize,
8443                    base,
8444                );
8445            },
8446
8447            Thunk::LogMel {
8448                spec,
8449                filters,
8450                dst,
8451                outer,
8452                n_fft,
8453                n_bins,
8454                n_mels,
8455            } => unsafe {
8456                execute_log_mel_f32(
8457                    *spec,
8458                    *filters,
8459                    *dst,
8460                    *outer as usize,
8461                    *n_fft as usize,
8462                    *n_bins as usize,
8463                    *n_mels as usize,
8464                    base,
8465                );
8466            },
8467
8468            Thunk::LogMelBackward {
8469                spec,
8470                filters,
8471                dy,
8472                dst,
8473                outer,
8474                n_fft,
8475                n_bins,
8476                n_mels,
8477            } => unsafe {
8478                execute_log_mel_backward_f32(
8479                    *spec,
8480                    *filters,
8481                    *dy,
8482                    *dst,
8483                    *outer as usize,
8484                    *n_fft as usize,
8485                    *n_bins as usize,
8486                    *n_mels as usize,
8487                    base,
8488                );
8489            },
8490
8491            Thunk::WelchPeaks {
8492                spec,
8493                dst,
8494                welch_batch,
8495                n_fft,
8496                n_segments,
8497                k,
8498            } => unsafe {
8499                execute_welch_peaks_f32(
8500                    *spec,
8501                    *dst,
8502                    *welch_batch as usize,
8503                    *n_fft as usize,
8504                    *n_segments as usize,
8505                    *k as usize,
8506                    base,
8507                );
8508            },
8509
8510            // CustomFn dispatch (interpreted path). Mirrors the
8511            // pre-compiled-closure variant elsewhere in this file.
8512            // Patched by rlx-eda.
8513            Thunk::CustomFn {
8514                body,
8515                body_init,
8516                inputs,
8517                body_output_off,
8518                outer_output_off,
8519                out_bytes,
8520            } => {
8521                let mut body_buf: Vec<u8> = (**body_init).clone();
8522                unsafe {
8523                    for (body_in_off, outer_in_off, n_bytes) in inputs.iter() {
8524                        let src = (base as *const u8).add(*outer_in_off);
8525                        let dst = body_buf.as_mut_ptr().add(*body_in_off);
8526                        std::ptr::copy_nonoverlapping(src, dst, *n_bytes as usize);
8527                    }
8528                }
8529                execute_thunks(body, &mut body_buf);
8530                unsafe {
8531                    let src = body_buf.as_ptr().add(*body_output_off);
8532                    let dst = base.add(*outer_output_off);
8533                    std::ptr::copy_nonoverlapping(src, dst, *out_bytes as usize);
8534                }
8535            }
8536
8537            Thunk::Sgemm { a, b, c, m, k, n } => {
8538                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
8539                if trace_thunks {
8540                    eprintln!("[sgemm] m={m} k={k} n={n} a={} b={} c={}", *a, *b, *c);
8541                }
8542                let c_len = m.saturating_mul(n);
8543                let a_len = m.saturating_mul(k);
8544                let b_len = k.saturating_mul(n);
8545                let arena_len = arena_buf.len();
8546                let max_a = (arena_len.saturating_sub(*a)) / 4;
8547                let max_b = (arena_len.saturating_sub(*b)) / 4;
8548                let max_c = (arena_len.saturating_sub(*c)) / 4;
8549                let a_len = a_len.min(max_a);
8550                let b_len = b_len.min(max_b);
8551                let c_len = c_len.min(max_c);
8552                unsafe {
8553                    let a_sl = sl(*a, base, a_len);
8554                    let b_sl = sl(*b, base, b_len);
8555                    let c_sl = sl_mut(*c, base, c_len);
8556                    if std::ptr::eq(a_sl.as_ptr(), c_sl.as_ptr())
8557                        || std::ptr::eq(b_sl.as_ptr(), c_sl.as_ptr())
8558                    {
8559                        let mut tmp = vec![0.0f32; c_len];
8560                        crate::blas::sgemm_auto(a_sl, b_sl, &mut tmp, m, k, n);
8561                        c_sl.copy_from_slice(&tmp);
8562                    } else {
8563                        crate::blas::sgemm_auto(a_sl, b_sl, c_sl, m, k, n);
8564                    }
8565                }
8566            }
8567
8568            Thunk::CgemmC64 { a, b, c, m, k, n } => unsafe {
8569                cgemm_c64(*a, *b, *c, *m as usize, *k as usize, *n as usize, base);
8570            },
8571
8572            Thunk::DenseSolveF64 { a, b, x, n, nrhs } => {
8573                let (n_, nrhs_) = (*n as usize, *nrhs as usize);
8574                // LAPACK overwrites both A and B; clone into scratch
8575                // each call. Caller's A and b must be preserved for
8576                // VJP recompute. (Eventually: swap to a factor-once /
8577                // solve-many scheme; that's the symbolic-reuse story
8578                // and lives with the sparse path.)
8579                unsafe {
8580                    let a_src = sl_f64(*a, base, n_ * n_);
8581                    let b_src = sl_f64(*b, base, n_ * nrhs_);
8582                    let mut a_scratch: Vec<f64> = a_src.to_vec();
8583                    let mut x_buf: Vec<f64> = b_src.to_vec();
8584                    let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
8585                    if info != 0 {
8586                        panic!(
8587                            "DenseSolveF64: dgesv reported singular matrix \
8588                                (info={info}, n={n_}, nrhs={nrhs_})"
8589                        );
8590                    }
8591                    let dst = sl_mut_f64(*x, base, n_ * nrhs_);
8592                    dst.copy_from_slice(&x_buf);
8593                }
8594            }
8595
8596            Thunk::DenseSolveF32 { a, b, x, n, nrhs } => {
8597                let (n_, nrhs_) = (*n as usize, *nrhs as usize);
8598                unsafe {
8599                    let a_src = sl(*a, base, n_ * n_);
8600                    let b_src = sl(*b, base, n_ * nrhs_);
8601                    let mut a_scratch: Vec<f32> = a_src.to_vec();
8602                    let mut x_buf: Vec<f32> = b_src.to_vec();
8603                    let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
8604                    if info != 0 {
8605                        panic!(
8606                            "DenseSolveF32: sgesv reported singular matrix \
8607                             (info={info}, n={n_}, nrhs={nrhs_})"
8608                        );
8609                    }
8610                    let dst = sl_mut(*x, base, n_ * nrhs_);
8611                    dst.copy_from_slice(&x_buf);
8612                }
8613            }
8614
8615            Thunk::BatchedDenseSolveF64 {
8616                a,
8617                b,
8618                x,
8619                batch,
8620                n,
8621                nrhs,
8622            } => {
8623                // Per slice: extract A_i and b_i, dgesv, write x_i.
8624                // LAPACK has no batched dgesv on Accelerate, so this
8625                // is a serial loop over the batch axis. cuSOLVER /
8626                // hipSOLVER expose `getrfBatched` / `getrsBatched` for
8627                // the GPU path — we'll wire that in rlx-cuda when
8628                // someone needs Linux+CUDA.
8629                let (b_, n_, nrhs_) = (*batch as usize, *n as usize, *nrhs as usize);
8630                let a_stride = n_ * n_;
8631                let b_stride = n_ * nrhs_;
8632                unsafe {
8633                    let a_full = sl_f64(*a, base, b_ * a_stride);
8634                    let b_full = sl_f64(*b, base, b_ * b_stride);
8635                    let x_full = sl_mut_f64(*x, base, b_ * b_stride);
8636                    for bi in 0..b_ {
8637                        let mut a_scratch: Vec<f64> =
8638                            a_full[bi * a_stride..(bi + 1) * a_stride].to_vec();
8639                        let mut x_buf: Vec<f64> =
8640                            b_full[bi * b_stride..(bi + 1) * b_stride].to_vec();
8641                        let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
8642                        if info != 0 {
8643                            panic!(
8644                                "BatchedDenseSolveF64: slice {bi} \
8645                                    singular (info={info}, n={n_}, nrhs={nrhs_})"
8646                            );
8647                        }
8648                        x_full[bi * b_stride..(bi + 1) * b_stride].copy_from_slice(&x_buf);
8649                    }
8650                }
8651            }
8652
8653            Thunk::BatchedDenseSolveF32 {
8654                a,
8655                b,
8656                x,
8657                batch,
8658                n,
8659                nrhs,
8660            } => {
8661                let (b_, n_, nrhs_) = (*batch as usize, *n as usize, *nrhs as usize);
8662                let a_stride = n_ * n_;
8663                let b_stride = n_ * nrhs_;
8664                unsafe {
8665                    let a_full = sl(*a, base, b_ * a_stride);
8666                    let b_full = sl(*b, base, b_ * b_stride);
8667                    let x_full = sl_mut(*x, base, b_ * b_stride);
8668                    for bi in 0..b_ {
8669                        let mut a_scratch = a_full[bi * a_stride..(bi + 1) * a_stride].to_vec();
8670                        let mut x_buf = b_full[bi * b_stride..(bi + 1) * b_stride].to_vec();
8671                        let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
8672                        if info != 0 {
8673                            panic!("BatchedDenseSolveF32: slice {bi} singular (info={info})");
8674                        }
8675                        x_full[bi * b_stride..(bi + 1) * b_stride].copy_from_slice(&x_buf);
8676                    }
8677                }
8678            }
8679
8680            Thunk::BatchedDgemmF64 {
8681                a,
8682                b,
8683                c,
8684                batch,
8685                m,
8686                k,
8687                n,
8688            } => {
8689                let (b_, m_, k_, n_) = (*batch as usize, *m as usize, *k as usize, *n as usize);
8690                let a_stride = m_ * k_;
8691                let b_stride = k_ * n_;
8692                let c_stride = m_ * n_;
8693                unsafe {
8694                    let a_full = sl_f64(*a, base, b_ * a_stride);
8695                    let b_full = sl_f64(*b, base, b_ * b_stride);
8696                    let c_full = sl_mut_f64(*c, base, b_ * c_stride);
8697                    for bi in 0..b_ {
8698                        let a_slice = &a_full[bi * a_stride..(bi + 1) * a_stride];
8699                        let b_slice = &b_full[bi * b_stride..(bi + 1) * b_stride];
8700                        let c_slice = &mut c_full[bi * c_stride..(bi + 1) * c_stride];
8701                        crate::blas::dgemm(a_slice, b_slice, c_slice, m_, k_, n_);
8702                    }
8703                }
8704            }
8705
8706            Thunk::BatchedSgemm {
8707                a,
8708                b,
8709                c,
8710                batch,
8711                m,
8712                k,
8713                n,
8714            } => {
8715                let (b_, m_, k_, n_) = (*batch as usize, *m as usize, *k as usize, *n as usize);
8716                if trace_thunks {
8717                    eprintln!(
8718                        "[batched-sgemm] batch={b_} m={m_} k={k_} n={n_} a={} b={} c={}",
8719                        *a, *b, *c
8720                    );
8721                }
8722                let a_stride = m_.saturating_mul(k_);
8723                let b_stride = k_.saturating_mul(n_);
8724                let c_stride = m_.saturating_mul(n_);
8725                let arena_len = arena_buf.len();
8726                let a_cap = (arena_len.saturating_sub(*a)) / 4;
8727                let b_cap = (arena_len.saturating_sub(*b)) / 4;
8728                let c_cap = (arena_len.saturating_sub(*c)) / 4;
8729                let a_elems = (b_ * a_stride).min(a_cap);
8730                let b_elems = (b_ * b_stride).min(b_cap);
8731                let c_elems = (b_ * c_stride).min(c_cap);
8732                let b_eff = b_
8733                    .min(a_elems.checked_div(a_stride).unwrap_or(0))
8734                    .min(b_elems.checked_div(b_stride).unwrap_or(0))
8735                    .min(c_elems.checked_div(c_stride).unwrap_or(0));
8736                unsafe {
8737                    let a_full = sl(*a, base, a_elems);
8738                    let b_full = sl(*b, base, b_elems);
8739                    let c_full = sl_mut(*c, base, c_elems);
8740                    for bi in 0..b_eff {
8741                        let a0 = bi * a_stride;
8742                        let b0 = bi * b_stride;
8743                        let c0 = bi * c_stride;
8744                        if a0 + a_stride > a_full.len()
8745                            || b0 + b_stride > b_full.len()
8746                            || c0 + c_stride > c_full.len()
8747                        {
8748                            break;
8749                        }
8750                        let a_slice = &a_full[a0..a0 + a_stride];
8751                        let b_slice = &b_full[b0..b0 + b_stride];
8752                        let c_slice = &mut c_full[c0..c0 + c_stride];
8753                        if std::ptr::eq(a_slice.as_ptr(), c_slice.as_mut_ptr())
8754                            || std::ptr::eq(b_slice.as_ptr(), c_slice.as_mut_ptr())
8755                        {
8756                            let mut tmp = vec![0.0f32; c_stride];
8757                            crate::blas::sgemm_auto(a_slice, b_slice, &mut tmp, m_, k_, n_);
8758                            c_slice.copy_from_slice(&tmp);
8759                        } else {
8760                            crate::blas::sgemm_auto(a_slice, b_slice, c_slice, m_, k_, n_);
8761                        }
8762                    }
8763                }
8764            }
8765
8766            Thunk::Dgemm { a, b, c, m, k, n } => {
8767                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
8768                unsafe {
8769                    crate::blas::dgemm(
8770                        sl_f64(*a, base, m * k),
8771                        sl_f64(*b, base, k * n),
8772                        sl_mut_f64(*c, base, m * n),
8773                        m,
8774                        k,
8775                        n,
8776                    );
8777                }
8778            }
8779
8780            Thunk::TransposeF64 {
8781                src,
8782                dst,
8783                in_total,
8784                out_dims,
8785                in_strides,
8786            } => unsafe {
8787                let inp = sl_f64(*src, base, *in_total as usize);
8788                let out_total: usize = out_dims.iter().map(|d| *d as usize).product();
8789                let out = sl_mut_f64(*dst, base, out_total);
8790                transpose_walk_f64(inp, out, out_dims, in_strides);
8791            },
8792
8793            Thunk::ActivationF64 {
8794                src,
8795                dst,
8796                len,
8797                kind,
8798            } => {
8799                let len = *len as usize;
8800                unsafe {
8801                    let inp = sl_f64(*src, base, len);
8802                    let out = sl_mut_f64(*dst, base, len);
8803                    apply_activation_f64(inp, out, *kind);
8804                }
8805            }
8806
8807            Thunk::ReduceSumF64 {
8808                src,
8809                dst,
8810                outer,
8811                reduced,
8812                inner,
8813            } => {
8814                let (o, r, n) = (*outer as usize, *reduced as usize, *inner as usize);
8815                unsafe {
8816                    let inp = sl_f64(*src, base, o * r * n);
8817                    let out = sl_mut_f64(*dst, base, o * n);
8818                    reduce_sum_f64(inp, out, o, r, n);
8819                }
8820            }
8821
8822            Thunk::CopyF64 { src, dst, len } => {
8823                let mut len = *len as usize;
8824                if *src == *dst || len == 0 {
8825                    continue;
8826                }
8827                let arena_len = arena_buf.len();
8828                let max_from_src = (arena_len.saturating_sub(*src)) / 8;
8829                let max_from_dst = (arena_len.saturating_sub(*dst)) / 8;
8830                len = len.min(max_from_src).min(max_from_dst);
8831                if len == 0 {
8832                    continue;
8833                }
8834                let byte_len = len.saturating_mul(8);
8835                unsafe {
8836                    std::ptr::copy(base.add(*src), base.add(*dst), byte_len);
8837                }
8838            }
8839
8840            Thunk::CopyI64 { src, dst, len } => {
8841                let mut len = *len as usize;
8842                if *src == *dst || len == 0 {
8843                    continue;
8844                }
8845                let arena_len = arena_buf.len();
8846                let max_from_src = (arena_len.saturating_sub(*src)) / 8;
8847                let max_from_dst = (arena_len.saturating_sub(*dst)) / 8;
8848                len = len.min(max_from_src).min(max_from_dst);
8849                if len == 0 {
8850                    continue;
8851                }
8852                let byte_len = len.saturating_mul(8);
8853                unsafe {
8854                    std::ptr::copy(base.add(*src), base.add(*dst), byte_len);
8855                }
8856            }
8857
8858            Thunk::CastF32ToI64 { src, dst, len } => {
8859                let len = *len as usize;
8860                if len == 0 {
8861                    continue;
8862                }
8863                unsafe {
8864                    let inp = sl(*src, base, len);
8865                    let out = sl_mut_i64(*dst, base, len);
8866                    for i in 0..len {
8867                        out[i] = inp[i].round() as i64;
8868                    }
8869                }
8870            }
8871
8872            Thunk::CastF32ToF64 { src, dst, len } => {
8873                let len = *len as usize;
8874                if len == 0 {
8875                    continue;
8876                }
8877                unsafe {
8878                    let inp = sl(*src, base, len);
8879                    let out = sl_mut_f64(*dst, base, len);
8880                    for i in 0..len {
8881                        out[i] = inp[i] as f64;
8882                    }
8883                }
8884            }
8885
8886            Thunk::CastF32ToI32 { src, dst, len } => {
8887                let len = *len as usize;
8888                if len == 0 {
8889                    continue;
8890                }
8891                unsafe {
8892                    let inp = sl(*src, base, len);
8893                    let out = sl_mut_i32(*dst, base, len);
8894                    for i in 0..len {
8895                        out[i] = inp[i].round() as i32;
8896                    }
8897                }
8898            }
8899
8900            Thunk::CastI64ToF32 { src, dst, len } => {
8901                let len = *len as usize;
8902                if len == 0 {
8903                    continue;
8904                }
8905                unsafe {
8906                    let inp = sl_i64(*src, base, len);
8907                    let out = sl_mut(*dst, base, len);
8908                    for i in 0..len {
8909                        out[i] = inp[i] as f32;
8910                    }
8911                }
8912            }
8913
8914            Thunk::CastBoolToI32 { src, dst, len } => {
8915                let len = *len as usize;
8916                if len == 0 {
8917                    continue;
8918                }
8919                unsafe {
8920                    let inp = &arena_buf[*src..*src + len];
8921                    let out = sl_mut_i32(*dst, base, len);
8922                    for i in 0..len {
8923                        out[i] = i32::from(inp[i] != 0);
8924                    }
8925                }
8926            }
8927
8928            Thunk::CastI32ToF32 { src, dst, len } => {
8929                let len = *len as usize;
8930                if len == 0 {
8931                    continue;
8932                }
8933                unsafe {
8934                    let inp = sl_i32(*src, base, len);
8935                    let out = sl_mut(*dst, base, len);
8936                    for i in 0..len {
8937                        out[i] = inp[i] as f32;
8938                    }
8939                }
8940            }
8941
8942            Thunk::CastBoolToF32 { src, dst, len } => {
8943                let len = *len as usize;
8944                if len == 0 {
8945                    continue;
8946                }
8947                unsafe {
8948                    let inp = &arena_buf[*src..*src + len];
8949                    let out = sl_mut(*dst, base, len);
8950                    for i in 0..len {
8951                        out[i] = if inp[i] != 0 { 1.0 } else { 0.0 };
8952                    }
8953                }
8954            }
8955
8956            Thunk::BinaryFullF64 {
8957                lhs,
8958                rhs,
8959                dst,
8960                len,
8961                lhs_len,
8962                rhs_len,
8963                op,
8964                out_dims_bcast,
8965                bcast_lhs_strides,
8966                bcast_rhs_strides,
8967            } => {
8968                let len = *len as usize;
8969                let lhs_len = *lhs_len as usize;
8970                let rhs_len = *rhs_len as usize;
8971                unsafe {
8972                    let l = sl_f64(*lhs, base, lhs_len);
8973                    let r = sl_f64(*rhs, base, rhs_len);
8974                    let d = sl_mut_f64(*dst, base, len);
8975                    if lhs_len == len && rhs_len == len {
8976                        for i in 0..len {
8977                            d[i] = binary_op_f64(*op, l[i], r[i]);
8978                        }
8979                    } else if !out_dims_bcast.is_empty() {
8980                        // Shape-aware broadcast path: correct for
8981                        // arbitrary NumPy-style broadcasts including
8982                        // bidirectional `[N,1] op [1,S]`.
8983                        let rank = out_dims_bcast.len();
8984                        let mut coords = vec![0u32; rank];
8985                        for i in 0..len {
8986                            let mut rem = i;
8987                            for ax in (0..rank).rev() {
8988                                let sz = out_dims_bcast[ax] as usize;
8989                                coords[ax] = (rem % sz) as u32;
8990                                rem /= sz;
8991                            }
8992                            let mut li: usize = 0;
8993                            let mut ri: usize = 0;
8994                            for ax in 0..rank {
8995                                li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
8996                                ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
8997                            }
8998                            d[i] = binary_op_f64(*op, l[li], r[ri]);
8999                        }
9000                    } else {
9001                        // Fallback: legacy modulo path (preserved for
9002                        // dynamic-shape graphs where strides can't be
9003                        // precomputed). Only correct for scalar /
9004                        // last-axis broadcast.
9005                        for i in 0..len {
9006                            d[i] = binary_op_f64(*op, l[i % lhs_len], r[i % rhs_len]);
9007                        }
9008                    }
9009                }
9010            }
9011
9012            Thunk::BinaryFullC64 {
9013                lhs,
9014                rhs,
9015                dst,
9016                len,
9017                lhs_len,
9018                rhs_len,
9019                op,
9020                out_dims_bcast,
9021                bcast_lhs_strides,
9022                bcast_rhs_strides,
9023            } => {
9024                // Complex element layout: [re_0, im_0, re_1, im_1, ...]
9025                // Underlying f32 buffer length is 2·N (N = complex
9026                // element count). All offsets are byte offsets; the
9027                // `sl` helper reads as f32 starting at the byte
9028                // offset, so f32-length = 2·complex-len.
9029                let n_out = *len as usize;
9030                let n_l = *lhs_len as usize;
9031                let n_r = *rhs_len as usize;
9032                unsafe {
9033                    let l = sl(*lhs, base, 2 * n_l);
9034                    let r = sl(*rhs, base, 2 * n_r);
9035                    let d = sl_mut(*dst, base, 2 * n_out);
9036                    let do_c64 = |a_re: f32, a_im: f32, b_re: f32, b_im: f32| -> (f32, f32) {
9037                        match op {
9038                            BinaryOp::Add => (a_re + b_re, a_im + b_im),
9039                            BinaryOp::Sub => (a_re - b_re, a_im - b_im),
9040                            BinaryOp::Mul => (a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re),
9041                            BinaryOp::Div => {
9042                                let denom = b_re * b_re + b_im * b_im;
9043                                (
9044                                    (a_re * b_re + a_im * b_im) / denom,
9045                                    (a_im * b_re - a_re * b_im) / denom,
9046                                )
9047                            }
9048                            BinaryOp::Max | BinaryOp::Min | BinaryOp::Pow => {
9049                                unreachable!("C64 max/min/pow rejected at lowering")
9050                            }
9051                        }
9052                    };
9053                    if n_l == n_out && n_r == n_out {
9054                        for i in 0..n_out {
9055                            let (re, im) = do_c64(l[2 * i], l[2 * i + 1], r[2 * i], r[2 * i + 1]);
9056                            d[2 * i] = re;
9057                            d[2 * i + 1] = im;
9058                        }
9059                    } else if !out_dims_bcast.is_empty() {
9060                        // Strided complex broadcast: strides are in
9061                        // *complex element* units; multiply by 2 when
9062                        // indexing into the f32 buffer.
9063                        let rank = out_dims_bcast.len();
9064                        let mut coords = vec![0u32; rank];
9065                        for i in 0..n_out {
9066                            let mut rem = i;
9067                            for ax in (0..rank).rev() {
9068                                let sz = out_dims_bcast[ax] as usize;
9069                                coords[ax] = (rem % sz) as u32;
9070                                rem /= sz;
9071                            }
9072                            let mut li: usize = 0;
9073                            let mut ri: usize = 0;
9074                            for ax in 0..rank {
9075                                li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
9076                                ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
9077                            }
9078                            let (re, im) =
9079                                do_c64(l[2 * li], l[2 * li + 1], r[2 * ri], r[2 * ri + 1]);
9080                            d[2 * i] = re;
9081                            d[2 * i + 1] = im;
9082                        }
9083                    } else {
9084                        // Modulo fallback (scalar / last-axis broadcast).
9085                        for i in 0..n_out {
9086                            let li = if n_l == 1 { 0 } else { i % n_l };
9087                            let ri = if n_r == 1 { 0 } else { i % n_r };
9088                            let (re, im) =
9089                                do_c64(l[2 * li], l[2 * li + 1], r[2 * ri], r[2 * ri + 1]);
9090                            d[2 * i] = re;
9091                            d[2 * i + 1] = im;
9092                        }
9093                    }
9094                }
9095            }
9096
9097            Thunk::ComplexNormSqF32 { src, dst, len } => {
9098                let n = *len as usize;
9099                unsafe {
9100                    let s = sl(*src, base, 2 * n);
9101                    let d = sl_mut(*dst, base, n);
9102                    for i in 0..n {
9103                        let re = s[2 * i];
9104                        let im = s[2 * i + 1];
9105                        d[i] = re * re + im * im;
9106                    }
9107                }
9108            }
9109
9110            Thunk::ComplexNormSqBackwardF32 { z, g, dz, len } => {
9111                // Wirtinger: dz = g · z, element-wise complex
9112                // (g is real, z is complex).
9113                let n = *len as usize;
9114                unsafe {
9115                    let zb = sl(*z, base, 2 * n);
9116                    let gb = sl(*g, base, n);
9117                    let db = sl_mut(*dz, base, 2 * n);
9118                    for i in 0..n {
9119                        let re = zb[2 * i];
9120                        let im = zb[2 * i + 1];
9121                        let gv = gb[i];
9122                        db[2 * i] = gv * re;
9123                        db[2 * i + 1] = gv * im;
9124                    }
9125                }
9126            }
9127
9128            Thunk::ConjugateC64 { src, dst, len } => {
9129                let n = *len as usize;
9130                unsafe {
9131                    let s = sl(*src, base, 2 * n);
9132                    let d = sl_mut(*dst, base, 2 * n);
9133                    for i in 0..n {
9134                        d[2 * i] = s[2 * i];
9135                        d[2 * i + 1] = -s[2 * i + 1];
9136                    }
9137                }
9138            }
9139
9140            Thunk::ActivationC64 {
9141                src,
9142                dst,
9143                len,
9144                kind,
9145            } => {
9146                let n = *len as usize;
9147                unsafe {
9148                    let s = sl(*src, base, 2 * n);
9149                    let d = sl_mut(*dst, base, 2 * n);
9150                    for i in 0..n {
9151                        let a = s[2 * i];
9152                        let b = s[2 * i + 1];
9153                        let (re, im) = match kind {
9154                            Activation::Neg => (-a, -b),
9155                            Activation::Exp => {
9156                                // exp(a + bi) = e^a · (cos b + i·sin b)
9157                                let ea = a.exp();
9158                                (ea * b.cos(), ea * b.sin())
9159                            }
9160                            Activation::Log => {
9161                                // log(z) = log|z| + i·arg(z), principal branch
9162                                let r = (a * a + b * b).sqrt();
9163                                (r.ln(), b.atan2(a))
9164                            }
9165                            Activation::Sqrt => {
9166                                // sqrt(a+bi) = sqrt((|z|+a)/2) + sign(b)·i·sqrt((|z|-a)/2)
9167                                // Principal branch; for b == 0 and a < 0 returns +i·sqrt(|a|).
9168                                let r = (a * a + b * b).sqrt();
9169                                let re = ((r + a) * 0.5).max(0.0).sqrt();
9170                                let im_mag = ((r - a) * 0.5).max(0.0).sqrt();
9171                                let im = if b >= 0.0 { im_mag } else { -im_mag };
9172                                (re, im)
9173                            }
9174                            _ => unreachable!("non-C64 activation kind survived lowering"),
9175                        };
9176                        d[2 * i] = re;
9177                        d[2 * i + 1] = im;
9178                    }
9179                }
9180            }
9181
9182            Thunk::Scan {
9183                body,
9184                body_init,
9185                body_input_off,
9186                body_output_off,
9187                outer_init_off,
9188                outer_final_off,
9189                length,
9190                carry_bytes,
9191                save_trajectory,
9192                xs_inputs,
9193                bcast_inputs,
9194                num_checkpoints,
9195            } => {
9196                let cb = *carry_bytes as usize;
9197                let n_steps = *length as usize;
9198                // Checkpoint mode: when 0 < K < length, save trajectory[k]
9199                // only when t == c_k = floor((k+1) * length / K) - 1.
9200                // The last index c_{K-1} = length - 1 always.
9201                let k_total = if *num_checkpoints == 0 || *num_checkpoints == *length {
9202                    n_steps // save every step
9203                } else {
9204                    *num_checkpoints as usize
9205                };
9206                let checkpoint_t_for_k = |k: usize| -> usize {
9207                    if k_total == n_steps {
9208                        k
9209                    } else {
9210                        ((k + 1) * n_steps)
9211                            .div_ceil(k_total)
9212                            .saturating_sub(1)
9213                            .min(n_steps - 1)
9214                    }
9215                };
9216                let mut next_k = 0usize;
9217
9218                let mut body_buf: Vec<u8> = (**body_init).clone();
9219                unsafe {
9220                    std::ptr::copy_nonoverlapping(
9221                        base.add(*outer_init_off),
9222                        body_buf.as_mut_ptr().add(*body_input_off),
9223                        cb,
9224                    );
9225                    // Broadcast inputs: copy each one into the body's
9226                    // input slot ONCE. They aren't touched in the
9227                    // iteration loop below (in contrast to xs).
9228                    for (body_b_off, outer_b_off, total_bytes) in bcast_inputs.iter() {
9229                        std::ptr::copy_nonoverlapping(
9230                            base.add(*outer_b_off),
9231                            body_buf.as_mut_ptr().add(*body_b_off),
9232                            *total_bytes as usize,
9233                        );
9234                    }
9235                }
9236                for t in 0..n_steps {
9237                    for (body_x_off, outer_xs_off, per_step_bytes) in xs_inputs.iter() {
9238                        let psb = *per_step_bytes as usize;
9239                        unsafe {
9240                            std::ptr::copy_nonoverlapping(
9241                                base.add(*outer_xs_off + t * psb),
9242                                body_buf.as_mut_ptr().add(*body_x_off),
9243                                psb,
9244                            );
9245                        }
9246                    }
9247
9248                    execute_thunks(body, &mut body_buf);
9249
9250                    if *save_trajectory && next_k < k_total && t == checkpoint_t_for_k(next_k) {
9251                        unsafe {
9252                            std::ptr::copy_nonoverlapping(
9253                                body_buf.as_ptr().add(*body_output_off),
9254                                base.add(*outer_final_off + next_k * cb),
9255                                cb,
9256                            );
9257                        }
9258                        next_k += 1;
9259                    }
9260
9261                    if *body_output_off != *body_input_off {
9262                        body_buf
9263                            .copy_within(*body_output_off..*body_output_off + cb, *body_input_off);
9264                    }
9265                }
9266
9267                if !*save_trajectory {
9268                    // Single final-carry write.
9269                    unsafe {
9270                        std::ptr::copy_nonoverlapping(
9271                            body_buf.as_ptr().add(*body_output_off),
9272                            base.add(*outer_final_off),
9273                            cb,
9274                        );
9275                    }
9276                }
9277            }
9278
            Thunk::ScanBackward {
                body_vjp,
                body_init,
                body_carry_in_off,
                body_x_offs,
                body_d_output_off,
                body_dcarry_out_off,
                outer_init_off,
                outer_traj_off,
                outer_upstream_off,
                outer_xs_offs,
                outer_dinit_off,
                length,
                carry_bytes,
                save_trajectory,
                num_checkpoints,
                forward_body,
                forward_body_init,
                forward_body_carry_in_off,
                forward_body_output_off,
                forward_body_x_offs,
                carry_elem_size,
            } => {
                // Reverse-mode pass over a Scan. Walks steps from t = length-1
                // down to 0 and threads the carry cotangent (`dcarry`) backward
                // through `body_vjp`. Finally writes d(init carry) to
                // `outer_dinit_off`.
                //
                // Two backward paths share the same per-iteration body
                // (body_vjp run + dcarry threading). The "All" path
                // reads the carry directly from the saved trajectory
                // each step. The "Recursive checkpointing" path stores
                // only K saved checkpoints and reconstructs intermediate
                // carries via Griewank-style recursive subdivision —
                // see [`griewank_process_segment`]. Auxiliary memory
                // is `O(log(segment_size) · carry_bytes)` for the
                // recursion stack, vs the old segment-cache scheme's
                // `O(segment_size · carry_bytes)`. Total recompute work
                // grows from `O(length)` to `O(length · log)`, which
                // is the canonical Griewank trade.
                let cb = *carry_bytes as usize;
                let n_steps = *length as usize;
                let k_total = *num_checkpoints as usize;
                // Checkpoint mode iff K is neither 0 nor `length`.
                let is_recursive = k_total != 0 && k_total != n_steps;
                // Step index after which checkpoint k was saved:
                // ceil((k+1)·length/K) - 1, clamped to length - 1. This is the
                // same formula the forward Scan thunk uses.
                let checkpoint_t_for_k = |k: usize| -> usize {
                    ((k + 1) * n_steps)
                        .div_ceil(k_total)
                        .saturating_sub(1)
                        .min(n_steps - 1)
                };

                // Scratch arena for forward-body recompute; only needed in
                // checkpoint mode.
                let mut fwd_buf: Vec<u8> = if is_recursive {
                    (**forward_body_init.as_ref().unwrap()).clone()
                } else {
                    Vec::new()
                };

                // Carry cotangent. With a per-step upstream (save_trajectory)
                // it starts at zero and upstream[t] is added inside
                // `process_iter`. Otherwise it is seeded once from the single
                // final-carry upstream.
                let mut dcarry: Vec<u8> = vec![0u8; cb];
                if !*save_trajectory {
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            base.add(*outer_upstream_off),
                            dcarry.as_mut_ptr(),
                            cb,
                        );
                    }
                }

                let mut body_buf: Vec<u8> = (**body_init).clone();

                // Per-iteration backward action — shared between the
                // direct-trajectory (All) and Griewank (Recursive) paths.
                // Both feed the same body_vjp run with carry-at-t,
                // x_t_i, and d_output, then thread dcarry backward.
                let process_iter =
                    |t: usize, carry_in: &[u8], dcarry: &mut Vec<u8>, body_buf: &mut Vec<u8>| {
                        if *save_trajectory {
                            // NOTE(review): upstream is indexed by step `t`
                            // (one `cb`-sized slot per step). In checkpoint
                            // mode the forward Scan writes only K trajectory
                            // slots. Confirm the lowering never pairs
                            // save_trajectory with 0 < K < length; otherwise
                            // this reads past a K-slot upstream buffer.
                            unsafe {
                                let up_off = *outer_upstream_off + t * cb;
                                // Accumulate upstream[t] into dcarry in the
                                // carry's element type.
                                match *carry_elem_size {
                                    4 => {
                                        let up_ptr = base.add(up_off) as *const f32;
                                        let dc_ptr = dcarry.as_mut_ptr() as *mut f32;
                                        let n_elems = cb / 4;
                                        for i in 0..n_elems {
                                            *dc_ptr.add(i) += *up_ptr.add(i);
                                        }
                                    }
                                    8 => {
                                        let up_ptr = base.add(up_off) as *const f64;
                                        let dc_ptr = dcarry.as_mut_ptr() as *mut f64;
                                        let n_elems = cb / 8;
                                        for i in 0..n_elems {
                                            *dc_ptr.add(i) += *up_ptr.add(i);
                                        }
                                    }
                                    other => panic!(
                                        "ScanBackward: unsupported carry elem size {other} \
                                     (only f32/f64 carries are supported today)"
                                    ),
                                }
                            }
                        }
                        // Stage the VJP body's inputs: carry-in at step t,
                        // slice t of every xs input, and the current dcarry
                        // as the output cotangent.
                        body_buf[*body_carry_in_off..*body_carry_in_off + cb]
                            .copy_from_slice(carry_in);
                        unsafe {
                            for (i, body_x_off) in body_x_offs.iter().enumerate() {
                                let (outer_xs_off, per_step_bytes) = outer_xs_offs[i];
                                let psb = per_step_bytes as usize;
                                std::ptr::copy_nonoverlapping(
                                    base.add(outer_xs_off + t * psb),
                                    body_buf.as_mut_ptr().add(*body_x_off),
                                    psb,
                                );
                            }
                            std::ptr::copy_nonoverlapping(
                                dcarry.as_ptr(),
                                body_buf.as_mut_ptr().add(*body_d_output_off),
                                cb,
                            );
                        }
                        execute_thunks(body_vjp, body_buf);
                        // The VJP's d(carry_in) becomes dcarry for step t-1.
                        unsafe {
                            std::ptr::copy_nonoverlapping(
                                body_buf.as_ptr().add(*body_dcarry_out_off),
                                dcarry.as_mut_ptr(),
                                cb,
                            );
                        }
                    };

                if is_recursive {
                    // Griewank treeverse path. Process saved-checkpoint
                    // segments from highest-t to lowest-t; within each,
                    // recursive binary subdivision via
                    // `griewank_process_segment`. Auxiliary memory:
                    // O(log(seg_size) · cb) for the recursion stack
                    // (vs O(seg_size · cb) for the older segment-cache
                    // scheme); recompute work: O(seg_size · log).
                    let leaf_threshold = 4usize;
                    let fb_sched = forward_body.as_ref().unwrap();
                    let fb_init = forward_body_init.as_ref().unwrap().as_slice();
                    // NOTE(review): underflows if n_steps == 0 while
                    // 0 < K (zero-length scan in checkpoint mode); presumably
                    // rejected at lowering — confirm.
                    let mut segment_end = n_steps - 1;
                    for seg_k in (0..k_total).rev() {
                        // Segment seg_k covers steps
                        // [c_{seg_k-1} + 1, c_{seg_k}] (or [0, c_0]).
                        let segment_start = if seg_k == 0 {
                            0
                        } else {
                            checkpoint_t_for_k(seg_k - 1) + 1
                        };
                        // Anchor = carry entering `segment_start`: the initial
                        // carry for the first segment, otherwise the previous
                        // saved checkpoint.
                        let mut anchor: Vec<u8> = vec![0u8; cb];
                        unsafe {
                            let src = if seg_k == 0 {
                                base.add(*outer_init_off)
                            } else {
                                base.add(*outer_traj_off + (seg_k - 1) * cb)
                            };
                            std::ptr::copy_nonoverlapping(src, anchor.as_mut_ptr(), cb);
                        }
                        // Closure adapter for the helper's signature
                        // (mutably re-borrows dcarry / body_buf each call).
                        let mut leaf_action = |t: usize, carry_in: &[u8]| {
                            process_iter(t, carry_in, &mut dcarry, &mut body_buf);
                        };
                        unsafe {
                            griewank_process_segment(
                                segment_start,
                                segment_end,
                                &anchor,
                                cb,
                                fb_sched,
                                fb_init,
                                *forward_body_carry_in_off,
                                *forward_body_output_off,
                                forward_body_x_offs,
                                base,
                                outer_xs_offs,
                                &mut fwd_buf,
                                leaf_threshold,
                                &mut leaf_action,
                            );
                        }
                        if seg_k == 0 {
                            break;
                        }
                        segment_end = segment_start - 1;
                    }
                } else {
                    // All-trajectory path: read each carry directly
                    // from the saved trajectory buffer.
                    // Carry-in at step t is the init carry for t == 0,
                    // else trajectory[t - 1] (the carry after step t-1).
                    let mut carry_buf: Vec<u8> = vec![0u8; cb];
                    for t in (0..n_steps).rev() {
                        unsafe {
                            let src = if t == 0 {
                                base.add(*outer_init_off)
                            } else {
                                base.add(*outer_traj_off + (t - 1) * cb)
                            };
                            std::ptr::copy_nonoverlapping(src, carry_buf.as_mut_ptr(), cb);
                        }
                        process_iter(t, &carry_buf, &mut dcarry, &mut body_buf);
                    }
                }

                // After stepping back past t = 0, dcarry is d(init carry).
                unsafe {
                    std::ptr::copy_nonoverlapping(dcarry.as_ptr(), base.add(*outer_dinit_off), cb);
                }
            }
9481
9482            Thunk::ScanBackwardXs {
9483                body_vjp,
9484                body_init,
9485                body_carry_in_off,
9486                body_x_offs,
9487                body_d_output_off,
9488                body_dcarry_out_off,
9489                body_dxs_out_off,
9490                outer_init_off,
9491                outer_traj_off,
9492                outer_upstream_off,
9493                outer_xs_offs,
9494                outer_dxs_off,
9495                length,
9496                carry_bytes,
9497                carry_elem_size,
9498                per_step_bytes,
9499                save_trajectory,
9500                num_checkpoints,
9501                forward_body,
9502                forward_body_init,
9503                forward_body_carry_in_off,
9504                forward_body_output_off,
9505                forward_body_x_offs,
9506            } => {
9507                let cb = *carry_bytes as usize;
9508                let psb = *per_step_bytes as usize;
9509                let n_steps = *length as usize;
9510                let k_total = *num_checkpoints as usize;
9511                let is_recursive = k_total != 0 && k_total != n_steps;
9512                let checkpoint_t_for_k = |k: usize| -> usize {
9513                    ((k + 1) * n_steps)
9514                        .div_ceil(k_total)
9515                        .saturating_sub(1)
9516                        .min(n_steps - 1)
9517                };
9518
9519                // Forward-body recompute scratch + segment cache —
9520                // exact mirror of the ScanBackward path. With ≈√length
9521                // checkpoints, total recompute work is O(length).
9522                let mut fwd_buf: Vec<u8> = if is_recursive {
9523                    (**forward_body_init.as_ref().unwrap()).clone()
9524                } else {
9525                    Vec::new()
9526                };
9527                let mut seg_cache: Vec<u8> = Vec::new();
9528                let mut seg_start_t: usize = usize::MAX;
9529                let mut seg_count: usize = 0;
9530                let recompute_carry_t =
9531                    |t: usize,
9532                     dst: &mut [u8],
9533                     fwd_buf: &mut Vec<u8>,
9534                     seg_cache: &mut Vec<u8>,
9535                     seg_start_t: &mut usize,
9536                     seg_count: &mut usize| {
9537                        if !is_recursive {
9538                            unsafe {
9539                                let src = if t == 0 {
9540                                    base.add(*outer_init_off)
9541                                } else {
9542                                    base.add(*outer_traj_off + (t - 1) * cb)
9543                                };
9544                                std::ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), cb);
9545                            }
9546                            return;
9547                        }
9548                        if *seg_start_t != usize::MAX
9549                            && t >= *seg_start_t
9550                            && t < *seg_start_t + *seg_count
9551                        {
9552                            let off = (t - *seg_start_t) * cb;
9553                            dst.copy_from_slice(&seg_cache[off..off + cb]);
9554                            return;
9555                        }
9556                        let seg_k = (0..k_total)
9557                            .find(|&k| t <= checkpoint_t_for_k(k))
9558                            .unwrap_or(k_total - 1);
9559                        let (anchor_t, anchor_ptr): (usize, *const u8) = if seg_k == 0 {
9560                            (0, unsafe { base.add(*outer_init_off) as *const u8 })
9561                        } else {
9562                            let prev_ck = checkpoint_t_for_k(seg_k - 1);
9563                            (prev_ck + 1, unsafe {
9564                                base.add(*outer_traj_off + (seg_k - 1) * cb) as *const u8
9565                            })
9566                        };
9567                        let seg_end_t = checkpoint_t_for_k(seg_k);
9568                        let seg_size = seg_end_t - anchor_t + 1;
9569
9570                        fwd_buf.copy_from_slice(forward_body_init.as_ref().unwrap());
9571                        unsafe {
9572                            std::ptr::copy_nonoverlapping(
9573                                anchor_ptr,
9574                                fwd_buf.as_mut_ptr().add(*forward_body_carry_in_off),
9575                                cb,
9576                            );
9577                        }
9578                        seg_cache.resize(seg_size * cb, 0u8);
9579                        seg_cache[0..cb].copy_from_slice(
9580                            &fwd_buf[*forward_body_carry_in_off..*forward_body_carry_in_off + cb],
9581                        );
9582                        let fb_sched = forward_body.as_ref().unwrap();
9583                        for i in 1..seg_size {
9584                            let cur_iter = anchor_t + i - 1;
9585                            for (idx, fb_x_off) in forward_body_x_offs.iter().enumerate() {
9586                                let (outer_xs_off, x_psb) = outer_xs_offs[idx];
9587                                let xb = x_psb as usize;
9588                                unsafe {
9589                                    std::ptr::copy_nonoverlapping(
9590                                        base.add(outer_xs_off + cur_iter * xb),
9591                                        fwd_buf.as_mut_ptr().add(*fb_x_off),
9592                                        xb,
9593                                    );
9594                                }
9595                            }
9596                            execute_thunks(fb_sched, fwd_buf);
9597                            if *forward_body_output_off != *forward_body_carry_in_off {
9598                                fwd_buf.copy_within(
9599                                    *forward_body_output_off..*forward_body_output_off + cb,
9600                                    *forward_body_carry_in_off,
9601                                );
9602                            }
9603                            let cache_off = i * cb;
9604                            seg_cache[cache_off..cache_off + cb].copy_from_slice(
9605                                &fwd_buf
9606                                    [*forward_body_carry_in_off..*forward_body_carry_in_off + cb],
9607                            );
9608                        }
9609                        *seg_start_t = anchor_t;
9610                        *seg_count = seg_size;
9611
9612                        let off = (t - anchor_t) * cb;
9613                        dst.copy_from_slice(&seg_cache[off..off + cb]);
9614                    };
9615
9616                let mut dcarry: Vec<u8> = vec![0u8; cb];
9617                if !*save_trajectory {
9618                    unsafe {
9619                        std::ptr::copy_nonoverlapping(
9620                            base.add(*outer_upstream_off),
9621                            dcarry.as_mut_ptr(),
9622                            cb,
9623                        );
9624                    }
9625                }
9626
9627                let mut body_buf: Vec<u8> = (**body_init).clone();
9628
9629                for t in (0..n_steps).rev() {
9630                    if *save_trajectory {
9631                        unsafe {
9632                            let up_off = *outer_upstream_off + t * cb;
9633                            match *carry_elem_size {
9634                                4 => {
9635                                    let up_ptr = base.add(up_off) as *const f32;
9636                                    let dc_ptr = dcarry.as_mut_ptr() as *mut f32;
9637                                    let n_elems = cb / 4;
9638                                    for i in 0..n_elems {
9639                                        *dc_ptr.add(i) += *up_ptr.add(i);
9640                                    }
9641                                }
9642                                8 => {
9643                                    let up_ptr = base.add(up_off) as *const f64;
9644                                    let dc_ptr = dcarry.as_mut_ptr() as *mut f64;
9645                                    let n_elems = cb / 8;
9646                                    for i in 0..n_elems {
9647                                        *dc_ptr.add(i) += *up_ptr.add(i);
9648                                    }
9649                                }
9650                                other => panic!(
9651                                    "ScanBackwardXs: unsupported carry elem size {other} \
9652                                     (only f32/f64 carries are supported today)"
9653                                ),
9654                            }
9655                        }
9656                    }
9657
9658                    // Seed body_vjp's carry input via the recompute
9659                    // helper (works for both All and Recursive modes),
9660                    // then x_t_i + d_output.
9661                    let carry_dst_start = *body_carry_in_off;
9662                    {
9663                        let carry_slice = &mut body_buf[carry_dst_start..carry_dst_start + cb];
9664                        recompute_carry_t(
9665                            t,
9666                            carry_slice,
9667                            &mut fwd_buf,
9668                            &mut seg_cache,
9669                            &mut seg_start_t,
9670                            &mut seg_count,
9671                        );
9672                    }
9673                    unsafe {
9674                        for (i, body_x_off) in body_x_offs.iter().enumerate() {
9675                            let (outer_xs_off, x_psb) = outer_xs_offs[i];
9676                            let xb = x_psb as usize;
9677                            std::ptr::copy_nonoverlapping(
9678                                base.add(outer_xs_off + t * xb),
9679                                body_buf.as_mut_ptr().add(*body_x_off),
9680                                xb,
9681                            );
9682                        }
9683                        std::ptr::copy_nonoverlapping(
9684                            dcarry.as_ptr(),
9685                            body_buf.as_mut_ptr().add(*body_d_output_off),
9686                            cb,
9687                        );
9688                    }
9689
9690                    execute_thunks(body_vjp, &mut body_buf);
9691
9692                    // Stash this step's dxs into row `t` of the outer
9693                    // [length, *per_step_xs] output.
9694                    unsafe {
9695                        std::ptr::copy_nonoverlapping(
9696                            body_buf.as_ptr().add(*body_dxs_out_off),
9697                            base.add(*outer_dxs_off + t * psb),
9698                            psb,
9699                        );
9700                    }
9701
9702                    // Update dcarry for next backward iteration.
9703                    unsafe {
9704                        std::ptr::copy_nonoverlapping(
9705                            body_buf.as_ptr().add(*body_dcarry_out_off),
9706                            dcarry.as_mut_ptr(),
9707                            cb,
9708                        );
9709                    }
9710                }
9711            }
9712
9713            Thunk::FusedMmBiasAct {
9714                a,
9715                w,
9716                bias,
9717                c,
9718                m,
9719                k,
9720                n,
9721                act,
9722            } => {
9723                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
9724                unsafe {
9725                    let out = sl_mut(*c, base, m * n);
9726                    crate::blas::sgemm_auto(sl(*a, base, m * k), sl(*w, base, k * n), out, m, k, n);
9727                    match act {
9728                        Some(Activation::Gelu) => {
9729                            crate::kernels::par_bias_gelu(out, sl(*bias, base, n), m, n)
9730                        }
9731                        Some(other) => {
9732                            crate::blas::bias_add(out, sl(*bias, base, n), m, n);
9733                            apply_activation_inplace(out, *other);
9734                        }
9735                        None => crate::blas::bias_add(out, sl(*bias, base, n), m, n),
9736                    }
9737                }
9738            }
9739
9740            Thunk::FusedResidualLN {
9741                x,
9742                res,
9743                bias,
9744                g,
9745                b,
9746                out,
9747                rows,
9748                h,
9749                eps,
9750                has_bias,
9751            } => {
9752                let (rows, h) = (*rows as usize, *h as usize);
9753                unsafe {
9754                    let zero = &zero_bias[..h];
9755                    let bi = if *has_bias { sl(*bias, base, h) } else { zero };
9756                    let x_ptr = sl(*x, base, rows * h).as_ptr() as usize;
9757                    let r_ptr = sl(*res, base, rows * h).as_ptr() as usize;
9758                    let o_ptr = sl_mut(*out, base, rows * h).as_mut_ptr() as usize;
9759                    let bi_ptr = bi.as_ptr() as usize;
9760                    let g_ptr = sl(*g, base, h).as_ptr() as usize;
9761                    let b_ptr = sl(*b, base, h).as_ptr() as usize;
9762                    let e = *eps;
9763                    crate::pool::par_for(rows, 4, &|off, cnt| {
9764                        let xs =
9765                            std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
9766                        let rs =
9767                            std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
9768                        let os = std::slice::from_raw_parts_mut(
9769                            (o_ptr as *mut f32).add(off * h),
9770                            cnt * h,
9771                        );
9772                        let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
9773                        let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
9774                        let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
9775                        crate::kernels::residual_bias_layer_norm(xs, rs, bi, g, b, os, cnt, h, e);
9776                    });
9777                }
9778            }
9779
9780            Thunk::FusedResidualRmsNorm {
9781                x,
9782                res,
9783                bias,
9784                g,
9785                b,
9786                out,
9787                rows,
9788                h,
9789                eps,
9790                has_bias,
9791            } => {
9792                let (rows, h) = (*rows as usize, *h as usize);
9793                unsafe {
9794                    let zero = &zero_bias[..h];
9795                    let bi = if *has_bias { sl(*bias, base, h) } else { zero };
9796                    let x_ptr = sl(*x, base, rows * h).as_ptr() as usize;
9797                    let r_ptr = sl(*res, base, rows * h).as_ptr() as usize;
9798                    let o_ptr = sl_mut(*out, base, rows * h).as_mut_ptr() as usize;
9799                    let bi_ptr = bi.as_ptr() as usize;
9800                    let g_ptr = sl(*g, base, h).as_ptr() as usize;
9801                    let b_ptr = sl(*b, base, h).as_ptr() as usize;
9802                    let e = *eps;
9803                    crate::pool::par_for(rows, 4, &|off, cnt| {
9804                        let xs =
9805                            std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
9806                        let rs =
9807                            std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
9808                        let os = std::slice::from_raw_parts_mut(
9809                            (o_ptr as *mut f32).add(off * h),
9810                            cnt * h,
9811                        );
9812                        let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
9813                        let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
9814                        let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
9815                        crate::kernels::residual_bias_rms_norm(xs, rs, bi, g, b, os, cnt, h, e);
9816                    });
9817                }
9818            }
9819
9820            Thunk::BiasAdd {
9821                src,
9822                bias,
9823                dst,
9824                m,
9825                n,
9826            } => {
9827                let (m, n) = (*m as usize, *n as usize);
9828                let len = m * n;
9829                unsafe {
9830                    let out = sl_mut(*dst, base, len);
9831                    if *src != *dst {
9832                        let src_ptr = base.add(*src) as *const f32;
9833                        let dst_ptr = base.add(*dst) as *mut f32;
9834                        if src_ptr != dst_ptr {
9835                            std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len);
9836                        }
9837                    }
9838                    crate::blas::bias_add(out, sl(*bias, base, n), m, n);
9839                }
9840            }
9841
9842            Thunk::BinaryFull {
9843                lhs,
9844                rhs,
9845                dst,
9846                len,
9847                lhs_len,
9848                rhs_len,
9849                op,
9850                out_dims_bcast,
9851                bcast_lhs_strides,
9852                bcast_rhs_strides,
9853                elem_bytes,
9854            } => {
9855                let len = *len as usize;
9856                let ll = (*lhs_len as usize).max(1);
9857                let rl = (*rhs_len as usize).max(1);
9858                let eb = (*elem_bytes).max(1) as usize;
9859                let arena_len = arena_buf.len();
9860                let ll = ll.min((arena_len.saturating_sub(*lhs)) / eb);
9861                let rl = rl.min((arena_len.saturating_sub(*rhs)) / eb);
9862                let len = len.min((arena_len.saturating_sub(*dst)) / eb);
9863                unsafe {
9864                    if eb == 8 {
9865                        let l = sl_i64(*lhs, base, ll);
9866                        let r = sl_i64(*rhs, base, rl);
9867                        let o = sl_mut_i64(*dst, base, len);
9868                        if !out_dims_bcast.is_empty() {
9869                            let rank = out_dims_bcast.len();
9870                            let mut coords = vec![0u32; rank];
9871                            for i in 0..len {
9872                                let mut rem = i;
9873                                for ax in (0..rank).rev() {
9874                                    let sz = out_dims_bcast[ax] as usize;
9875                                    coords[ax] = (rem % sz) as u32;
9876                                    rem /= sz;
9877                                }
9878                                let mut li = 0usize;
9879                                let mut ri = 0usize;
9880                                for ax in 0..rank {
9881                                    li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
9882                                    ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
9883                                }
9884                                o[i] = match op {
9885                                    BinaryOp::Add => l[li].wrapping_add(r[ri]),
9886                                    BinaryOp::Sub => l[li].wrapping_sub(r[ri]),
9887                                    BinaryOp::Mul => l[li].wrapping_mul(r[ri]),
9888                                    BinaryOp::Div => {
9889                                        if r[ri] == 0 {
9890                                            0
9891                                        } else {
9892                                            l[li] / r[ri]
9893                                        }
9894                                    }
9895                                    BinaryOp::Max => l[li].max(r[ri]),
9896                                    BinaryOp::Min => l[li].min(r[ri]),
9897                                    BinaryOp::Pow => l[li].pow(r[ri].max(0) as u32),
9898                                };
9899                            }
9900                        } else {
9901                            for i in 0..len {
9902                                let li = if ll == 1 { 0 } else { i % ll };
9903                                let ri = if rl == 1 { 0 } else { i % rl };
9904                                o[i] = match op {
9905                                    BinaryOp::Add => l[li].wrapping_add(r[ri]),
9906                                    BinaryOp::Sub => l[li].wrapping_sub(r[ri]),
9907                                    BinaryOp::Mul => l[li].wrapping_mul(r[ri]),
9908                                    BinaryOp::Div => {
9909                                        if r[ri] == 0 {
9910                                            0
9911                                        } else {
9912                                            l[li] / r[ri]
9913                                        }
9914                                    }
9915                                    BinaryOp::Max => l[li].max(r[ri]),
9916                                    BinaryOp::Min => l[li].min(r[ri]),
9917                                    BinaryOp::Pow => l[li].pow(r[ri].max(0) as u32),
9918                                };
9919                            }
9920                        }
9921                    } else {
9922                        let l = sl(*lhs, base, ll);
9923                        let r = sl(*rhs, base, rl);
9924                        let o = sl_mut(*dst, base, len);
9925                        if ll == len && rl == len {
9926                            #[cfg(target_arch = "aarch64")]
9927                            if matches!(op, BinaryOp::Add | BinaryOp::Mul) {
9928                                use std::arch::aarch64::*;
9929                                let chunks = len / 4;
9930                                for c in 0..chunks {
9931                                    let off = c * 4;
9932                                    let vl = vld1q_f32(l.as_ptr().add(off));
9933                                    let vr = vld1q_f32(r.as_ptr().add(off));
9934                                    let res = match op {
9935                                        BinaryOp::Add => vaddq_f32(vl, vr),
9936                                        BinaryOp::Mul => vmulq_f32(vl, vr),
9937                                        _ => unreachable!(),
9938                                    };
9939                                    vst1q_f32(o.as_mut_ptr().add(off), res);
9940                                }
9941                                for i in (chunks * 4)..len {
9942                                    o[i] = match op {
9943                                        BinaryOp::Add => l[i] + r[i],
9944                                        BinaryOp::Mul => l[i] * r[i],
9945                                        _ => unreachable!(),
9946                                    };
9947                                }
9948                                continue;
9949                            }
9950                        }
9951                        if !out_dims_bcast.is_empty() {
9952                            let rank = out_dims_bcast.len();
9953                            let mut coords = vec![0u32; rank];
9954                            for i in 0..len {
9955                                let mut rem = i;
9956                                for ax in (0..rank).rev() {
9957                                    let sz = out_dims_bcast[ax] as usize;
9958                                    coords[ax] = (rem % sz) as u32;
9959                                    rem /= sz;
9960                                }
9961                                let mut li = 0usize;
9962                                let mut ri = 0usize;
9963                                for ax in 0..rank {
9964                                    li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
9965                                    ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
9966                                }
9967                                o[i] = match op {
9968                                    BinaryOp::Add => l[li] + r[ri],
9969                                    BinaryOp::Sub => l[li] - r[ri],
9970                                    BinaryOp::Mul => l[li] * r[ri],
9971                                    BinaryOp::Div => l[li] / r[ri],
9972                                    BinaryOp::Max => l[li].max(r[ri]),
9973                                    BinaryOp::Min => l[li].min(r[ri]),
9974                                    BinaryOp::Pow => l[li].powf(r[ri]),
9975                                };
9976                            }
9977                        } else {
9978                            for i in 0..len {
9979                                let li = if ll == 1 { 0 } else { i % ll };
9980                                let ri = if rl == 1 { 0 } else { i % rl };
9981                                o[i] = match op {
9982                                    BinaryOp::Add => l[li] + r[ri],
9983                                    BinaryOp::Sub => l[li] - r[ri],
9984                                    BinaryOp::Mul => l[li] * r[ri],
9985                                    BinaryOp::Div => l[li] / r[ri],
9986                                    BinaryOp::Max => l[li].max(r[ri]),
9987                                    BinaryOp::Min => l[li].min(r[ri]),
9988                                    BinaryOp::Pow => l[li].powf(r[ri]),
9989                                };
9990                            }
9991                        }
9992                    }
9993                }
9994            }
9995
9996            Thunk::Gather {
9997                table,
9998                table_len,
9999                idx,
10000                dst,
10001                num_idx,
10002                trailing,
10003                idx_i64,
10004                table_bytes,
10005            } => {
10006                let (ni, tr) = (*num_idx as usize, *trailing as usize);
10007                let rows = *table_len as usize / tr.max(1);
10008                unsafe {
10009                    if *table_bytes == 8 {
10010                        let tab = sl_i64(*table, base, *table_len as usize);
10011                        let out = sl_mut_i64(*dst, base, ni * tr);
10012                        if *idx_i64 != 0 {
10013                            let ids = sl_i64(*idx, base, ni);
10014                            for i in 0..ni {
10015                                let row = ids[i].max(0) as usize;
10016                                if row < rows {
10017                                    out[i * tr..(i + 1) * tr]
10018                                        .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
10019                                }
10020                            }
10021                        } else {
10022                            let ids = sl(*idx, base, ni);
10023                            for i in 0..ni {
10024                                let row = ids[i] as usize;
10025                                if row < rows {
10026                                    out[i * tr..(i + 1) * tr]
10027                                        .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
10028                                }
10029                            }
10030                        }
10031                    } else {
10032                        let tab = sl(*table, base, *table_len as usize);
10033                        let out = sl_mut(*dst, base, ni * tr);
10034                        if *idx_i64 != 0 {
10035                            let ids = sl_i64(*idx, base, ni);
10036                            for i in 0..ni {
10037                                let row = ids[i].max(0) as usize;
10038                                if row < rows {
10039                                    out[i * tr..(i + 1) * tr]
10040                                        .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
10041                                }
10042                            }
10043                        } else {
10044                            let ids = sl(*idx, base, ni);
10045                            for i in 0..ni {
10046                                let row = ids[i] as usize;
10047                                if row < rows {
10048                                    out[i * tr..(i + 1) * tr]
10049                                        .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
10050                                }
10051                            }
10052                        }
10053                    }
10054                }
10055            }
10056
10057            Thunk::Narrow {
10058                src,
10059                dst,
10060                outer,
10061                src_stride,
10062                dst_stride,
10063                inner,
10064                elem_bytes,
10065            } => {
10066                let (outer, ss, ds, inner, eb) = (
10067                    *outer as usize,
10068                    *src_stride as usize,
10069                    *dst_stride as usize,
10070                    *inner as usize,
10071                    *elem_bytes as usize,
10072                );
10073                let row_bytes = inner.saturating_mul(eb);
10074                let src_row_stride = ss.saturating_mul(eb);
10075                let dst_row_stride = ds.saturating_mul(eb);
10076                if trace_thunks {
10077                    eprintln!(
10078                        "[narrow] src={} dst={} outer={outer} ss={ss} ds={ds} inner={inner} eb={eb} row={row_bytes} arena={}",
10079                        *src,
10080                        *dst,
10081                        arena_buf.len()
10082                    );
10083                }
10084                if row_bytes > 0 && *src != *dst {
10085                    let arena_len = arena_buf.len();
10086                    for o in 0..outer {
10087                        let s_off = *src + o * src_row_stride;
10088                        let d_off = *dst + o * dst_row_stride;
10089                        if s_off == d_off {
10090                            continue;
10091                        }
10092                        if s_off.saturating_add(row_bytes) > arena_len
10093                            || d_off.saturating_add(row_bytes) > arena_len
10094                        {
10095                            break;
10096                        }
10097                        unsafe {
10098                            std::ptr::copy_nonoverlapping(
10099                                base.add(s_off),
10100                                base.add(d_off),
10101                                row_bytes,
10102                            );
10103                        }
10104                    }
10105                }
10106            }
10107
10108            Thunk::Copy { src, dst, len } => {
10109                let mut len = *len as usize;
10110                if *src == *dst || len == 0 {
10111                    continue;
10112                }
10113                let arena_len = arena_buf.len();
10114                let max_from_src = (arena_len.saturating_sub(*src)) / 4;
10115                let max_from_dst = (arena_len.saturating_sub(*dst)) / 4;
10116                len = len.min(max_from_src).min(max_from_dst);
10117                if len == 0 {
10118                    continue;
10119                }
10120                let byte_len = len.saturating_mul(4);
10121                unsafe {
10122                    std::ptr::copy(base.add(*src), base.add(*dst), byte_len);
10123                }
10124            }
10125
10126            Thunk::LayerNorm {
10127                src,
10128                g,
10129                b,
10130                dst,
10131                rows,
10132                h,
10133                eps,
10134            } => {
10135                let (rows, h) = (*rows as usize, *h as usize);
10136                unsafe {
10137                    let input = sl(*src, base, rows * h);
10138                    let gamma = sl(*g, base, h);
10139                    let beta = sl(*b, base, h);
10140                    let output = sl_mut(*dst, base, rows * h);
10141                    // Parallelize across rows (same pattern as FusedResidualLN)
10142                    if rows >= 4 && rows * h >= 30_000 {
10143                        let i_ptr = input.as_ptr() as usize;
10144                        let o_ptr = output.as_mut_ptr() as usize;
10145                        let g_ptr = gamma.as_ptr() as usize;
10146                        let b_ptr = beta.as_ptr() as usize;
10147                        let e = *eps;
10148                        crate::pool::par_for(rows, 4, &|off, cnt| {
10149                            let inp = std::slice::from_raw_parts(
10150                                (i_ptr as *const f32).add(off * h),
10151                                cnt * h,
10152                            );
10153                            let out = std::slice::from_raw_parts_mut(
10154                                (o_ptr as *mut f32).add(off * h),
10155                                cnt * h,
10156                            );
10157                            let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
10158                            let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
10159                            for row in 0..cnt {
10160                                crate::kernels::layer_norm_row(
10161                                    &inp[row * h..(row + 1) * h],
10162                                    g,
10163                                    b,
10164                                    &mut out[row * h..(row + 1) * h],
10165                                    h,
10166                                    e,
10167                                );
10168                            }
10169                        });
10170                    } else {
10171                        for row in 0..rows {
10172                            crate::kernels::layer_norm_row(
10173                                &input[row * h..(row + 1) * h],
10174                                gamma,
10175                                beta,
10176                                &mut output[row * h..(row + 1) * h],
10177                                h,
10178                                *eps,
10179                            );
10180                        }
10181                    }
10182                }
10183            }
10184
10185            Thunk::GroupNorm {
10186                src,
10187                g,
10188                b,
10189                dst,
10190                n,
10191                c,
10192                h,
10193                w,
10194                num_groups,
10195                eps,
10196            } => {
10197                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
10198                let plane = c * h * w;
10199                unsafe {
10200                    for ni in 0..n {
10201                        let input = sl(*src, base.add(ni * plane), plane);
10202                        let gamma = sl(*g, base, c);
10203                        let beta = sl(*b, base, c);
10204                        let output = sl_mut(*dst, base.add(ni * plane), plane);
10205                        crate::kernels::group_norm_nchw(
10206                            input,
10207                            gamma,
10208                            beta,
10209                            output,
10210                            1,
10211                            c,
10212                            h,
10213                            w,
10214                            *num_groups as usize,
10215                            *eps,
10216                        );
10217                    }
10218                }
10219            }
10220
10221            Thunk::BatchNormInference {
10222                src,
10223                g,
10224                b,
10225                mean,
10226                var,
10227                dst,
10228                count,
10229                channels,
10230                eps,
10231            } => {
10232                let count = *count as usize;
10233                let c = *channels as usize;
10234                let n = count * c;
10235                unsafe {
10236                    crate::kernels::batch_norm_inference(
10237                        sl(*src, base, n),
10238                        sl(*g, base, c),
10239                        sl(*b, base, c),
10240                        sl(*mean, base, c),
10241                        sl(*var, base, c),
10242                        sl_mut(*dst, base, n),
10243                        c,
10244                        *eps,
10245                    );
10246                }
10247            }
10248
10249            Thunk::LayerNorm2d {
10250                src,
10251                g,
10252                b,
10253                dst,
10254                n,
10255                c,
10256                h,
10257                w,
10258                eps,
10259            } => {
10260                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
10261                let plane = c * h * w;
10262                unsafe {
10263                    let input = sl(*src, base, n * plane);
10264                    let gamma = sl(*g, base, c);
10265                    let beta = sl(*b, base, c);
10266                    let output = sl_mut(*dst, base, n * plane);
10267                    crate::kernels::layer_norm2d_nchw(input, gamma, beta, output, n, c, h, w, *eps);
10268                }
10269            }
10270
10271            Thunk::ConvTranspose2d {
10272                src,
10273                weight,
10274                dst,
10275                n,
10276                c_in,
10277                h,
10278                w_in,
10279                c_out,
10280                h_out,
10281                w_out,
10282                kh,
10283                kw,
10284                sh,
10285                sw,
10286                ph,
10287                pw,
10288                dh,
10289                dw,
10290                groups,
10291            } => {
10292                let n = *n as usize;
10293                let c_in = *c_in as usize;
10294                let h = *h as usize;
10295                let w_in = *w_in as usize;
10296                let c_out = *c_out as usize;
10297                let h_out = *h_out as usize;
10298                let w_out = *w_out as usize;
10299                unsafe {
10300                    let inp = sl(*src, base, n * c_in * h * w_in);
10301                    let wt = sl(
10302                        *weight,
10303                        base,
10304                        c_in * (c_out / *groups as usize) * (*kh as usize) * (*kw as usize),
10305                    );
10306                    let out = sl_mut(*dst, base, n * c_out * h_out * w_out);
10307                    crate::kernels::conv_transpose2d_nchw(
10308                        inp,
10309                        wt,
10310                        out,
10311                        n,
10312                        c_in,
10313                        h,
10314                        w_in,
10315                        c_out,
10316                        h_out,
10317                        w_out,
10318                        *kh as usize,
10319                        *kw as usize,
10320                        *sh as usize,
10321                        *sw as usize,
10322                        *ph as usize,
10323                        *pw as usize,
10324                        *dh as usize,
10325                        *dw as usize,
10326                        *groups as usize,
10327                    );
10328                }
10329            }
10330
10331            Thunk::ResizeNearest2x {
10332                src,
10333                dst,
10334                n,
10335                c,
10336                h,
10337                w,
10338            } => {
10339                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
10340                let in_plane = c * h * w;
10341                let out_plane = c * h * 2 * w * 2;
10342                unsafe {
10343                    for ni in 0..n {
10344                        let input = sl(*src, base.add(ni * in_plane), in_plane);
10345                        let output = sl_mut(*dst, base.add(ni * out_plane), out_plane);
10346                        crate::kernels::resize_nearest_2x_nchw(input, output, c, h, w);
10347                    }
10348                }
10349            }
10350
10351            Thunk::AxialRope2d {
10352                src,
10353                dst,
10354                batch,
10355                seq,
10356                hidden,
10357                end_x,
10358                end_y,
10359                head_dim,
10360                num_heads,
10361                theta,
10362                repeat_factor,
10363            } => {
10364                let b = *batch as usize;
10365                let s = *seq as usize;
10366                let hdim = *head_dim as usize;
10367                let nh = *num_heads as usize;
10368                let plane = s * (*hidden as usize);
10369                unsafe {
10370                    for bi in 0..b {
10371                        let input = sl(*src, base.add(bi * plane), plane);
10372                        let output = sl_mut(*dst, base.add(bi * plane), plane);
10373                        let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
10374                            input,
10375                            nh,
10376                            s,
10377                            hdim,
10378                            *end_x as usize,
10379                            *end_y as usize,
10380                            *theta,
10381                            *repeat_factor as usize,
10382                        );
10383                        output.copy_from_slice(&rotated);
10384                    }
10385                }
10386            }
10387
10388            Thunk::RmsNorm {
10389                src,
10390                g,
10391                b,
10392                dst,
10393                rows,
10394                h,
10395                eps,
10396            } => {
10397                let (rows, h) = (*rows as usize, *h as usize);
10398                unsafe {
10399                    let input = sl(*src, base, rows * h);
10400                    let gamma = sl(*g, base, h);
10401                    let beta = sl(*b, base, h);
10402                    let output = sl_mut(*dst, base, rows * h);
10403                    let inv_h = 1.0 / h as f32;
10404                    for row in 0..rows {
10405                        let in_row = &input[row * h..(row + 1) * h];
10406                        let out_row = &mut output[row * h..(row + 1) * h];
10407                        // RMS = sqrt(mean(x^2) + eps); scale = 1/RMS.
10408                        let mut sumsq = 0f32;
10409                        for &v in in_row {
10410                            sumsq += v * v;
10411                        }
10412                        let inv_rms = (sumsq * inv_h + *eps).sqrt().recip();
10413                        for i in 0..h {
10414                            out_row[i] = in_row[i] * inv_rms * gamma[i] + beta[i];
10415                        }
10416                    }
10417                }
10418            }
10419
10420            Thunk::Softmax { data, rows, cols } => {
10421                let (rows, cols) = (*rows as usize, *cols as usize);
10422                unsafe {
10423                    crate::kernels::neon_softmax(sl_mut(*data, base, rows * cols), rows, cols);
10424                }
10425            }
10426
10427            Thunk::Cumsum {
10428                src,
10429                dst,
10430                rows,
10431                cols,
10432                exclusive,
10433            } => {
10434                let (rows, cols) = (*rows as usize, *cols as usize);
10435                unsafe {
10436                    let s = sl(*src, base, rows * cols);
10437                    let d = sl_mut(*dst, base, rows * cols);
10438                    if *exclusive {
10439                        for r in 0..rows {
10440                            let mut acc = 0.0f32;
10441                            for c in 0..cols {
10442                                d[r * cols + c] = acc;
10443                                acc += s[r * cols + c];
10444                            }
10445                        }
10446                    } else {
10447                        for r in 0..rows {
10448                            let mut acc = 0.0f32;
10449                            for c in 0..cols {
10450                                acc += s[r * cols + c];
10451                                d[r * cols + c] = acc;
10452                            }
10453                        }
10454                    }
10455                }
10456            }
10457
10458            Thunk::Sample {
10459                logits,
10460                dst,
10461                batch,
10462                vocab,
10463                top_k,
10464                top_p,
10465                temperature,
10466                seed,
10467            } => {
10468                let (b, v) = (*batch as usize, *vocab as usize);
10469                let k = (*top_k as usize).min(v);
10470                unsafe {
10471                    let lg = sl(*logits, base, b * v);
10472                    let out = sl_mut(*dst, base, b);
10473                    let mut rng =
10474                        rlx_ir::Philox4x32::new(if *seed == 0 { 0xDEADBEEF } else { *seed });
10475                    for bi in 0..b {
10476                        let row = &lg[bi * v..(bi + 1) * v];
10477                        out[bi] = sample_row(row, k, *top_p, *temperature, &mut rng) as f32;
10478                    }
10479                }
10480            }
10481
10482            Thunk::RngNormal {
10483                dst,
10484                len,
10485                mean,
10486                scale,
10487                key,
10488                op_seed,
10489            } => {
10490                let n = *len as usize;
10491                unsafe {
10492                    let out = sl_mut(*dst, base, n);
10493                    let opts = *schedule.rng.read().unwrap();
10494                    rlx_ir::fill_normal_like(out, *mean, *scale, opts, *key, *op_seed);
10495                }
10496            }
10497
10498            Thunk::RngUniform {
10499                dst,
10500                len,
10501                low,
10502                high,
10503                key,
10504                op_seed,
10505            } => {
10506                let n = *len as usize;
10507                unsafe {
10508                    let out = sl_mut(*dst, base, n);
10509                    let opts = *schedule.rng.read().unwrap();
10510                    rlx_ir::fill_uniform_like(out, *low, *high, opts, *key, *op_seed);
10511                }
10512            }
10513
10514            Thunk::GatedDeltaNet {
10515                q,
10516                k,
10517                v,
10518                g,
10519                beta,
10520                state,
10521                dst,
10522                batch,
10523                seq,
10524                heads,
10525                state_size,
10526            } => unsafe {
10527                execute_gated_delta_net_f32(
10528                    *q,
10529                    *k,
10530                    *v,
10531                    *g,
10532                    *beta,
10533                    *state,
10534                    *dst,
10535                    *batch as usize,
10536                    *seq as usize,
10537                    *heads as usize,
10538                    *state_size as usize,
10539                    base,
10540                );
10541            },
10542
10543            Thunk::Lstm {
10544                x,
10545                w_ih,
10546                w_hh,
10547                bias,
10548                h0,
10549                c0,
10550                dst,
10551                batch,
10552                seq,
10553                input_size,
10554                hidden,
10555                num_layers,
10556                bidirectional,
10557                carry,
10558            } => unsafe {
10559                execute_lstm_f32(
10560                    *x,
10561                    *w_ih,
10562                    *w_hh,
10563                    *bias,
10564                    *h0,
10565                    *c0,
10566                    *dst,
10567                    *batch as usize,
10568                    *seq as usize,
10569                    *input_size as usize,
10570                    *hidden as usize,
10571                    *num_layers as usize,
10572                    *bidirectional,
10573                    *carry,
10574                    base,
10575                );
10576            },
10577
10578            Thunk::SelectiveScan {
10579                x,
10580                delta,
10581                a,
10582                b: bp,
10583                c: cp,
10584                dst,
10585                batch,
10586                seq,
10587                hidden,
10588                state_size,
10589            } => {
10590                let (b, s, h, n) = (
10591                    *batch as usize,
10592                    *seq as usize,
10593                    *hidden as usize,
10594                    *state_size as usize,
10595                );
10596                unsafe {
10597                    let xs = sl(*x, base, b * s * h);
10598                    let dt = sl(*delta, base, b * s * h);
10599                    let am = sl(*a, base, h * n);
10600                    let bm = sl(*bp, base, b * s * n);
10601                    let cm = sl(*cp, base, b * s * n);
10602                    let out = sl_mut(*dst, base, b * s * h);
10603
10604                    // State buffer per-batch: h channels × n state.
10605                    // Sequential along the seq dimension; could
10606                    // parallelize over batch+channel later.
10607                    let mut state = vec![0f32; h * n];
10608                    for bi in 0..b {
10609                        // Reset state at the start of each batch row.
10610                        for v in state.iter_mut() {
10611                            *v = 0.0;
10612                        }
10613                        for si in 0..s {
10614                            let x_row = &xs[bi * s * h + si * h..bi * s * h + (si + 1) * h];
10615                            let dt_row = &dt[bi * s * h + si * h..bi * s * h + (si + 1) * h];
10616                            let b_row = &bm[bi * s * n + si * n..bi * s * n + (si + 1) * n];
10617                            let c_row = &cm[bi * s * n + si * n..bi * s * n + (si + 1) * n];
10618                            let out_row = &mut out[bi * s * h + si * h..bi * s * h + (si + 1) * h];
10619
10620                            for ci in 0..h {
10621                                let d = dt_row[ci];
10622                                let xv = x_row[ci];
10623                                let mut acc = 0f32;
10624                                for ni in 0..n {
10625                                    // Discretize: exp(d * a) and d * b.
10626                                    let da = (d * am[ci * n + ni]).exp();
10627                                    state[ci * n + ni] =
10628                                        da * state[ci * n + ni] + d * b_row[ni] * xv;
10629                                    acc += c_row[ni] * state[ci * n + ni];
10630                                }
10631                                out_row[ci] = acc;
10632                            }
10633                        }
10634                    }
10635                }
10636            }
10637
10638            Thunk::DequantMatMul {
10639                x,
10640                w_q,
10641                scale,
10642                zp,
10643                dst,
10644                m,
10645                k,
10646                n,
10647                block_size,
10648                is_asymmetric,
10649            } => {
10650                let (m, k, n, bs) = (*m as usize, *k as usize, *n as usize, *block_size as usize);
10651                let n_blocks = k.div_ceil(bs);
10652                unsafe {
10653                    let xs = sl(*x, base, m * k);
10654                    let w_bytes = std::slice::from_raw_parts(base.add(*w_q) as *const i8, k * n);
10655                    let scales = sl(*scale, base, n_blocks * n);
10656                    let zps = if *is_asymmetric {
10657                        sl(*zp, base, n_blocks * n)
10658                    } else {
10659                        &[][..]
10660                    };
10661                    let out = sl_mut(*dst, base, m * n);
10662                    dequant_matmul_int8(xs, w_bytes, scales, zps, out, m, k, n, bs, *is_asymmetric);
10663                }
10664            }
10665
10666            Thunk::DequantMatMulGguf {
10667                x,
10668                w_q,
10669                dst,
10670                m,
10671                k,
10672                n,
10673                scheme,
10674            } => {
10675                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
10676                let block_bytes = scheme.gguf_block_bytes() as usize;
10677                let block_elems = scheme.gguf_block_size() as usize;
10678                debug_assert!(
10679                    block_bytes > 0 && block_elems > 0,
10680                    "non-GGUF scheme in GGUF arm"
10681                );
10682                debug_assert!(
10683                    (k * n).is_multiple_of(block_elems),
10684                    "k*n={} not aligned to GGUF block size {}",
10685                    k * n,
10686                    block_elems
10687                );
10688                let total_bytes = (k * n) / block_elems * block_bytes;
10689                unsafe {
10690                    let xs = sl(*x, base, m * k);
10691                    let w_bytes_ptr = base.add(*w_q) as *const u8;
10692                    let w_bytes = std::slice::from_raw_parts(w_bytes_ptr, total_bytes);
10693                    let out = sl_mut(*dst, base, m * n);
10694                    crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, *scheme);
10695                }
10696            }
10697
10698            Thunk::DequantMatMulInt4 {
10699                x,
10700                w_q,
10701                scale,
10702                zp,
10703                dst,
10704                m,
10705                k,
10706                n,
10707                block_size,
10708                is_asymmetric,
10709            } => {
10710                let (m, k, n, bs) = (*m as usize, *k as usize, *n as usize, *block_size as usize);
10711                let n_blocks = k.div_ceil(bs);
10712                unsafe {
10713                    let xs = sl(*x, base, m * k);
10714                    let w_bytes = std::slice::from_raw_parts(
10715                        base.add(*w_q) as *const u8,
10716                        (k * n).div_ceil(2),
10717                    );
10718                    let scales = sl(*scale, base, n_blocks * n);
10719                    let zps = if *is_asymmetric {
10720                        sl(*zp, base, n_blocks * n)
10721                    } else {
10722                        &[][..]
10723                    };
10724                    let out = sl_mut(*dst, base, m * n);
10725                    dequant_matmul_int4(xs, w_bytes, scales, zps, out, m, k, n, bs, *is_asymmetric);
10726                }
10727            }
10728
10729            Thunk::DequantMatMulFp8 {
10730                x,
10731                w_q,
10732                scale,
10733                dst,
10734                m,
10735                k,
10736                n,
10737                e5m2,
10738            } => {
10739                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
10740                unsafe {
10741                    let xs = sl(*x, base, m * k);
10742                    let w_bytes = std::slice::from_raw_parts(base.add(*w_q) as *const u8, k * n);
10743                    let scales = sl(*scale, base, n);
10744                    let out = sl_mut(*dst, base, m * n);
10745                    dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, *e5m2);
10746                }
10747            }
10748
10749            Thunk::DequantMatMulNvfp4 {
10750                x,
10751                w_q,
10752                scale,
10753                global_scale,
10754                dst,
10755                m,
10756                k,
10757                n,
10758            } => {
10759                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
10760                let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
10761                unsafe {
10762                    let xs = sl(*x, base, m * k);
10763                    let w_bytes = std::slice::from_raw_parts(
10764                        base.add(*w_q) as *const u8,
10765                        (k * n).div_ceil(2),
10766                    );
10767                    let scale_bytes =
10768                        std::slice::from_raw_parts(base.add(*scale) as *const u8, n_scale);
10769                    let gs = sl(*global_scale, base, 1)[0];
10770                    let out = sl_mut(*dst, base, m * n);
10771                    dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
10772                }
10773            }
10774
10775            Thunk::LoraMatMul {
10776                x,
10777                w,
10778                a,
10779                b,
10780                dst,
10781                m,
10782                k,
10783                n,
10784                r,
10785                scale,
10786            } => {
10787                let (m, k, n, r) = (*m as usize, *k as usize, *n as usize, *r as usize);
10788                unsafe {
10789                    let xs = sl(*x, base, m * k);
10790                    let ws = sl(*w, base, k * n);
10791                    let a_s = sl(*a, base, k * r);
10792                    let bs = sl(*b, base, r * n);
10793                    let out = sl_mut(*dst, base, m * n);
10794                    crate::blas::sgemm(xs, ws, out, m, k, n);
10795                    let mut tmp = vec![0f32; m * r];
10796                    crate::blas::sgemm(xs, a_s, &mut tmp, m, k, r);
10797                    if *scale != 1.0 {
10798                        for v in tmp.iter_mut() {
10799                            *v *= *scale;
10800                        }
10801                    }
10802                    crate::blas::sgemm_accumulate(&tmp, bs, out, m, r, n);
10803                }
10804            }
10805
10806            Thunk::Attention {
10807                q,
10808                k,
10809                v,
10810                mask,
10811                out,
10812                batch,
10813                seq,
10814                kv_seq,
10815                heads,
10816                head_dim,
10817                mask_kind,
10818                scale,
10819                q_row_stride,
10820                k_row_stride,
10821                v_row_stride,
10822                bhsd,
10823            } => {
10824                let (b, q_s, k_s, nh, dh) = (
10825                    *batch as usize,
10826                    *seq as usize,
10827                    *kv_seq as usize,
10828                    *heads as usize,
10829                    *head_dim as usize,
10830                );
10831                let hs = nh * dh;
10832                // For [B, H, S, D] layout each (b, h) tile is dense
10833                // contiguous; the qrs/krs/vrs strides are not used.
10834                let (qrs, krs, vrs) = if *bhsd {
10835                    (dh, dh, dh)
10836                } else {
10837                    (
10838                        *q_row_stride as usize,
10839                        *k_row_stride as usize,
10840                        *v_row_stride as usize,
10841                    )
10842                };
10843                let bhsd = *bhsd;
10844                let _ = (q_row_stride, k_row_stride, v_row_stride);
10845                let scale = *scale;
10846                let ss = q_s * k_s;
10847                let cfg = crate::config::RuntimeConfig::global();
10848                unsafe {
10849                    // Slice lengths cover the strided span. When Q/K/V
10850                    // alias the parent QKV (post-#46-fusion), the same
10851                    // bytes back all three slices — compiler bounds
10852                    // checks see the right size. For [B, H, S, D] the
10853                    // buffer is densely B*H*S*D elements; the row
10854                    // strides aren't used.
10855                    let q_len = if bhsd {
10856                        b * nh * q_s * dh
10857                    } else {
10858                        b * q_s * qrs
10859                    };
10860                    let k_len = if bhsd {
10861                        b * nh * k_s * dh
10862                    } else {
10863                        b * k_s * krs
10864                    };
10865                    let v_len = if bhsd {
10866                        b * nh * k_s * dh
10867                    } else {
10868                        b * k_s * vrs
10869                    };
10870                    let q_data = sl(*q, base, q_len);
10871                    let k_data = sl(*k, base, k_len);
10872                    let v_data = sl(*v, base, v_len);
10873                    let mask_data: &[f32] = match mask_kind {
10874                        rlx_ir::op::MaskKind::Custom => sl(*mask, base, b * k_s),
10875                        rlx_ir::op::MaskKind::Bias => sl(*mask, base, b * nh * q_s * k_s),
10876                        _ => &[],
10877                    };
10878                    let out_len = if bhsd {
10879                        b * nh * q_s * dh
10880                    } else {
10881                        b * q_s * hs
10882                    };
10883                    let out_data = sl_mut(*out, base, out_len);
10884
10885                    // ── [B, H, S, D] fallback ──────────────────────
10886                    // The NEON / strided-BLAS specializations below
10887                    // are written for the [B, S, H, D] layout. When
10888                    // the input is head-major ([B, H, S, D] —
10889                    // matching rlx-cuda / rlx-rocm / rlx-tpu), bypass
10890                    // them and run a simple (correct but slower)
10891                    // scalar implementation. Production-CPU inference
10892                    // graphs use [B, S, H, D] so they still hit the
10893                    // hot path; cross-backend parity tests use
10894                    // [B, H, S, D] and land here.
10895                    if bhsd {
10896                        let scores = &mut sdpa_scores[..ss];
10897                        for bi in 0..b {
10898                            for hi in 0..nh {
10899                                let q_head_base = bi * nh * q_s * dh + hi * q_s * dh;
10900                                let k_head_base = bi * nh * k_s * dh + hi * k_s * dh;
10901                                // Q@K^T
10902                                for qi in 0..q_s {
10903                                    let q_base = q_head_base + qi * dh;
10904                                    for ki in 0..k_s {
10905                                        let k_base = k_head_base + ki * dh;
10906                                        let mut dot = 0f32;
10907                                        for d in 0..dh {
10908                                            dot += q_data[q_base + d] * k_data[k_base + d];
10909                                        }
10910                                        scores[qi * k_s + ki] = dot * scale;
10911                                        if matches!(mask_kind, rlx_ir::op::MaskKind::Custom)
10912                                            && !mask_data.is_empty()
10913                                            && mask_data[bi * k_s + ki] < mask_thr
10914                                        {
10915                                            scores[qi * k_s + ki] = mask_neg;
10916                                        }
10917                                    }
10918                                }
10919                                if matches!(mask_kind, rlx_ir::op::MaskKind::Bias) {
10920                                    let off = (bi * nh + hi) * q_s * k_s;
10921                                    for i in 0..q_s * k_s {
10922                                        scores[i] += mask_data[off + i];
10923                                    }
10924                                }
10925                                apply_synthetic_mask(scores, q_s, k_s, *mask_kind);
10926                                crate::kernels::neon_softmax(scores, q_s, k_s);
10927                                // score @ V
10928                                for qi in 0..q_s {
10929                                    let o_base = q_head_base + qi * dh;
10930                                    for d in 0..dh {
10931                                        out_data[o_base + d] = 0.0;
10932                                    }
10933                                    for ki in 0..k_s {
10934                                        let sc = scores[qi * k_s + ki];
10935                                        if sc > score_thr {
10936                                            let v_base = k_head_base + ki * dh;
10937                                            for d in 0..dh {
10938                                                out_data[o_base + d] += sc * v_data[v_base + d];
10939                                            }
10940                                        }
10941                                    }
10942                                }
10943                            }
10944                        }
10945                        continue;
10946                    }
10947
10948                    // ── Auto-select kernel: NEON dots vs strided BLAS ───
10949                    // For tiny inputs (batch=1, short seq), per-head BLAS call
10950                    // overhead (~0.5µs × 2 calls × num_heads × num_layers)
10951                    // exceeds the NEON compute cost. Use direct strided NEON
10952                    // with zero dispatch overhead.
10953                    // For batch≥2: always BLAS + par_for (parallelism wins).
10954                    if b == 1 && q_s.max(k_s) <= cfg.sdpa_seq_threshold {
10955                        // ── Sequential NEON path (zero overhead) ──
10956                        let scores = &mut sdpa_scores[..ss];
10957                        #[cfg(target_arch = "aarch64")]
10958                        let neon_chunks = dh / 4;
10959
10960                        for bi in 0..b {
10961                            for hi in 0..nh {
10962                                // Q@K^T via strided NEON dot products
10963                                for qi in 0..q_s {
10964                                    let q_off = bi * q_s * qrs + qi * qrs + hi * dh;
10965                                    for ki in 0..k_s {
10966                                        let k_off = bi * k_s * krs + ki * krs + hi * dh;
10967                                        #[cfg(target_arch = "aarch64")]
10968                                        let mut dot;
10969                                        #[cfg(not(target_arch = "aarch64"))]
10970                                        let mut dot = 0f32;
10971                                        #[cfg(target_arch = "aarch64")]
10972                                        {
10973                                            use std::arch::aarch64::*;
10974                                            let mut acc = vdupq_n_f32(0.0);
10975                                            for c in 0..neon_chunks {
10976                                                let vq =
10977                                                    vld1q_f32(q_data.as_ptr().add(q_off + c * 4));
10978                                                let vk =
10979                                                    vld1q_f32(k_data.as_ptr().add(k_off + c * 4));
10980                                                acc = vfmaq_f32(acc, vq, vk);
10981                                            }
10982                                            dot = vaddvq_f32(acc);
10983                                            for d in (neon_chunks * 4)..dh {
10984                                                dot += q_data[q_off + d] * k_data[k_off + d];
10985                                            }
10986                                        }
10987                                        #[cfg(not(target_arch = "aarch64"))]
10988                                        for d in 0..dh {
10989                                            dot += q_data[q_off + d] * k_data[k_off + d];
10990                                        }
10991                                        scores[qi * k_s + ki] = dot * scale;
10992                                        // Inner-loop Custom mask check —
10993                                        // Causal / SlidingWindow / None
10994                                        // apply outside the loop below.
10995                                        // Skip for Bias — that mask is a
10996                                        // per-head additive tensor, not a
10997                                        // 0/1 key-padding mask.
10998                                        if matches!(mask_kind, rlx_ir::op::MaskKind::Custom)
10999                                            && !mask_data.is_empty()
11000                                            && mask_data[bi * k_s + ki] < mask_thr
11001                                        {
11002                                            scores[qi * k_s + ki] = mask_neg;
11003                                        }
11004                                    }
11005                                }
11006
11007                                if matches!(mask_kind, rlx_ir::op::MaskKind::Bias) {
11008                                    let off = (bi * nh + hi) * q_s * k_s;
11009                                    for i in 0..q_s * k_s {
11010                                        scores[i] += mask_data[off + i];
11011                                    }
11012                                }
11013                                apply_synthetic_mask(scores, q_s, k_s, *mask_kind);
11014                                crate::kernels::neon_softmax(scores, q_s, k_s);
11015
11016                                // Score@V via strided NEON accumulation (zero-copy)
11017                                for qi in 0..q_s {
11018                                    let o_off = bi * q_s * hs + qi * hs + hi * dh;
11019                                    // Zero output for this head position
11020                                    for d in 0..dh {
11021                                        out_data[o_off + d] = 0.0;
11022                                    }
11023                                    for ki in 0..k_s {
11024                                        let sc = scores[qi * k_s + ki];
11025                                        if sc > score_thr {
11026                                            let v_off = bi * k_s * vrs + ki * vrs + hi * dh;
11027                                            #[cfg(target_arch = "aarch64")]
11028                                            {
11029                                                use std::arch::aarch64::*;
11030                                                let vsc = vdupq_n_f32(sc);
11031                                                for c in 0..neon_chunks {
11032                                                    let off = c * 4;
11033                                                    let vo = vld1q_f32(
11034                                                        out_data.as_ptr().add(o_off + off),
11035                                                    );
11036                                                    let vv =
11037                                                        vld1q_f32(v_data.as_ptr().add(v_off + off));
11038                                                    vst1q_f32(
11039                                                        out_data.as_mut_ptr().add(o_off + off),
11040                                                        vfmaq_f32(vo, vsc, vv),
11041                                                    );
11042                                                }
11043                                            }
11044                                            #[cfg(not(target_arch = "aarch64"))]
11045                                            for d in 0..dh {
11046                                                out_data[o_off + d] += sc * v_data[v_off + d];
11047                                            }
11048                                        }
11049                                    }
11050                                }
11051                            }
11052                        }
11053                    } else {
11054                        // ── Parallel strided BLAS path (high throughput) ──
11055                        let total_work = b * nh;
11056                        let q_addr = q_data.as_ptr() as usize;
11057                        let k_addr = k_data.as_ptr() as usize;
11058                        let v_addr = v_data.as_ptr() as usize;
11059                        let m_addr = mask_data.as_ptr() as usize;
11060                        let o_addr = out_data.as_mut_ptr() as usize;
11061                        let sc_addr = sdpa_scores.as_mut_ptr() as usize;
11062
11063                        crate::pool::par_for(total_work, 1, &|off, cnt| {
11064                            for idx in off..off + cnt {
11065                                let bi = idx / nh;
11066                                let hi = idx % nh;
11067
11068                                let q_start = (q_addr as *const f32).add(bi * q_s * qrs + hi * dh);
11069                                let k_start = (k_addr as *const f32).add(bi * k_s * krs + hi * dh);
11070                                let v_start = (v_addr as *const f32).add(bi * k_s * vrs + hi * dh);
11071                                let o_start = (o_addr as *mut f32).add(bi * q_s * hs + hi * dh);
11072                                let sc = std::slice::from_raw_parts_mut(
11073                                    (sc_addr as *mut f32).add(idx * ss),
11074                                    ss,
11075                                );
11076
11077                                // LDA = qrs, LDB = krs (parent row strides
11078                                // when fused; hs otherwise).
11079                                crate::blas::sgemm_general(
11080                                    q_start,
11081                                    k_start,
11082                                    sc.as_mut_ptr(),
11083                                    q_s,
11084                                    k_s,
11085                                    dh,
11086                                    scale,
11087                                    0.0,
11088                                    qrs,
11089                                    krs,
11090                                    k_s,
11091                                    false,
11092                                    true,
11093                                );
11094
11095                                match mask_kind {
11096                                    rlx_ir::op::MaskKind::Custom => {
11097                                        let mask_bi = std::slice::from_raw_parts(
11098                                            (m_addr as *const f32).add(bi * k_s),
11099                                            k_s,
11100                                        );
11101                                        for ki in 0..k_s {
11102                                            if mask_bi[ki] < mask_thr {
11103                                                for qi in 0..q_s {
11104                                                    sc[qi * k_s + ki] = mask_neg;
11105                                                }
11106                                            }
11107                                        }
11108                                    }
11109                                    rlx_ir::op::MaskKind::Bias => {
11110                                        // Per-head additive bias slice.
11111                                        let bias = std::slice::from_raw_parts(
11112                                            (m_addr as *const f32).add((bi * nh + hi) * q_s * k_s),
11113                                            q_s * k_s,
11114                                        );
11115                                        for i in 0..q_s * k_s {
11116                                            sc[i] += bias[i];
11117                                        }
11118                                    }
11119                                    _ => apply_synthetic_mask(sc, q_s, k_s, *mask_kind),
11120                                }
11121
11122                                crate::kernels::neon_softmax(sc, q_s, k_s);
11123
11124                                // LDB = vrs (parent row stride when
11125                                // fused; hs otherwise). LDC stays hs —
11126                                // output is its own contiguous buffer.
11127                                crate::blas::sgemm_general(
11128                                    sc.as_ptr(),
11129                                    v_start,
11130                                    o_start,
11131                                    q_s,
11132                                    dh,
11133                                    k_s,
11134                                    1.0,
11135                                    0.0,
11136                                    k_s,
11137                                    vrs,
11138                                    hs,
11139                                    false,
11140                                    false,
11141                                );
11142                            }
11143                        });
11144                    }
11145                }
11146            }
11147
11148            Thunk::AttentionBackward {
11149                q,
11150                k,
11151                v,
11152                dy,
11153                mask,
11154                out,
11155                batch,
11156                seq,
11157                kv_seq,
11158                heads,
11159                head_dim,
11160                mask_kind,
11161                wrt,
11162                bhsd,
11163            } => {
11164                let (b, q_s, k_s, nh, dh) = (
11165                    *batch as usize,
11166                    *seq as usize,
11167                    *kv_seq as usize,
11168                    *heads as usize,
11169                    *head_dim as usize,
11170                );
11171                unsafe {
11172                    let q_len = if *bhsd {
11173                        b * nh * q_s * dh
11174                    } else {
11175                        b * q_s * nh * dh
11176                    };
11177                    let k_len = if *bhsd {
11178                        b * nh * k_s * dh
11179                    } else {
11180                        b * k_s * nh * dh
11181                    };
11182                    let out_len = match wrt {
11183                        rlx_ir::op::AttentionBwdWrt::Key | rlx_ir::op::AttentionBwdWrt::Value => {
11184                            k_len
11185                        }
11186                        rlx_ir::op::AttentionBwdWrt::Query => q_len,
11187                    };
11188                    let q_data = sl(*q, base, q_len);
11189                    let k_data = sl(*k, base, k_len);
11190                    let v_data = sl(*v, base, k_len);
11191                    let dy_data = sl(*dy, base, q_len);
11192                    let out_data = sl_mut(*out, base, out_len);
11193                    let mask_data: &[f32] = if *mask != 0 {
11194                        let ml = match mask_kind {
11195                            rlx_ir::op::MaskKind::Custom => b * k_s,
11196                            rlx_ir::op::MaskKind::Bias => b * nh * q_s * k_s,
11197                            _ => 0,
11198                        };
11199                        sl(*mask, base, ml)
11200                    } else {
11201                        &[]
11202                    };
11203                    crate::attention_bwd::attention_backward(
11204                        *wrt, q_data, k_data, v_data, dy_data, out_data, b, nh, q_s, k_s, dh,
11205                        *mask_kind, mask_data, *bhsd,
11206                    );
11207                }
11208            }
11209
11210            Thunk::ActivationInPlace { data, len, act } => {
11211                let len = *len as usize;
11212                unsafe {
11213                    let d = sl_mut(*data, base, len);
11214                    match act {
11215                        Activation::Gelu => crate::kernels::par_gelu_inplace(d),
11216                        Activation::GeluApprox => crate::kernels::par_gelu_approx_inplace(d),
11217                        Activation::Silu => crate::kernels::par_silu_inplace(d),
11218                        Activation::Relu => {
11219                            for v in d.iter_mut() {
11220                                *v = v.max(0.0);
11221                            }
11222                        }
11223                        Activation::Sigmoid => {
11224                            for v in d.iter_mut() {
11225                                *v = 1.0 / (1.0 + (-*v).exp());
11226                            }
11227                        }
11228                        Activation::Tanh => {
11229                            for v in d.iter_mut() {
11230                                *v = v.tanh();
11231                            }
11232                        }
11233                        Activation::Exp => {
11234                            for v in d.iter_mut() {
11235                                *v = v.exp();
11236                            }
11237                        }
11238                        Activation::Log => {
11239                            for v in d.iter_mut() {
11240                                *v = v.ln();
11241                            }
11242                        }
11243                        Activation::Sqrt => {
11244                            for v in d.iter_mut() {
11245                                *v = v.sqrt();
11246                            }
11247                        }
11248                        Activation::Rsqrt => {
11249                            for v in d.iter_mut() {
11250                                *v = 1.0 / v.sqrt();
11251                            }
11252                        }
11253                        Activation::Neg => {
11254                            for v in d.iter_mut() {
11255                                *v = -*v;
11256                            }
11257                        }
11258                        Activation::Abs => {
11259                            for v in d.iter_mut() {
11260                                *v = v.abs();
11261                            }
11262                        }
11263                        Activation::Round => {
11264                            for v in d.iter_mut() {
11265                                *v = v.round();
11266                            }
11267                        }
11268                        Activation::Sin => {
11269                            for v in d.iter_mut() {
11270                                *v = v.sin();
11271                            }
11272                        }
11273                        Activation::Cos => {
11274                            for v in d.iter_mut() {
11275                                *v = v.cos();
11276                            }
11277                        }
11278                        Activation::Tan => {
11279                            for v in d.iter_mut() {
11280                                *v = v.tan();
11281                            }
11282                        }
11283                        Activation::Atan => {
11284                            for v in d.iter_mut() {
11285                                *v = v.atan();
11286                            }
11287                        }
11288                    }
11289                }
11290            }
11291
11292            Thunk::FusedAttnBlock {
11293                hidden,
11294                qkv_w,
11295                out_w,
11296                mask,
11297                out,
11298                qkv_b,
11299                out_b,
11300                cos,
11301                sin,
11302                cos_len,
11303                batch,
11304                seq,
11305                hs,
11306                nh,
11307                dh,
11308                has_bias,
11309                has_rope,
11310            } => {
11311                let (b, s) = (*batch as usize, *seq as usize);
11312                let (h, n_h, d_h) = (*hs as usize, *nh as usize, *dh as usize);
11313                let m = b * s;
11314                let scale = (d_h as f32).powf(-0.5);
11315                let half = d_h / 2;
11316                unsafe {
11317                    let inp = sl(*hidden, base, m * h);
11318                    let wq = sl(*qkv_w, base, h * 3 * h);
11319                    let wo = sl(*out_w, base, h * h);
11320                    let mk = sl(*mask, base, b * s);
11321                    let dst = sl_mut(*out, base, m * h);
11322
11323                    // Stack-allocated intermediates — all fit in L1 cache for small batch
11324                    let mut qkv = vec![0f32; m * 3 * h];
11325                    let mut attn_out = vec![0f32; m * h];
11326                    let mut scores_buf = vec![0f32; s * s]; // one head at a time
11327
11328                    // 1. QKV projection: [m, h] @ [h, 3h] → [m, 3h]
11329                    crate::blas::sgemm(inp, wq, &mut qkv, m, h, 3 * h);
11330                    if *has_bias {
11331                        let bias = sl(*qkv_b, base, 3 * h);
11332                        crate::blas::bias_add(&mut qkv, bias, m, 3 * h);
11333                    }
11334
11335                    // 2. Multi-head SDPA (Q/K/V are views into qkv at offsets 0, h, 2h)
11336                    //    Process heads sequentially with inline RoPE — zero copy.
11337                    #[cfg(target_arch = "aarch64")]
11338                    let neon_chunks = d_h / 4;
11339                    #[cfg(target_arch = "aarch64")]
11340                    let _rope_chunks = half / 4;
11341
11342                    for bi in 0..b {
11343                        for hi in 0..n_h {
11344                            // For each (query_pos, key_pos): compute Q@K^T with inline RoPE
11345                            for qi in 0..s {
11346                                let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
11347                                for ki in 0..s {
11348                                    let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
11349                                    let mut dot = 0f32;
11350
11351                                    if *has_rope {
11352                                        // Apply RoPE inline during dot product
11353                                        let q_cos = qi * half;
11354                                        let k_cos = ki * half;
11355                                        let cos_tab = sl(*cos, base, *cos_len as usize);
11356                                        let sin_tab = sl(*sin, base, *cos_len as usize);
11357                                        // First half: (q1*c - q2*s) * (k1*c - k2*s)
11358                                        // Second half: (q2*c + q1*s) * (k2*c + k1*s)
11359                                        for i in 0..half {
11360                                            let q1 = qkv[q_base + i];
11361                                            let q2 = qkv[q_base + half + i];
11362                                            let k1 = qkv[k_base + i];
11363                                            let k2 = qkv[k_base + half + i];
11364                                            let c_q = cos_tab[q_cos + i];
11365                                            let s_q = sin_tab[q_cos + i];
11366                                            let c_k = cos_tab[k_cos + i];
11367                                            let s_k = sin_tab[k_cos + i];
11368                                            let qr1 = q1 * c_q - q2 * s_q;
11369                                            let kr1 = k1 * c_k - k2 * s_k;
11370                                            let qr2 = q2 * c_q + q1 * s_q;
11371                                            let kr2 = k2 * c_k + k1 * s_k;
11372                                            dot += qr1 * kr1 + qr2 * kr2;
11373                                        }
11374                                    } else {
11375                                        // Standard dot product
11376                                        #[cfg(target_arch = "aarch64")]
11377                                        {
11378                                            use std::arch::aarch64::*;
11379                                            let mut acc = vdupq_n_f32(0.0);
11380                                            for c in 0..neon_chunks {
11381                                                let vq =
11382                                                    vld1q_f32(qkv.as_ptr().add(q_base + c * 4));
11383                                                let vk =
11384                                                    vld1q_f32(qkv.as_ptr().add(k_base + c * 4));
11385                                                acc = vfmaq_f32(acc, vq, vk);
11386                                            }
11387                                            dot = vaddvq_f32(acc);
11388                                            for d in (neon_chunks * 4)..d_h {
11389                                                dot += qkv[q_base + d] * qkv[k_base + d];
11390                                            }
11391                                        }
11392                                        #[cfg(not(target_arch = "aarch64"))]
11393                                        for d in 0..d_h {
11394                                            dot += qkv[q_base + d] * qkv[k_base + d];
11395                                        }
11396                                    }
11397
11398                                    scores_buf[qi * s + ki] = dot * scale;
11399                                    if mk[bi * s + ki] < mask_thr {
11400                                        scores_buf[qi * s + ki] = mask_neg;
11401                                    }
11402                                }
11403                            }
11404
11405                            // Softmax
11406                            crate::kernels::neon_softmax(&mut scores_buf[..s * s], s, s);
11407
11408                            // Score @ V accumulation (V at offset 2h in QKV)
11409                            for qi in 0..s {
11410                                let o_base = bi * s * h + qi * h + hi * d_h;
11411                                for d in 0..d_h {
11412                                    attn_out[o_base + d] = 0.0;
11413                                }
11414                                for ki in 0..s {
11415                                    let sc = scores_buf[qi * s + ki];
11416                                    if sc > score_thr {
11417                                        let v_base = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
11418                                        #[cfg(target_arch = "aarch64")]
11419                                        {
11420                                            use std::arch::aarch64::*;
11421                                            let vsc = vdupq_n_f32(sc);
11422                                            for c in 0..neon_chunks {
11423                                                let off = c * 4;
11424                                                let vo =
11425                                                    vld1q_f32(attn_out.as_ptr().add(o_base + off));
11426                                                let vv = vld1q_f32(qkv.as_ptr().add(v_base + off));
11427                                                vst1q_f32(
11428                                                    attn_out.as_mut_ptr().add(o_base + off),
11429                                                    vfmaq_f32(vo, vsc, vv),
11430                                                );
11431                                            }
11432                                        }
11433                                        #[cfg(not(target_arch = "aarch64"))]
11434                                        for d in 0..d_h {
11435                                            attn_out[o_base + d] += sc * qkv[v_base + d];
11436                                        }
11437                                    }
11438                                }
11439                            }
11440                        }
11441                    }
11442
11443                    // 3. Output projection: [m, h] @ [h, h] → dst
11444                    crate::blas::sgemm(&attn_out, wo, dst, m, h, h);
11445                    if *has_bias {
11446                        let bias = sl(*out_b, base, h);
11447                        crate::blas::bias_add(dst, bias, m, h);
11448                    }
11449                }
11450            }
11451
11452            Thunk::Rope {
11453                src,
11454                cos,
11455                sin,
11456                dst,
11457                batch,
11458                seq,
11459                hidden,
11460                head_dim,
11461                n_rot,
11462                cos_len,
11463                src_row_stride,
11464            } => {
11465                let (b, s, hs, dh, nr) = (
11466                    *batch as usize,
11467                    *seq as usize,
11468                    *hidden as usize,
11469                    *head_dim as usize,
11470                    *n_rot as usize,
11471                );
11472                let tab_half = dh / 2;
11473                let rot_half = nr / 2;
11474                let nh = hs / dh;
11475                let cl = *cos_len as usize;
11476                let src_rs = *src_row_stride as usize;
11477                unsafe {
11478                    let x = sl(*src, base, b * s * src_rs);
11479                    let cos_tab = sl(*cos, base, cl);
11480                    let sin_tab = sl(*sin, base, cl);
11481                    let out = sl_mut(*dst, base, b * s * hs);
11482
11483                    let total = b * s;
11484                    let x_ptr = x.as_ptr() as usize;
11485                    let o_ptr = out.as_mut_ptr() as usize;
11486                    let c_ptr = cos_tab.as_ptr() as usize;
11487                    let s_ptr = sin_tab.as_ptr() as usize;
11488
11489                    crate::pool::par_for(total, 4, &|off, cnt| {
11490                        for idx in off..off + cnt {
11491                            let bi = idx / s;
11492                            let si = idx % s;
11493                            let tab_off = si * tab_half;
11494
11495                            for hi in 0..nh {
11496                                let src_base = bi * s * src_rs + si * src_rs + hi * dh;
11497                                let dst_base = bi * s * hs + si * hs + hi * dh;
11498                                let xp = (x_ptr as *const f32).add(src_base);
11499                                let op = (o_ptr as *mut f32).add(dst_base);
11500                                let cp = (c_ptr as *const f32).add(tab_off);
11501                                let sp = (s_ptr as *const f32).add(tab_off);
11502
11503                                for i in 0..rot_half {
11504                                    let x1 = *xp.add(i);
11505                                    let x2 = *xp.add(rot_half + i);
11506                                    let cv = *cp.add(i);
11507                                    let sv = *sp.add(i);
11508                                    *op.add(i) = x1 * cv - x2 * sv;
11509                                    *op.add(rot_half + i) = x2 * cv + x1 * sv;
11510                                }
11511                                for j in nr..dh {
11512                                    *op.add(j) = *xp.add(j);
11513                                }
11514                            }
11515                        }
11516                    });
11517                }
11518            }
11519            Thunk::FusedBertLayer {
11520                hidden,
11521                qkv_w,
11522                qkv_b,
11523                out_w,
11524                out_b,
11525                mask,
11526                ln1_g,
11527                ln1_b,
11528                eps1,
11529                fc1_w,
11530                fc1_b,
11531                fc2_w,
11532                fc2_b,
11533                ln2_g,
11534                ln2_b,
11535                eps2,
11536                out,
11537                batch,
11538                seq,
11539                hs,
11540                nh,
11541                dh,
11542                int_dim,
11543            } => {
11544                let (b, s, h, n_h, d_h) = (
11545                    *batch as usize,
11546                    *seq as usize,
11547                    *hs as usize,
11548                    *nh as usize,
11549                    *dh as usize,
11550                );
11551                let m = b * s;
11552                let id = *int_dim as usize;
11553                let scale = (d_h as f32).powf(-0.5);
11554                let _half = d_h / 2;
11555                #[cfg(target_arch = "aarch64")]
11556                let neon_chunks = d_h / 4;
11557                unsafe {
11558                    let inp = sl(*hidden, base, m * h);
11559                    let dst = sl_mut(*out, base, m * h);
11560                    let mk = sl(*mask, base, b * s);
11561
11562                    // Pre-allocated buffers (zero malloc per layer — allocated once before thunk loop)
11563                    let qkv = std::slice::from_raw_parts_mut(fl_qkv.as_mut_ptr(), m * 3 * h);
11564                    let attn = std::slice::from_raw_parts_mut(fl_attn.as_mut_ptr(), m * h);
11565                    let res = std::slice::from_raw_parts_mut(fl_res.as_mut_ptr(), m * h);
11566                    let normed = std::slice::from_raw_parts_mut(fl_normed.as_mut_ptr(), m * h);
11567                    let ffn = std::slice::from_raw_parts_mut(fl_ffn.as_mut_ptr(), m * id);
11568                    let sc = std::slice::from_raw_parts_mut(fl_sc.as_mut_ptr(), s * s);
11569
11570                    // QKV (parallelized across cores — multiple AMX coprocessors)
11571                    crate::blas::par_sgemm_bias(
11572                        inp,
11573                        sl(*qkv_w, base, h * 3 * h),
11574                        sl(*qkv_b, base, 3 * h),
11575                        qkv,
11576                        m,
11577                        h,
11578                        3 * h,
11579                    );
11580
11581                    // SDPA per head (sequential NEON, inline — zero overhead)
11582                    for bi in 0..b {
11583                        for hi in 0..n_h {
11584                            for qi in 0..s {
11585                                for ki in 0..s {
11586                                    let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
11587                                    let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
11588                                    #[cfg(target_arch = "aarch64")]
11589                                    let dot;
11590                                    #[cfg(not(target_arch = "aarch64"))]
11591                                    let mut dot = 0f32;
11592                                    #[cfg(target_arch = "aarch64")]
11593                                    {
11594                                        use std::arch::aarch64::*;
11595                                        let mut acc = vdupq_n_f32(0.0);
11596                                        for c in 0..neon_chunks {
11597                                            acc = vfmaq_f32(
11598                                                acc,
11599                                                vld1q_f32(qkv.as_ptr().add(q_base + c * 4)),
11600                                                vld1q_f32(qkv.as_ptr().add(k_base + c * 4)),
11601                                            );
11602                                        }
11603                                        dot = vaddvq_f32(acc);
11604                                    }
11605                                    #[cfg(not(target_arch = "aarch64"))]
11606                                    for d in 0..d_h {
11607                                        dot += qkv[q_base + d] * qkv[k_base + d];
11608                                    }
11609                                    sc[qi * s + ki] = dot * scale;
11610                                    if mk[bi * s + ki] < mask_thr {
11611                                        sc[qi * s + ki] = mask_neg;
11612                                    }
11613                                }
11614                            }
11615                            crate::kernels::neon_softmax(&mut sc[..s * s], s, s);
11616                            for qi in 0..s {
11617                                let o = bi * s * h + qi * h + hi * d_h;
11618                                for d in 0..d_h {
11619                                    attn[o + d] = 0.0;
11620                                }
11621                                for ki in 0..s {
11622                                    let w = sc[qi * s + ki];
11623                                    if w > score_thr {
11624                                        let v = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
11625                                        #[cfg(target_arch = "aarch64")]
11626                                        {
11627                                            use std::arch::aarch64::*;
11628                                            let vw = vdupq_n_f32(w);
11629                                            for c in 0..neon_chunks {
11630                                                let off = c * 4;
11631                                                vst1q_f32(
11632                                                    attn.as_mut_ptr().add(o + off),
11633                                                    vfmaq_f32(
11634                                                        vld1q_f32(attn.as_ptr().add(o + off)),
11635                                                        vw,
11636                                                        vld1q_f32(qkv.as_ptr().add(v + off)),
11637                                                    ),
11638                                                );
11639                                            }
11640                                        }
11641                                        #[cfg(not(target_arch = "aarch64"))]
11642                                        for d in 0..d_h {
11643                                            attn[o + d] += w * qkv[v + d];
11644                                        }
11645                                    }
11646                                }
11647                            }
11648                        }
11649                    }
11650
11651                    // Out proj (sgemm + bias fused) + residual add with NEON
11652                    crate::blas::sgemm_bias(
11653                        attn,
11654                        sl(*out_w, base, h * h),
11655                        sl(*out_b, base, h),
11656                        res,
11657                        m,
11658                        h,
11659                        h,
11660                    );
11661                    #[cfg(target_arch = "aarch64")]
11662                    {
11663                        use std::arch::aarch64::*;
11664                        let chunks_h = (m * h) / 4;
11665                        for c in 0..chunks_h {
11666                            let off = c * 4;
11667                            vst1q_f32(
11668                                res.as_mut_ptr().add(off),
11669                                vaddq_f32(
11670                                    vld1q_f32(res.as_ptr().add(off)),
11671                                    vld1q_f32(inp.as_ptr().add(off)),
11672                                ),
11673                            );
11674                        }
11675                        for i in (chunks_h * 4)..(m * h) {
11676                            res[i] += inp[i];
11677                        }
11678                    }
11679                    #[cfg(not(target_arch = "aarch64"))]
11680                    for i in 0..m * h {
11681                        res[i] += inp[i];
11682                    }
11683
11684                    // LN1 (fused residual already done above — just normalize)
11685                    let g1 = sl(*ln1_g, base, h);
11686                    let b1 = sl(*ln1_b, base, h);
11687                    for r in 0..m {
11688                        crate::kernels::layer_norm_row(
11689                            &res[r * h..(r + 1) * h],
11690                            g1,
11691                            b1,
11692                            &mut normed[r * h..(r + 1) * h],
11693                            h,
11694                            *eps1,
11695                        );
11696                    }
11697
11698                    // FFN: fc1 (parallel across cores) + GELU
11699                    crate::blas::par_sgemm_bias(
11700                        normed,
11701                        sl(*fc1_w, base, h * id),
11702                        sl(*fc1_b, base, id),
11703                        ffn,
11704                        m,
11705                        h,
11706                        id,
11707                    );
11708                    crate::kernels::par_gelu_inplace(ffn);
11709
11710                    // fc2 + bias (parallel across cores) + residual with NEON
11711                    crate::blas::par_sgemm_bias(
11712                        ffn,
11713                        sl(*fc2_w, base, id * h),
11714                        sl(*fc2_b, base, h),
11715                        res,
11716                        m,
11717                        id,
11718                        h,
11719                    );
11720                    #[cfg(target_arch = "aarch64")]
11721                    {
11722                        use std::arch::aarch64::*;
11723                        let chunks_h = (m * h) / 4;
11724                        for c in 0..chunks_h {
11725                            let off = c * 4;
11726                            vst1q_f32(
11727                                res.as_mut_ptr().add(off),
11728                                vaddq_f32(
11729                                    vld1q_f32(res.as_ptr().add(off)),
11730                                    vld1q_f32(normed.as_ptr().add(off)),
11731                                ),
11732                            );
11733                        }
11734                        for i in (chunks_h * 4)..(m * h) {
11735                            res[i] += normed[i];
11736                        }
11737                    }
11738                    #[cfg(not(target_arch = "aarch64"))]
11739                    for i in 0..m * h {
11740                        res[i] += normed[i];
11741                    }
11742
11743                    // LN2 → output
11744                    let g2 = sl(*ln2_g, base, h);
11745                    let b2 = sl(*ln2_b, base, h);
11746                    for r in 0..m {
11747                        crate::kernels::layer_norm_row(
11748                            &res[r * h..(r + 1) * h],
11749                            g2,
11750                            b2,
11751                            &mut dst[r * h..(r + 1) * h],
11752                            h,
11753                            *eps2,
11754                        );
11755                    }
11756                }
11757            }
11758
11759            Thunk::FusedNomicLayer {
11760                hidden,
11761                qkv_w,
11762                out_w,
11763                mask,
11764                cos,
11765                sin,
11766                cos_len,
11767                ln1_g,
11768                ln1_b,
11769                eps1,
11770                fc11_w,
11771                fc12_w: _,
11772                fc2_w,
11773                ln2_g,
11774                ln2_b,
11775                eps2,
11776                out,
11777                batch,
11778                seq,
11779                hs,
11780                nh,
11781                dh,
11782                int_dim,
11783            } => {
11784                let (b, s, h, n_h, d_h) = (
11785                    *batch as usize,
11786                    *seq as usize,
11787                    *hs as usize,
11788                    *nh as usize,
11789                    *dh as usize,
11790                );
11791                let m = b * s;
11792                let id = *int_dim as usize;
11793                let scale = (d_h as f32).powf(-0.5);
11794                let half_dh = d_h / 2;
11795                #[cfg(target_arch = "aarch64")]
11796                let neon_chunks = d_h / 4;
11797                unsafe {
11798                    let inp = sl(*hidden, base, m * h);
11799                    let dst = sl_mut(*out, base, m * h);
11800                    let mk = sl(*mask, base, b * s);
11801                    let cos_tab = sl(*cos, base, *cos_len as usize);
11802                    let sin_tab = sl(*sin, base, *cos_len as usize);
11803                    // fc11_w is the fused [h, 2*int_dim] weight (fc11 || fc12 concatenated)
11804                    let fused_fc_w = sl(*fc11_w, base, h * 2 * id);
11805
11806                    let mut qkv = vec![0f32; m * 3 * h];
11807                    let mut attn = vec![0f32; m * h];
11808                    let mut res = vec![0f32; m * h];
11809                    let mut normed = vec![0f32; m * h];
11810                    let mut ffn_concat = vec![0f32; m * 2 * id]; // fc11||fc12 output
11811                    let mut sc = vec![0f32; s * s];
11812
11813                    // QKV (no bias)
11814                    crate::blas::sgemm(inp, sl(*qkv_w, base, h * 3 * h), &mut qkv, m, h, 3 * h);
11815
11816                    // SDPA with inline RoPE
11817                    for bi in 0..b {
11818                        for hi in 0..n_h {
11819                            for qi in 0..s {
11820                                for ki in 0..s {
11821                                    let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
11822                                    let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
11823                                    let mut dot = 0f32;
11824                                    for i in 0..half_dh {
11825                                        let q1 = qkv[q_base + i];
11826                                        let q2 = qkv[q_base + half_dh + i];
11827                                        let k1 = qkv[k_base + i];
11828                                        let k2 = qkv[k_base + half_dh + i];
11829                                        let cq = cos_tab[qi * half_dh + i];
11830                                        let sq = sin_tab[qi * half_dh + i];
11831                                        let ck = cos_tab[ki * half_dh + i];
11832                                        let sk = sin_tab[ki * half_dh + i];
11833                                        dot += (q1 * cq - q2 * sq) * (k1 * ck - k2 * sk)
11834                                            + (q2 * cq + q1 * sq) * (k2 * ck + k1 * sk);
11835                                    }
11836                                    sc[qi * s + ki] = dot * scale;
11837                                    if mk[bi * s + ki] < mask_thr {
11838                                        sc[qi * s + ki] = mask_neg;
11839                                    }
11840                                }
11841                            }
11842                            crate::kernels::neon_softmax(&mut sc[..s * s], s, s);
11843                            for qi in 0..s {
11844                                let o = bi * s * h + qi * h + hi * d_h;
11845                                for d in 0..d_h {
11846                                    attn[o + d] = 0.0;
11847                                }
11848                                for ki in 0..s {
11849                                    let w = sc[qi * s + ki];
11850                                    if w > score_thr {
11851                                        let v = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
11852                                        #[cfg(target_arch = "aarch64")]
11853                                        {
11854                                            use std::arch::aarch64::*;
11855                                            let vw = vdupq_n_f32(w);
11856                                            for c in 0..neon_chunks {
11857                                                let off = c * 4;
11858                                                vst1q_f32(
11859                                                    attn.as_mut_ptr().add(o + off),
11860                                                    vfmaq_f32(
11861                                                        vld1q_f32(attn.as_ptr().add(o + off)),
11862                                                        vw,
11863                                                        vld1q_f32(qkv.as_ptr().add(v + off)),
11864                                                    ),
11865                                                );
11866                                            }
11867                                        }
11868                                        #[cfg(not(target_arch = "aarch64"))]
11869                                        for d in 0..d_h {
11870                                            attn[o + d] += w * qkv[v + d];
11871                                        }
11872                                    }
11873                                }
11874                            }
11875                        }
11876                    }
11877
11878                    // Out proj (no bias) + residual
11879                    crate::blas::sgemm(&attn, sl(*out_w, base, h * h), &mut res, m, h, h);
11880                    for i in 0..m * h {
11881                        res[i] += inp[i];
11882                    }
11883
11884                    // LN1
11885                    let g1 = sl(*ln1_g, base, h);
11886                    let b1 = sl(*ln1_b, base, h);
11887                    for r in 0..m {
11888                        crate::kernels::layer_norm_row(
11889                            &res[r * h..(r + 1) * h],
11890                            g1,
11891                            b1,
11892                            &mut normed[r * h..(r + 1) * h],
11893                            h,
11894                            *eps1,
11895                        );
11896                    }
11897
11898                    // SwiGLU: fused fc11+fc12 sgemm, then split, silu, mul
11899                    crate::blas::sgemm(&normed, fused_fc_w, &mut ffn_concat, m, h, 2 * id);
11900                    // Split: first id cols = fc11 (up), second id cols = fc12 (gate)
11901                    // SiLU on gate, then multiply up * gate → store in up region
11902                    for row in 0..m {
11903                        let bo = row * 2 * id;
11904                        // SiLU in-place on gate portion
11905                        for j in 0..id {
11906                            let x = ffn_concat[bo + id + j];
11907                            ffn_concat[bo + id + j] = x / (1.0 + (-x).exp());
11908                        }
11909                        // Multiply: up[j] *= gate[j]
11910                        for j in 0..id {
11911                            ffn_concat[bo + j] *= ffn_concat[bo + id + j];
11912                        }
11913                    }
11914
11915                    // fc2 (no bias) + residual  — read from first id cols of ffn_concat
11916                    // Need contiguous [m, id] for sgemm. Copy or use strided sgemm.
11917                    // The up*gate result is at ffn_concat[row * 2*id .. row * 2*id + id]
11918                    // Stride = 2*id. Use sgemm_general with lda = 2*id.
11919                    crate::blas::sgemm_general(
11920                        ffn_concat.as_ptr(),
11921                        sl(*fc2_w, base, id * h).as_ptr(),
11922                        res.as_mut_ptr(),
11923                        m,
11924                        h,
11925                        id,
11926                        1.0,
11927                        0.0,
11928                        2 * id,
11929                        h,
11930                        h,
11931                        false,
11932                        false,
11933                    );
11934                    for i in 0..m * h {
11935                        res[i] += normed[i];
11936                    }
11937
11938                    // LN2 → output
11939                    let g2 = sl(*ln2_g, base, h);
11940                    let b2 = sl(*ln2_b, base, h);
11941                    for r in 0..m {
11942                        crate::kernels::layer_norm_row(
11943                            &res[r * h..(r + 1) * h],
11944                            g2,
11945                            b2,
11946                            &mut dst[r * h..(r + 1) * h],
11947                            h,
11948                            *eps2,
11949                        );
11950                    }
11951                }
11952            }
11953
11954            Thunk::FusedSwiGLU {
11955                src,
11956                dst,
11957                n_half,
11958                total,
11959                gate_first,
11960            } => {
11961                let n = *n_half as usize;
11962                let t = *total as usize;
11963                let outer = t / n;
11964                let in_total = outer * 2 * n;
11965                let gate_first = *gate_first;
11966                unsafe {
11967                    let inp = sl(*src, base, in_total);
11968                    let out = sl_mut(*dst, base, t);
11969                    for o in 0..outer {
11970                        let in_row = &inp[o * 2 * n..(o + 1) * 2 * n];
11971                        let out_row = &mut out[o * n..(o + 1) * n];
11972                        for i in 0..n {
11973                            let (up, gate) = if gate_first {
11974                                (in_row[n + i], in_row[i])
11975                            } else {
11976                                (in_row[i], in_row[n + i])
11977                            };
11978                            out_row[i] = up * (gate / (1.0 + (-gate).exp()));
11979                        }
11980                    }
11981                }
11982            }
11983
11984            Thunk::Concat {
11985                dst,
11986                outer,
11987                inner,
11988                total_axis,
11989                inputs,
11990            } => {
11991                let outer = *outer as usize;
11992                let inner = *inner as usize;
11993                let total_axis = *total_axis as usize;
11994                let row_stride = total_axis * inner;
11995                let out_total = outer * row_stride;
11996                unsafe {
11997                    let out = sl_mut(*dst, base, out_total);
11998                    let mut cum: usize = 0;
11999                    for (src_off, in_axis, in_numel) in inputs {
12000                        let in_axis = *in_axis as usize;
12001                        let copy_per_row = in_axis * inner;
12002                        let dst_col_off = cum * inner;
12003                        let inp = sl(*src_off, base, (*in_numel as usize).max(1));
12004                        concat_copy_rows_f32(
12005                            out,
12006                            inp,
12007                            outer,
12008                            copy_per_row,
12009                            row_stride,
12010                            dst_col_off,
12011                            *in_numel as usize,
12012                        );
12013                        cum += in_axis;
12014                    }
12015                }
12016            }
12017
12018            Thunk::ConcatF64 {
12019                dst,
12020                outer,
12021                inner,
12022                total_axis,
12023                inputs,
12024            } => {
12025                let outer = *outer as usize;
12026                let inner = *inner as usize;
12027                let total_axis = *total_axis as usize;
12028                let row_stride = total_axis * inner;
12029                let out_total = outer * row_stride;
12030                unsafe {
12031                    let out = sl_mut_f64(*dst, base, out_total);
12032                    let mut cum: usize = 0;
12033                    for (src_off, in_axis, in_numel) in inputs {
12034                        let in_axis = *in_axis as usize;
12035                        let copy_per_row = in_axis * inner;
12036                        let dst_col_off = cum * inner;
12037                        let inp = sl_f64(*src_off, base, (*in_numel as usize).max(1));
12038                        concat_copy_rows_f64(
12039                            out,
12040                            inp,
12041                            outer,
12042                            copy_per_row,
12043                            row_stride,
12044                            dst_col_off,
12045                            *in_numel as usize,
12046                        );
12047                        cum += in_axis;
12048                    }
12049                }
12050            }
12051
12052            Thunk::Compare {
12053                lhs,
12054                rhs,
12055                dst,
12056                len,
12057                op,
12058                inputs_i64,
12059                inputs_elem_bytes,
12060                dst_elem_bytes,
12061            } => {
12062                let len = *len as usize;
12063                let arena_len = arena_buf.len();
12064                let elem = (*inputs_elem_bytes).max(1) as usize;
12065                let dst_eb = (*dst_elem_bytes).max(1) as usize;
12066                let max_l = (arena_len.saturating_sub(*lhs)) / elem;
12067                let max_r = (arena_len.saturating_sub(*rhs)) / elem;
12068                let max_d = (arena_len.saturating_sub(*dst)) / dst_eb;
12069                let len = len.min(max_l).min(max_r).min(max_d);
12070                if trace_thunks && len > 0 {
12071                    eprintln!("[compare] len={len} lhs={} rhs={} dst={}", *lhs, *rhs, *dst);
12072                }
12073                if elem == 1 {
12074                    let l = arena_buf[*lhs..*lhs + len].to_vec();
12075                    let r = arena_buf[*rhs..*rhs + len].to_vec();
12076                    for i in 0..len {
12077                        let v = match op {
12078                            CmpOp::Eq => l[i] == r[i],
12079                            CmpOp::Ne => l[i] != r[i],
12080                            CmpOp::Lt => l[i] < r[i],
12081                            CmpOp::Le => l[i] <= r[i],
12082                            CmpOp::Gt => l[i] > r[i],
12083                            CmpOp::Ge => l[i] >= r[i],
12084                        };
12085                        if *dst_elem_bytes == 1 {
12086                            arena_buf[*dst + i] = u8::from(v);
12087                        } else {
12088                            unsafe {
12089                                let o = sl_mut(*dst, base, len);
12090                                o[i] = if v { 1.0 } else { 0.0 };
12091                            }
12092                        }
12093                    }
12094                } else if *inputs_i64 != 0 {
12095                    unsafe {
12096                        let l = sl_i64(*lhs, base, len);
12097                        let r = sl_i64(*rhs, base, len);
12098                        for i in 0..len {
12099                            let v = match op {
12100                                CmpOp::Eq => l[i] == r[i],
12101                                CmpOp::Ne => l[i] != r[i],
12102                                CmpOp::Lt => l[i] < r[i],
12103                                CmpOp::Le => l[i] <= r[i],
12104                                CmpOp::Gt => l[i] > r[i],
12105                                CmpOp::Ge => l[i] >= r[i],
12106                            };
12107                            if *dst_elem_bytes == 1 {
12108                                arena_buf[*dst + i] = u8::from(v);
12109                            } else {
12110                                let o = sl_mut(*dst, base, len);
12111                                o[i] = if v { 1.0 } else { 0.0 };
12112                            }
12113                        }
12114                    }
12115                } else {
12116                    unsafe {
12117                        let l = sl(*lhs, base, len);
12118                        let r = sl(*rhs, base, len);
12119                        for i in 0..len {
12120                            let v = match op {
12121                                CmpOp::Eq => l[i] == r[i],
12122                                CmpOp::Ne => l[i] != r[i],
12123                                CmpOp::Lt => l[i] < r[i],
12124                                CmpOp::Le => l[i] <= r[i],
12125                                CmpOp::Gt => l[i] > r[i],
12126                                CmpOp::Ge => l[i] >= r[i],
12127                            };
12128                            if *dst_elem_bytes == 1 {
12129                                arena_buf[*dst + i] = u8::from(v);
12130                            } else {
12131                                let o = sl_mut(*dst, base, len);
12132                                o[i] = if v { 1.0 } else { 0.0 };
12133                            }
12134                        }
12135                    }
12136                }
12137            }
12138
12139            Thunk::Where {
12140                cond,
12141                on_true,
12142                on_false,
12143                dst,
12144                len,
12145                elem_bytes,
12146                cond_elem_bytes,
12147            } => {
12148                let len = *len as usize;
12149                let eb = *elem_bytes as usize;
12150                let cond_eb = (*cond_elem_bytes).max(1) as usize;
12151                let arena_len = arena_buf.len();
12152                let len = len
12153                    .min((arena_len.saturating_sub(*cond)) / cond_eb)
12154                    .min((arena_len.saturating_sub(*on_true)) / eb)
12155                    .min((arena_len.saturating_sub(*on_false)) / eb)
12156                    .min((arena_len.saturating_sub(*dst)) / eb);
12157                unsafe {
12158                    if *elem_bytes == 8 {
12159                        let t = sl_i64(*on_true, base, len);
12160                        let e = sl_i64(*on_false, base, len);
12161                        let o = sl_mut_i64(*dst, base, len);
12162                        if *cond_elem_bytes == 1 {
12163                            let c = &arena_buf[*cond..*cond + len];
12164                            for i in 0..len {
12165                                o[i] = if c[i] != 0 { t[i] } else { e[i] };
12166                            }
12167                        } else {
12168                            let c = sl_i64(*cond, base, len);
12169                            for i in 0..len {
12170                                o[i] = if c[i] != 0 { t[i] } else { e[i] };
12171                            }
12172                        }
12173                    } else if *cond_elem_bytes == 1 {
12174                        let c = &arena_buf[*cond..*cond + len];
12175                        let t = sl(*on_true, base, len);
12176                        let e = sl(*on_false, base, len);
12177                        let o = sl_mut(*dst, base, len);
12178                        for i in 0..len {
12179                            o[i] = if c[i] != 0 { t[i] } else { e[i] };
12180                        }
12181                    } else {
12182                        let c = sl(*cond, base, len);
12183                        let t = sl(*on_true, base, len);
12184                        let e = sl(*on_false, base, len);
12185                        let o = sl_mut(*dst, base, len);
12186                        for i in 0..len {
12187                            o[i] = if c[i] != 0.0 { t[i] } else { e[i] };
12188                        }
12189                    }
12190                }
12191            }
12192
12193            Thunk::ScatterAdd {
12194                updates,
12195                indices,
12196                dst,
12197                num_updates,
12198                out_dim,
12199                trailing,
12200            } => {
12201                let num_updates = *num_updates as usize;
12202                let out_dim = *out_dim as usize;
12203                let trailing = *trailing as usize;
12204                unsafe {
12205                    let upd = sl(*updates, base, num_updates * trailing);
12206                    let ids = sl(*indices, base, num_updates);
12207                    let out = sl_mut(*dst, base, out_dim * trailing);
12208                    // Zero the output first — semantics are accumulate-into-zeros.
12209                    for v in out.iter_mut() {
12210                        *v = 0.0;
12211                    }
12212                    for i in 0..num_updates {
12213                        let row = ids[i] as usize;
12214                        debug_assert!(row < out_dim, "ScatterAdd index out of range");
12215                        let src_off = i * trailing;
12216                        let dst_off = row * trailing;
12217                        for j in 0..trailing {
12218                            out[dst_off + j] += upd[src_off + j];
12219                        }
12220                    }
12221                }
12222            }
12223
12224            Thunk::GroupedMatMul {
12225                input,
12226                weight,
12227                expert_idx,
12228                dst,
12229                m,
12230                k_dim,
12231                n,
12232                num_experts,
12233            } => {
12234                let m = *m as usize;
12235                let k_dim = *k_dim as usize;
12236                let n = *n as usize;
12237                let num_experts = *num_experts as usize;
12238                unsafe {
12239                    let inp = sl(*input, base, m * k_dim);
12240                    let wt = sl(*weight, base, num_experts * k_dim * n);
12241                    let ids = sl(*expert_idx, base, m);
12242                    let out = sl_mut(*dst, base, m * n);
12243
12244                    // Counting-sort tokens by their assigned expert.
12245                    // counts[e] = how many tokens routed to expert e.
12246                    let mut counts = vec![0usize; num_experts];
12247                    for i in 0..m {
12248                        let e = ids[i] as usize;
12249                        debug_assert!(
12250                            e < num_experts,
12251                            "expert_idx out of range: {e} >= {num_experts}"
12252                        );
12253                        counts[e] += 1;
12254                    }
12255                    // Cumulative offsets into the packed buffer.
12256                    let mut offsets = vec![0usize; num_experts + 1];
12257                    for e in 0..num_experts {
12258                        offsets[e + 1] = offsets[e] + counts[e];
12259                    }
12260                    // Pack: each expert's rows land contiguously in `packed_in`.
12261                    // `original_pos[packed_idx] = original_token_idx` for the
12262                    // unpermute step at the end.
12263                    let mut packed_in = vec![0f32; m * k_dim];
12264                    let mut original_pos = vec![0usize; m];
12265                    let mut write_idx = vec![0usize; num_experts];
12266                    for i in 0..m {
12267                        let e = ids[i] as usize;
12268                        let dst_row = offsets[e] + write_idx[e];
12269                        packed_in[dst_row * k_dim..(dst_row + 1) * k_dim]
12270                            .copy_from_slice(&inp[i * k_dim..(i + 1) * k_dim]);
12271                        original_pos[dst_row] = i;
12272                        write_idx[e] += 1;
12273                    }
12274
12275                    // One BLAS sgemm per expert. Skip experts with no
12276                    // tokens — common at the tail when M is much smaller
12277                    // than num_experts × k.
12278                    let mut packed_out = vec![0f32; m * n];
12279                    let expert_stride = k_dim * n;
12280                    let gmm_ord = crate::moe_residency::next_gmm_ord();
12281                    let moe_layer = gmm_ord / 3;
12282                    for e in 0..num_experts {
12283                        let count = counts[e];
12284                        if count == 0 {
12285                            continue;
12286                        }
12287                        crate::moe_residency::record_expert_tokens(moe_layer, e, count);
12288                        let in_start = offsets[e];
12289                        let in_slice = &packed_in[in_start * k_dim..(in_start + count) * k_dim];
12290                        let w_slab: &[f32] =
12291                            if !crate::moe_residency::expert_on_device_for_layer(moe_layer, e) {
12292                                if let Some(ptr) =
12293                                    crate::moe_residency::host_expert_weight_ptr(gmm_ord, e)
12294                                {
12295                                    std::slice::from_raw_parts(ptr, expert_stride)
12296                                } else {
12297                                    &wt[e * expert_stride..(e + 1) * expert_stride]
12298                                }
12299                            } else {
12300                                &wt[e * expert_stride..(e + 1) * expert_stride]
12301                            };
12302                        let out_slice = &mut packed_out[in_start * n..(in_start + count) * n];
12303                        crate::blas::sgemm(in_slice, w_slab, out_slice, count, k_dim, n);
12304                    }
12305
12306                    // Unpermute back to original token order.
12307                    for packed_idx in 0..m {
12308                        let i = original_pos[packed_idx];
12309                        out[i * n..(i + 1) * n]
12310                            .copy_from_slice(&packed_out[packed_idx * n..(packed_idx + 1) * n]);
12311                    }
12312                }
12313            }
12314
12315            Thunk::DequantGroupedMatMulGguf {
12316                input,
12317                w_q,
12318                expert_idx,
12319                dst,
12320                m,
12321                k_dim,
12322                n,
12323                num_experts,
12324                scheme,
12325            } => {
12326                let m = *m as usize;
12327                let k_dim = *k_dim as usize;
12328                let n = *n as usize;
12329                let num_experts = *num_experts as usize;
12330                let block_elems = scheme.gguf_block_size() as usize;
12331                let block_bytes = scheme.gguf_block_bytes() as usize;
12332                let slab_bytes = (k_dim * n) / block_elems * block_bytes;
12333                unsafe {
12334                    let inp = sl(*input, base, m * k_dim);
12335                    let wt = std::slice::from_raw_parts(
12336                        base.add(*w_q) as *const u8,
12337                        num_experts * slab_bytes,
12338                    );
12339                    let ids = sl(*expert_idx, base, m);
12340                    let out = sl_mut(*dst, base, m * n);
12341                    crate::gguf_matmul::gguf_grouped_matmul_bt(
12342                        inp,
12343                        wt,
12344                        ids,
12345                        out,
12346                        m,
12347                        k_dim,
12348                        n,
12349                        num_experts,
12350                        *scheme,
12351                    );
12352                }
12353            }
12354
12355            Thunk::DequantMoEWeightsGguf {
12356                w_q,
12357                dst,
12358                k_dim,
12359                n,
12360                num_experts,
12361                scheme,
12362            } => {
12363                let k_dim = *k_dim as usize;
12364                let n = *n as usize;
12365                let num_experts = *num_experts as usize;
12366                let block_elems = scheme.gguf_block_size() as usize;
12367                let block_bytes = scheme.gguf_block_bytes() as usize;
12368                let slab_bytes = (k_dim * n) / block_elems * block_bytes;
12369                unsafe {
12370                    let wt = std::slice::from_raw_parts(
12371                        base.add(*w_q) as *const u8,
12372                        num_experts * slab_bytes,
12373                    );
12374                    let out = sl_mut(*dst, base, num_experts * k_dim * n);
12375                    crate::gguf_matmul::dequant_moe_weights_to_grouped_f32(
12376                        wt,
12377                        out,
12378                        num_experts,
12379                        k_dim,
12380                        n,
12381                        *scheme,
12382                    );
12383                }
12384            }
12385
12386            Thunk::TopK {
12387                src,
12388                dst,
12389                outer,
12390                axis_dim,
12391                k,
12392                indices_i64,
12393            } => {
12394                let outer = *outer as usize;
12395                let axis_dim = *axis_dim as usize;
12396                let k = *k as usize;
12397                unsafe {
12398                    let inp = sl(*src, base, outer * axis_dim);
12399                    // Repeated argmax with masking. O(k * axis_dim) per row;
12400                    // good enough for small k (MoE typical k=2–8). For larger
12401                    // k a partial heap would win.
12402                    let mut row_buf: Vec<f32> = vec![0.0; axis_dim];
12403                    if *indices_i64 != 0 {
12404                        let out = sl_mut_i64(*dst, base, outer * k);
12405                        for o in 0..outer {
12406                            row_buf.copy_from_slice(&inp[o * axis_dim..(o + 1) * axis_dim]);
12407                            for ki in 0..k {
12408                                let mut best_i = 0usize;
12409                                let mut best_v = row_buf[0];
12410                                for i in 1..axis_dim {
12411                                    let v = row_buf[i];
12412                                    if v > best_v {
12413                                        best_v = v;
12414                                        best_i = i;
12415                                    }
12416                                }
12417                                out[o * k + ki] = best_i as i64;
12418                                row_buf[best_i] = f32::NEG_INFINITY;
12419                            }
12420                        }
12421                    } else {
12422                        let out = sl_mut(*dst, base, outer * k);
12423                        for o in 0..outer {
12424                            row_buf.copy_from_slice(&inp[o * axis_dim..(o + 1) * axis_dim]);
12425                            for ki in 0..k {
12426                                let mut best_i = 0usize;
12427                                let mut best_v = row_buf[0];
12428                                for i in 1..axis_dim {
12429                                    let v = row_buf[i];
12430                                    if v > best_v {
12431                                        best_v = v;
12432                                        best_i = i;
12433                                    }
12434                                }
12435                                out[o * k + ki] = best_i as f32;
12436                                row_buf[best_i] = f32::NEG_INFINITY;
12437                            }
12438                        }
12439                        if let Some(cap) = schedule.moe_topk_capture.as_ref() {
12440                            cap.push_topk_f32(&out[..outer * k], axis_dim);
12441                        }
12442                    }
12443                }
12444            }
12445
12446            Thunk::Reduce {
12447                src,
12448                dst,
12449                outer,
12450                reduced,
12451                inner,
12452                op,
12453            } => {
12454                let outer = *outer as usize;
12455                let reduced = *reduced as usize;
12456                let inner = *inner as usize;
12457                let in_total = outer * reduced * inner;
12458                let out_total = outer * inner;
12459                unsafe {
12460                    let inp = sl(*src, base, in_total);
12461                    let out = sl_mut(*dst, base, out_total);
12462                    for o in 0..outer {
12463                        for i in 0..inner {
12464                            let mut acc = match op {
12465                                ReduceOp::Max => f32::NEG_INFINITY,
12466                                ReduceOp::Min => f32::INFINITY,
12467                                ReduceOp::Prod => 1.0f32,
12468                                _ => 0.0f32, // Sum / Mean
12469                            };
12470                            // Walk the reduced axis with stride `inner`.
12471                            for r in 0..reduced {
12472                                let v = inp[o * reduced * inner + r * inner + i];
12473                                acc = match op {
12474                                    ReduceOp::Sum | ReduceOp::Mean => acc + v,
12475                                    ReduceOp::Max => acc.max(v),
12476                                    ReduceOp::Min => acc.min(v),
12477                                    ReduceOp::Prod => acc * v,
12478                                };
12479                            }
12480                            if matches!(op, ReduceOp::Mean) {
12481                                acc /= reduced as f32;
12482                            }
12483                            out[o * inner + i] = acc;
12484                        }
12485                    }
12486                }
12487            }
12488
12489            Thunk::ArgReduce {
12490                src,
12491                dst,
12492                outer,
12493                reduced,
12494                inner,
12495                is_max,
12496            } => {
12497                let outer = *outer as usize;
12498                let reduced = *reduced as usize;
12499                let inner = *inner as usize;
12500                let in_total = outer * reduced * inner;
12501                let out_total = outer * inner;
12502                unsafe {
12503                    let inp = sl(*src, base, in_total);
12504                    let out = sl_mut(*dst, base, out_total);
12505                    for o in 0..outer {
12506                        for i in 0..inner {
12507                            let mut best = inp[o * reduced * inner + i];
12508                            let mut best_idx = 0usize;
12509                            for r in 1..reduced {
12510                                let v = inp[o * reduced * inner + r * inner + i];
12511                                let better = if *is_max { v > best } else { v < best };
12512                                if better {
12513                                    best = v;
12514                                    best_idx = r;
12515                                }
12516                            }
12517                            out[o * inner + i] = best_idx as f32;
12518                        }
12519                    }
12520                }
12521            }
12522
12523            Thunk::Conv2D1x1 {
12524                src,
12525                weight,
12526                dst,
12527                n,
12528                c_in,
12529                c_out,
12530                hw,
12531            } => {
12532                let n = *n as usize;
12533                let c_in = *c_in as usize;
12534                let c_out = *c_out as usize;
12535                let hw = *hw as usize;
12536                unsafe {
12537                    let inp = sl(*src, base, n * c_in * hw);
12538                    let wt = sl(*weight, base, c_out * c_in);
12539                    let out = sl_mut(*dst, base, n * c_out * hw);
12540                    // Per-batch sgemm: weight [c_out, c_in] @ input
12541                    // [c_in, hw] = output [c_out, hw]. The weight is
12542                    // shared across batches, so we get to dispatch
12543                    // BLAS once per N (typically 1).
12544                    for ni in 0..n {
12545                        let in_off = ni * c_in * hw;
12546                        let out_off = ni * c_out * hw;
12547                        crate::blas::sgemm(
12548                            wt,
12549                            &inp[in_off..in_off + c_in * hw],
12550                            &mut out[out_off..out_off + c_out * hw],
12551                            c_out,
12552                            c_in,
12553                            hw,
12554                        );
12555                    }
12556                }
12557            }
12558
            Thunk::Conv2D {
                src,
                weight,
                dst,
                n,
                c_in,
                h,
                w,
                c_out,
                h_out,
                w_out,
                kh,
                kw,
                sh,
                sw,
                ph,
                pw,
                dh,
                dw,
                groups,
            } => {
                // Direct (loop-nest) 2D convolution with stride (sh, sw), zero
                // padding (ph, pw), dilation (dh, dw) and channel groups. No bias
                // is added. Layouts implied by the flat indexing below:
                //   src    : [n, c_in, h, w]                 f32
                //   weight : [c_out, c_in / groups, kh, kw]  f32
                //   dst    : [n, c_out, h_out, w_out]        f32
                // `h_out`/`w_out` are taken as given, not recomputed here.
                // NOTE(review): `c_in / groups` and `c_out / groups` truncate;
                // presumably the lowering guarantees divisibility. Confirm.
                // `groups == 0` panics on integer division by zero.
                let n = *n as usize;
                let c_in = *c_in as usize;
                let h = *h as usize;
                let w = *w as usize;
                let c_out = *c_out as usize;
                let h_out = *h_out as usize;
                let w_out = *w_out as usize;
                let kh = *kh as usize;
                let kw = *kw as usize;
                let sh = *sh as usize;
                let sw = *sw as usize;
                let ph = *ph as usize;
                let pw = *pw as usize;
                let dh = *dh as usize;
                let dw = *dw as usize;
                let groups = *groups as usize;
                let c_in_per_g = c_in / groups;
                let c_out_per_g = c_out / groups;
                unsafe {
                    let inp = sl(*src, base, n * c_in * h * w);
                    let wt = sl(*weight, base, c_out * c_in_per_g * kh * kw);
                    let out = sl_mut(*dst, base, n * c_out * h_out * w_out);
                    for ni in 0..n {
                        for co in 0..c_out {
                            // Output channel `co` belongs to group `g` and only
                            // reads that group's slice of input channels.
                            let g = co / c_out_per_g;
                            let ci_start = g * c_in_per_g;
                            for ho in 0..h_out {
                                for wo in 0..w_out {
                                    let mut acc = 0f32;
                                    for ci_off in 0..c_in_per_g {
                                        let ci = ci_start + ci_off;
                                        let in_chan = ((ni * c_in) + ci) * h * w;
                                        let wt_chan = ((co * c_in_per_g) + ci_off) * kh * kw;
                                        for ki in 0..kh {
                                            for kj in 0..kw {
                                                // Tap position in padded coordinates.
                                                let hi = ho * sh + ki * dh;
                                                let wi = wo * sw + kj * dw;
                                                // Taps inside the leading padding would
                                                // contribute zero; skipping them also
                                                // avoids usize underflow below.
                                                if hi < ph || wi < pw {
                                                    continue;
                                                }
                                                let hi = hi - ph;
                                                let wi = wi - pw;
                                                // Taps past the input edge (trailing
                                                // padding) are skipped as well.
                                                if hi >= h || wi >= w {
                                                    continue;
                                                }
                                                acc += inp[in_chan + hi * w + wi]
                                                    * wt[wt_chan + ki * kw + kj];
                                            }
                                        }
                                    }
                                    out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] =
                                        acc;
                                }
                            }
                        }
                    }
                }
            }
12638
12639            Thunk::Pool2D {
12640                src,
12641                dst,
12642                n,
12643                c,
12644                h,
12645                w,
12646                h_out,
12647                w_out,
12648                kh,
12649                kw,
12650                sh,
12651                sw,
12652                ph,
12653                pw,
12654                kind,
12655            } => {
12656                let n = *n as usize;
12657                let c = *c as usize;
12658                let h = *h as usize;
12659                let w = *w as usize;
12660                let h_out = *h_out as usize;
12661                let w_out = *w_out as usize;
12662                let kh = *kh as usize;
12663                let kw = *kw as usize;
12664                let sh = *sh as usize;
12665                let sw = *sw as usize;
12666                let ph = *ph as usize;
12667                let pw = *pw as usize;
12668                let kernel_area = (kh * kw) as f32;
12669                unsafe {
12670                    let inp = sl(*src, base, n * c * h * w);
12671                    let out = sl_mut(*dst, base, n * c * h_out * w_out);
12672                    for ni in 0..n {
12673                        for ci in 0..c {
12674                            let in_chan = ni * c * h * w + ci * h * w;
12675                            let out_chan = ni * c * h_out * w_out + ci * h_out * w_out;
12676                            for ho in 0..h_out {
12677                                for wo in 0..w_out {
12678                                    let mut acc = match kind {
12679                                        ReduceOp::Max => f32::NEG_INFINITY,
12680                                        _ => 0f32, // Mean (and Sum/Min/Prod fall back here)
12681                                    };
12682                                    for ki in 0..kh {
12683                                        for kj in 0..kw {
12684                                            let hi = ho * sh + ki;
12685                                            let wi = wo * sw + kj;
12686                                            // Padded-zero region.
12687                                            if hi < ph || wi < pw {
12688                                                continue;
12689                                            }
12690                                            let hi = hi - ph;
12691                                            let wi = wi - pw;
12692                                            if hi >= h || wi >= w {
12693                                                continue;
12694                                            }
12695                                            let v = inp[in_chan + hi * w + wi];
12696                                            match kind {
12697                                                ReduceOp::Max => acc = acc.max(v),
12698                                                _ => acc += v,
12699                                            }
12700                                        }
12701                                    }
12702                                    if matches!(kind, ReduceOp::Mean) {
12703                                        acc /= kernel_area;
12704                                    }
12705                                    out[out_chan + ho * w_out + wo] = acc;
12706                                }
12707                            }
12708                        }
12709                    }
12710                }
12711            }
12712
12713            Thunk::ReluBackward { x, dy, dx, len } => {
12714                let len = *len as usize;
12715                unsafe {
12716                    let xs = sl(*x, base, len);
12717                    let dys = sl(*dy, base, len);
12718                    let out = sl_mut(*dx, base, len);
12719                    for i in 0..len {
12720                        out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
12721                    }
12722                }
12723            }
12724
12725            Thunk::ReluBackwardF64 { x, dy, dx, len } => {
12726                let len = *len as usize;
12727                unsafe {
12728                    let xs = sl_f64(*x, base, len);
12729                    let dys = sl_f64(*dy, base, len);
12730                    let out = sl_mut_f64(*dx, base, len);
12731                    for i in 0..len {
12732                        out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
12733                    }
12734                }
12735            }
12736
12737            Thunk::QMatMul {
12738                x,
12739                w,
12740                bias,
12741                out,
12742                m,
12743                k,
12744                n,
12745                x_zp,
12746                w_zp,
12747                out_zp,
12748                mult,
12749            } => {
12750                let m = *m as usize;
12751                let k = *k as usize;
12752                let n = *n as usize;
12753                unsafe {
12754                    let x_ptr = base.add(*x) as *const i8;
12755                    let w_ptr = base.add(*w) as *const i8;
12756                    let bias_ptr = base.add(*bias) as *const i32;
12757                    let out_ptr = base.add(*out) as *mut i8;
12758                    for mi in 0..m {
12759                        for ni in 0..n {
12760                            let mut acc: i32 = *bias_ptr.add(ni);
12761                            for ki in 0..k {
12762                                let xv = *x_ptr.add(mi * k + ki) as i32 - *x_zp;
12763                                let wv = *w_ptr.add(ki * n + ni) as i32 - *w_zp;
12764                                acc += xv * wv;
12765                            }
12766                            // Requantize: round(acc · mult) + out_zp,
12767                            // clamped to i8.
12768                            let r = (acc as f32 * *mult).round() as i32 + *out_zp;
12769                            let r = r.clamp(-128, 127) as i8;
12770                            *out_ptr.add(mi * n + ni) = r;
12771                        }
12772                    }
12773                }
12774            }
12775
            Thunk::QConv2d {
                x,
                w,
                bias,
                out,
                n,
                c_in,
                h,
                w_in,
                c_out,
                h_out,
                w_out,
                kh,
                kw,
                sh,
                sw,
                ph,
                pw,
                dh,
                dw,
                groups,
                x_zp,
                w_zp,
                out_zp,
                mult,
            } => {
                // Direct (loop-nest) quantized 2D convolution on i8 data with
                // zero points, stride, zero padding, dilation and groups.
                // (`w` is the weight offset, so the input width is `w_in`.)
                // Layouts implied by the flat indexing below:
                //   x      : [n, c_in, h, w_in]              i8
                //   w      : [c_out, c_in / groups, kh, kw]  i8
                //   bias   : [c_out]                         i32
                //   out    : [n, c_out, h_out, w_out]        i8
                // Each output is bias[co] + Σ (x − x_zp)·(w − w_zp), requantized
                // as clamp(round(acc · mult) + out_zp, -128, 127) with
                // `f32::round` (ties away from zero).
                // NOTE(review): `acc` is i32; `+=` panics on overflow in debug
                // builds and wraps in release. `groups == 0` panics on division
                // by zero; divisibility of c_in/c_out by groups is presumably
                // guaranteed by the lowering. Confirm.
                let n = *n as usize;
                let c_in = *c_in as usize;
                let h = *h as usize;
                let w_in = *w_in as usize;
                let c_out = *c_out as usize;
                let h_out = *h_out as usize;
                let w_out = *w_out as usize;
                let kh = *kh as usize;
                let kw = *kw as usize;
                let sh = *sh as usize;
                let sw = *sw as usize;
                let ph = *ph as usize;
                let pw = *pw as usize;
                let dh = *dh as usize;
                let dw = *dw as usize;
                let groups = *groups as usize;
                let c_in_per_g = c_in / groups;
                let c_out_per_g = c_out / groups;
                unsafe {
                    let x_ptr = base.add(*x) as *const i8;
                    let w_ptr = base.add(*w) as *const i8;
                    let bias_ptr = base.add(*bias) as *const i32;
                    let out_ptr = base.add(*out) as *mut i8;
                    for ni in 0..n {
                        for co in 0..c_out {
                            // Output channel `co` reads only its group's input channels.
                            let g = co / c_out_per_g;
                            let ci_start = g * c_in_per_g;
                            for ho in 0..h_out {
                                for wo in 0..w_out {
                                    // Accumulator starts at the per-output-channel bias.
                                    let mut acc: i32 = *bias_ptr.add(co);
                                    for ci_off in 0..c_in_per_g {
                                        let ci = ci_start + ci_off;
                                        let in_chan = ((ni * c_in) + ci) * h * w_in;
                                        let wt_chan = ((co * c_in_per_g) + ci_off) * kh * kw;
                                        for ki in 0..kh {
                                            for kj in 0..kw {
                                                // Tap position in padded coordinates.
                                                let hi = ho * sh + ki * dh;
                                                let wi = wo * sw + kj * dw;
                                                // Padding taps are skipped (zero
                                                // contribution); this also avoids
                                                // usize underflow below.
                                                if hi < ph || wi < pw {
                                                    continue;
                                                }
                                                let hi = hi - ph;
                                                let wi = wi - pw;
                                                if hi >= h || wi >= w_in {
                                                    continue;
                                                }
                                                let xv = *x_ptr.add(in_chan + hi * w_in + wi)
                                                    as i32
                                                    - *x_zp;
                                                let wv = *w_ptr.add(wt_chan + ki * kw + kj) as i32
                                                    - *w_zp;
                                                acc += xv * wv;
                                            }
                                        }
                                    }
                                    // Requantize to i8.
                                    let r = (acc as f32 * *mult).round() as i32 + *out_zp;
                                    let r = r.clamp(-128, 127) as i8;
                                    let dst = ((ni * c_out) + co) * h_out * w_out + ho * w_out + wo;
                                    *out_ptr.add(dst) = r;
                                }
                            }
                        }
                    }
                }
            }
12867
12868            Thunk::Quantize {
12869                x,
12870                q,
12871                len,
12872                chan_axis: _,
12873                chan_dim,
12874                inner,
12875                scales,
12876                zero_points,
12877            } => {
12878                let len = *len as usize;
12879                let chan_dim = *chan_dim as usize;
12880                let inner = *inner as usize;
12881                unsafe {
12882                    let xs = sl(*x, base, len);
12883                    let q_ptr = base.add(*q) as *mut i8;
12884                    for i in 0..len {
12885                        let c = if chan_dim == 1 {
12886                            0
12887                        } else {
12888                            (i / inner) % chan_dim
12889                        };
12890                        let inv_scale = 1.0 / scales[c];
12891                        let zp = zero_points[c];
12892                        let v = (xs[i] * inv_scale).round() as i32 + zp;
12893                        *q_ptr.add(i) = v.clamp(-128, 127) as i8;
12894                    }
12895                }
12896            }
12897
12898            Thunk::Dequantize {
12899                q,
12900                x,
12901                len,
12902                chan_axis: _,
12903                chan_dim,
12904                inner,
12905                scales,
12906                zero_points,
12907            } => {
12908                let len = *len as usize;
12909                let chan_dim = *chan_dim as usize;
12910                let inner = *inner as usize;
12911                unsafe {
12912                    let q_ptr = base.add(*q) as *const i8;
12913                    let out = sl_mut(*x, base, len);
12914                    for i in 0..len {
12915                        let c = if chan_dim == 1 {
12916                            0
12917                        } else {
12918                            (i / inner) % chan_dim
12919                        };
12920                        let scale = scales[c];
12921                        let zp = zero_points[c];
12922                        let qv = *q_ptr.add(i) as i32;
12923                        out[i] = (qv - zp) as f32 * scale;
12924                    }
12925                }
12926            }
12927
12928            Thunk::FakeQuantize {
12929                x,
12930                out,
12931                len,
12932                chan_axis: _,
12933                chan_dim,
12934                inner,
12935                bits,
12936                ste: _,
12937                scale_mode,
12938                state_off,
12939            } => {
12940                use rlx_ir::op::ScaleMode;
12941                let len = *len as usize;
12942                let chan_dim = *chan_dim as usize;
12943                let inner = *inner as usize;
12944                let q_max: f32 = match *bits {
12945                    8 => 127.0,
12946                    4 => 7.0,
12947                    2 => 1.0,
12948                    n => panic!("FakeQuantize: unsupported bits {n}"),
12949                };
12950                unsafe {
12951                    let xs = sl(*x, base, len);
12952                    let outs = sl_mut(*out, base, len);
12953
12954                    let mut scale = vec![0f32; chan_dim];
12955                    match scale_mode {
12956                        ScaleMode::PerBatch => {
12957                            let mut max_abs = vec![0f32; chan_dim];
12958                            for i in 0..len {
12959                                let c = if chan_dim == 1 {
12960                                    0
12961                                } else {
12962                                    (i / inner) % chan_dim
12963                                };
12964                                let a = xs[i].abs();
12965                                if a > max_abs[c] {
12966                                    max_abs[c] = a;
12967                                }
12968                            }
12969                            for c in 0..chan_dim {
12970                                scale[c] = (max_abs[c] / q_max).max(1e-12);
12971                            }
12972                        }
12973                        ScaleMode::EMA { decay } => {
12974                            // Per-channel current max-abs, then blend
12975                            // into the running state in place.
12976                            let mut max_abs = vec![0f32; chan_dim];
12977                            for i in 0..len {
12978                                let c = if chan_dim == 1 {
12979                                    0
12980                                } else {
12981                                    (i / inner) % chan_dim
12982                                };
12983                                let a = xs[i].abs();
12984                                if a > max_abs[c] {
12985                                    max_abs[c] = a;
12986                                }
12987                            }
12988                            let state =
12989                                sl_mut(state_off.expect("EMA needs state_off"), base, chan_dim);
12990                            for c in 0..chan_dim {
12991                                let cur = (max_abs[c] / q_max).max(1e-12);
12992                                // Cold-start: state==0 → seed directly.
12993                                let blended = if state[c] <= 0.0 {
12994                                    cur
12995                                } else {
12996                                    *decay * state[c] + (1.0 - *decay) * cur
12997                                };
12998                                state[c] = blended;
12999                                scale[c] = blended;
13000                            }
13001                        }
13002                        ScaleMode::Fixed => {
13003                            let state =
13004                                sl(state_off.expect("Fixed needs state_off"), base, chan_dim);
13005                            for c in 0..chan_dim {
13006                                scale[c] = state[c].max(1e-12);
13007                            }
13008                        }
13009                    }
13010
13011                    for i in 0..len {
13012                        let c = if chan_dim == 1 {
13013                            0
13014                        } else {
13015                            (i / inner) % chan_dim
13016                        };
13017                        let s = scale[c];
13018                        let qv = (xs[i] / s).round().clamp(-q_max, q_max);
13019                        outs[i] = qv * s;
13020                    }
13021                }
13022            }
13023
13024            Thunk::ActivationBackward {
13025                x,
13026                dy,
13027                dx,
13028                len,
13029                kind,
13030            } => {
13031                let len = *len as usize;
13032                unsafe {
13033                    let xs = sl(*x, base, len);
13034                    let dys = sl(*dy, base, len);
13035                    let out = sl_mut(*dx, base, len);
13036                    activation_backward_kernel(*kind, xs, dys, out);
13037                }
13038            }
13039
13040            Thunk::ActivationBackwardF64 {
13041                x,
13042                dy,
13043                dx,
13044                len,
13045                kind,
13046            } => {
13047                let len = *len as usize;
13048                unsafe {
13049                    let xs = sl_f64(*x, base, len);
13050                    let dys = sl_f64(*dy, base, len);
13051                    let out = sl_mut_f64(*dx, base, len);
13052                    activation_backward_kernel_f64(*kind, xs, dys, out);
13053                }
13054            }
13055
13056            Thunk::FakeQuantizeLSQ {
13057                x,
13058                scale_off,
13059                out,
13060                len,
13061                chan_axis: _,
13062                chan_dim,
13063                inner,
13064                bits,
13065            } => {
13066                let len = *len as usize;
13067                let chan_dim = *chan_dim as usize;
13068                let inner = *inner as usize;
13069                let q_max: f32 = match *bits {
13070                    8 => 127.0,
13071                    4 => 7.0,
13072                    2 => 1.0,
13073                    n => panic!("FakeQuantizeLSQ: bad bits {n}"),
13074                };
13075                unsafe {
13076                    let xs = sl(*x, base, len);
13077                    let scale = sl(*scale_off, base, chan_dim);
13078                    let outs = sl_mut(*out, base, len);
13079                    for i in 0..len {
13080                        let c = if chan_dim == 1 {
13081                            0
13082                        } else {
13083                            (i / inner) % chan_dim
13084                        };
13085                        let s = scale[c].max(1e-12);
13086                        let qv = (xs[i] / s).round().clamp(-q_max, q_max);
13087                        outs[i] = qv * s;
13088                    }
13089                }
13090            }
13091
13092            Thunk::FakeQuantizeLSQBackwardX {
13093                x,
13094                scale_off,
13095                dy,
13096                dx,
13097                len,
13098                chan_axis: _,
13099                chan_dim,
13100                inner,
13101                bits,
13102            } => {
13103                let len = *len as usize;
13104                let chan_dim = *chan_dim as usize;
13105                let inner = *inner as usize;
13106                let q_max: f32 = match *bits {
13107                    8 => 127.0,
13108                    4 => 7.0,
13109                    2 => 1.0,
13110                    n => panic!("FakeQuantizeLSQBackwardX: bad bits {n}"),
13111                };
13112                unsafe {
13113                    let xs = sl(*x, base, len);
13114                    let scale = sl(*scale_off, base, chan_dim);
13115                    let dys = sl(*dy, base, len);
13116                    let outs = sl_mut(*dx, base, len);
13117                    // STE-clipped: dx = dy when |x/s| ≤ q_max, else 0.
13118                    for i in 0..len {
13119                        let c = if chan_dim == 1 {
13120                            0
13121                        } else {
13122                            (i / inner) % chan_dim
13123                        };
13124                        let z = xs[i] / scale[c].max(1e-12);
13125                        outs[i] = if z.abs() <= q_max { dys[i] } else { 0.0 };
13126                    }
13127                }
13128            }
13129
13130            Thunk::FakeQuantizeLSQBackwardScale {
13131                x,
13132                scale_off,
13133                dy,
13134                dscale,
13135                len,
13136                chan_axis: _,
13137                chan_dim,
13138                inner,
13139                bits,
13140            } => {
13141                let len = *len as usize;
13142                let chan_dim = *chan_dim as usize;
13143                let inner = *inner as usize;
13144                let q_max: f32 = match *bits {
13145                    8 => 127.0,
13146                    4 => 7.0,
13147                    2 => 1.0,
13148                    n => panic!("FakeQuantizeLSQBackwardScale: bad bits {n}"),
13149                };
13150                unsafe {
13151                    let xs = sl(*x, base, len);
13152                    let scale = sl(*scale_off, base, chan_dim);
13153                    let dys = sl(*dy, base, len);
13154                    let outs = sl_mut(*dscale, base, chan_dim);
13155                    for v in outs.iter_mut() {
13156                        *v = 0.0;
13157                    }
13158                    // ψ(z) = -z + round(z) inside range, sign(z)·q_max outside.
13159                    // dscale[c] = sum_i ψ(x_i/s[c]) * upstream[i].
13160                    for i in 0..len {
13161                        let c = if chan_dim == 1 {
13162                            0
13163                        } else {
13164                            (i / inner) % chan_dim
13165                        };
13166                        let s = scale[c].max(1e-12);
13167                        let z = xs[i] / s;
13168                        let psi = if z.abs() <= q_max {
13169                            -z + z.round()
13170                        } else if z > 0.0 {
13171                            q_max
13172                        } else {
13173                            -q_max
13174                        };
13175                        outs[c] += psi * dys[i];
13176                    }
13177                }
13178            }
13179
13180            Thunk::FakeQuantizeBackward {
13181                x,
13182                dy,
13183                dx,
13184                len,
13185                chan_axis: _,
13186                chan_dim,
13187                inner,
13188                bits,
13189                ste,
13190            } => {
13191                use rlx_ir::op::SteKind;
13192                let len = *len as usize;
13193                let chan_dim = *chan_dim as usize;
13194                let inner = *inner as usize;
13195                let q_max: f32 = match *bits {
13196                    8 => 127.0,
13197                    4 => 7.0,
13198                    2 => 1.0,
13199                    n => panic!("FakeQuantizeBackward: bad bits {n}"),
13200                };
13201                unsafe {
13202                    let xs = sl(*x, base, len);
13203                    let dys = sl(*dy, base, len);
13204                    let outs = sl_mut(*dx, base, len);
13205
13206                    // Per-channel max-abs → scale, same as forward.
13207                    let mut max_abs = vec![0f32; chan_dim];
13208                    for i in 0..len {
13209                        let c = if chan_dim == 1 {
13210                            0
13211                        } else {
13212                            (i / inner) % chan_dim
13213                        };
13214                        let a = xs[i].abs();
13215                        if a > max_abs[c] {
13216                            max_abs[c] = a;
13217                        }
13218                    }
13219                    let mut scale = vec![0f32; chan_dim];
13220                    for c in 0..chan_dim {
13221                        scale[c] = (max_abs[c] / q_max).max(1e-12);
13222                    }
13223
13224                    match *ste {
13225                        SteKind::Identity => {
13226                            // dx = dy unchanged.
13227                            outs.copy_from_slice(dys);
13228                        }
13229                        SteKind::ClippedIdentity => {
13230                            // dx = dy * (|x| <= q_max·s); zero if the
13231                            // forward saturated.
13232                            for i in 0..len {
13233                                let c = if chan_dim == 1 {
13234                                    0
13235                                } else {
13236                                    (i / inner) % chan_dim
13237                                };
13238                                let bound = q_max * scale[c];
13239                                outs[i] = if xs[i].abs() <= bound { dys[i] } else { 0.0 };
13240                            }
13241                        }
13242                        SteKind::Tanh => {
13243                            // dx = dy * (1 - tanh²(x/s)).
13244                            for i in 0..len {
13245                                let c = if chan_dim == 1 {
13246                                    0
13247                                } else {
13248                                    (i / inner) % chan_dim
13249                                };
13250                                let t = (xs[i] / scale[c]).tanh();
13251                                outs[i] = dys[i] * (1.0 - t * t);
13252                            }
13253                        }
13254                        SteKind::HardTanh => {
13255                            // dx = dy * max(0, 1 - |x/(q_max·s)|).
13256                            for i in 0..len {
13257                                let c = if chan_dim == 1 {
13258                                    0
13259                                } else {
13260                                    (i / inner) % chan_dim
13261                                };
13262                                let bound = q_max * scale[c];
13263                                let attenuation = (1.0 - (xs[i] / bound).abs()).max(0.0);
13264                                outs[i] = dys[i] * attenuation;
13265                            }
13266                        }
13267                    }
13268                }
13269            }
13270
13271            Thunk::LayerNormBackwardInput {
13272                x,
13273                gamma,
13274                dy,
13275                dx,
13276                rows,
13277                h,
13278                eps,
13279            } => {
13280                let rows = *rows as usize;
13281                let h = *h as usize;
13282                let eps = *eps;
13283                unsafe {
13284                    let xs = sl(*x, base, rows * h);
13285                    let g = sl(*gamma, base, h);
13286                    let dys = sl(*dy, base, rows * h);
13287                    let out = sl_mut(*dx, base, rows * h);
13288                    let n_inv = 1.0 / h as f32;
13289                    for r in 0..rows {
13290                        let xr = &xs[r * h..(r + 1) * h];
13291                        let dyr = &dys[r * h..(r + 1) * h];
13292                        // Per-row mean and inv_std (recompute — no saved
13293                        // tensor from the forward pass).
13294                        let mut sum = 0f32;
13295                        for &v in xr {
13296                            sum += v;
13297                        }
13298                        let mean = sum * n_inv;
13299                        let mut var = 0f32;
13300                        for &v in xr {
13301                            let d = v - mean;
13302                            var += d * d;
13303                        }
13304                        let inv_std = 1.0 / (var * n_inv + eps).sqrt();
13305
13306                        // sums needed for the closed-form:
13307                        //   mean(dy·γ) and mean(dy·γ·x̂)
13308                        let mut s_sy = 0f32;
13309                        let mut s_sxh = 0f32;
13310                        for d in 0..h {
13311                            let xh = (xr[d] - mean) * inv_std;
13312                            let sy = dyr[d] * g[d];
13313                            s_sy += sy;
13314                            s_sxh += sy * xh;
13315                        }
13316                        let m_sy = s_sy * n_inv;
13317                        let m_sxh = s_sxh * n_inv;
13318
13319                        for d in 0..h {
13320                            let xh = (xr[d] - mean) * inv_std;
13321                            let sy = dyr[d] * g[d];
13322                            out[r * h + d] = inv_std * (sy - m_sy - xh * m_sxh);
13323                        }
13324                    }
13325                }
13326            }
13327
13328            Thunk::BatchNormInferenceBackwardInput {
13329                x,
13330                gamma,
13331                mean,
13332                var,
13333                dy,
13334                dx,
13335                count,
13336                channels,
13337                eps,
13338            } => {
13339                let count = *count as usize;
13340                let c = *channels as usize;
13341                let n = count * c;
13342                let eps = *eps;
13343                unsafe {
13344                    crate::kernels::batch_norm_inference_backward_input(
13345                        sl(*x, base, n),
13346                        sl(*gamma, base, c),
13347                        sl(*mean, base, c),
13348                        sl(*var, base, c),
13349                        sl(*dy, base, n),
13350                        sl_mut(*dx, base, n),
13351                        c,
13352                        eps,
13353                    );
13354                }
13355            }
13356
13357            Thunk::BatchNormInferenceBackwardGamma {
13358                x,
13359                mean,
13360                var,
13361                dy,
13362                dgamma,
13363                count,
13364                channels,
13365                eps,
13366            } => {
13367                let count = *count as usize;
13368                let c = *channels as usize;
13369                let n = count * c;
13370                let eps = *eps;
13371                unsafe {
13372                    crate::kernels::batch_norm_inference_backward_gamma(
13373                        sl(*x, base, n),
13374                        sl(*mean, base, c),
13375                        sl(*var, base, c),
13376                        sl(*dy, base, n),
13377                        sl_mut(*dgamma, base, c),
13378                        c,
13379                        eps,
13380                    );
13381                }
13382            }
13383
13384            Thunk::BatchNormInferenceBackwardBeta {
13385                dy,
13386                dbeta,
13387                count,
13388                channels,
13389            } => {
13390                let count = *count as usize;
13391                let c = *channels as usize;
13392                let n = count * c;
13393                unsafe {
13394                    crate::kernels::batch_norm_inference_backward_beta(
13395                        sl(*dy, base, n),
13396                        sl_mut(*dbeta, base, c),
13397                        c,
13398                    );
13399                }
13400            }
13401
13402            Thunk::LayerNormBackwardGamma {
13403                x,
13404                dy,
13405                dgamma,
13406                rows,
13407                h,
13408                eps,
13409            } => {
13410                let rows = *rows as usize;
13411                let h = *h as usize;
13412                let eps = *eps;
13413                unsafe {
13414                    let xs = sl(*x, base, rows * h);
13415                    let dys = sl(*dy, base, rows * h);
13416                    let out = sl_mut(*dgamma, base, h);
13417                    for v in out.iter_mut() {
13418                        *v = 0.0;
13419                    }
13420                    let n_inv = 1.0 / h as f32;
13421                    for r in 0..rows {
13422                        let xr = &xs[r * h..(r + 1) * h];
13423                        let dyr = &dys[r * h..(r + 1) * h];
13424                        let mut sum = 0f32;
13425                        for &v in xr {
13426                            sum += v;
13427                        }
13428                        let mean = sum * n_inv;
13429                        let mut var = 0f32;
13430                        for &v in xr {
13431                            let d = v - mean;
13432                            var += d * d;
13433                        }
13434                        let inv_std = 1.0 / (var * n_inv + eps).sqrt();
13435                        for d in 0..h {
13436                            let xh = (xr[d] - mean) * inv_std;
13437                            out[d] += dyr[d] * xh;
13438                        }
13439                    }
13440                }
13441            }
13442
13443            Thunk::RmsNormBackwardInput {
13444                x,
13445                gamma,
13446                beta,
13447                dy,
13448                dx,
13449                rows,
13450                h,
13451                eps,
13452            } => {
13453                let (rows, h) = (*rows as usize, *h as usize);
13454                unsafe {
13455                    let xs = sl(*x, base, rows * h);
13456                    let g = sl(*gamma, base, h);
13457                    let b = sl(*beta, base, h);
13458                    let dys = sl(*dy, base, rows * h);
13459                    let out = sl_mut(*dx, base, rows * h);
13460                    let mut dg = vec![0f32; h];
13461                    let mut db = vec![0f32; h];
13462                    for r in 0..rows {
13463                        crate::training_bwd::rms_norm_backward_row(
13464                            &xs[r * h..(r + 1) * h],
13465                            g,
13466                            b,
13467                            &dys[r * h..(r + 1) * h],
13468                            &mut out[r * h..(r + 1) * h],
13469                            &mut dg,
13470                            &mut db,
13471                            *eps,
13472                        );
13473                    }
13474                }
13475            }
13476
13477            Thunk::RmsNormBackwardGamma {
13478                x,
13479                gamma,
13480                beta,
13481                dy,
13482                dgamma,
13483                rows,
13484                h,
13485                eps,
13486            } => {
13487                let (rows, h) = (*rows as usize, *h as usize);
13488                unsafe {
13489                    let xs = sl(*x, base, rows * h);
13490                    let g = sl(*gamma, base, h);
13491                    let b = sl(*beta, base, h);
13492                    let dys = sl(*dy, base, rows * h);
13493                    let out = sl_mut(*dgamma, base, h);
13494                    for v in out.iter_mut() {
13495                        *v = 0.0;
13496                    }
13497                    let mut dx = vec![0f32; h];
13498                    let mut db = vec![0f32; h];
13499                    for r in 0..rows {
13500                        crate::training_bwd::rms_norm_backward_row(
13501                            &xs[r * h..(r + 1) * h],
13502                            g,
13503                            b,
13504                            &dys[r * h..(r + 1) * h],
13505                            &mut dx,
13506                            &mut *out,
13507                            &mut db,
13508                            *eps,
13509                        );
13510                    }
13511                }
13512            }
13513
13514            Thunk::RmsNormBackwardBeta {
13515                x,
13516                gamma,
13517                beta,
13518                dy,
13519                dbeta,
13520                rows,
13521                h,
13522                eps,
13523            } => {
13524                let (rows, h) = (*rows as usize, *h as usize);
13525                unsafe {
13526                    let xs = sl(*x, base, rows * h);
13527                    let g = sl(*gamma, base, h);
13528                    let b = sl(*beta, base, h);
13529                    let dys = sl(*dy, base, rows * h);
13530                    let out = sl_mut(*dbeta, base, h);
13531                    for v in out.iter_mut() {
13532                        *v = 0.0;
13533                    }
13534                    let mut dx = vec![0f32; h];
13535                    let mut dg = vec![0f32; h];
13536                    for r in 0..rows {
13537                        crate::training_bwd::rms_norm_backward_row(
13538                            &xs[r * h..(r + 1) * h],
13539                            g,
13540                            b,
13541                            &dys[r * h..(r + 1) * h],
13542                            &mut dx,
13543                            &mut dg,
13544                            &mut *out,
13545                            *eps,
13546                        );
13547                    }
13548                }
13549            }
13550
13551            Thunk::RopeBackward {
13552                dy,
13553                cos,
13554                sin,
13555                dx,
13556                batch,
13557                seq,
13558                hidden,
13559                head_dim,
13560                n_rot,
13561                cos_len,
13562            } => {
13563                let (b, s, hs, dh, nr, cl) = (
13564                    *batch as usize,
13565                    *seq as usize,
13566                    *hidden as usize,
13567                    *head_dim as usize,
13568                    *n_rot as usize,
13569                    *cos_len as usize,
13570                );
13571                let nh = hs / dh;
13572                let tab_half = dh / 2;
13573                unsafe {
13574                    let dys = sl(*dy, base, b * s * hs);
13575                    let cos_tab = sl(*cos, base, cl);
13576                    let sin_tab = sl(*sin, base, cl);
13577                    let out = sl_mut(*dx, base, b * s * hs);
13578                    for bi in 0..b {
13579                        for si in 0..s {
13580                            let tab_off = si.saturating_mul(tab_half) % cl.max(1);
13581                            let cp = &cos_tab[tab_off..tab_off + tab_half.min(cl)];
13582                            let sp = &sin_tab[tab_off..tab_off + tab_half.min(cl)];
13583                            for hi in 0..nh {
13584                                let base_idx = bi * s * hs + si * hs + hi * dh;
13585                                crate::training_bwd::rope_backward_row(
13586                                    &dys[base_idx..base_idx + dh],
13587                                    cp,
13588                                    sp,
13589                                    &mut out[base_idx..base_idx + dh],
13590                                    dh,
13591                                    nr,
13592                                );
13593                            }
13594                        }
13595                    }
13596                }
13597            }
13598
13599            Thunk::CumsumBackward {
13600                dy,
13601                dx,
13602                rows,
13603                cols,
13604                exclusive,
13605            } => {
13606                let (rows, cols) = (*rows as usize, *cols as usize);
13607                unsafe {
13608                    let dys = sl(*dy, base, rows * cols);
13609                    let out = sl_mut(*dx, base, rows * cols);
13610                    for r in 0..rows {
13611                        crate::training_bwd::cumsum_backward_row(
13612                            &dys[r * cols..(r + 1) * cols],
13613                            &mut out[r * cols..(r + 1) * cols],
13614                            *exclusive,
13615                        );
13616                    }
13617                }
13618            }
13619
13620            Thunk::GroupNormBackwardInput {
13621                x,
13622                gamma,
13623                beta: _beta,
13624                dy,
13625                dx,
13626                n,
13627                c,
13628                h,
13629                w,
13630                num_groups,
13631                eps,
13632            } => {
13633                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
13634                let plane = c * h * w;
13635                unsafe {
13636                    let xs = sl(*x, base, n * plane);
13637                    let g = sl(*gamma, base, c);
13638                    let dys = sl(*dy, base, n * plane);
13639                    let out = sl_mut(*dx, base, n * plane);
13640                    crate::training_bwd::group_norm_backward_input_nchw(
13641                        xs,
13642                        g,
13643                        dys,
13644                        out,
13645                        n,
13646                        c,
13647                        h,
13648                        w,
13649                        *num_groups as usize,
13650                        *eps,
13651                    );
13652                }
13653            }
13654
13655            Thunk::GroupNormBackwardGamma {
13656                x,
13657                dy,
13658                dgamma,
13659                n,
13660                c,
13661                h,
13662                w,
13663                num_groups,
13664                eps,
13665            } => {
13666                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
13667                let plane = c * h * w;
13668                unsafe {
13669                    let xs = sl(*x, base, n * plane);
13670                    let dys = sl(*dy, base, n * plane);
13671                    let out = sl_mut(*dgamma, base, c);
13672                    crate::training_bwd::group_norm_backward_gamma_nchw(
13673                        xs,
13674                        dys,
13675                        out,
13676                        n,
13677                        c,
13678                        h,
13679                        w,
13680                        *num_groups as usize,
13681                        *eps,
13682                    );
13683                }
13684            }
13685
13686            Thunk::GroupNormBackwardBeta {
13687                dy,
13688                dbeta,
13689                n,
13690                c,
13691                h,
13692                w,
13693            } => {
13694                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
13695                let plane = c * h * w;
13696                unsafe {
13697                    let dys = sl(*dy, base, n * plane);
13698                    let out = sl_mut(*dbeta, base, c);
13699                    crate::training_bwd::group_norm_backward_beta_nchw(dys, out, n, c, h, w);
13700                }
13701            }
13702
13703            Thunk::GatherBackward {
13704                dy,
13705                indices,
13706                dst,
13707                outer,
13708                axis_dim,
13709                num_idx,
13710                trailing,
13711            } => {
13712                let (outer, axis_dim, num_idx, trailing) = (
13713                    *outer as usize,
13714                    *axis_dim as usize,
13715                    *num_idx as usize,
13716                    *trailing as usize,
13717                );
13718                unsafe {
13719                    let dys = sl(*dy, base, outer * num_idx * trailing);
13720                    let ids = sl(*indices, base, num_idx);
13721                    let out = sl_mut(*dst, base, outer * axis_dim * trailing);
13722                    for v in out.iter_mut() {
13723                        *v = 0.0;
13724                    }
13725                    crate::training_bwd::gather_axis_backward(
13726                        dys, ids, out, outer, axis_dim, num_idx, trailing,
13727                    );
13728                }
13729            }
13730
13731            Thunk::MaxPool2dBackward {
13732                x,
13733                dy,
13734                dx,
13735                n,
13736                c,
13737                h,
13738                w,
13739                h_out,
13740                w_out,
13741                kh,
13742                kw,
13743                sh,
13744                sw,
13745                ph,
13746                pw,
13747            } => unsafe {
13748                execute_maxpool2d_backward_f32(
13749                    *x, *dy, *dx, *n, *c, *h, *w, *h_out, *w_out, *kh, *kw, *sh, *sw, *ph, *pw,
13750                    base,
13751                );
13752            },
13753
13754            Thunk::Conv2dBackwardInput {
13755                dy,
13756                w,
13757                dx,
13758                n,
13759                c_in,
13760                h,
13761                w_in,
13762                c_out,
13763                h_out,
13764                w_out,
13765                kh,
13766                kw,
13767                sh,
13768                sw,
13769                ph,
13770                pw,
13771                dh,
13772                dw,
13773                groups,
13774            } => {
13775                // Per-group GEMM + col2im. Two orders of magnitude faster
13776                // than the naive 6-deep nested loop on training shapes.
13777                //
13778                //   dcol_n_g = w_g^T  @  dy_n_g            (sgemm)
13779                //   dx_n_g  += col2im(dcol_n_g)            (scatter-add)
13780                //
13781                // Layouts (all row-major):
13782                //   w_g       [c_out_per_g, c_in_per_g · kh · kw]
13783                //   dy_n_g    [c_out_per_g, h_out · w_out]
13784                //   dcol_n_g  [c_in_per_g · kh · kw, h_out · w_out]
13785                //   dx_n_g    [c_in_per_g, h · w_in]
13786                let n = *n as usize;
13787                let c_in = *c_in as usize;
13788                let h = *h as usize;
13789                let w_in = *w_in as usize;
13790                let c_out = *c_out as usize;
13791                let h_out = *h_out as usize;
13792                let w_out = *w_out as usize;
13793                let kh = *kh as usize;
13794                let kw = *kw as usize;
13795                let sh = *sh as usize;
13796                let sw = *sw as usize;
13797                let ph = *ph as usize;
13798                let pw = *pw as usize;
13799                let dh = *dh as usize;
13800                let dw = *dw as usize;
13801                let groups = *groups as usize;
13802                let c_in_per_g = c_in / groups;
13803                let c_out_per_g = c_out / groups;
13804
13805                let m_dim = c_in_per_g * kh * kw;
13806                let n_dim = h_out * w_out;
13807                let k_dim = c_out_per_g;
13808
13809                let dy_stride_n = c_out * h_out * w_out;
13810                let dy_stride_g = c_out_per_g * h_out * w_out;
13811                let w_stride_g = c_out_per_g * c_in_per_g * kh * kw;
13812                let dx_stride_n = c_in * h * w_in;
13813                let dx_stride_g = c_in_per_g * h * w_in;
13814
13815                unsafe {
13816                    let dys = sl(*dy, base, n * c_out * h_out * w_out);
13817                    let ws = sl(*w, base, c_out * c_in_per_g * kh * kw);
13818                    let dxs = sl_mut(*dx, base, n * c_in * h * w_in);
13819                    for v in dxs.iter_mut() {
13820                        *v = 0.0;
13821                    }
13822
13823                    // Reused scratch buffer for the [m_dim, n_dim] dcol.
13824                    let mut dcol = vec![0f32; m_dim * n_dim];
13825
13826                    for ni in 0..n {
13827                        for g in 0..groups {
13828                            let w_g_off = g * w_stride_g;
13829                            let dy_n_g_off = ni * dy_stride_n + g * dy_stride_g;
13830                            let dx_n_g_off = ni * dx_stride_n + g * dx_stride_g;
13831
13832                            // dcol = w_g^T @ dy_n_g
13833                            // w_g  is stored as [k_dim rows, m_dim cols] row-major
13834                            // (i.e. K×M storage with lda = M = m_dim — exactly what
13835                            // sgemm_general wants for trans_a=true).
13836                            crate::blas::sgemm_general(
13837                                ws.as_ptr().add(w_g_off),
13838                                dys.as_ptr().add(dy_n_g_off),
13839                                dcol.as_mut_ptr(),
13840                                m_dim,
13841                                n_dim,
13842                                k_dim,
13843                                1.0,
13844                                0.0,
13845                                /*lda=*/ m_dim,
13846                                /*ldb=*/ n_dim,
13847                                /*ldc=*/ n_dim,
13848                                /*trans_a=*/ true,
13849                                /*trans_b=*/ false,
13850                            );
13851
13852                            // dx_n_g += col2im(dcol)
13853                            col2im(
13854                                &dcol,
13855                                &mut dxs[dx_n_g_off..dx_n_g_off + dx_stride_g],
13856                                c_in_per_g,
13857                                h,
13858                                w_in,
13859                                h_out,
13860                                w_out,
13861                                kh,
13862                                kw,
13863                                sh,
13864                                sw,
13865                                ph,
13866                                pw,
13867                                dh,
13868                                dw,
13869                            );
13870                        }
13871                    }
13872                }
13873            }
13874
13875            Thunk::Conv2dBackwardWeight {
13876                x,
13877                dy,
13878                dw,
13879                n,
13880                c_in,
13881                h,
13882                w,
13883                c_out,
13884                h_out,
13885                w_out,
13886                kh,
13887                kw,
13888                sh,
13889                sw,
13890                ph,
13891                pw,
13892                dh,
13893                dw_dil,
13894                groups,
13895            } => {
13896                let n = *n as usize;
13897                let c_in = *c_in as usize;
13898                let h = *h as usize;
13899                let w = *w as usize;
13900                // Per-group im2col + GEMM, summed across batch.
13901                //
13902                //   col_n_g  = im2col(x_n_g)               (gather)
13903                //   dw_g    += dy_n_g  @  col_n_g^T        (sgemm, β=1)
13904                //
13905                // Layouts:
13906                //   x_n_g     [c_in_per_g, h · w]
13907                //   col_n_g   [c_in_per_g · kh · kw, h_out · w_out]
13908                //   dy_n_g    [c_out_per_g, h_out · w_out]
13909                //   dw_g      [c_out_per_g, c_in_per_g · kh · kw]
13910                let c_out = *c_out as usize;
13911                let h_out = *h_out as usize;
13912                let w_out = *w_out as usize;
13913                let kh = *kh as usize;
13914                let kw = *kw as usize;
13915                let sh = *sh as usize;
13916                let sw = *sw as usize;
13917                let ph = *ph as usize;
13918                let pw = *pw as usize;
13919                let dh = *dh as usize;
13920                let dw_dil = *dw_dil as usize;
13921                let groups = *groups as usize;
13922                let c_in_per_g = c_in / groups;
13923                let c_out_per_g = c_out / groups;
13924
13925                let m_dim = c_out_per_g;
13926                let n_dim = c_in_per_g * kh * kw;
13927                let k_dim = h_out * w_out;
13928
13929                let x_stride_n = c_in * h * w;
13930                let x_stride_g = c_in_per_g * h * w;
13931                let dy_stride_n = c_out * h_out * w_out;
13932                let dy_stride_g = c_out_per_g * h_out * w_out;
13933                let dw_stride_g = c_out_per_g * c_in_per_g * kh * kw;
13934
13935                unsafe {
13936                    let xs = sl(*x, base, n * c_in * h * w);
13937                    let dys = sl(*dy, base, n * c_out * h_out * w_out);
13938                    let dws = sl_mut(*dw, base, c_out * c_in_per_g * kh * kw);
13939                    for v in dws.iter_mut() {
13940                        *v = 0.0;
13941                    }
13942
13943                    let mut col = vec![0f32; n_dim * k_dim];
13944
13945                    for ni in 0..n {
13946                        for g in 0..groups {
13947                            let x_n_g_off = ni * x_stride_n + g * x_stride_g;
13948                            im2col(
13949                                &xs[x_n_g_off..x_n_g_off + x_stride_g],
13950                                &mut col,
13951                                c_in_per_g,
13952                                h,
13953                                w,
13954                                h_out,
13955                                w_out,
13956                                kh,
13957                                kw,
13958                                sh,
13959                                sw,
13960                                ph,
13961                                pw,
13962                                dh,
13963                                dw_dil,
13964                            );
13965
13966                            let dy_n_g_off = ni * dy_stride_n + g * dy_stride_g;
13967                            let dw_g_off = g * dw_stride_g;
13968
13969                            // dw_g += dy_n_g @ col^T
13970                            //
13971                            // Output shape m × n_out = c_out_per_g × (c_in_per_g·kh·kw).
13972                            // dy_n_g is stored M×K row-major (lda = K = k_dim).
13973                            // col is stored as N×K row-major; with trans_b=true,
13974                            // sgemm_general uses ldb = K = k_dim and treats it as
13975                            // transposed. β=1 accumulates across the batch loop.
13976                            crate::blas::sgemm_general(
13977                                dys.as_ptr().add(dy_n_g_off),
13978                                col.as_ptr(),
13979                                dws.as_mut_ptr().add(dw_g_off),
13980                                m_dim,
13981                                n_dim,
13982                                k_dim,
13983                                1.0,
13984                                1.0,
13985                                /*lda=*/ k_dim,
13986                                /*ldb=*/ k_dim,
13987                                /*ldc=*/ n_dim,
13988                                /*trans_a=*/ false,
13989                                /*trans_b=*/ true,
13990                            );
13991                        }
13992                    }
13993                }
13994            }
13995
13996            Thunk::Im2Col {
13997                x,
13998                col,
13999                n,
14000                c_in,
14001                h,
14002                w,
14003                h_out,
14004                w_out,
14005                kh,
14006                kw,
14007                sh,
14008                sw,
14009                ph,
14010                pw,
14011                dh,
14012                dw_dil,
14013            } => {
14014                let c_in = *c_in as usize;
14015                let h = *h as usize;
14016                let w = *w as usize;
14017                let h_out = *h_out as usize;
14018                let w_out = *w_out as usize;
14019                let kh = *kh as usize;
14020                let kw = *kw as usize;
14021                let sh = *sh as usize;
14022                let sw = *sw as usize;
14023                let ph = *ph as usize;
14024                let pw = *pw as usize;
14025                let dh = *dh as usize;
14026                let dw_dil = *dw_dil as usize;
14027                let per_batch = c_in * h * w;
14028                unsafe {
14029                    let n_eff = if *n == 0 { 0usize } else { *n as usize };
14030                    let x_floats = if n_eff == 0 {
14031                        per_batch.max(1)
14032                    } else {
14033                        n_eff * per_batch
14034                    };
14035                    let xs = sl(*x, base, x_floats);
14036                    let n = if *n == 0 {
14037                        xs.len() / per_batch.max(1)
14038                    } else {
14039                        n_eff
14040                    };
14041                    let m = n * h_out * w_out;
14042                    let k = c_in * kh * kw;
14043                    let cols = sl_mut(*col, base, m * k);
14044                    crate::im2col::im2col_rows_layout(
14045                        xs, cols, n, c_in, h, w, h_out, w_out, kh, kw, sh, sw, ph, pw, dh, dw_dil,
14046                    );
14047                }
14048            }
14049
14050            Thunk::SoftmaxCrossEntropy {
14051                logits,
14052                labels,
14053                dst,
14054                n,
14055                c,
14056            } => {
14057                let n = *n as usize;
14058                let c = *c as usize;
14059                unsafe {
14060                    let lg = sl(*logits, base, n * c);
14061                    let lb = sl(*labels, base, n);
14062                    let out = sl_mut(*dst, base, n);
14063                    for ni in 0..n {
14064                        let row = &lg[ni * c..(ni + 1) * c];
14065                        // log-sum-exp: max-subtract for stability.
14066                        let mut m = f32::NEG_INFINITY;
14067                        for &v in row {
14068                            if v > m {
14069                                m = v;
14070                            }
14071                        }
14072                        let mut sum = 0f32;
14073                        for &v in row {
14074                            sum += (v - m).exp();
14075                        }
14076                        let lse = m + sum.ln();
14077                        let label_idx = lb[ni] as usize;
14078                        // loss = -(logits[label] - lse) = lse - logits[label].
14079                        out[ni] = lse - row[label_idx];
14080                    }
14081                }
14082            }
14083
14084            Thunk::SoftmaxCrossEntropyBackward {
14085                logits,
14086                labels,
14087                d_loss,
14088                dlogits,
14089                n,
14090                c,
14091            } => {
14092                let n = *n as usize;
14093                let c = *c as usize;
14094                unsafe {
14095                    let lg = sl(*logits, base, n * c);
14096                    let lb = sl(*labels, base, n);
14097                    let dl = sl(*d_loss, base, n);
14098                    let out = sl_mut(*dlogits, base, n * c);
14099                    for ni in 0..n {
14100                        let row = &lg[ni * c..(ni + 1) * c];
14101                        let label_idx = lb[ni] as usize;
14102                        let scale = dl[ni];
14103                        let mut m = f32::NEG_INFINITY;
14104                        for &v in row {
14105                            if v > m {
14106                                m = v;
14107                            }
14108                        }
14109                        let mut sum = 0f32;
14110                        for &v in row {
14111                            sum += (v - m).exp();
14112                        }
14113                        let inv_sum = 1.0 / sum;
14114                        let dst_row = &mut out[ni * c..(ni + 1) * c];
14115                        for k in 0..c {
14116                            let p = (row[k] - m).exp() * inv_sum;
14117                            let one_hot = if k == label_idx { 1.0 } else { 0.0 };
14118                            dst_row[k] = (p - one_hot) * scale;
14119                        }
14120                    }
14121                }
14122            }
14123
14124            Thunk::GatherAxis {
14125                table,
14126                idx,
14127                dst,
14128                outer,
14129                axis_dim,
14130                num_idx,
14131                trailing,
14132                idx_i64,
14133                table_bytes,
14134            } => {
14135                let outer = *outer as usize;
14136                let axis_dim = *axis_dim as usize;
14137                let num_idx = *num_idx as usize;
14138                let trailing = *trailing as usize;
14139                unsafe {
14140                    if *table_bytes == 8 {
14141                        let tab = sl_i64(*table, base, outer * axis_dim * trailing);
14142                        let out = sl_mut_i64(*dst, base, outer * num_idx * trailing);
14143                        for o in 0..outer {
14144                            let tab_outer = o * axis_dim * trailing;
14145                            let out_outer = o * num_idx * trailing;
14146                            if *idx_i64 != 0 {
14147                                let ids = sl_i64(*idx, base, num_idx);
14148                                for k in 0..num_idx {
14149                                    let row = ids[k].max(0) as usize;
14150                                    if row < axis_dim {
14151                                        let tab_row = tab_outer + row * trailing;
14152                                        let out_row = out_outer + k * trailing;
14153                                        out[out_row..out_row + trailing]
14154                                            .copy_from_slice(&tab[tab_row..tab_row + trailing]);
14155                                    }
14156                                }
14157                            } else {
14158                                let ids = sl(*idx, base, num_idx);
14159                                for k in 0..num_idx {
14160                                    let row = ids[k] as usize;
14161                                    if row < axis_dim {
14162                                        let tab_row = tab_outer + row * trailing;
14163                                        let out_row = out_outer + k * trailing;
14164                                        out[out_row..out_row + trailing]
14165                                            .copy_from_slice(&tab[tab_row..tab_row + trailing]);
14166                                    }
14167                                }
14168                            }
14169                        }
14170                    } else {
14171                        let tab = sl(*table, base, outer * axis_dim * trailing);
14172                        let out = sl_mut(*dst, base, outer * num_idx * trailing);
14173                        for o in 0..outer {
14174                            let tab_outer = o * axis_dim * trailing;
14175                            let out_outer = o * num_idx * trailing;
14176                            if *idx_i64 != 0 {
14177                                let ids = sl_i64(*idx, base, num_idx);
14178                                for k in 0..num_idx {
14179                                    let row = ids[k].max(0) as usize;
14180                                    if row < axis_dim {
14181                                        let tab_row = tab_outer + row * trailing;
14182                                        let out_row = out_outer + k * trailing;
14183                                        out[out_row..out_row + trailing]
14184                                            .copy_from_slice(&tab[tab_row..tab_row + trailing]);
14185                                    }
14186                                }
14187                            } else {
14188                                let ids = sl(*idx, base, num_idx);
14189                                for k in 0..num_idx {
14190                                    let row = ids[k] as usize;
14191                                    if row < axis_dim {
14192                                        let tab_row = tab_outer + row * trailing;
14193                                        let out_row = out_outer + k * trailing;
14194                                        out[out_row..out_row + trailing]
14195                                            .copy_from_slice(&tab[tab_row..tab_row + trailing]);
14196                                    }
14197                                }
14198                            }
14199                        }
14200                    }
14201                }
14202            }
14203
14204            Thunk::Transpose {
14205                src,
14206                dst,
14207                in_total,
14208                out_dims,
14209                in_strides,
14210                elem_bytes,
14211            } => {
14212                // N-D index walk: for each output flat index, decompose into
14213                // multi-dim coords using out_dims, then dot with in_strides
14214                // to find the source flat index. Stride 0 = broadcast (read
14215                // the same input element repeatedly along that dim).
14216                let rank = out_dims.len();
14217                let total: usize = out_dims.iter().map(|&d| d as usize).product();
14218                let in_total = *in_total as usize;
14219                unsafe {
14220                    if *elem_bytes == 1 {
14221                        // 1-byte dtypes (Bool / I8 / U8). Without this branch the
14222                        // `else` path below reads/writes 4 bytes per element via the
14223                        // f32 slice, corrupting e.g. a broadcast of the VITS attention
14224                        // mask (Bool, expanded over heads) — masking wrong positions.
14225                        let inp = arena_buf[*src..*src + in_total].to_vec();
14226                        let out = &mut arena_buf[*dst..*dst + total];
14227                        let mut idx = vec![0usize; rank];
14228                        for o in 0..total {
14229                            let mut src_idx = 0usize;
14230                            for d in 0..rank {
14231                                src_idx += idx[d] * in_strides[d] as usize;
14232                            }
14233                            out[o] = inp[broadcast_src_index(src_idx, in_total)];
14234                            for d in (0..rank).rev() {
14235                                idx[d] += 1;
14236                                if idx[d] < out_dims[d] as usize {
14237                                    break;
14238                                }
14239                                idx[d] = 0;
14240                            }
14241                        }
14242                    } else if *elem_bytes == 8 {
14243                        let inp = sl_i64(*src, base, in_total);
14244                        let out = sl_mut_i64(*dst, base, total);
14245                        let mut idx = vec![0usize; rank];
14246                        for o in 0..total {
14247                            let mut src_idx = 0usize;
14248                            for d in 0..rank {
14249                                src_idx += idx[d] * in_strides[d] as usize;
14250                            }
14251                            out[o] = inp[broadcast_src_index(src_idx, in_total)];
14252                            for d in (0..rank).rev() {
14253                                idx[d] += 1;
14254                                if idx[d] < out_dims[d] as usize {
14255                                    break;
14256                                }
14257                                idx[d] = 0;
14258                            }
14259                        }
14260                    } else {
14261                        let inp = sl(*src, base, in_total);
14262                        let out = sl_mut(*dst, base, total);
14263                        let mut idx = vec![0usize; rank];
14264                        for o in 0..total {
14265                            let mut src_idx = 0usize;
14266                            for d in 0..rank {
14267                                src_idx += idx[d] * in_strides[d] as usize;
14268                            }
14269                            out[o] = inp[broadcast_src_index(src_idx, in_total)];
14270                            for d in (0..rank).rev() {
14271                                idx[d] += 1;
14272                                if idx[d] < out_dims[d] as usize {
14273                                    break;
14274                                }
14275                                idx[d] = 0;
14276                            }
14277                        }
14278                    }
14279                }
14280            }
14281
            // Thunk::DenseSolveF64 and Thunk::ScanBackward used to have placeholder
            // `panic!` stubs here from the wire-up. Both variants are handled by
            // their real implementations earlier in this match, so the stubs were
            // unreachable (shadowed by those arms) and have been removed.
14287            Thunk::CustomOp {
14288                kernel,
14289                inputs,
14290                output,
14291                attrs,
14292            } => {
14293                let (out_off, out_len, out_shape) = output;
14294                unsafe {
14295                    dispatch_custom_op(
14296                        &**kernel, inputs, *out_off, *out_len, out_shape, attrs, base,
14297                    );
14298                }
14299            }
14300        }
14301        if trace_done {
14302            eprintln!("[thunk {i} done]");
14303        }
14304    }
14305}
14306
14307/// Griewank treeverse: process backward iterations `[t_lo..=t_hi]` (with
14308/// the carry entering iteration `t_lo` supplied as `anchor_carry`) by
14309/// recursive binary subdivision. Total work `O((t_hi-t_lo+1) · log)`,
14310/// auxiliary memory `O(log · carry_bytes)` for the recursion stack.
14311///
14312/// Compared to the iterative segment-cached scheme, this trades extra
14313/// recompute for less working memory — each level of recursion holds
14314/// one `cb`-sized intermediate carry on the stack but never the whole
14315/// segment at once. With K saved outer checkpoints, the outer driver
14316/// invokes this helper once per segment.
14317///
14318/// `process_iter(t, carry_at_t)` is the per-iteration leaf action: it
14319/// runs `body_vjp` at iteration `t` with the supplied carry, threads
14320/// `dcarry` backward, and (for ScanBackwardXs) writes `dxs[t]`.
14321#[allow(clippy::too_many_arguments)]
14322unsafe fn griewank_process_segment(
14323    t_lo: usize,
14324    t_hi: usize,
14325    anchor_carry: &[u8],
14326    cb: usize,
14327    fwd_sched: &ThunkSchedule,
14328    fwd_init: &[u8],
14329    fwd_carry_in_off: usize,
14330    fwd_output_off: usize,
14331    fwd_x_offs: &[usize],
14332    base: *mut u8,
14333    outer_xs_offs: &[(usize, u32)],
14334    fwd_buf: &mut Vec<u8>,
14335    leaf_threshold: usize,
14336    process_iter: &mut dyn FnMut(usize, &[u8]),
14337) {
14338    unsafe {
14339        let size = t_hi - t_lo + 1;
14340        if size == 1 {
14341            process_iter(t_lo, anchor_carry);
14342            return;
14343        }
14344        if size <= leaf_threshold {
14345            // Walk forward, cache each carry, run backward in reverse.
14346            let mut cache: Vec<u8> = Vec::with_capacity(size * cb);
14347            cache.extend_from_slice(anchor_carry);
14348            fwd_buf.copy_from_slice(fwd_init);
14349            std::ptr::copy_nonoverlapping(
14350                anchor_carry.as_ptr(),
14351                fwd_buf.as_mut_ptr().add(fwd_carry_in_off),
14352                cb,
14353            );
14354            for i in 1..size {
14355                let cur_iter = t_lo + i - 1;
14356                for (idx, fb_x_off) in fwd_x_offs.iter().enumerate() {
14357                    let (outer_xs_off, x_psb) = outer_xs_offs[idx];
14358                    let xb = x_psb as usize;
14359                    std::ptr::copy_nonoverlapping(
14360                        base.add(outer_xs_off + cur_iter * xb),
14361                        fwd_buf.as_mut_ptr().add(*fb_x_off),
14362                        xb,
14363                    );
14364                }
14365                execute_thunks(fwd_sched, fwd_buf);
14366                if fwd_output_off != fwd_carry_in_off {
14367                    fwd_buf.copy_within(fwd_output_off..fwd_output_off + cb, fwd_carry_in_off);
14368                }
14369                cache.extend_from_slice(&fwd_buf[fwd_carry_in_off..fwd_carry_in_off + cb]);
14370            }
14371            // Process backward.
14372            for t in (t_lo..=t_hi).rev() {
14373                let idx = t - t_lo;
14374                let carry = &cache[idx * cb..(idx + 1) * cb];
14375                process_iter(t, carry);
14376            }
14377            return;
14378        }
14379
14380        // Split: walk forward from anchor to compute carry entering `mid`.
14381        // (We need `mid - t_lo` body executions: one per iteration in
14382        // [t_lo, mid).)
14383        let mid = t_lo + size / 2;
14384        fwd_buf.copy_from_slice(fwd_init);
14385        std::ptr::copy_nonoverlapping(
14386            anchor_carry.as_ptr(),
14387            fwd_buf.as_mut_ptr().add(fwd_carry_in_off),
14388            cb,
14389        );
14390        for cur_iter in t_lo..mid {
14391            for (idx, fb_x_off) in fwd_x_offs.iter().enumerate() {
14392                let (outer_xs_off, x_psb) = outer_xs_offs[idx];
14393                let xb = x_psb as usize;
14394                std::ptr::copy_nonoverlapping(
14395                    base.add(outer_xs_off + cur_iter * xb),
14396                    fwd_buf.as_mut_ptr().add(*fb_x_off),
14397                    xb,
14398                );
14399            }
14400            execute_thunks(fwd_sched, fwd_buf);
14401            if fwd_output_off != fwd_carry_in_off {
14402                fwd_buf.copy_within(fwd_output_off..fwd_output_off + cb, fwd_carry_in_off);
14403            }
14404        }
14405        let mid_carry: Vec<u8> = fwd_buf[fwd_carry_in_off..fwd_carry_in_off + cb].to_vec();
14406
14407        // Right half first (higher t values processed first to match the
14408        // canonical reverse-mode iteration order: dcarry threads from
14409        // t=length-1 down to t=0).
14410        griewank_process_segment(
14411            mid,
14412            t_hi,
14413            &mid_carry,
14414            cb,
14415            fwd_sched,
14416            fwd_init,
14417            fwd_carry_in_off,
14418            fwd_output_off,
14419            fwd_x_offs,
14420            base,
14421            outer_xs_offs,
14422            fwd_buf,
14423            leaf_threshold,
14424            process_iter,
14425        );
14426        // Then left half with original anchor.
14427        griewank_process_segment(
14428            t_lo,
14429            mid - 1,
14430            anchor_carry,
14431            cb,
14432            fwd_sched,
14433            fwd_init,
14434            fwd_carry_in_off,
14435            fwd_output_off,
14436            fwd_x_offs,
14437            base,
14438            outer_xs_offs,
14439            fwd_buf,
14440            leaf_threshold,
14441            process_iter,
14442        );
14443    }
14444}
14445
14446/// Execute a batched 1D FFT in the f64 2N-real-block layout.
14447/// Each "row" is `2N` f64 elements: first `N` real, then `N` imag.
14448/// The `outer` rows are independent and processed sequentially.
14449///
14450/// Both forward and inverse use the same Cooley-Tukey radix-2 DIT
14451/// kernel — only the twiddle-factor sign differs. Power-of-2 only
14452/// (the IR builder rejects non-power-of-2 sizes at graph-build time).
14453/// Batched 1D FFT on the f64 2N-real-block layout. Public so other
14454/// backend crates can invoke this as a host fallback against a
14455/// unified-memory arena (e.g. rlx-metal: sync the command buffer,
14456/// pass the Metal `Buffer::contents()` pointer as `base`, restart the
14457/// command buffer). Self-contained — no rlx-cpu state required.
14458///
14459/// Safety: `base + src` and `base + dst` must be valid for the
14460/// `outer * 2 * n_complex * sizeof::<f64>()` byte range and stay
14461/// alive for the duration of the call.
14462pub unsafe fn execute_fft1d_f64(
14463    src: usize,
14464    dst: usize,
14465    outer: usize,
14466    n_complex: usize,
14467    inverse: bool,
14468    norm_tag: u32,
14469    base: *mut u8,
14470) {
14471    let row_elems = 2 * n_complex;
14472    let mut re = vec![0f64; n_complex];
14473    let mut im = vec![0f64; n_complex];
14474    let norm = rlx_ir::fft::FftNorm::from_tag(norm_tag);
14475    let scale = norm.output_scale(n_complex, inverse);
14476    // Scratch reused across rows for the Bluestein path. Empty when
14477    // we're on the radix-2 fast path.
14478    let mut scratch = if n_complex.is_power_of_two() || n_complex <= 16 {
14479        BluesteinScratchF64::empty()
14480    } else {
14481        BluesteinScratchF64::build(n_complex, inverse)
14482    };
14483    for o in 0..outer {
14484        let row_offset = src + o * row_elems * std::mem::size_of::<f64>();
14485        let s = unsafe { sl_f64(row_offset, base, row_elems) };
14486        re.copy_from_slice(&s[..n_complex]);
14487        im.copy_from_slice(&s[n_complex..]);
14488        if n_complex.is_power_of_two() {
14489            fft_radix2_inplace_f64(&mut re, &mut im, inverse);
14490        } else if n_complex <= 16 {
14491            fft_naive_inplace_f64(&mut re, &mut im, inverse);
14492        } else {
14493            fft_bluestein_inplace_f64(&mut re, &mut im, inverse, &mut scratch);
14494        }
14495        if scale != 1.0 {
14496            re.iter_mut().for_each(|v| *v *= scale);
14497            im.iter_mut().for_each(|v| *v *= scale);
14498        }
14499        let dst_offset = dst + o * row_elems * std::mem::size_of::<f64>();
14500        let d = unsafe { sl_mut_f64(dst_offset, base, row_elems) };
14501        d[..n_complex].copy_from_slice(&re);
14502        d[n_complex..].copy_from_slice(&im);
14503    }
14504}
14505
14506/// f32 counterpart of `execute_fft1d_f64`. Same 2N-real-block layout
14507/// (first N real, second N imag per row), same unnormalized
14508/// convention; only the element width differs. Twiddle factors are
14509/// computed in f64 and cast to f32 to keep large-N error closer to
14510/// the f64 path (the savings from f32 are in memory bandwidth, not in
14511/// twiddle precision).
14512/// Complex (C64) dense GEMM `C[m,n] = A[m,k] · B[k,n]`. Operands are
14513/// interleaved `[re, im]` f32; `a_off`/`b_off`/`c_off` are byte offsets
14514/// into `base`. Parallel over output rows (disjoint writes).
14515unsafe fn cgemm_c64(
14516    a_off: usize,
14517    b_off: usize,
14518    c_off: usize,
14519    m: usize,
14520    k: usize,
14521    n: usize,
14522    base: *mut u8,
14523) {
14524    use rayon::prelude::*;
14525    let bptr = base as usize;
14526    unsafe {
14527        let a = std::slice::from_raw_parts((bptr + a_off) as *const f32, 2 * m * k);
14528        let b = std::slice::from_raw_parts((bptr + b_off) as *const f32, 2 * k * n);
14529        let c_base = bptr + c_off;
14530        (0..m).into_par_iter().for_each(|i| {
14531            let crow = std::slice::from_raw_parts_mut((c_base + i * n * 8) as *mut f32, 2 * n);
14532            for j in 0..n {
14533                let mut re = 0f32;
14534                let mut im = 0f32;
14535                for l in 0..k {
14536                    let ar = a[2 * (i * k + l)];
14537                    let ai = a[2 * (i * k + l) + 1];
14538                    let br = b[2 * (l * n + j)];
14539                    let bi = b[2 * (l * n + j) + 1];
14540                    re += ar * br - ai * bi;
14541                    im += ar * bi + ai * br;
14542                }
14543                crow[2 * j] = re;
14544                crow[2 * j + 1] = im;
14545            }
14546        });
14547    }
14548}
14549
14550/// Reference / host-fallback entry for `Op::Lstm` (multi-layer,
14551/// optionally bidirectional, optional decode carry). Gate order i, f, g,
14552/// o. Shared by the CPU backend and the CUDA / ROCm / wgpu / Metal host
14553/// fallbacks. Tensors are `f32` in the arena; weights are packed (see
14554/// `Op::Lstm`). `h0`/`c0` are byte offsets when `carry`, else ignored;
14555/// the final `hn`/`cn` are written back into them in place. `dst` is
14556/// `[batch, seq, D*hidden]`. Batch items run in parallel per direction.
14557#[allow(clippy::too_many_arguments)]
14558pub unsafe fn execute_lstm_f32(
14559    x: usize,
14560    w_ih: usize,
14561    w_hh: usize,
14562    bias: usize,
14563    h0: usize,
14564    c0: usize,
14565    dst: usize,
14566    batch: usize,
14567    seq: usize,
14568    input_size: usize,
14569    hidden: usize,
14570    num_layers: usize,
14571    bidirectional: bool,
14572    carry: bool,
14573    base: *mut u8,
14574) {
14575    use rayon::prelude::*;
14576
14577    #[inline]
14578    fn sigmoid(z: f32) -> f32 {
14579        1.0 / (1.0 + (-z).exp())
14580    }
14581
14582    let bptr = base as usize;
14583    let four_h = 4 * hidden;
14584    let dirs = if bidirectional { 2 } else { 1 };
14585
14586    unsafe {
14587        let f32s = |off: usize, n: usize| -> &[f32] {
14588            std::slice::from_raw_parts((bptr + off) as *const f32, n)
14589        };
14590
14591        // Layer 0 reads x; later layers read the previous layer's output.
14592        let mut layer_in: Vec<f32> = f32s(x, batch * seq * input_size).to_vec();
14593        let mut in_l = input_size;
14594        // Running element cursor into the packed `w_ih` buffer (block width
14595        // `4h * in_l` varies per layer; `w_hh`/`bias` blocks are uniform).
14596        let mut wih_cursor = 0usize;
14597
14598        for l in 0..num_layers {
14599            let out_width = dirs * hidden;
14600            let mut layer_out = vec![0f32; batch * seq * out_width];
14601            let lo_ptr = layer_out.as_mut_ptr() as usize;
14602            let li_ref: &[f32] = &layer_in;
14603            let wih_block = four_h * in_l;
14604
14605            for dir in 0..dirs {
14606                let ld = l * dirs + dir;
14607                let wih = f32s((w_ih / 4 + wih_cursor + dir * wih_block) * 4, wih_block);
14608                let whh = f32s(w_hh + ld * four_h * hidden * 4, four_h * hidden);
14609                let bs = f32s(bias + ld * four_h * 4, four_h);
14610                let h0p = bptr + h0 + ld * batch * hidden * 4;
14611                let c0p = bptr + c0 + ld * batch * hidden * 4;
14612
14613                (0..batch).into_par_iter().for_each(|b| {
14614                    let lo = lo_ptr as *mut f32;
14615                    let mut h = vec![0f32; hidden];
14616                    let mut c = vec![0f32; hidden];
14617                    if carry {
14618                        let hin = std::slice::from_raw_parts(
14619                            (h0p + b * hidden * 4) as *const f32,
14620                            hidden,
14621                        );
14622                        let cin = std::slice::from_raw_parts(
14623                            (c0p + b * hidden * 4) as *const f32,
14624                            hidden,
14625                        );
14626                        h.copy_from_slice(hin);
14627                        c.copy_from_slice(cin);
14628                    }
14629                    let mut z = vec![0f32; four_h];
14630                    for step in 0..seq {
14631                        let t = if dir == 0 { step } else { seq - 1 - step };
14632                        let x_t = &li_ref[(b * seq + t) * in_l..(b * seq + t + 1) * in_l];
14633                        for r in 0..four_h {
14634                            let wr = &wih[r * in_l..(r + 1) * in_l];
14635                            let mut acc = bs[r];
14636                            for j in 0..in_l {
14637                                acc += wr[j] * x_t[j];
14638                            }
14639                            let hr = &whh[r * hidden..(r + 1) * hidden];
14640                            for (j, &hj) in h.iter().enumerate() {
14641                                acc += hr[j] * hj;
14642                            }
14643                            z[r] = acc;
14644                        }
14645                        for k in 0..hidden {
14646                            let i_g = sigmoid(z[k]);
14647                            let f_g = sigmoid(z[hidden + k]);
14648                            let g_g = z[2 * hidden + k].tanh();
14649                            let o_g = sigmoid(z[3 * hidden + k]);
14650                            let c_new = f_g * c[k] + i_g * g_g;
14651                            c[k] = c_new;
14652                            let h_new = o_g * c_new.tanh();
14653                            h[k] = h_new;
14654                            // [batch, seq, D*hidden]; this direction owns the
14655                            // `dir*hidden .. dir*hidden+hidden` feature slice.
14656                            *lo.add((b * seq + t) * out_width + dir * hidden + k) = h_new;
14657                        }
14658                    }
14659                    if carry {
14660                        let hout = std::slice::from_raw_parts_mut(
14661                            (h0p + b * hidden * 4) as *mut f32,
14662                            hidden,
14663                        );
14664                        let cout = std::slice::from_raw_parts_mut(
14665                            (c0p + b * hidden * 4) as *mut f32,
14666                            hidden,
14667                        );
14668                        hout.copy_from_slice(&h);
14669                        cout.copy_from_slice(&c);
14670                    }
14671                });
14672            }
14673
14674            wih_cursor += dirs * wih_block;
14675            layer_in = layer_out;
14676            in_l = out_width;
14677        }
14678
14679        // Final layer output → dst [batch, seq, D*hidden].
14680        let dst_slice = std::slice::from_raw_parts_mut((bptr + dst) as *mut f32, layer_in.len());
14681        dst_slice.copy_from_slice(&layer_in);
14682    }
14683}
14684
14685/// Host-fallback entry for `Op::GatedDeltaNet` (Metal / unified memory).
14686/// When `state == 0`, uses a zero-initialized scratch state per batch item.
14687pub unsafe fn execute_gated_delta_net_f32(
14688    q: usize,
14689    k: usize,
14690    v: usize,
14691    g: usize,
14692    beta: usize,
14693    state: usize,
14694    dst: usize,
14695    batch: usize,
14696    seq: usize,
14697    heads: usize,
14698    state_size: usize,
14699    base: *mut u8,
14700) {
14701    use rayon::prelude::*;
14702
14703    #[derive(Copy, Clone)]
14704    struct ArenaPtr(usize);
14705    unsafe impl Send for ArenaPtr {}
14706    unsafe impl Sync for ArenaPtr {}
14707    impl ArenaPtr {
14708        #[inline]
14709        fn get(self) -> *mut u8 {
14710            self.0 as *mut u8
14711        }
14712    }
14713
14714    unsafe {
14715        let arena = ArenaPtr(base as usize);
14716        let (b, s, h, n) = (batch, seq, heads, state_size);
14717        let scale = 1.0f32 / (n as f32).sqrt();
14718        let use_external = state != 0;
14719        let mut owned_state = vec![0f32; h * n * n];
14720
14721        crate::pool::num_threads();
14722
14723        assert!(
14724            n <= crate::gdn::GDN_MAX_STATE,
14725            "GatedDeltaNet state_size={n} exceeds stack scratch ({})",
14726            crate::gdn::GDN_MAX_STATE
14727        );
14728
14729        let qs = sl(q, arena.get(), b * s * h * n);
14730        let ks = sl(k, arena.get(), b * s * h * n);
14731        let vs = sl(v, arena.get(), b * s * h * n);
14732        let gs = sl(g, arena.get(), b * s * h);
14733        let betas = sl(beta, arena.get(), b * s * h);
14734        let _out = sl_mut(dst, arena.get(), b * s * h * n);
14735        let hs_n = h * n;
14736
14737        let run_head = |bi: usize, hi: usize, s_mat: &mut [f32], sk: &mut [f32]| {
14738            for ti in 0..s {
14739                let qkv_step = bi * s * hs_n + ti * hs_n + hi * n;
14740                let gb_step = bi * s * h + ti * h + hi;
14741                let out_row = sl_mut(dst + qkv_step * std::mem::size_of::<f32>(), arena.get(), n);
14742                crate::gdn::gdn_step_blas(
14743                    s_mat,
14744                    &qs[qkv_step..qkv_step + n],
14745                    &ks[qkv_step..qkv_step + n],
14746                    &vs[qkv_step..qkv_step + n],
14747                    gs[gb_step],
14748                    betas[gb_step],
14749                    out_row,
14750                    sk,
14751                    n,
14752                    scale,
14753                );
14754            }
14755        };
14756
14757        // Prefill (seq>1, ephemeral state): time-outer, parallel over heads —
14758        // better occupancy than head-outer when prompt length dominates.
14759        if !use_external && s > 1 {
14760            for bi in 0..b {
14761                (0..h).into_par_iter().for_each(|hi| {
14762                    let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
14763                    let sk = &mut sk_buf[..n];
14764                    let mut local_state =
14765                        [0f32; crate::gdn::GDN_MAX_STATE * crate::gdn::GDN_MAX_STATE];
14766                    let s_mat = &mut local_state[..n * n];
14767                    s_mat.fill(0.0);
14768                    run_head(bi, hi, s_mat, sk);
14769                });
14770            }
14771            return;
14772        }
14773
14774        if use_external {
14775            let state_bytes = state;
14776            (0..b * h).into_par_iter().for_each(|bhi| {
14777                let bi = bhi / h;
14778                let hi = bhi % h;
14779                let elem_off = bi * h * n * n + hi * n * n;
14780                let s_mat = sl_mut(
14781                    state_bytes + elem_off * std::mem::size_of::<f32>(),
14782                    arena.get(),
14783                    n * n,
14784                );
14785                let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
14786                run_head(bi, hi, s_mat, &mut sk_buf[..n]);
14787            });
14788        } else {
14789            for bi in 0..b {
14790                owned_state.fill(0.0);
14791                owned_state
14792                    .par_chunks_mut(n * n)
14793                    .enumerate()
14794                    .for_each(|(hi, s_mat)| {
14795                        let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
14796                        run_head(bi, hi, s_mat, &mut sk_buf[..n]);
14797                    });
14798            }
14799        }
14800    }
14801}
14802
14803/// Host-fallback: `Op::RmsNormBackwardInput` (GPU unified-memory / D2H arenas).
14804pub unsafe fn execute_rms_norm_backward_input_f32(
14805    x: usize,
14806    gamma: usize,
14807    beta: usize,
14808    dy: usize,
14809    dx: usize,
14810    rows: u32,
14811    h: u32,
14812    eps: f32,
14813    base: *mut u8,
14814) {
14815    let (rows, h) = (rows as usize, h as usize);
14816    let mut dg = vec![0f32; h];
14817    let mut db = vec![0f32; h];
14818    let xs = sl(x, base, rows * h);
14819    let dys = sl(dy, base, rows * h);
14820    let g = sl(gamma, base, h);
14821    let b = sl(beta, base, h);
14822    let out = sl_mut(dx, base, rows * h);
14823    for r in 0..rows {
14824        crate::training_bwd::rms_norm_backward_row(
14825            &xs[r * h..(r + 1) * h],
14826            g,
14827            b,
14828            &dys[r * h..(r + 1) * h],
14829            &mut out[r * h..(r + 1) * h],
14830            &mut dg,
14831            &mut db,
14832            eps,
14833        );
14834    }
14835}
14836
14837pub unsafe fn execute_rms_norm_backward_gamma_f32(
14838    x: usize,
14839    gamma: usize,
14840    beta: usize,
14841    dy: usize,
14842    dgamma: usize,
14843    rows: u32,
14844    h: u32,
14845    eps: f32,
14846    base: *mut u8,
14847) {
14848    let (rows, h) = (rows as usize, h as usize);
14849    let out = sl_mut(dgamma, base, h);
14850    out.fill(0.0);
14851    let mut dx = vec![0f32; h];
14852    let mut db = vec![0f32; h];
14853    let xs = sl(x, base, rows * h);
14854    let dys = sl(dy, base, rows * h);
14855    let g = sl(gamma, base, h);
14856    let b = sl(beta, base, h);
14857    for r in 0..rows {
14858        crate::training_bwd::rms_norm_backward_row(
14859            &xs[r * h..(r + 1) * h],
14860            g,
14861            b,
14862            &dys[r * h..(r + 1) * h],
14863            &mut dx,
14864            out,
14865            &mut db,
14866            eps,
14867        );
14868    }
14869}
14870
14871pub unsafe fn execute_rms_norm_backward_beta_f32(
14872    x: usize,
14873    gamma: usize,
14874    beta: usize,
14875    dy: usize,
14876    dbeta: usize,
14877    rows: u32,
14878    h: u32,
14879    eps: f32,
14880    base: *mut u8,
14881) {
14882    let (rows, h) = (rows as usize, h as usize);
14883    let out = sl_mut(dbeta, base, h);
14884    out.fill(0.0);
14885    let mut dx = vec![0f32; h];
14886    let mut dg = vec![0f32; h];
14887    let xs = sl(x, base, rows * h);
14888    let dys = sl(dy, base, rows * h);
14889    let g = sl(gamma, base, h);
14890    let b = sl(beta, base, h);
14891    for r in 0..rows {
14892        crate::training_bwd::rms_norm_backward_row(
14893            &xs[r * h..(r + 1) * h],
14894            g,
14895            b,
14896            &dys[r * h..(r + 1) * h],
14897            &mut dx,
14898            &mut dg,
14899            out,
14900            eps,
14901        );
14902    }
14903}
14904
14905#[allow(clippy::too_many_arguments)]
14906pub unsafe fn execute_conv2d_forward_f32(
14907    src: usize,
14908    weight: usize,
14909    dst: usize,
14910    n: u32,
14911    c_in: u32,
14912    h: u32,
14913    w: u32,
14914    c_out: u32,
14915    h_out: u32,
14916    w_out: u32,
14917    kh: u32,
14918    kw: u32,
14919    sh: u32,
14920    sw: u32,
14921    ph: u32,
14922    pw: u32,
14923    dh: u32,
14924    dw: u32,
14925    groups: u32,
14926    base: *mut u8,
14927) {
14928    let n = n as usize;
14929    let c_in = c_in as usize;
14930    let h = h as usize;
14931    let w = w as usize;
14932    let c_out = c_out as usize;
14933    let h_out = h_out as usize;
14934    let w_out = w_out as usize;
14935    let kh = kh as usize;
14936    let kw = kw as usize;
14937    let sh = sh as usize;
14938    let sw = sw as usize;
14939    let ph = ph as usize;
14940    let pw = pw as usize;
14941    let dh = dh as usize;
14942    let dw = dw as usize;
14943    let groups = groups as usize;
14944    let c_in_per_g = c_in / groups;
14945    let inp = sl(src, base, n * c_in * h * w);
14946    let wt = sl(weight, base, c_out * c_in_per_g * kh * kw);
14947    let out = sl_mut(dst, base, n * c_out * h_out * w_out);
14948    crate::conv_fwd::conv2d_forward_nchw_f32(
14949        inp, wt, out, n, c_in, h, w, c_out, h_out, w_out, kh, kw, sh, sw, ph, pw, dh, dw, groups,
14950    );
14951}
14952
14953pub unsafe fn execute_maxpool2d_backward_f32(
14954    x: usize,
14955    dy: usize,
14956    dx: usize,
14957    n: u32,
14958    c: u32,
14959    h: u32,
14960    w: u32,
14961    h_out: u32,
14962    w_out: u32,
14963    kh: u32,
14964    kw: u32,
14965    sh: u32,
14966    sw: u32,
14967    ph: u32,
14968    pw: u32,
14969    base: *mut u8,
14970) {
14971    let (n, c, h, w) = (n as usize, c as usize, h as usize, w as usize);
14972    let (h_out, w_out) = (h_out as usize, w_out as usize);
14973    let (kh, kw) = (kh as usize, kw as usize);
14974    let (sh, sw) = (sh as usize, sw as usize);
14975    let (ph, pw) = (ph as usize, pw as usize);
14976    let xs = sl(x, base, n * c * h * w);
14977    let dys = sl(dy, base, n * c * h_out * w_out);
14978    let dxs = sl_mut(dx, base, n * c * h * w);
14979    crate::training_bwd::maxpool2d_backward_nchw(
14980        xs, dys, dxs, n, c, h, w, h_out, w_out, kh, kw, sh, sw, ph, pw,
14981    );
14982}
14983
14984pub unsafe fn execute_rope_backward_f32(
14985    dy: usize,
14986    cos: usize,
14987    sin: usize,
14988    dx: usize,
14989    batch: u32,
14990    seq: u32,
14991    hidden: u32,
14992    head_dim: u32,
14993    n_rot: u32,
14994    cos_len: u32,
14995    base: *mut u8,
14996) {
14997    let (b, s, hs, dh, nr, cl) = (
14998        batch as usize,
14999        seq as usize,
15000        hidden as usize,
15001        head_dim as usize,
15002        n_rot as usize,
15003        cos_len as usize,
15004    );
15005    let nh = hs / dh;
15006    let tab_half = dh / 2;
15007    let dys = sl(dy, base, b * s * hs);
15008    let cos_tab = sl(cos, base, cl);
15009    let sin_tab = sl(sin, base, cl);
15010    let out = sl_mut(dx, base, b * s * hs);
15011    for bi in 0..b {
15012        for si in 0..s {
15013            let tab_off = si.saturating_mul(tab_half) % cl.max(1);
15014            let cp = &cos_tab[tab_off..tab_off + tab_half.min(cl)];
15015            let sp = &sin_tab[tab_off..tab_off + tab_half.min(cl)];
15016            for hi in 0..nh {
15017                let base_idx = bi * s * hs + si * hs + hi * dh;
15018                crate::training_bwd::rope_backward_row(
15019                    &dys[base_idx..base_idx + dh],
15020                    cp,
15021                    sp,
15022                    &mut out[base_idx..base_idx + dh],
15023                    dh,
15024                    nr,
15025                );
15026            }
15027        }
15028    }
15029}
15030
15031pub unsafe fn execute_cumsum_backward_f32(
15032    dy: usize,
15033    dx: usize,
15034    rows: u32,
15035    cols: u32,
15036    exclusive: bool,
15037    base: *mut u8,
15038) {
15039    let (rows, cols) = (rows as usize, cols as usize);
15040    let dys = sl(dy, base, rows * cols);
15041    let out = sl_mut(dx, base, rows * cols);
15042    for r in 0..rows {
15043        crate::training_bwd::cumsum_backward_row(
15044            &dys[r * cols..(r + 1) * cols],
15045            &mut out[r * cols..(r + 1) * cols],
15046            exclusive,
15047        );
15048    }
15049}
15050
15051pub unsafe fn execute_gather_backward_f32(
15052    dy: usize,
15053    indices: usize,
15054    dst: usize,
15055    outer: u32,
15056    axis_dim: u32,
15057    num_idx: u32,
15058    trailing: u32,
15059    base: *mut u8,
15060) {
15061    let (outer, axis_dim, num_idx, trailing) = (
15062        outer as usize,
15063        axis_dim as usize,
15064        num_idx as usize,
15065        trailing as usize,
15066    );
15067    let out = sl_mut(dst, base, outer * axis_dim * trailing);
15068    out.fill(0.0);
15069    crate::training_bwd::gather_axis_backward(
15070        sl(dy, base, outer * num_idx * trailing),
15071        sl(indices, base, num_idx),
15072        out,
15073        outer,
15074        axis_dim,
15075        num_idx,
15076        trailing,
15077    );
15078}
15079
15080/// Host-fallback entry for GGUF `Op::DequantMatMul` (Metal unified memory).
15081pub unsafe fn execute_dequant_matmul_gguf_f32(
15082    x: usize,
15083    w_q: usize,
15084    dst: usize,
15085    m: usize,
15086    k: usize,
15087    n: usize,
15088    scheme: rlx_ir::quant::QuantScheme,
15089    base: *mut u8,
15090) {
15091    unsafe {
15092        let block_bytes = scheme.gguf_block_bytes() as usize;
15093        let block_elems = scheme.gguf_block_size() as usize;
15094        let total_bytes = (k * n) / block_elems * block_bytes;
15095        let xs = sl(x, base, m * k);
15096        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, total_bytes);
15097        let out = sl_mut(dst, base, m * n);
15098        crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, scheme);
15099    }
15100}
15101
15102/// Host-fallback entry for GGUF `Op::DequantGroupedMatMul` (MoE expert stack).
15103pub unsafe fn execute_dequant_grouped_matmul_gguf_f32(
15104    input: usize,
15105    w_q: usize,
15106    expert_idx: usize,
15107    dst: usize,
15108    m: usize,
15109    k: usize,
15110    n: usize,
15111    num_experts: usize,
15112    scheme: rlx_ir::quant::QuantScheme,
15113    base: *mut u8,
15114) {
15115    unsafe {
15116        let block_bytes = scheme.gguf_block_bytes() as usize;
15117        let block_elems = scheme.gguf_block_size() as usize;
15118        let slab_bytes = (k * n) / block_elems * block_bytes;
15119        let xs = sl(input, base, m * k);
15120        let w_bytes =
15121            std::slice::from_raw_parts(base.add(w_q) as *const u8, num_experts * slab_bytes);
15122        let ids = sl(expert_idx, base, m);
15123        let out = sl_mut(dst, base, m * n);
15124        crate::gguf_matmul::gguf_grouped_matmul_bt(
15125            xs,
15126            w_bytes,
15127            ids,
15128            out,
15129            m,
15130            k,
15131            n,
15132            num_experts,
15133            scheme,
15134        );
15135    }
15136}
15137
15138/// Host-fallback entry for Int8 `Op::DequantMatMul` (Metal unified memory).
15139pub unsafe fn execute_dequant_matmul_int8_f32(
15140    x: usize,
15141    w_q: usize,
15142    scale: usize,
15143    zp: usize,
15144    dst: usize,
15145    m: usize,
15146    k: usize,
15147    n: usize,
15148    block_size: u32,
15149    is_asymmetric: bool,
15150    base: *mut u8,
15151) {
15152    let bs = block_size as usize;
15153    let n_blocks = k.div_ceil(bs);
15154    unsafe {
15155        let xs = sl(x, base, m * k);
15156        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const i8, k * n);
15157        let scales = sl(scale, base, n_blocks * n);
15158        let zps = if is_asymmetric {
15159            sl(zp, base, n_blocks * n)
15160        } else {
15161            &[][..]
15162        };
15163        let out = sl_mut(dst, base, m * n);
15164        dequant_matmul_int8(xs, w_bytes, scales, zps, out, m, k, n, bs, is_asymmetric);
15165    }
15166}
15167
15168/// Host-fallback entry for Int4 `Op::DequantMatMul` (Metal unified memory).
15169pub unsafe fn execute_dequant_matmul_int4_f32(
15170    x: usize,
15171    w_q: usize,
15172    scale: usize,
15173    zp: usize,
15174    dst: usize,
15175    m: usize,
15176    k: usize,
15177    n: usize,
15178    block_size: u32,
15179    is_asymmetric: bool,
15180    base: *mut u8,
15181) {
15182    let bs = block_size as usize;
15183    let n_blocks = k.div_ceil(bs);
15184    unsafe {
15185        let xs = sl(x, base, m * k);
15186        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, (k * n).div_ceil(2));
15187        let scales = sl(scale, base, n_blocks * n);
15188        let zps = if is_asymmetric {
15189            sl(zp, base, n_blocks * n)
15190        } else {
15191            &[][..]
15192        };
15193        let out = sl_mut(dst, base, m * n);
15194        dequant_matmul_int4(xs, w_bytes, scales, zps, out, m, k, n, bs, is_asymmetric);
15195    }
15196}
15197
15198/// Host-fallback entry for FP8 `Op::DequantMatMul` (Metal unified memory).
15199pub unsafe fn execute_dequant_matmul_fp8_f32(
15200    x: usize,
15201    w_q: usize,
15202    scale: usize,
15203    dst: usize,
15204    m: usize,
15205    k: usize,
15206    n: usize,
15207    e5m2: bool,
15208    base: *mut u8,
15209) {
15210    unsafe {
15211        let xs = sl(x, base, m * k);
15212        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, k * n);
15213        let scales = sl(scale, base, n);
15214        let out = sl_mut(dst, base, m * n);
15215        dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, e5m2);
15216    }
15217}
15218
15219/// Host-fallback entry for NVFP4 `Op::DequantMatMul` (Metal unified memory).
15220pub unsafe fn execute_dequant_matmul_nvfp4_f32(
15221    x: usize,
15222    w_q: usize,
15223    scale: usize,
15224    global_scale: usize,
15225    dst: usize,
15226    m: usize,
15227    k: usize,
15228    n: usize,
15229    base: *mut u8,
15230) {
15231    let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
15232    unsafe {
15233        let xs = sl(x, base, m * k);
15234        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, (k * n).div_ceil(2));
15235        let scale_bytes = std::slice::from_raw_parts(base.add(scale) as *const u8, n_scale);
15236        let gs = sl(global_scale, base, 1)[0];
15237        let out = sl_mut(dst, base, m * n);
15238        dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
15239    }
15240}
15241
/// Host-fallback entry for f16 `Op::GatedDeltaNet` tensors on Metal.
///
/// Decodes every operand from little-endian f16 bytes in the arena at `base`
/// into f32 buffers and runs the gated delta-rule recurrence on the CPU. It
/// then re-encodes the output, plus the updated state when one was supplied,
/// back to f16.
///
/// Layouts implied by the indexing below (`n = state_size`):
/// - `q`, `k`, `v`, `dst`: `[batch, seq, heads, n]`
/// - `g`, `beta`: `[batch, seq, heads]`
/// - `state`: `[batch, heads, n, n]`, read and written back in place
///
/// For every (batch, step, head) the per-head `n × n` state `S` evolves as:
/// - `S ← exp(g)·S`
/// - `δ = β·(v − Sᵀk)`
/// - `S ← S + k·δᵀ`
///
/// The output row is `Sᵀq / √n`, computed from the updated `S`.
///
/// NOTE(review): a `state` offset of `0` is treated as "no recurrent state".
/// In that case every batch starts from zeros and nothing is written back.
/// If the arena can place a real buffer at offset 0, this sentinel would
/// silently drop it. Confirm with callers.
pub unsafe fn execute_gated_delta_net_f16(
    q: usize,
    k: usize,
    v: usize,
    g: usize,
    beta: usize,
    state: usize,
    dst: usize,
    batch: usize,
    seq: usize,
    heads: usize,
    state_size: usize,
    base: *mut u8,
) {
    use half::f16;
    unsafe {
        // Decode `len` little-endian f16 values starting at byte offset `off`.
        let read_f16 = |off: usize, len: usize| -> Vec<f32> {
            let raw = std::slice::from_raw_parts(base.add(off) as *const u8, len * 2);
            raw.chunks_exact(2)
                .map(|c| f16::from_le_bytes([c[0], c[1]]).to_f32())
                .collect()
        };
        // Encode `data` as little-endian f16 starting at byte offset `off`.
        let write_f16 = |off: usize, data: &[f32]| {
            let out = std::slice::from_raw_parts_mut(base.add(off), data.len() * 2);
            for (i, &v) in data.iter().enumerate() {
                let le = f16::from_f32(v).to_le_bytes();
                out[i * 2] = le[0];
                out[i * 2 + 1] = le[1];
            }
        };

        let (b, s, h, n) = (batch, seq, heads, state_size);
        let q_f = read_f16(q, b * s * h * n);
        let k_f = read_f16(k, b * s * h * n);
        let v_f = read_f16(v, b * s * h * n);
        let g_f = read_f16(g, b * s * h);
        let b_f = read_f16(beta, b * s * h);
        // NOTE(review): when `state == 0` this zero buffer is never read
        // (`owned_state` is used instead), so the allocation is wasted.
        let mut state_f = if state != 0 {
            read_f16(state, b * h * n * n)
        } else {
            vec![0f32; b * h * n * n]
        };
        let mut out_f = vec![0f32; b * s * h * n];
        // Output scaling by 1/√state_size.
        let scale = 1.0f32 / (n as f32).sqrt();
        // Reused per head: first holds Sᵀk, then the update vector δ.
        let mut sk_buf = vec![0f32; n];
        // Per-batch scratch state used only when no state buffer is given.
        // NOTE(review): also allocated when `state != 0`, where it is unused.
        let mut owned_state = vec![0f32; h * n * n];

        for bi in 0..b {
            // Either this batch's slice of the caller's state, or a freshly
            // zeroed scratch state.
            let state_slice: &mut [f32] = if state != 0 {
                let start = bi * h * n * n;
                &mut state_f[start..start + h * n * n]
            } else {
                owned_state.fill(0.0);
                &mut owned_state
            };

            // Time steps are processed strictly in order: each one reads the
            // state written by the previous step.
            for ti in 0..s {
                let qkv_step_base = bi * s * h * n + ti * h * n;
                let gb_step_base = bi * s * h + ti * h;

                for hi in 0..h {
                    let q_row = &q_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
                    let k_row = &k_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
                    let v_row = &v_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
                    let g_t = g_f[gb_step_base + hi];
                    let beta_t = b_f[gb_step_base + hi];

                    // This head's n×n state matrix, row-major.
                    let s_base = hi * n * n;
                    let s_mat = &mut state_slice[s_base..s_base + n * n];

                    // Gate: decay the whole state by exp(g).
                    let g_exp = g_t.exp();
                    for st in s_mat.iter_mut() {
                        *st *= g_exp;
                    }

                    // sk_buf[j] = Σ_i S[i][j]·k[i]  (i.e. Sᵀk).
                    for j in 0..n {
                        let mut acc = 0f32;
                        for i in 0..n {
                            acc += s_mat[i * n + j] * k_row[i];
                        }
                        sk_buf[j] = acc;
                    }

                    // δ = β·(v − Sᵀk), computed in place.
                    for j in 0..n {
                        sk_buf[j] = (v_row[j] - sk_buf[j]) * beta_t;
                    }

                    // Rank-1 update: S += k·δᵀ.
                    for i in 0..n {
                        let ki = k_row[i];
                        for j in 0..n {
                            s_mat[i * n + j] += ki * sk_buf[j];
                        }
                    }

                    // out[j] = scale · Σ_i S[i][j]·q[i], using the updated S.
                    let out_row = &mut out_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
                    for j in 0..n {
                        let mut acc = 0f32;
                        for i in 0..n {
                            acc += s_mat[i * n + j] * q_row[i];
                        }
                        out_row[j] = acc * scale;
                    }
                }
            }
        }

        write_f16(dst, &out_f);
        // Persist the recurrent state only when the caller supplied one.
        if state != 0 {
            write_f16(state, &state_f);
        }
    }
}
15355
15356/// Host fallback for NCHW group norm (Metal unified-memory arena).
15357pub unsafe fn execute_group_norm_nchw_f32(
15358    src: usize,
15359    g: usize,
15360    b: usize,
15361    dst: usize,
15362    n: usize,
15363    c: usize,
15364    h: usize,
15365    w: usize,
15366    num_groups: usize,
15367    eps: f32,
15368    base: *mut u8,
15369) {
15370    let plane = c * h * w;
15371    for ni in 0..n {
15372        let input = unsafe { sl(src + ni * plane * std::mem::size_of::<f32>(), base, plane) };
15373        let gamma = unsafe { sl(g, base, c) };
15374        let beta = unsafe { sl(b, base, c) };
15375        let output = unsafe { sl_mut(dst + ni * plane * std::mem::size_of::<f32>(), base, plane) };
15376        crate::kernels::group_norm_nchw(input, gamma, beta, output, 1, c, h, w, num_groups, eps);
15377    }
15378}
15379
15380/// Host fallback for NCHW LayerNorm2d (SAM / candle semantics).
15381pub unsafe fn execute_layer_norm2d_nchw_f32(
15382    src: usize,
15383    g: usize,
15384    b: usize,
15385    dst: usize,
15386    n: usize,
15387    c: usize,
15388    h: usize,
15389    w: usize,
15390    eps: f32,
15391    base: *mut u8,
15392) {
15393    let plane = c * h * w;
15394    unsafe {
15395        let input = sl(src, base, n * plane);
15396        let gamma = sl(g, base, c);
15397        let beta = sl(b, base, c);
15398        let output = sl_mut(dst, base, n * plane);
15399        crate::kernels::layer_norm2d_nchw(input, gamma, beta, output, n, c, h, w, eps);
15400    }
15401}
15402
15403/// Host fallback for NCHW ConvTranspose2d.
15404pub unsafe fn execute_conv_transpose2d_nchw_f32(
15405    src: usize,
15406    weight: usize,
15407    dst: usize,
15408    n: usize,
15409    c_in: usize,
15410    h: usize,
15411    w_in: usize,
15412    c_out: usize,
15413    h_out: usize,
15414    w_out: usize,
15415    kh: usize,
15416    kw: usize,
15417    sh: usize,
15418    sw: usize,
15419    ph: usize,
15420    pw: usize,
15421    dh: usize,
15422    dw: usize,
15423    groups: usize,
15424    base: *mut u8,
15425) {
15426    let in_elems = n * c_in * h * w_in;
15427    let w_elems = c_in * (c_out / groups) * kh * kw;
15428    let out_elems = n * c_out * h_out * w_out;
15429    unsafe {
15430        let input = sl(src, base, in_elems);
15431        let wt = sl(weight, base, w_elems);
15432        let output = sl_mut(dst, base, out_elems);
15433        crate::kernels::conv_transpose2d_nchw(
15434            input, wt, output, n, c_in, h, w_in, c_out, h_out, w_out, kh, kw, sh, sw, ph, pw, dh,
15435            dw, groups,
15436        );
15437    }
15438}
15439
15440/// Host fallback for nearest 2× upsample on NCHW.
15441pub unsafe fn execute_resize_nearest_2x_f32(
15442    src: usize,
15443    dst: usize,
15444    n: usize,
15445    c: usize,
15446    h: usize,
15447    w: usize,
15448    base: *mut u8,
15449) {
15450    let in_plane = c * h * w;
15451    let out_plane = c * h * 2 * w * 2;
15452    for ni in 0..n {
15453        let input = unsafe {
15454            sl(
15455                src + ni * in_plane * std::mem::size_of::<f32>(),
15456                base,
15457                in_plane,
15458            )
15459        };
15460        let output = unsafe {
15461            sl_mut(
15462                dst + ni * out_plane * std::mem::size_of::<f32>(),
15463                base,
15464                out_plane,
15465            )
15466        };
15467        crate::kernels::resize_nearest_2x_nchw(input, output, c, h, w);
15468    }
15469}
15470
15471/// Host axial 2-D RoPE for Metal (and other) fallbacks on unified memory.
15472pub unsafe fn execute_axial_rope2d_f32(
15473    src: usize,
15474    dst: usize,
15475    batch: usize,
15476    seq: usize,
15477    hidden: usize,
15478    end_x: usize,
15479    end_y: usize,
15480    head_dim: usize,
15481    num_heads: usize,
15482    theta: f32,
15483    repeat_factor: usize,
15484    base: *mut u8,
15485) {
15486    let plane = seq * hidden;
15487    let plane_bytes = plane * std::mem::size_of::<f32>();
15488    for bi in 0..batch {
15489        let in_off = src + bi * plane_bytes;
15490        let input = unsafe { sl(in_off, base, plane) };
15491        let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
15492            input,
15493            num_heads,
15494            seq,
15495            head_dim,
15496            end_x,
15497            end_y,
15498            theta,
15499            repeat_factor,
15500        );
15501        let out_off = dst + bi * plane_bytes;
15502        let output = unsafe { sl_mut(out_off, base, plane) };
15503        output.copy_from_slice(&rotated);
15504    }
15505}
15506
15507/// Ternary pruned radix-2 butterfly stage on `[batch, n_fft, 2]` interleaved state.
15508pub unsafe fn execute_fft_butterfly_stage_f32(
15509    state_src: usize,
15510    state_dst: usize,
15511    gate_src: usize,
15512    rev_src: usize,
15513    tw_re_src: usize,
15514    tw_im_src: usize,
15515    batch: usize,
15516    n_fft: usize,
15517    stage: usize,
15518    base: *mut u8,
15519) {
15520    let half = n_fft / 2;
15521    let stride = 1usize << stage;
15522    let gate = unsafe { sl(gate_src, base, half) };
15523    let rev = unsafe { sl(rev_src, base, half) };
15524    let tw_re = unsafe { sl(tw_re_src, base, half) };
15525    let tw_im = unsafe { sl(tw_im_src, base, half) };
15526    let row_elems = n_fft * 2;
15527    for b in 0..batch {
15528        let in_off = state_src + b * row_elems * std::mem::size_of::<f32>();
15529        let out_off = state_dst + b * row_elems * std::mem::size_of::<f32>();
15530        let inp = unsafe { sl(in_off, base, row_elems) };
15531        let out = unsafe { sl_mut(out_off, base, row_elems) };
15532        out.copy_from_slice(inp);
15533        for bf in 0..half {
15534            if gate[bf] == 0.0 {
15535                continue;
15536            }
15537            let group = bf / stride;
15538            let k = bf % stride;
15539            let i0 = group * 2 * stride + k;
15540            let i1 = i0 + stride;
15541            let w_re = tw_re[bf];
15542            let w_im = tw_im[bf];
15543            let in_a_re = inp[i0 * 2];
15544            let in_a_im = inp[i0 * 2 + 1];
15545            let in_b_re = inp[i1 * 2];
15546            let in_b_im = inp[i1 * 2 + 1];
15547            let (b_re, b_im) = (
15548                in_b_re * w_re - in_b_im * w_im,
15549                in_b_re * w_im + in_b_im * w_re,
15550            );
15551            let (top_re, top_im) = (in_a_re + b_re, in_a_im + b_im);
15552            let (bot_re, bot_im) = (in_a_re - b_re, in_a_im - b_im);
15553            let (oa_re, oa_im, ob_re, ob_im) = if rev[bf] >= 0.5 {
15554                (bot_re, bot_im, top_re, top_im)
15555            } else {
15556                (top_re, top_im, bot_re, bot_im)
15557            };
15558            out[i0 * 2] = oa_re;
15559            out[i0 * 2 + 1] = oa_im;
15560            out[i1 * 2] = ob_re;
15561            out[i1 * 2 + 1] = ob_im;
15562        }
15563    }
15564}
15565
15566/// f32 mirror of `execute_fft1d_f64`. Same public-host-fallback role.
15567pub unsafe fn execute_fft1d_f32(
15568    src: usize,
15569    dst: usize,
15570    outer: usize,
15571    n_complex: usize,
15572    inverse: bool,
15573    norm_tag: u32,
15574    base: *mut u8,
15575) {
15576    let row_elems = 2 * n_complex;
15577    let mut re = vec![0f32; n_complex];
15578    let mut im = vec![0f32; n_complex];
15579    let norm = rlx_ir::fft::FftNorm::from_tag(norm_tag);
15580    let scale = norm.output_scale(n_complex, inverse) as f32;
15581    let mut scratch = if n_complex.is_power_of_two() || n_complex <= 16 {
15582        BluesteinScratchF32::empty()
15583    } else {
15584        BluesteinScratchF32::build(n_complex, inverse)
15585    };
15586    for o in 0..outer {
15587        let row_offset = src + o * row_elems * std::mem::size_of::<f32>();
15588        let s = unsafe { sl(row_offset, base, row_elems) };
15589        re.copy_from_slice(&s[..n_complex]);
15590        im.copy_from_slice(&s[n_complex..]);
15591        if n_complex.is_power_of_two() {
15592            fft_radix2_inplace_f32(&mut re, &mut im, inverse);
15593        } else if n_complex <= 16 {
15594            fft_naive_inplace_f32(&mut re, &mut im, inverse);
15595        } else {
15596            fft_bluestein_inplace_f32(&mut re, &mut im, inverse, &mut scratch);
15597        }
15598        if scale != 1.0 {
15599            re.iter_mut().for_each(|v| *v *= scale);
15600            im.iter_mut().for_each(|v| *v *= scale);
15601        }
15602        let dst_offset = dst + o * row_elems * std::mem::size_of::<f32>();
15603        let d = unsafe { sl_mut(dst_offset, base, row_elems) };
15604        d[..n_complex].copy_from_slice(&re);
15605        d[n_complex..].copy_from_slice(&im);
15606    }
15607}
15608
15609/// C64 interleaved layout: each complex element is `[re: f32, im: f32]`.
15610pub unsafe fn execute_fft1d_c64(
15611    src: usize,
15612    dst: usize,
15613    outer: usize,
15614    n_complex: usize,
15615    inverse: bool,
15616    norm_tag: u32,
15617    base: *mut u8,
15618) {
15619    let row_bytes = n_complex * 8;
15620    let mut re = vec![0f32; n_complex];
15621    let mut im = vec![0f32; n_complex];
15622    let norm = rlx_ir::fft::FftNorm::from_tag(norm_tag);
15623    let scale = norm.output_scale(n_complex, inverse) as f32;
15624    let mut scratch = if n_complex.is_power_of_two() || n_complex <= 16 {
15625        BluesteinScratchF32::empty()
15626    } else {
15627        BluesteinScratchF32::build(n_complex, inverse)
15628    };
15629    for o in 0..outer {
15630        let row_offset = src + o * row_bytes;
15631        for i in 0..n_complex {
15632            let elem_off = row_offset + i * 8;
15633            re[i] = f32::from_le_bytes([
15634                *base.add(elem_off),
15635                *base.add(elem_off + 1),
15636                *base.add(elem_off + 2),
15637                *base.add(elem_off + 3),
15638            ]);
15639            im[i] = f32::from_le_bytes([
15640                *base.add(elem_off + 4),
15641                *base.add(elem_off + 5),
15642                *base.add(elem_off + 6),
15643                *base.add(elem_off + 7),
15644            ]);
15645        }
15646        if n_complex.is_power_of_two() {
15647            fft_radix2_inplace_f32(&mut re, &mut im, inverse);
15648        } else if n_complex <= 16 {
15649            fft_naive_inplace_f32(&mut re, &mut im, inverse);
15650        } else {
15651            fft_bluestein_inplace_f32(&mut re, &mut im, inverse, &mut scratch);
15652        }
15653        if scale != 1.0 {
15654            re.iter_mut().for_each(|v| *v *= scale);
15655            im.iter_mut().for_each(|v| *v *= scale);
15656        }
15657        let dst_row = dst + o * row_bytes;
15658        for i in 0..n_complex {
15659            let elem_off = dst_row + i * 8;
15660            let re_b = re[i].to_le_bytes();
15661            let im_b = im[i].to_le_bytes();
15662            for j in 0..4 {
15663                *base.add(elem_off + j) = re_b[j];
15664                *base.add(elem_off + 4 + j) = im_b[j];
15665            }
15666        }
15667    }
15668}
15669
15670/// Dtype-dispatching host entry for `Op::LogMel` (shared by GPU host fallbacks).
15671pub unsafe fn execute_log_mel(
15672    spec: usize,
15673    filters: usize,
15674    dst: usize,
15675    outer: usize,
15676    n_fft: usize,
15677    n_bins: usize,
15678    n_mels: usize,
15679    base: *mut u8,
15680) {
15681    execute_log_mel_f32(spec, filters, dst, outer, n_fft, n_bins, n_mels, base);
15682}
15683
15684pub unsafe fn execute_log_mel_f32(
15685    spec: usize,
15686    filters: usize,
15687    dst: usize,
15688    outer: usize,
15689    n_fft: usize,
15690    n_bins: usize,
15691    n_mels: usize,
15692    base: *mut u8,
15693) {
15694    let spec_ptr = base.add(spec) as *const f32;
15695    let filt_ptr = base.add(filters) as *const f32;
15696    let dst_ptr = base.add(dst) as *mut f32;
15697    let spec = std::slice::from_raw_parts(spec_ptr, outer * n_fft * 2);
15698    let filters = std::slice::from_raw_parts(filt_ptr, n_mels * n_bins);
15699    let out = std::slice::from_raw_parts_mut(dst_ptr, outer * n_mels);
15700    rlx_ir::audio::log_mel_block_f32(spec, filters, outer, n_fft, n_bins, n_mels, out);
15701}
15702
15703pub unsafe fn execute_welch_peaks_f32(
15704    spec: usize,
15705    dst: usize,
15706    welch_batch: usize,
15707    n_fft: usize,
15708    n_segments: usize,
15709    k: usize,
15710    base: *mut u8,
15711) {
15712    let spec_ptr = base.add(spec) as *const f32;
15713    let dst_ptr = base.add(dst) as *mut f32;
15714    let outer = welch_batch * n_segments;
15715    let spec = std::slice::from_raw_parts(spec_ptr, outer * n_fft * 2);
15716    let out = std::slice::from_raw_parts_mut(dst_ptr, welch_batch * k * 2);
15717    rlx_ir::audio::welch_peaks_block_f32(spec, welch_batch, n_fft, n_segments, k, out);
15718}
15719
15720pub unsafe fn execute_log_mel_backward_f32(
15721    spec: usize,
15722    filters: usize,
15723    dy: usize,
15724    dst: usize,
15725    outer: usize,
15726    n_fft: usize,
15727    n_bins: usize,
15728    n_mels: usize,
15729    base: *mut u8,
15730) {
15731    let spec_ptr = base.add(spec) as *const f32;
15732    let filt_ptr = base.add(filters) as *const f32;
15733    let dy_ptr = base.add(dy) as *const f32;
15734    let dst_ptr = base.add(dst) as *mut f32;
15735    let spec = std::slice::from_raw_parts(spec_ptr, outer * n_fft * 2);
15736    let filters = std::slice::from_raw_parts(filt_ptr, n_mels * n_bins);
15737    let dy = std::slice::from_raw_parts(dy_ptr, outer * n_mels);
15738    let d_spec = std::slice::from_raw_parts_mut(dst_ptr, outer * n_fft * 2);
15739    d_spec.fill(0.0);
15740    rlx_ir::audio::log_mel_block_vjp(spec, filters, dy, outer, n_fft, n_bins, n_mels, d_spec);
15741}
15742
15743/// Dtype-dispatching host entry for `Op::Fft` (shared by GPU host fallbacks).
15744pub unsafe fn execute_fft1d(
15745    src: usize,
15746    dst: usize,
15747    outer: usize,
15748    n_complex: usize,
15749    inverse: bool,
15750    norm_tag: u32,
15751    dtype: rlx_ir::DType,
15752    base: *mut u8,
15753) {
15754    match dtype {
15755        rlx_ir::DType::F32 => {
15756            execute_fft1d_f32(src, dst, outer, n_complex, inverse, norm_tag, base)
15757        }
15758        rlx_ir::DType::F64 => {
15759            execute_fft1d_f64(src, dst, outer, n_complex, inverse, norm_tag, base)
15760        }
15761        rlx_ir::DType::C64 => {
15762            execute_fft1d_c64(src, dst, outer, n_complex, inverse, norm_tag, base)
15763        }
15764        other => panic!("execute_fft1d: unsupported dtype {other:?}"),
15765    }
15766}
15767
/// f32 in-place radix-2 decimation-in-time Cooley-Tukey FFT (unnormalised).
///
/// Same structure as `fft_radix2_inplace_f64`. The butterflies run in f32,
/// but each stage's twiddle recurrence is carried in f64 and rounded to f32
/// only at the point of use. This keeps accumulated rotation drift from
/// dominating the per-stage error at larger N. Forward uses
/// `exp(-2πi/len)`, inverse `exp(+2πi/len)`. `re.len()` must be a power of
/// two; lengths 0 and 1 are returned untouched.
fn fft_radix2_inplace_f32(re: &mut [f32], im: &mut [f32], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    debug_assert!(
        n.is_power_of_two(),
        "fft_radix2_f32: n={n} must be a power of two"
    );
    if n <= 1 {
        return;
    }

    // Bit-reversal reordering, tracking the reversed index incrementally.
    let mut rev = 0usize;
    for idx in 1..n {
        let mut mask = n >> 1;
        while rev & mask != 0 {
            rev ^= mask;
            mask >>= 1;
        }
        rev ^= mask;
        if idx < rev {
            re.swap(idx, rev);
            im.swap(idx, rev);
        }
    }

    let direction = if inverse { 1.0_f64 } else { -1.0_f64 };
    let mut span = 2usize;
    while span <= n {
        let half = span / 2;
        let angle = direction * std::f64::consts::TAU / span as f64;
        let (step_re, step_im) = (angle.cos(), angle.sin());
        for start in (0..n).step_by(span) {
            // Every segment restarts its twiddle at 1 + 0i.
            let (mut tw_re, mut tw_im) = (1.0_f64, 0.0_f64);
            for offset in 0..half {
                let top = start + offset;
                let bottom = top + half;
                let (cr, ci) = (tw_re as f32, tw_im as f32);
                let prod_re = cr * re[bottom] - ci * im[bottom];
                let prod_im = cr * im[bottom] + ci * re[bottom];
                let (top_re, top_im) = (re[top], im[top]);
                re[top] = top_re + prod_re;
                im[top] = top_im + prod_im;
                re[bottom] = top_re - prod_re;
                im[bottom] = top_im - prod_im;
                (tw_re, tw_im) = (
                    tw_re * step_re - tw_im * step_im,
                    tw_re * step_im + tw_im * step_re,
                );
            }
        }
        span <<= 1;
    }
}
15829
15830/// In-place radix-2 DIT Cooley-Tukey FFT on split (real, imag) f64
15831/// arrays. `n = re.len() = im.len()` must be a power of two. Forward
15832/// uses ω = exp(-2πi/n); inverse uses ω = exp(+2πi/n) (no 1/N scale).
15833fn fft_radix2_inplace_f64(re: &mut [f64], im: &mut [f64], inverse: bool) {
15834    let n = re.len();
15835    debug_assert_eq!(im.len(), n);
15836    debug_assert!(
15837        n.is_power_of_two(),
15838        "fft_radix2: n={n} must be a power of two"
15839    );
15840    if n <= 1 {
15841        return;
15842    }
15843
15844    // Bit-reverse permutation.
15845    let mut j = 0usize;
15846    for i in 1..n {
15847        let mut bit = n >> 1;
15848        while j & bit != 0 {
15849            j ^= bit;
15850            bit >>= 1;
15851        }
15852        j ^= bit;
15853        if i < j {
15854            re.swap(i, j);
15855            im.swap(i, j);
15856        }
15857    }
15858
15859    // Cooley-Tukey butterflies: ω_len = exp(±2πi/len).
15860    let sign = if inverse { 1.0 } else { -1.0 };
15861    let mut len = 2usize;
15862    while len <= n {
15863        let half = len / 2;
15864        let theta = sign * 2.0 * std::f64::consts::PI / (len as f64);
15865        let w_re_step = theta.cos();
15866        let w_im_step = theta.sin();
15867        let mut i = 0usize;
15868        while i < n {
15869            // Twiddle starts at 1+0i for each segment.
15870            let mut wre = 1.0_f64;
15871            let mut wim = 0.0_f64;
15872            for k in 0..half {
15873                let t_re = wre * re[i + k + half] - wim * im[i + k + half];
15874                let t_im = wre * im[i + k + half] + wim * re[i + k + half];
15875                let u_re = re[i + k];
15876                let u_im = im[i + k];
15877                re[i + k] = u_re + t_re;
15878                im[i + k] = u_im + t_im;
15879                re[i + k + half] = u_re - t_re;
15880                im[i + k + half] = u_im - t_im;
15881                let new_wre = wre * w_re_step - wim * w_im_step;
15882                let new_wim = wre * w_im_step + wim * w_re_step;
15883                wre = new_wre;
15884                wim = new_wim;
15885            }
15886            i += len;
15887        }
15888        len <<= 1;
15889    }
15890}
15891
15892/// Pre-computed chirp + filter-spectrum for one (N, direction) pair.
15893/// Built once per call to `execute_fft1d_f64` and reused across rows
15894/// when `outer > 1` — the chirp and FFT(b) don't depend on the input.
15895struct BluesteinScratchF64 {
15896    /// Power-of-two convolution length, ≥ 2N - 1.
15897    m: usize,
15898    /// `w[k] = exp(sign · iπ · k² / N)` for k=0..N, where sign matches
15899    /// the requested direction. Forward chirp on the way in, output
15900    /// chirp on the way out.
15901    w_re: Vec<f64>,
15902    w_im: Vec<f64>,
15903    /// FFT of the embedded filter `b[k] = conj(w[|k|])` in length-M.
15904    /// Doesn't depend on the input — precomputed once.
15905    bf_re: Vec<f64>,
15906    bf_im: Vec<f64>,
15907    /// Workspace reused per row (avoids per-row allocation).
15908    ar: Vec<f64>,
15909    ai: Vec<f64>,
15910}
15911
15912impl BluesteinScratchF64 {
15913    fn empty() -> Self {
15914        Self {
15915            m: 0,
15916            w_re: Vec::new(),
15917            w_im: Vec::new(),
15918            bf_re: Vec::new(),
15919            bf_im: Vec::new(),
15920            ar: Vec::new(),
15921            ai: Vec::new(),
15922        }
15923    }
15924
15925    fn build(n: usize, inverse: bool) -> Self {
15926        // M = next power of two ≥ 2N - 1 keeps the inner FFT on the
15927        // fast radix-2 path. For N=1 fall back to M=1 (no-op convolution).
15928        let m = if n <= 1 {
15929            1
15930        } else {
15931            (2 * n - 1).next_power_of_two()
15932        };
15933
15934        // Chirp arg reduced via k² mod 2N — without this, large N
15935        // bleeds precision into the trig call (n² grows quadratically).
15936        let mod_2n = (2 * n) as u64;
15937        let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
15938        let mut w_re = vec![0.0_f64; n];
15939        let mut w_im = vec![0.0_f64; n];
15940        for k in 0..n {
15941            let k2 = (k as u64).wrapping_mul(k as u64) % mod_2n;
15942            let theta = sign * std::f64::consts::PI * (k2 as f64) / (n as f64);
15943            w_re[k] = theta.cos();
15944            w_im[k] = theta.sin();
15945        }
15946
15947        // Embed b[k] = conj(w[|k|]) into length M with the negative
15948        // indices wrapping to the tail: b[-j] → B[M-j] for j=1..N-1.
15949        let mut bf_re = vec![0.0_f64; m];
15950        let mut bf_im = vec![0.0_f64; m];
15951        if n > 0 {
15952            bf_re[0] = w_re[0];
15953            bf_im[0] = -w_im[0];
15954            for k in 1..n {
15955                bf_re[k] = w_re[k];
15956                bf_im[k] = -w_im[k];
15957                bf_re[m - k] = w_re[k];
15958                bf_im[m - k] = -w_im[k];
15959            }
15960        }
15961        if m > 1 {
15962            fft_radix2_inplace_f64(&mut bf_re, &mut bf_im, false);
15963        }
15964
15965        Self {
15966            m,
15967            w_re,
15968            w_im,
15969            bf_re,
15970            bf_im,
15971            ar: vec![0.0_f64; m],
15972            ai: vec![0.0_f64; m],
15973        }
15974    }
15975}
15976
/// Direct O(N²) DFT for small non-power-of-two N (cheaper than Bluestein
/// setup).
///
/// Forward uses `exp(-2πi·nk/N)`, inverse `exp(+2πi·nk/N)`. Neither
/// direction applies a 1/N scale, matching the radix-2 and Bluestein paths.
///
/// Only N distinct twiddles `exp(±2πi·j/N)` exist. They are tabulated once
/// and looked up with the exact integer residue `(nn·k) mod N`. This keeps
/// every trig argument in `[0, 2π)` and needs N rather than N² `sin`/`cos`
/// calls.
fn fft_naive_inplace_f64(re: &mut [f64], im: &mut [f64], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    if n <= 1 {
        return;
    }
    let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
    let step = sign * std::f64::consts::TAU / n as f64;
    let (tw_re, tw_im): (Vec<f64>, Vec<f64>) = (0..n)
        .map(|j| {
            let theta = step * j as f64;
            (theta.cos(), theta.sin())
        })
        .unzip();
    let mut out_re = vec![0.0_f64; n];
    let mut out_im = vec![0.0_f64; n];
    for k in 0..n {
        let (mut acc_re, mut acc_im) = (0.0_f64, 0.0_f64);
        // `idx` tracks (nn * k) mod n without forming the product.
        let mut idx = 0usize;
        for nn in 0..n {
            let (c, s) = (tw_re[idx], tw_im[idx]);
            acc_re += re[nn] * c - im[nn] * s;
            acc_im += re[nn] * s + im[nn] * c;
            idx += k;
            if idx >= n {
                idx -= n;
            }
        }
        out_re[k] = acc_re;
        out_im[k] = acc_im;
    }
    re.copy_from_slice(&out_re);
    im.copy_from_slice(&out_im);
}
15998
/// Direct O(N²) DFT on f32 data for small non-power-of-two N.
///
/// Forward uses `exp(-2πi·nk/N)`, inverse `exp(+2πi·nk/N)`. Neither
/// direction applies a 1/N scale, matching the radix-2 and Bluestein paths.
///
/// The N distinct twiddles are tabulated once in f64, rounded to f32, and
/// looked up with the exact integer residue `(nn·k) mod N`. This follows
/// the same "twiddles in f64" policy as the radix-2 f32 path. Forming the
/// angle `2π·nn·k/N` directly in f32 loses accuracy as `nn·k` grows
/// (≈82 rad at N = 15), and would also cost N² trig calls instead of N.
/// Accumulation stays in f32.
fn fft_naive_inplace_f32(re: &mut [f32], im: &mut [f32], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    if n <= 1 {
        return;
    }
    let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
    let step = sign * std::f64::consts::TAU / n as f64;
    let (tw_re, tw_im): (Vec<f32>, Vec<f32>) = (0..n)
        .map(|j| {
            let theta = step * j as f64;
            (theta.cos() as f32, theta.sin() as f32)
        })
        .unzip();
    let mut out_re = vec![0.0_f32; n];
    let mut out_im = vec![0.0_f32; n];
    for k in 0..n {
        let (mut acc_re, mut acc_im) = (0.0_f32, 0.0_f32);
        // `idx` tracks (nn * k) mod n without forming the product.
        let mut idx = 0usize;
        for nn in 0..n {
            let (c, s) = (tw_re[idx], tw_im[idx]);
            acc_re += re[nn] * c - im[nn] * s;
            acc_im += re[nn] * s + im[nn] * c;
            idx += k;
            if idx >= n {
                idx -= n;
            }
        }
        out_re[k] = acc_re;
        out_im[k] = acc_im;
    }
    re.copy_from_slice(&out_re);
    im.copy_from_slice(&out_im);
}
16019
16020/// Bluestein (chirp-z) FFT for arbitrary N. Identity used:
16021///   `n·k = (n² + k² - (k-n)²) / 2`
16022/// which lets the DFT be written as a linear convolution sandwiched
16023/// between two chirp multiplies:
16024///   `X[k] = w[k] · ((x·w) ⊛ conj(w))[k]`   where `w[n] = exp(±iπ·n²/N)`.
16025/// The convolution is computed via a length-M radix-2 FFT (M ≥ 2N-1).
16026/// Both directions stay unnormalized to match the radix-2 path, so the
16027/// chain rule keeps working without scaling.
16028fn fft_bluestein_inplace_f64(
16029    re: &mut [f64],
16030    im: &mut [f64],
16031    _inverse: bool,
16032    s: &mut BluesteinScratchF64,
16033) {
16034    let n = re.len();
16035    debug_assert_eq!(im.len(), n);
16036    debug_assert_eq!(s.w_re.len(), n);
16037    if n <= 1 {
16038        return;
16039    }
16040    let m = s.m;
16041
16042    // Pre-chirp: a[k] = x[k] · w[k], zero-padded to M.
16043    for k in 0..m {
16044        s.ar[k] = 0.0;
16045        s.ai[k] = 0.0;
16046    }
16047    for k in 0..n {
16048        s.ar[k] = re[k] * s.w_re[k] - im[k] * s.w_im[k];
16049        s.ai[k] = re[k] * s.w_im[k] + im[k] * s.w_re[k];
16050    }
16051
16052    // Length-M forward FFT of the padded chirped input.
16053    fft_radix2_inplace_f64(&mut s.ar, &mut s.ai, false);
16054
16055    // Pointwise product with FFT(b). Stored back into (ar, ai).
16056    for k in 0..m {
16057        let ar = s.ar[k];
16058        let ai = s.ai[k];
16059        let br = s.bf_re[k];
16060        let bi = s.bf_im[k];
16061        s.ar[k] = ar * br - ai * bi;
16062        s.ai[k] = ar * bi + ai * br;
16063    }
16064
16065    // Inverse FFT — radix-2 here is the unnormalized inverse, so we
16066    // divide by M to recover the true circular convolution.
16067    fft_radix2_inplace_f64(&mut s.ar, &mut s.ai, true);
16068    let inv_m = 1.0 / (m as f64);
16069
16070    // Post-chirp: X[k] = w[k] · Y[k] / M for k = 0..N.
16071    for k in 0..n {
16072        let yr = s.ar[k] * inv_m;
16073        let yi = s.ai[k] * inv_m;
16074        re[k] = yr * s.w_re[k] - yi * s.w_im[k];
16075        im[k] = yr * s.w_im[k] + yi * s.w_re[k];
16076    }
16077}
16078
16079/// f32 mirror of `BluesteinScratchF64`. Chirp is computed in f64 for
16080/// precision (same justification as the radix-2 f32 path: twiddles in
16081/// f64, butterflies in f32). The actual conv buffers are f32.
16082struct BluesteinScratchF32 {
16083    m: usize,
16084    w_re: Vec<f32>,
16085    w_im: Vec<f32>,
16086    bf_re: Vec<f32>,
16087    bf_im: Vec<f32>,
16088    ar: Vec<f32>,
16089    ai: Vec<f32>,
16090}
16091
16092impl BluesteinScratchF32 {
16093    fn empty() -> Self {
16094        Self {
16095            m: 0,
16096            w_re: Vec::new(),
16097            w_im: Vec::new(),
16098            bf_re: Vec::new(),
16099            bf_im: Vec::new(),
16100            ar: Vec::new(),
16101            ai: Vec::new(),
16102        }
16103    }
16104
16105    fn build(n: usize, inverse: bool) -> Self {
16106        let m = if n <= 1 {
16107            1
16108        } else {
16109            (2 * n - 1).next_power_of_two()
16110        };
16111
16112        let mod_2n = (2 * n) as u64;
16113        let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
16114        let mut w_re = vec![0.0_f32; n];
16115        let mut w_im = vec![0.0_f32; n];
16116        for k in 0..n {
16117            let k2 = (k as u64).wrapping_mul(k as u64) % mod_2n;
16118            let theta = sign * std::f64::consts::PI * (k2 as f64) / (n as f64);
16119            w_re[k] = theta.cos() as f32;
16120            w_im[k] = theta.sin() as f32;
16121        }
16122
16123        let mut bf_re = vec![0.0_f32; m];
16124        let mut bf_im = vec![0.0_f32; m];
16125        if n > 0 {
16126            bf_re[0] = w_re[0];
16127            bf_im[0] = -w_im[0];
16128            for k in 1..n {
16129                bf_re[k] = w_re[k];
16130                bf_im[k] = -w_im[k];
16131                bf_re[m - k] = w_re[k];
16132                bf_im[m - k] = -w_im[k];
16133            }
16134        }
16135        if m > 1 {
16136            fft_radix2_inplace_f32(&mut bf_re, &mut bf_im, false);
16137        }
16138
16139        Self {
16140            m,
16141            w_re,
16142            w_im,
16143            bf_re,
16144            bf_im,
16145            ar: vec![0.0_f32; m],
16146            ai: vec![0.0_f32; m],
16147        }
16148    }
16149}
16150
16151fn fft_bluestein_inplace_f32(
16152    re: &mut [f32],
16153    im: &mut [f32],
16154    _inverse: bool,
16155    s: &mut BluesteinScratchF32,
16156) {
16157    let n = re.len();
16158    debug_assert_eq!(im.len(), n);
16159    debug_assert_eq!(s.w_re.len(), n);
16160    if n <= 1 {
16161        return;
16162    }
16163    let m = s.m;
16164
16165    for k in 0..m {
16166        s.ar[k] = 0.0;
16167        s.ai[k] = 0.0;
16168    }
16169    for k in 0..n {
16170        s.ar[k] = re[k] * s.w_re[k] - im[k] * s.w_im[k];
16171        s.ai[k] = re[k] * s.w_im[k] + im[k] * s.w_re[k];
16172    }
16173
16174    fft_radix2_inplace_f32(&mut s.ar, &mut s.ai, false);
16175
16176    for k in 0..m {
16177        let ar = s.ar[k];
16178        let ai = s.ai[k];
16179        let br = s.bf_re[k];
16180        let bi = s.bf_im[k];
16181        s.ar[k] = ar * br - ai * bi;
16182        s.ai[k] = ar * bi + ai * br;
16183    }
16184
16185    fft_radix2_inplace_f32(&mut s.ar, &mut s.ai, true);
16186    let inv_m = 1.0_f32 / (m as f32);
16187
16188    for k in 0..n {
16189        let yr = s.ar[k] * inv_m;
16190        let yi = s.ai[k] * inv_m;
16191        re[k] = yr * s.w_re[k] - yi * s.w_im[k];
16192        im[k] = yr * s.w_im[k] + yi * s.w_re[k];
16193    }
16194}
16195
/// Shared dispatch path for `Thunk::CustomOp`. Builds a typed
/// [`CpuTensorRef`] for each input *at that input's declared dtype*
/// (so a sparse-LU op with mixed F64/I32 inputs gets the right
/// typed slices) and a [`CpuTensorMut`] for the output, then calls
/// the kernel's single `execute` method.
///
/// `inputs` holds one `(byte offset, element count, shape)` triple per
/// input; an input offset of `usize::MAX` yields an empty slice (that is
/// how `sl_typed` treats it). The output view covers `out_len` elements of
/// `out_shape.dtype()` starting at byte offset `out_off`.
///
/// # Panics
///
/// If any input or the output has `DType::C64`, or if the kernel's
/// `execute` returns `Err`.
///
/// # Safety
///
/// Every `(offset, len)` region, for the inputs and the output, must lie
/// inside the arena behind `base` and be aligned for its dtype's element
/// type. The output region must not overlap any input region, because it
/// is handed out as `&mut` while the inputs are borrowed as `&` for the
/// duration of `execute`.
unsafe fn dispatch_custom_op(
    kernel: &dyn crate::op_registry::CpuKernel,
    inputs: &[(usize, u32, Shape)],
    out_off: usize,
    out_len: u32,
    out_shape: &Shape,
    attrs: &[u8],
    base: *mut u8,
) {
    use crate::op_registry::{CpuTensorMut, CpuTensorRef};
    use rlx_ir::DType;

    // One arm per `DType` variant — single source of truth for
    // "which dtypes the CPU custom-op dispatcher wires." If a new
    // DType lands in `rlx-ir`, the compiler flags this match as
    // non-exhaustive and the gap gets named at the right place.
    // `build_in_view!` wraps `$n` elements at byte offset `$off` as a
    // read-only view of the given variant/element type.
    macro_rules! build_in_view {
        ($shape:expr, $off:expr, $n:expr, $variant:ident, $rust_ty:ty) => {
            CpuTensorRef::$variant {
                data: unsafe { sl_typed::<$rust_ty>($off, base, $n) },
                shape: $shape,
            }
        };
    }
    // `build_out_view!` wraps the output region (`out_off`, `out_len`) as a
    // mutable view of the given variant/element type.
    macro_rules! build_out_view {
        ($variant:ident, $rust_ty:ty) => {
            CpuTensorMut::$variant {
                data: unsafe { sl_mut_typed::<$rust_ty>(out_off, base, out_len as usize) },
                shape: out_shape,
            }
        };
    }

    // One typed, read-only view per input, each at that input's own dtype.
    // Bool is viewed as raw `u8` bytes.
    let in_views: Vec<CpuTensorRef<'_>> = inputs
        .iter()
        .map(|(off, len, shape)| {
            let n = *len as usize;
            let off = *off;
            match shape.dtype() {
                DType::F32 => build_in_view!(shape, off, n, F32, f32),
                DType::F64 => build_in_view!(shape, off, n, F64, f64),
                DType::F16 => build_in_view!(shape, off, n, F16, half::f16),
                DType::BF16 => build_in_view!(shape, off, n, BF16, half::bf16),
                DType::I8 => build_in_view!(shape, off, n, I8, i8),
                DType::I16 => build_in_view!(shape, off, n, I16, i16),
                DType::I32 => build_in_view!(shape, off, n, I32, i32),
                DType::I64 => build_in_view!(shape, off, n, I64, i64),
                DType::U8 => build_in_view!(shape, off, n, U8, u8),
                DType::U32 => build_in_view!(shape, off, n, U32, u32),
                DType::Bool => build_in_view!(shape, off, n, Bool, u8),
                // C64 isn't a CpuTensor variant today; the user-registered
                // op_registry path doesn't see complex inputs (those are
                // handled by built-in ops with dedicated kernels).
                DType::C64 => panic!(
                    "Op::Custom kernel input has DType::C64 — built-in \
                 complex ops handle their own kernels; user-registered \
                 ops don't yet see complex tensors"
                ),
            }
        })
        .collect();

    // The output view's element type follows `out_shape.dtype()`,
    // independently of the input dtypes.
    let result = match out_shape.dtype() {
        DType::F32 => kernel.execute(&in_views, build_out_view!(F32, f32), attrs),
        DType::F64 => kernel.execute(&in_views, build_out_view!(F64, f64), attrs),
        DType::F16 => kernel.execute(&in_views, build_out_view!(F16, half::f16), attrs),
        DType::BF16 => kernel.execute(&in_views, build_out_view!(BF16, half::bf16), attrs),
        DType::I8 => kernel.execute(&in_views, build_out_view!(I8, i8), attrs),
        DType::I16 => kernel.execute(&in_views, build_out_view!(I16, i16), attrs),
        DType::I32 => kernel.execute(&in_views, build_out_view!(I32, i32), attrs),
        DType::I64 => kernel.execute(&in_views, build_out_view!(I64, i64), attrs),
        DType::U8 => kernel.execute(&in_views, build_out_view!(U8, u8), attrs),
        DType::U32 => kernel.execute(&in_views, build_out_view!(U32, u32), attrs),
        DType::Bool => kernel.execute(&in_views, build_out_view!(Bool, u8), attrs),
        DType::C64 => panic!("Op::Custom output DType::C64 not supported"),
    };
    // Kernel errors are not propagated from here: they become a panic that
    // names the failing op.
    if let Err(e) = result {
        panic!("Op::Custom('{}') CPU kernel failed: {e}", kernel.name());
    }
}
16281
/// Generic raw-cast slice helpers. The per-dtype `sl_*` / `sl_mut_*`
/// helpers remain in use elsewhere in `thunk.rs` at call sites with a
/// concrete dtype; the custom-op dispatcher uses these generic versions so
/// it can cover every `DType` without one helper per dtype.
///
/// `sl_typed` views `len` elements of `T` starting `offset` bytes past
/// `base`. An `offset` of `usize::MAX` means "no buffer" and yields an
/// empty slice without touching `base`.
///
/// # Safety
///
/// Unless `offset == usize::MAX`, `base + offset` must point to `len`
/// initialized, `T`-aligned elements that outlive every use of the
/// returned slice (the `'static` lifetime is not checked) and are not
/// mutated through another path meanwhile.
#[inline(always)]
unsafe fn sl_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static [T] {
    if offset != usize::MAX {
        unsafe { std::slice::from_raw_parts(base.add(offset).cast::<T>(), len) }
    } else {
        &[]
    }
}

/// Mutable counterpart of `sl_typed`. There is no `usize::MAX` sentinel:
/// `offset` is always dereferenced.
///
/// # Safety
///
/// `base + offset` must point to `len` initialized, `T`-aligned elements
/// that outlive every use of the returned slice and are not accessed
/// through any other reference while it is alive.
#[inline(always)]
unsafe fn sl_mut_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static mut [T] {
    let first = unsafe { base.add(offset) }.cast::<T>();
    unsafe { std::slice::from_raw_parts_mut(first, len) }
}
16294
16295#[inline(always)]
16296unsafe fn sl_mut_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static mut [T] {
16297    unsafe { std::slice::from_raw_parts_mut(base.add(offset) as *mut T, len) }
16298}
16299
16300// Unsafe helpers to create slices from arena base + offset
16301#[inline(always)]
16302/// In-place per-element activation. Mirrors the dispatch in
16303/// `Thunk::ActivationInPlace`. Used by `Thunk::FusedMmBiasAct` to
16304/// apply the activation after `bias_add` for all non-Gelu cases.
16305fn apply_activation_inplace(d: &mut [f32], act: rlx_ir::op::Activation) {
16306    use rlx_ir::op::Activation;
16307    match act {
16308        Activation::Gelu => crate::kernels::par_gelu_inplace(d),
16309        Activation::GeluApprox => crate::kernels::par_gelu_approx_inplace(d),
16310        Activation::Silu => crate::kernels::par_silu_inplace(d),
16311        Activation::Relu => {
16312            for v in d.iter_mut() {
16313                *v = v.max(0.0);
16314            }
16315        }
16316        Activation::Sigmoid => {
16317            for v in d.iter_mut() {
16318                *v = 1.0 / (1.0 + (-*v).exp());
16319            }
16320        }
16321        Activation::Tanh => {
16322            for v in d.iter_mut() {
16323                *v = v.tanh();
16324            }
16325        }
16326        Activation::Exp => {
16327            for v in d.iter_mut() {
16328                *v = v.exp();
16329            }
16330        }
16331        Activation::Log => {
16332            for v in d.iter_mut() {
16333                *v = v.ln();
16334            }
16335        }
16336        Activation::Sqrt => {
16337            for v in d.iter_mut() {
16338                *v = v.sqrt();
16339            }
16340        }
16341        Activation::Rsqrt => {
16342            for v in d.iter_mut() {
16343                *v = 1.0 / v.sqrt();
16344            }
16345        }
16346        Activation::Neg => {
16347            for v in d.iter_mut() {
16348                *v = -*v;
16349            }
16350        }
16351        Activation::Abs => {
16352            for v in d.iter_mut() {
16353                *v = v.abs();
16354            }
16355        }
16356        Activation::Round => {
16357            for v in d.iter_mut() {
16358                *v = v.round();
16359            }
16360        }
16361        Activation::Sin => {
16362            for v in d.iter_mut() {
16363                *v = v.sin();
16364            }
16365        }
16366        Activation::Cos => {
16367            for v in d.iter_mut() {
16368                *v = v.cos();
16369            }
16370        }
16371        Activation::Tan => {
16372            for v in d.iter_mut() {
16373                *v = v.tan();
16374            }
16375        }
16376        Activation::Atan => {
16377            for v in d.iter_mut() {
16378                *v = v.atan();
16379            }
16380        }
16381    }
16382}
16383
/// Unfolds one image (a single batch entry and group slice) into the column
/// matrix that turns a convolution into a GEMM.
///
/// `x` is `[c_in, H, W]` row-major. `col` is filled as
/// `[c_in · kH · kW, H_out · W_out]` row-major. Every element of `col` is
/// written; taps that land in the padded border become 0.
///
/// `col[(ci · kH · kW + ki · kW + kj) · n_dim + ho · W_out + wo] =
///    x[ci, ho·sh + ki·dh − ph, wo·sw + kj·dw_dil − pw]`
#[allow(clippy::too_many_arguments)]
fn im2col(
    x: &[f32],
    col: &mut [f32],
    c_in: usize,
    h: usize,
    w: usize,
    h_out: usize,
    w_out: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    dh: usize,
    dw_dil: usize,
) {
    /// Input coordinate read by output position `o` at kernel tap `k`, or
    /// `None` when it falls in the zero padding outside `[0, extent)`.
    fn tap(o: usize, stride: usize, k: usize, dil: usize, pad: usize, extent: usize) -> Option<usize> {
        let pos = (o * stride + k * dil) as isize - pad as isize;
        if pos < 0 || pos >= extent as isize {
            None
        } else {
            Some(pos as usize)
        }
    }

    let n_dim = h_out * w_out;
    debug_assert_eq!(col.len(), c_in * kh * kw * n_dim);
    debug_assert_eq!(x.len(), c_in * h * w);

    // Start of the current (ci, ki, kj) row of `col`; rows are visited in order.
    let mut row_off = 0;
    for ci in 0..c_in {
        for ki in 0..kh {
            for kj in 0..kw {
                for ho in 0..h_out {
                    let line = row_off + ho * w_out;
                    let Some(hi) = tap(ho, sh, ki, dh, ph, h) else {
                        // The whole output row reads from vertical padding.
                        col[line..line + w_out].fill(0.0);
                        continue;
                    };
                    let src_line = (ci * h + hi) * w;
                    for wo in 0..w_out {
                        col[line + wo] = match tap(wo, sw, kj, dw_dil, pw, w) {
                            Some(wi) => x[src_line + wi],
                            None => 0.0,
                        };
                    }
                }
                row_off += n_dim;
            }
        }
    }
}
16445
/// col2im: the adjoint of `im2col`, scattering column entries back onto the
/// image with accumulation. `x` is added to, not overwritten; the caller
/// zeroes it first if needed (the conv-input-grad path zeros once before
/// the batch loop).
///
/// `x[ci, hi, wi] += col[(ci · kH · kW + ki · kW + kj) · n_dim + ho · W_out + wo]`
/// for all `(ki, kj, ho, wo)` whose `(hi, wi)` lands in `[0, H) × [0, W)`;
/// taps in the padding are dropped.
#[allow(clippy::too_many_arguments)]
fn col2im(
    col: &[f32],
    x: &mut [f32],
    c_in: usize,
    h: usize,
    w: usize,
    h_out: usize,
    w_out: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    dh: usize,
    dw_dil: usize,
) {
    /// Input coordinate hit by output position `o` at kernel tap `k`, or
    /// `None` when it falls outside `[0, extent)`.
    fn tap(o: usize, stride: usize, k: usize, dil: usize, pad: usize, extent: usize) -> Option<usize> {
        let pos = (o * stride + k * dil) as isize - pad as isize;
        if pos < 0 || pos >= extent as isize {
            None
        } else {
            Some(pos as usize)
        }
    }

    let n_dim = h_out * w_out;
    debug_assert_eq!(col.len(), c_in * kh * kw * n_dim);
    debug_assert_eq!(x.len(), c_in * h * w);

    // Start of the current (ci, ki, kj) row of `col`; rows are visited in order.
    let mut row_off = 0;
    for ci in 0..c_in {
        for ki in 0..kh {
            for kj in 0..kw {
                for ho in 0..h_out {
                    if let Some(hi) = tap(ho, sh, ki, dh, ph, h) {
                        let dst_line = (ci * h + hi) * w;
                        let line = row_off + ho * w_out;
                        for wo in 0..w_out {
                            if let Some(wi) = tap(wo, sw, kj, dw_dil, pw, w) {
                                x[dst_line + wi] += col[line + wo];
                            }
                        }
                    }
                }
                row_off += n_dim;
            }
        }
    }
}
16501
16502/// Element-wise backward for `Op::Activation`. `xs` is the original
16503/// input to the forward activation; `dys` is the upstream gradient.
16504/// Writes `out[i] = (d/dx act(xs[i])) * dys[i]`.
16505/// Decompose a per-channel quantization shape into the
16506/// `(chan_axis, chan_dim, inner)` triplet the kernel needs to map a
16507/// flat output index to a channel index. Per-tensor (`axis = None`)
16508/// degenerates to `chan_dim = 1, inner = len`, which makes the
16509/// kernel's `(i / inner) % chan_dim` always 0 — same fast path the
16510/// scalar version used.
16511fn quant_layout(shape: &rlx_ir::Shape, axis: Option<usize>) -> (usize, usize, usize) {
16512    match axis {
16513        None => (0, 1, shape.num_elements().unwrap_or(0).max(1)),
16514        Some(d) => {
16515            let chan_dim = shape.dim(d).unwrap_static();
16516            let inner: usize = (d + 1..shape.rank())
16517                .map(|i| shape.dim(i).unwrap_static())
16518                .product::<usize>()
16519                .max(1);
16520            (d, chan_dim, inner)
16521        }
16522    }
16523}
16524
16525fn activation_backward_kernel(
16526    act: rlx_ir::op::Activation,
16527    xs: &[f32],
16528    dys: &[f32],
16529    out: &mut [f32],
16530) {
16531    use rlx_ir::op::Activation;
16532    let n = xs.len();
16533    debug_assert_eq!(dys.len(), n);
16534    debug_assert_eq!(out.len(), n);
16535    match act {
16536        Activation::Relu => {
16537            for i in 0..n {
16538                out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
16539            }
16540        }
16541        Activation::Sigmoid => {
16542            for i in 0..n {
16543                let s = 1.0 / (1.0 + (-xs[i]).exp());
16544                out[i] = s * (1.0 - s) * dys[i];
16545            }
16546        }
16547        Activation::Tanh => {
16548            for i in 0..n {
16549                let t = xs[i].tanh();
16550                out[i] = (1.0 - t * t) * dys[i];
16551            }
16552        }
16553        Activation::Silu => {
16554            // y = x * σ(x);  dy/dx = σ(x) * (1 + x * (1 - σ(x))).
16555            for i in 0..n {
16556                let s = 1.0 / (1.0 + (-xs[i]).exp());
16557                out[i] = s * (1.0 + xs[i] * (1.0 - s)) * dys[i];
16558            }
16559        }
16560        Activation::Gelu => {
16561            // Exact erf-based GELU:  y = 0.5 x (1 + erf(x / √2)).
16562            //   dy/dx = 0.5 (1 + erf(x/√2)) + (x / √(2π)) · exp(-x²/2)
16563            const INV_SQRT2: f32 = 0.707_106_77;
16564            const INV_SQRT_2PI: f32 = 0.398_942_3;
16565            for i in 0..n {
16566                let x = xs[i];
16567                let phi = 0.5 * (1.0 + erf_f32(x * INV_SQRT2));
16568                let pdf = INV_SQRT_2PI * (-(x * x) * 0.5).exp();
16569                out[i] = (phi + x * pdf) * dys[i];
16570            }
16571        }
16572        Activation::GeluApprox => {
16573            // Tanh-approximation:
16574            //   y = 0.5 x (1 + tanh(c · (x + 0.044715 x³))) where c = √(2/π).
16575            const C: f32 = 0.797_884_6; // √(2/π)
16576            const A: f32 = 0.044_715;
16577            for i in 0..n {
16578                let x = xs[i];
16579                let inner = C * (x + A * x * x * x);
16580                let t = inner.tanh();
16581                let dinner = C * (1.0 + 3.0 * A * x * x);
16582                let d = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * dinner;
16583                out[i] = d * dys[i];
16584            }
16585        }
16586        Activation::Exp => {
16587            for i in 0..n {
16588                out[i] = xs[i].exp() * dys[i];
16589            }
16590        }
16591        Activation::Log => {
16592            for i in 0..n {
16593                out[i] = dys[i] / xs[i];
16594            }
16595        }
16596        Activation::Sqrt => {
16597            // d/dx √x = 0.5 / √x — undefined at x=0; clamp to 0.
16598            for i in 0..n {
16599                let s = xs[i].sqrt();
16600                out[i] = if s > 0.0 { 0.5 * dys[i] / s } else { 0.0 };
16601            }
16602        }
16603        Activation::Rsqrt => {
16604            // d/dx (1/√x) = -0.5 · x^(-3/2).
16605            for i in 0..n {
16606                let s = xs[i].sqrt();
16607                out[i] = if s > 0.0 {
16608                    -0.5 * dys[i] / (xs[i] * s)
16609                } else {
16610                    0.0
16611                };
16612            }
16613        }
16614        Activation::Neg => {
16615            for i in 0..n {
16616                out[i] = -dys[i];
16617            }
16618        }
16619        Activation::Abs => {
16620            // sign(x); 0 at x=0.
16621            for i in 0..n {
16622                let x = xs[i];
16623                let s = if x > 0.0 {
16624                    1.0
16625                } else if x < 0.0 {
16626                    -1.0
16627                } else {
16628                    0.0
16629                };
16630                out[i] = s * dys[i];
16631            }
16632        }
16633        Activation::Round => {
16634            // STE: pretend the round was identity in the backward
16635            // pass. The round step has zero gradient almost
16636            // everywhere, so without this trick the optimizer can't
16637            // learn through it.
16638            out.copy_from_slice(dys);
16639        }
16640        Activation::Sin => {
16641            // d/dx sin(x) = cos(x).
16642            for i in 0..n {
16643                out[i] = xs[i].cos() * dys[i];
16644            }
16645        }
16646        Activation::Cos => {
16647            for i in 0..n {
16648                out[i] = -xs[i].sin() * dys[i];
16649            }
16650        }
16651        Activation::Tan => {
16652            // d/dx tan(x) = sec²(x) = 1 + tan²(x)
16653            for i in 0..n {
16654                let t = xs[i].tan();
16655                out[i] = (1.0 + t * t) * dys[i];
16656            }
16657        }
16658        Activation::Atan => {
16659            // d/dx atan(x) = 1 / (1 + x²)
16660            for i in 0..n {
16661                let x = xs[i];
16662                out[i] = dys[i] / (1.0 + x * x);
16663            }
16664        }
16665    }
16666}
16667
16668/// f64 sibling of `activation_backward_kernel`. Same math, twice the
16669/// precision — used by f64 graphs where the f32 kernel reading bytes
16670/// as `&[f32]` would silently discard half of every f64 value.
16671fn activation_backward_kernel_f64(
16672    act: rlx_ir::op::Activation,
16673    xs: &[f64],
16674    dys: &[f64],
16675    out: &mut [f64],
16676) {
16677    use rlx_ir::op::Activation;
16678    let n = xs.len();
16679    debug_assert_eq!(dys.len(), n);
16680    debug_assert_eq!(out.len(), n);
16681    match act {
16682        Activation::Relu => {
16683            for i in 0..n {
16684                out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
16685            }
16686        }
16687        Activation::Sigmoid => {
16688            for i in 0..n {
16689                let s = 1.0 / (1.0 + (-xs[i]).exp());
16690                out[i] = s * (1.0 - s) * dys[i];
16691            }
16692        }
16693        Activation::Tanh => {
16694            for i in 0..n {
16695                let t = xs[i].tanh();
16696                out[i] = (1.0 - t * t) * dys[i];
16697            }
16698        }
16699        Activation::Silu => {
16700            for i in 0..n {
16701                let s = 1.0 / (1.0 + (-xs[i]).exp());
16702                out[i] = s * (1.0 + xs[i] * (1.0 - s)) * dys[i];
16703            }
16704        }
16705        Activation::Gelu | Activation::GeluApprox => {
16706            // Both rare on f64 paths; use the high-quality libm erf.
16707            const INV_SQRT2: f64 = std::f64::consts::FRAC_1_SQRT_2;
16708            const INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
16709            for i in 0..n {
16710                let x = xs[i];
16711                let phi = 0.5 * (1.0 + erf_f64(x * INV_SQRT2));
16712                let pdf = INV_SQRT_2PI * (-(x * x) * 0.5).exp();
16713                out[i] = (phi + x * pdf) * dys[i];
16714            }
16715        }
16716        Activation::Exp => {
16717            for i in 0..n {
16718                out[i] = xs[i].exp() * dys[i];
16719            }
16720        }
16721        Activation::Log => {
16722            for i in 0..n {
16723                out[i] = dys[i] / xs[i];
16724            }
16725        }
16726        Activation::Sqrt => {
16727            for i in 0..n {
16728                let s = xs[i].sqrt();
16729                out[i] = if s > 0.0 { 0.5 * dys[i] / s } else { 0.0 };
16730            }
16731        }
16732        Activation::Rsqrt => {
16733            for i in 0..n {
16734                let s = xs[i].sqrt();
16735                out[i] = if s > 0.0 {
16736                    -0.5 * dys[i] / (xs[i] * s)
16737                } else {
16738                    0.0
16739                };
16740            }
16741        }
16742        Activation::Neg => {
16743            for i in 0..n {
16744                out[i] = -dys[i];
16745            }
16746        }
16747        Activation::Abs => {
16748            for i in 0..n {
16749                let x = xs[i];
16750                let s = if x > 0.0 {
16751                    1.0
16752                } else if x < 0.0 {
16753                    -1.0
16754                } else {
16755                    0.0
16756                };
16757                out[i] = s * dys[i];
16758            }
16759        }
16760        Activation::Round => {
16761            out.copy_from_slice(dys);
16762        }
16763        Activation::Sin => {
16764            for i in 0..n {
16765                out[i] = xs[i].cos() * dys[i];
16766            }
16767        }
16768        Activation::Cos => {
16769            for i in 0..n {
16770                out[i] = -xs[i].sin() * dys[i];
16771            }
16772        }
16773        Activation::Tan => {
16774            for i in 0..n {
16775                let t = xs[i].tan();
16776                out[i] = (1.0 + t * t) * dys[i];
16777            }
16778        }
16779        Activation::Atan => {
16780            for i in 0..n {
16781                let x = xs[i];
16782                out[i] = dys[i] / (1.0 + x * x);
16783            }
16784        }
16785    }
16786}
16787
/// erf at f64 width using the Abramowitz & Stegun 7.1.26 rational
/// approximation (the same scheme as `erf_f32`, with its coefficients
/// written to more digits). Absolute error is about 1.5e-7, a limit set by
/// the polynomial rather than the arithmetic. That is adequate for gradient
/// kernels; swap in a libm dependency if more precision is ever needed.
#[inline(always)]
fn erf_f64(x: f64) -> f64 {
    // erf is odd: evaluate on |x| and restore the sign at the end.
    let (sign, a) = (x.signum(), x.abs());
    let t = 1.0 / (1.0 + 0.327_591_1 * a);
    let poly = ((((1.061_405_43 * t - 1.453_152_03) * t + 1.421_413_75) * t - 0.284_496_74) * t
        + 0.254_829_59)
        * t;
    sign * (1.0 - poly * (-a * a).exp())
}
16804
/// Inexpensive erf for f32 kernels: the Abramowitz & Stegun 7.1.26 rational
/// approximation, with absolute error about 1.5e-7 over all of ℝ. That is
/// well below f32 gradient noise.
#[inline(always)]
fn erf_f32(x: f32) -> f32 {
    // erf is odd: evaluate on |x| and restore the sign at the end.
    let (sign, a) = (x.signum(), x.abs());
    let t = 1.0 / (1.0 + 0.327_591_1 * a);
    let poly = ((((1.061_405_4 * t - 1.453_152_1) * t + 1.421_413_8) * t - 0.284_496_74) * t
        + 0.254_829_6)
        * t;
    sign * (1.0 - poly * (-a * a).exp())
}
16819
/// Builds the compiled-path closure for a strided narrow (slice) copy.
///
/// The closure copies `outer` rows of `inner` elements (`elem_bytes` bytes
/// each) from byte offset `src` to byte offset `dst` within the buffer it
/// is given. Each row advances `src_stride` / `dst_stride` elements.
///
/// Source and destination may alias. Rows whose source and destination
/// offsets coincide are skipped, and the remaining rows are moved with
/// `ptr::copy` (memmove semantics), so partially overlapping rows are
/// well-defined. The copy is skipped entirely only when every row would be
/// a self-copy (`src == dst` with equal row strides). When `src == dst` but
/// the strides differ (rows compacted in place), rows after the first still
/// have to move.
///
/// The compiled-fn path has no arena length to check against. The closure
/// trusts the offsets: callers must guarantee every row lies inside the
/// buffer passed to it.
fn narrow_thunk_closure(
    src: usize,
    dst: usize,
    outer: u32,
    src_stride: u32,
    dst_stride: u32,
    inner: u32,
    elem_bytes: u8,
) -> Arc<dyn Fn(*mut u8) + Send + Sync> {
    let (outer, ss, ds, inner, eb) = (
        outer as usize,
        src_stride as usize,
        dst_stride as usize,
        inner as usize,
        elem_bytes as usize,
    );
    let row_bytes = inner.saturating_mul(eb);
    let src_row_stride = ss.saturating_mul(eb);
    let dst_row_stride = ds.saturating_mul(eb);
    Arc::new(move |base: *mut u8| {
        if row_bytes == 0 || (src == dst && src_row_stride == dst_row_stride) {
            return;
        }
        for o in 0..outer {
            let s_off = src + o * src_row_stride;
            let d_off = dst + o * dst_row_stride;
            if s_off == d_off {
                continue;
            }
            // SAFETY: the caller guarantees both rows lie inside the buffer
            // behind `base`; `ptr::copy` is valid even if the rows overlap.
            unsafe { std::ptr::copy(base.add(s_off), base.add(d_off), row_bytes) };
        }
    })
}
16860
/// Read-only f32 view of `len` elements starting `offset` bytes past the
/// arena base `base`. An `offset` of `usize::MAX` means "no buffer" and
/// yields an empty slice without touching `base`.
///
/// Marked `#[inline(always)]` like the other `sl_*` helpers.
///
/// # Safety
///
/// Unless `offset == usize::MAX`, `base + offset` must point to `len`
/// initialized, 4-byte-aligned f32 values that outlive every use of the
/// returned slice (the `'static` lifetime is not checked) and are not
/// mutated through another path meanwhile.
#[inline(always)]
unsafe fn sl(offset: usize, base: *mut u8, len: usize) -> &'static [f32] {
    if offset == usize::MAX {
        return &[];
    }
    unsafe { std::slice::from_raw_parts(base.add(offset) as *const f32, len) }
}
16867
/// Mutable f32 view of `len` elements starting `offset` bytes past `base`.
/// There is no `usize::MAX` sentinel here: `offset` is always dereferenced.
///
/// # Safety
///
/// `base + offset` must point to `len` initialized, f32-aligned values that
/// outlive every use of the returned slice and are not accessed through any
/// other reference while it is alive.
#[inline(always)]
unsafe fn sl_mut(offset: usize, base: *mut u8, len: usize) -> &'static mut [f32] {
    let head = unsafe { base.add(offset) }.cast::<f32>();
    unsafe { std::slice::from_raw_parts_mut(head, len) }
}
16872
/// Read-only f64 view of `len` elements starting `offset` bytes past
/// `base`. An `offset` of `usize::MAX` yields an empty slice.
///
/// # Safety
///
/// Unless `offset == usize::MAX`, `base + offset` must point to `len`
/// initialized, f64-aligned values that outlive every use of the returned
/// slice and are not mutated through another path meanwhile.
#[inline(always)]
unsafe fn sl_f64(offset: usize, base: *mut u8, len: usize) -> &'static [f64] {
    if offset != usize::MAX {
        unsafe { std::slice::from_raw_parts(base.add(offset).cast::<f64>(), len) }
    } else {
        &[]
    }
}

/// Mutable f64 view of `len` elements starting `offset` bytes past `base`
/// (no sentinel; `offset` is always dereferenced).
///
/// # Safety
///
/// Same as `sl_f64`, and no other reference may access the region while
/// the returned slice is alive.
#[inline(always)]
unsafe fn sl_mut_f64(offset: usize, base: *mut u8, len: usize) -> &'static mut [f64] {
    let head = unsafe { base.add(offset) }.cast::<f64>();
    unsafe { std::slice::from_raw_parts_mut(head, len) }
}
16885
16886// i32 / i64 typed slice helpers — siblings of sl_f32/sl_f64. Kept for
16887// integer-tensor thunks that haven't landed yet (Sample, Gather index
16888// buffers); deleting them now would force re-deriving the unsafe
16889// boilerplate when the next int-typed thunk lands.
16890#[inline(always)]
16891#[allow(dead_code)]
16892unsafe fn sl_i32(offset: usize, base: *mut u8, len: usize) -> &'static [i32] {
16893    if offset == usize::MAX {
16894        return &[];
16895    }
16896    unsafe { std::slice::from_raw_parts(base.add(offset) as *const i32, len) }
16897}
16898
16899#[inline(always)]
16900#[allow(dead_code)]
16901unsafe fn sl_mut_i32(offset: usize, base: *mut u8, len: usize) -> &'static mut [i32] {
16902    unsafe { std::slice::from_raw_parts_mut(base.add(offset) as *mut i32, len) }
16903}
16904
/// Shared view of `len` `i64`s starting `offset` bytes past `base`.
///
/// The sentinel offset `usize::MAX` yields an empty slice and never touches
/// `base`.
///
/// # Safety
///
/// For any other offset, `base + offset` must be valid for reads of
/// `len * 8` bytes and aligned for `i64`. The memory must not be mutated
/// while the slice is used; the `'static` lifetime is not checked.
#[inline(always)]
unsafe fn sl_i64(offset: usize, base: *mut u8, len: usize) -> &'static [i64] {
    if offset == usize::MAX {
        &[]
    } else {
        let head = unsafe { base.add(offset) }.cast::<i64>();
        unsafe { std::slice::from_raw_parts(head, len) }
    }
}
16912
/// Mutable view of `len` `i64`s starting `offset` bytes past `base`.
///
/// No sentinel handling: `offset` is always applied.
///
/// # Safety
///
/// `base + offset` must be valid for reads and writes of `len * 8` bytes,
/// aligned for `i64`, and unaliased while the returned slice lives. The
/// `'static` lifetime is not checked.
#[inline(always)]
unsafe fn sl_mut_i64(offset: usize, base: *mut u8, len: usize) -> &'static mut [i64] {
    let data = unsafe { base.add(offset) }.cast::<i64>();
    unsafe { std::slice::from_raw_parts_mut(data, len) }
}
16917
16918/// f64 N-D index walk used by Transpose and Expand. `out_dims` gives
16919/// the output shape; `in_strides` gives the source stride for each
16920/// output dim (broadcast axes have stride 0).
16921fn transpose_walk_f64(inp: &[f64], out: &mut [f64], out_dims: &[u32], in_strides: &[u32]) {
16922    let rank = out_dims.len();
16923    let mut idx = vec![0u32; rank];
16924    for o in 0..out.len() {
16925        let mut src_off = 0usize;
16926        for d in 0..rank {
16927            src_off += idx[d] as usize * in_strides[d] as usize;
16928        }
16929        out[o] = inp[broadcast_src_index(src_off, inp.len())];
16930        // Increment index — last dim varies fastest.
16931        for d in (0..rank).rev() {
16932            idx[d] += 1;
16933            if idx[d] < out_dims[d] {
16934                break;
16935            }
16936            idx[d] = 0;
16937        }
16938    }
16939}
16940
16941/// f64 elementwise activation. Reads `inp`, writes `out`. For now
16942/// covers what the autodiff-emitted gradient graph needs (Neg, Exp,
16943/// Log, Sqrt, Rsqrt, Abs, Tanh, Sigmoid, Relu — the
16944/// transcendental-free subset). Approximate Gelu/Silu deferred until a
16945/// workload demands them at f64.
16946fn apply_activation_f64(inp: &[f64], out: &mut [f64], kind: Activation) {
16947    match kind {
16948        Activation::Neg => {
16949            for (o, &v) in out.iter_mut().zip(inp) {
16950                *o = -v;
16951            }
16952        }
16953        Activation::Exp => {
16954            for (o, &v) in out.iter_mut().zip(inp) {
16955                *o = v.exp();
16956            }
16957        }
16958        Activation::Log => {
16959            for (o, &v) in out.iter_mut().zip(inp) {
16960                *o = v.ln();
16961            }
16962        }
16963        Activation::Sqrt => {
16964            for (o, &v) in out.iter_mut().zip(inp) {
16965                *o = v.sqrt();
16966            }
16967        }
16968        Activation::Rsqrt => {
16969            for (o, &v) in out.iter_mut().zip(inp) {
16970                *o = 1.0 / v.sqrt();
16971            }
16972        }
16973        Activation::Abs => {
16974            for (o, &v) in out.iter_mut().zip(inp) {
16975                *o = v.abs();
16976            }
16977        }
16978        Activation::Tanh => {
16979            for (o, &v) in out.iter_mut().zip(inp) {
16980                *o = v.tanh();
16981            }
16982        }
16983        Activation::Sigmoid => {
16984            for (o, &v) in out.iter_mut().zip(inp) {
16985                *o = 1.0 / (1.0 + (-v).exp());
16986            }
16987        }
16988        Activation::Relu => {
16989            for (o, &v) in out.iter_mut().zip(inp) {
16990                *o = v.max(0.0);
16991            }
16992        }
16993        Activation::Round => {
16994            for (o, &v) in out.iter_mut().zip(inp) {
16995                *o = v.round_ties_even();
16996            }
16997        }
16998        Activation::Sin => {
16999            for (o, &v) in out.iter_mut().zip(inp) {
17000                *o = v.sin();
17001            }
17002        }
17003        Activation::Cos => {
17004            for (o, &v) in out.iter_mut().zip(inp) {
17005                *o = v.cos();
17006            }
17007        }
17008        Activation::Tan => {
17009            for (o, &v) in out.iter_mut().zip(inp) {
17010                *o = v.tan();
17011            }
17012        }
17013        Activation::Atan => {
17014            for (o, &v) in out.iter_mut().zip(inp) {
17015                *o = v.atan();
17016            }
17017        }
17018        Activation::Gelu | Activation::GeluApprox | Activation::Silu => {
17019            panic!(
17020                "apply_activation_f64: {kind:?} not yet implemented at f64. \
17021                    Add when a workload needs it."
17022            );
17023        }
17024    }
17025}
17026
17027#[inline]
17028fn binary_op_f64(op: BinaryOp, a: f64, b: f64) -> f64 {
17029    match op {
17030        BinaryOp::Add => a + b,
17031        BinaryOp::Sub => a - b,
17032        BinaryOp::Mul => a * b,
17033        BinaryOp::Div => a / b,
17034        BinaryOp::Max => a.max(b),
17035        BinaryOp::Min => a.min(b),
17036        BinaryOp::Pow => a.powf(b),
17037    }
17038}
17039
/// f64 sum reduction over a contiguous middle axis.
///
/// Layout: `inp` is `[outer, reduced, inner]` and `out` is `[outer, inner]`,
/// both row-major, with `out[o, n] = Σ_r inp[o, r, n]`. Every output element
/// in `0..outer * inner` is overwritten; `reduced == 0` yields zeros.
///
/// Each `inner`-wide input row is added into its output row, so the
/// innermost loop walks memory contiguously instead of striding by `inner`
/// through `inp`. For every output element the additions still run in
/// `r = 0..reduced` order starting from `0.0`. The result is therefore
/// bit-identical to a per-element scalar accumulation.
///
/// # Panics
///
/// Panics if `out` holds fewer than `outer * inner` elements or `inp` fewer
/// than `outer * reduced * inner`.
fn reduce_sum_f64(inp: &[f64], out: &mut [f64], outer: usize, reduced: usize, inner: usize) {
    for o in 0..outer {
        let dst = &mut out[o * inner..(o + 1) * inner];
        dst.fill(0.0);
        for r in 0..reduced {
            let start = (o * reduced + r) * inner;
            let src = &inp[start..start + inner];
            for (d, &s) in dst.iter_mut().zip(src) {
                *d += s;
            }
        }
    }
}
17053
17054/// Host-side RNG fill against a byte arena (Metal/CUDA unified-memory fallback).
17055///
17056/// # Safety
17057///
17058/// `arena` must point to a valid allocation with at least `dst_off + len * 4` bytes.
17059pub unsafe fn fill_rng_normal_arena(
17060    dst_off: usize,
17061    len: usize,
17062    mean: f32,
17063    scale: f32,
17064    key: u64,
17065    op_seed: Option<f32>,
17066    opts: rlx_ir::RngOptions,
17067    arena: *mut u8,
17068) {
17069    if len == 0 {
17070        return;
17071    }
17072    unsafe {
17073        let out = std::slice::from_raw_parts_mut((arena.add(dst_off)) as *mut f32, len);
17074        rlx_ir::fill_normal_like(out, mean, scale, opts, key, op_seed);
17075    }
17076}
17077
17078pub unsafe fn fill_rng_uniform_arena(
17079    dst_off: usize,
17080    len: usize,
17081    low: f32,
17082    high: f32,
17083    key: u64,
17084    op_seed: Option<f32>,
17085    opts: rlx_ir::RngOptions,
17086    arena: *mut u8,
17087) {
17088    if len == 0 {
17089        return;
17090    }
17091    unsafe {
17092        let out = std::slice::from_raw_parts_mut((arena.add(dst_off)) as *mut f32, len);
17093        rlx_ir::fill_uniform_like(out, low, high, opts, key, op_seed);
17094    }
17095}
17096
17097#[cfg(test)]
17098mod tests {
17099    use super::*;
17100    use rlx_ir::*;
17101
17102    /// Plan #45: when a Narrow's only consumer is a Rope, the thunk
17103    /// fusion pass collapses them — the Narrow becomes Nop, and the
17104    /// Rope reads from the parent buffer with its row stride. This
17105    /// test runs the unfused path (batch*seq > FusedAttnBlock
17106    /// threshold) and asserts the rewrite happened.
17107    #[test]
17108    fn narrow_rope_fuses_in_unfused_path() {
17109        let f = DType::F32;
17110        let mut g = Graph::new("nr_fuse");
17111        // Force batch*seq > 64 so FusedAttnBlock doesn't pre-empt us.
17112        let qkv = g.input("qkv", Shape::new(&[16, 8, 192], f)); // 16*8=128 > 64
17113        let cos = g.input("cos", Shape::new(&[16], f));
17114        let sin = g.input("sin", Shape::new(&[16], f));
17115        // Last-axis narrow: Q = qkv[..., 0..64]
17116        let q = g.narrow_(qkv, 2, 0, 64);
17117        let q_rope = g.rope(q, cos, sin, 16);
17118        g.set_outputs(vec![q_rope]);
17119
17120        let plan = rlx_opt::memory::plan_memory(&g);
17121        let arena = crate::arena::Arena::from_plan(plan);
17122        let sched = compile_thunks(&g, &arena);
17123
17124        let mut narrow_count = 0;
17125        let mut rope_with_stride: Option<u32> = None;
17126        for t in &sched.thunks {
17127            match t {
17128                Thunk::Narrow { .. } => narrow_count += 1,
17129                Thunk::Rope { src_row_stride, .. } => rope_with_stride = Some(*src_row_stride),
17130                _ => {}
17131            }
17132        }
17133        // After fusion the Narrow is gone; only the Rope remains, and
17134        // it now walks with the parent QKV's row stride (3 * 64 = 192).
17135        assert_eq!(
17136            narrow_count, 0,
17137            "Narrow→Rope fusion should leave zero Narrow thunks; saw {narrow_count}"
17138        );
17139        assert_eq!(
17140            rope_with_stride,
17141            Some(192),
17142            "Rope's src_row_stride should be 192 (parent qkv axis), saw {rope_with_stride:?}"
17143        );
17144    }
17145
17146    /// Plan #15: SSM selective scan matches a naive Python-style
17147    /// Python-style sequential reference.
17148    #[test]
17149    fn ssm_selective_scan_matches_reference() {
17150        use rlx_ir::Philox4x32;
17151        let bch = 1usize;
17152        let s = 4usize;
17153        let h = 3usize;
17154        let n = 2usize;
17155
17156        let mut rng = Philox4x32::new(13);
17157        let mut x = vec![0f32; bch * s * h];
17158        rng.fill_normal(&mut x);
17159        let mut delta = vec![0f32; bch * s * h];
17160        // Keep Δ small so exp(Δ·A) doesn't blow up.
17161        for v in delta.iter_mut() {
17162            *v = (rng.next_f32() - 0.5) * 0.1;
17163        }
17164        let mut a = vec![0f32; h * n];
17165        for v in a.iter_mut() {
17166            *v = -(rng.next_f32() * 0.5 + 0.1);
17167        } // negative for stability
17168        let mut b = vec![0f32; bch * s * n];
17169        rng.fill_normal(&mut b);
17170        let mut c = vec![0f32; bch * s * n];
17171        rng.fill_normal(&mut c);
17172
17173        // Reference scan.
17174        let mut expected = vec![0f32; bch * s * h];
17175        for bi in 0..bch {
17176            let mut state = vec![0f32; h * n];
17177            for si in 0..s {
17178                for ci in 0..h {
17179                    let d = delta[bi * s * h + si * h + ci];
17180                    let xv = x[bi * s * h + si * h + ci];
17181                    let mut acc = 0f32;
17182                    for ni in 0..n {
17183                        let da = (d * a[ci * n + ni]).exp();
17184                        state[ci * n + ni] =
17185                            da * state[ci * n + ni] + d * b[bi * s * n + si * n + ni] * xv;
17186                        acc += c[bi * s * n + si * n + ni] * state[ci * n + ni];
17187                    }
17188                    expected[bi * s * h + si * h + ci] = acc;
17189                }
17190            }
17191        }
17192
17193        // RLX path.
17194        let f = DType::F32;
17195        let mut g = Graph::new("ssm");
17196        let xn = g.input("x", Shape::new(&[bch, s, h], f));
17197        let dn = g.input("delta", Shape::new(&[bch, s, h], f));
17198        let an = g.param("a", Shape::new(&[h, n], f));
17199        let bn = g.param("b", Shape::new(&[bch, s, n], f));
17200        let cn = g.param("c", Shape::new(&[bch, s, n], f));
17201        let yn = g.selective_scan(xn, dn, an, bn, cn, n, Shape::new(&[bch, s, h], f));
17202        g.set_outputs(vec![yn]);
17203
17204        let plan = rlx_opt::memory::plan_memory(&g);
17205        let mut arena = crate::arena::Arena::from_plan(plan);
17206        let sched = compile_thunks(&g, &arena);
17207
17208        let xn_off = arena.byte_offset(xn);
17209        let dn_off = arena.byte_offset(dn);
17210        let an_off = arena.byte_offset(an);
17211        let bn_off = arena.byte_offset(bn);
17212        let cn_off = arena.byte_offset(cn);
17213        let yn_off = arena.byte_offset(yn);
17214        let buf = arena.raw_buf_mut();
17215        unsafe {
17216            let copy = |dst: *mut f32, data: &[f32]| {
17217                for (i, &v) in data.iter().enumerate() {
17218                    *dst.add(i) = v;
17219                }
17220            };
17221            copy(buf.as_mut_ptr().add(xn_off) as *mut f32, &x);
17222            copy(buf.as_mut_ptr().add(dn_off) as *mut f32, &delta);
17223            copy(buf.as_mut_ptr().add(an_off) as *mut f32, &a);
17224            copy(buf.as_mut_ptr().add(bn_off) as *mut f32, &b);
17225            copy(buf.as_mut_ptr().add(cn_off) as *mut f32, &c);
17226        }
17227        execute_thunks(&sched, arena.raw_buf_mut());
17228
17229        let actual: Vec<f32> = unsafe {
17230            let p = arena.raw_buf().as_ptr().add(yn_off) as *const f32;
17231            (0..bch * s * h).map(|i| *p.add(i)).collect()
17232        };
17233
17234        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
17235            assert!(
17236                (e - a).abs() < 1e-3,
17237                "mismatch at {i}: expected {e}, got {a}"
17238            );
17239        }
17240    }
17241
    /// Plan #26: a 1×1 conv is lowered to the `Conv2D1x1` fast path
    /// (per-batch sgemm) instead of the scalar `Conv2D` thunk, and its
    /// output matches a scalar reference. For a 1×1 kernel with unit
    /// stride, zero padding and unit dilation, the conv is a per-batch
    /// matmul over channels.
    #[test]
    fn conv_1x1_fast_path_matches_scalar() {
        use rlx_ir::Philox4x32;
        // [N=2, C_in=4, H=3, W=3]
        let n = 2usize;
        let c_in = 4usize;
        let h = 3usize;
        let w = 3usize;
        let c_out = 5usize;
        let mut rng = Philox4x32::new(31);
        let mut x = vec![0f32; n * c_in * h * w];
        rng.fill_normal(&mut x);
        // Weight is [C_out, C_in] flattened (the 1×1 spatial dims are implicit).
        let mut weight = vec![0f32; c_out * c_in];
        rng.fill_normal(&mut weight);

        // Reference: scalar 1×1 conv = per-batch matmul
        // out[ni, co, hi, wi] = sum_ci weight[co, ci] * x[ni, ci, hi, wi]
        // All tensors use NCHW row-major indexing.
        let mut expected = vec![0f32; n * c_out * h * w];
        for ni in 0..n {
            for co in 0..c_out {
                for hi in 0..h {
                    for wi in 0..w {
                        let mut acc = 0f32;
                        for ci in 0..c_in {
                            acc += weight[co * c_in + ci]
                                * x[((ni * c_in) + ci) * h * w + hi * w + wi];
                        }
                        expected[((ni * c_out) + co) * h * w + hi * w + wi] = acc;
                    }
                }
            }
        }

        // RLX path: build a graph with Op::Conv (kernel=[1,1], stride=[1,1], etc).
        let f = DType::F32;
        let mut g = Graph::new("conv1x1");
        let xn = g.input("x", Shape::new(&[n, c_in, h, w], f));
        let wn = g.param("w", Shape::new(&[c_out, c_in, 1, 1], f));
        // Manually add Op::Conv since there's no `g.conv()` helper.
        let cn = g.add_node(
            rlx_ir::Op::Conv {
                kernel_size: vec![1, 1],
                stride: vec![1, 1],
                padding: vec![0, 0],
                dilation: vec![1, 1],
                groups: 1,
            },
            vec![xn, wn],
            Shape::new(&[n, c_out, h, w], f),
        );
        g.set_outputs(vec![cn]);

        // Inputs are seeded by hand rather than via `run_graph`, because
        // the schedule must be inspected before execution.
        let plan = rlx_opt::memory::plan_memory(&g);
        let mut arena = crate::arena::Arena::from_plan(plan);
        let sched = compile_thunks(&g, &arena);

        // Verify the fast path was selected.
        let saw_fast = sched
            .thunks
            .iter()
            .any(|t| matches!(t, Thunk::Conv2D1x1 { .. }));
        let saw_slow = sched
            .thunks
            .iter()
            .any(|t| matches!(t, Thunk::Conv2D { .. }));
        assert!(saw_fast, "1×1 conv should emit Conv2D1x1");
        assert!(!saw_slow, "1×1 conv must not fall through to scalar Conv2D");

        // Offsets are read before `raw_buf_mut()` takes the mutable borrow.
        let xn_off = arena.byte_offset(xn);
        let wn_off = arena.byte_offset(wn);
        let cn_off = arena.byte_offset(cn);
        let buf = arena.raw_buf_mut();
        // SAFETY (test): assumes the planner gave `xn`/`wn` aligned f32
        // storage of at least `x.len()` / `weight.len()` elements.
        unsafe {
            let xp = buf.as_mut_ptr().add(xn_off) as *mut f32;
            for (i, &v) in x.iter().enumerate() {
                *xp.add(i) = v;
            }
            let wp = buf.as_mut_ptr().add(wn_off) as *mut f32;
            for (i, &v) in weight.iter().enumerate() {
                *wp.add(i) = v;
            }
        }
        execute_thunks(&sched, arena.raw_buf_mut());

        // SAFETY (test): `cn` is an f32 output of `n * c_out * h * w` elements.
        let actual: Vec<f32> = unsafe {
            let p = arena.raw_buf().as_ptr().add(cn_off) as *const f32;
            (0..(n * c_out * h * w)).map(|i| *p.add(i)).collect()
        };

        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
            assert!(
                (e - a).abs() < 1e-3,
                "mismatch at {i}: expected {e}, got {a}"
            );
        }
    }
17340
    /// Plan #5: fused dequant matmul matches the dequant-then-matmul
    /// reference (i.e. `x @ (scale * (q - z))` materialized). The scheme is
    /// symmetric here, so the zero point `z` is all zeros.
    #[test]
    fn dequant_matmul_int8_sym_matches_reference() {
        use rlx_ir::Philox4x32;
        use rlx_ir::quant::QuantScheme;

        let m = 3usize;
        let k = 8usize;
        let n = 4usize;
        let block_size = 4usize; // 2 blocks per column
        let blocks_per_col = k / block_size;

        // Random inputs: x f32, w_q i8, scales f32. Symmetric → no zp.
        let mut rng = Philox4x32::new(99);
        let mut x = vec![0f32; m * k];
        rng.fill_normal(&mut x);
        // Deterministic i8 weights in -63..=63: `% 127` gives 0..=126, then -63.
        let w_q: Vec<i8> = (0..(k * n))
            .map(|i| ((i as i32 * 13 + 7) % 127 - 63) as i8)
            .collect();
        // One scale per (k-block, output column), laid out [blocks_per_col, n].
        let scales: Vec<f32> = (0..(blocks_per_col * n))
            .map(|i| 0.01 + 0.001 * i as f32)
            .collect();

        // Reference: build f32 weights from (q * scale) per block.
        let mut w_f32 = vec![0f32; k * n];
        for p in 0..k {
            let block = p / block_size;
            for j in 0..n {
                let s = scales[block * n + j];
                w_f32[p * n + j] = w_q[p * n + j] as f32 * s;
            }
        }
        // expected = x [m, k] @ w_f32 [k, n].
        let mut expected = vec![0f32; m * n];
        for i in 0..m {
            for j in 0..n {
                let mut acc = 0f32;
                for p in 0..k {
                    acc += x[i * k + p] * w_f32[p * n + j];
                }
                expected[i * n + j] = acc;
            }
        }

        // RLX path.
        let f = DType::F32;
        let mut g = Graph::new("dq");
        let xn = g.input("x", Shape::new(&[m, k], f));
        let wn = g.param("w", Shape::new(&[k, n], DType::I8));
        let sn = g.param("scale", Shape::new(&[blocks_per_col, n], f));
        let zn = g.param("zp", Shape::new(&[blocks_per_col, n], f)); // unused (sym)
        let dq = g.dequant_matmul(
            xn,
            wn,
            sn,
            zn,
            QuantScheme::Int8Block {
                block_size: block_size as u32,
            },
            Shape::new(&[m, n], f),
        );
        g.set_outputs(vec![dq]);

        // Seeded by hand (not via `run_graph`) because `w` is an i8 tensor.
        let plan = rlx_opt::memory::plan_memory(&g);
        let mut arena = crate::arena::Arena::from_plan(plan);
        let sched = compile_thunks(&g, &arena);

        // Offsets are read before `raw_buf_mut()` takes the mutable borrow.
        let xn_off = arena.byte_offset(xn);
        let wn_off = arena.byte_offset(wn);
        let sn_off = arena.byte_offset(sn);
        let zn_off = arena.byte_offset(zn);
        let dq_off = arena.byte_offset(dq);
        let buf = arena.raw_buf_mut();
        // SAFETY (test): assumes the planner sized and aligned each node's
        // storage for its declared shape/dtype.
        unsafe {
            // Seed f32 inputs.
            let xp = buf.as_mut_ptr().add(xn_off) as *mut f32;
            for (i, &v) in x.iter().enumerate() {
                *xp.add(i) = v;
            }
            let sp = buf.as_mut_ptr().add(sn_off) as *mut f32;
            for (i, &v) in scales.iter().enumerate() {
                *sp.add(i) = v;
            }
            // Explicitly zero the zero-point tensor.
            let zp = buf.as_mut_ptr().add(zn_off) as *mut f32;
            for i in 0..(blocks_per_col * n) {
                *zp.add(i) = 0.0;
            }
            // Seed i8 weights byte-by-byte.
            let wp = buf.as_mut_ptr().add(wn_off) as *mut i8;
            for (i, &v) in w_q.iter().enumerate() {
                *wp.add(i) = v;
            }
        }
        execute_thunks(&sched, arena.raw_buf_mut());

        // SAFETY (test): `dq` is an f32 output of `m * n` elements.
        let actual: Vec<f32> = unsafe {
            let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
            (0..m * n).map(|i| *p.add(i)).collect()
        };

        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
            assert!(
                (e - a).abs() < 1e-3,
                "mismatch at {i}: expected {e}, got {a}"
            );
        }
    }
17448
17449    /// Plan #9: LoRA matmul matches the unfused 3-matmul reference.
17450    #[test]
17451    fn lora_matmul_matches_unfused_reference() {
17452        use rlx_ir::Philox4x32;
17453
17454        let m = 4usize;
17455        let k = 8usize;
17456        let n = 6usize;
17457        let r = 2usize;
17458        let scale = 0.5f32;
17459
17460        // Random inputs (deterministic via Philox).
17461        let mut rng = Philox4x32::new(42);
17462        let mut x = vec![0f32; m * k];
17463        rng.fill_normal(&mut x);
17464        let mut w = vec![0f32; k * n];
17465        rng.fill_normal(&mut w);
17466        let mut a = vec![0f32; k * r];
17467        rng.fill_normal(&mut a);
17468        let mut b = vec![0f32; r * n];
17469        rng.fill_normal(&mut b);
17470
17471        // Reference: out = x·W + scale * x·A·B. Naive triple-loop.
17472        let naive = |a_buf: &[f32], b_buf: &[f32], rows: usize, inner: usize, cols: usize| {
17473            let mut o = vec![0f32; rows * cols];
17474            for i in 0..rows {
17475                for j in 0..cols {
17476                    let mut acc = 0f32;
17477                    for p in 0..inner {
17478                        acc += a_buf[i * inner + p] * b_buf[p * cols + j];
17479                    }
17480                    o[i * cols + j] = acc;
17481                }
17482            }
17483            o
17484        };
17485        let xw = naive(&x, &w, m, k, n);
17486        let xa = naive(&x, &a, m, k, r);
17487        let xab = naive(&xa, &b, m, r, n);
17488        let mut expected = xw;
17489        for i in 0..(m * n) {
17490            expected[i] += scale * xab[i];
17491        }
17492
17493        // RLX path: build a graph with one LoraMatMul.
17494        let f = DType::F32;
17495        let mut g = Graph::new("lora");
17496        let xn = g.input("x", Shape::new(&[m, k], f));
17497        let wn = g.param("w", Shape::new(&[k, n], f));
17498        let an = g.param("a", Shape::new(&[k, r], f));
17499        let bn = g.param("b", Shape::new(&[r, n], f));
17500        let lm = g.lora_matmul(xn, wn, an, bn, scale, Shape::new(&[m, n], f));
17501        g.set_outputs(vec![lm]);
17502
17503        let plan = rlx_opt::memory::plan_memory(&g);
17504        let mut arena = crate::arena::Arena::from_plan(plan);
17505        let sched = compile_thunks(&g, &arena);
17506
17507        let xn_off = arena.byte_offset(xn);
17508        let wn_off = arena.byte_offset(wn);
17509        let an_off = arena.byte_offset(an);
17510        let bn_off = arena.byte_offset(bn);
17511        let lm_off = arena.byte_offset(lm);
17512        let buf = arena.raw_buf_mut();
17513        unsafe {
17514            let copy = |dst: *mut f32, data: &[f32]| {
17515                for (i, &v) in data.iter().enumerate() {
17516                    *dst.add(i) = v;
17517                }
17518            };
17519            copy(buf.as_mut_ptr().add(xn_off) as *mut f32, &x);
17520            copy(buf.as_mut_ptr().add(wn_off) as *mut f32, &w);
17521            copy(buf.as_mut_ptr().add(an_off) as *mut f32, &a);
17522            copy(buf.as_mut_ptr().add(bn_off) as *mut f32, &b);
17523        }
17524        execute_thunks(&sched, arena.raw_buf_mut());
17525
17526        let actual: Vec<f32> = unsafe {
17527            let p = arena.raw_buf().as_ptr().add(lm_off) as *const f32;
17528            (0..m * n).map(|i| *p.add(i)).collect()
17529        };
17530
17531        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
17532            assert!(
17533                (e - a).abs() < 1e-3,
17534                "mismatch at {i}: expected {e}, got {a}"
17535            );
17536        }
17537    }
17538
17539    /// Plan #42: fused sampling kernel determinism + greedy fallback.
17540    #[test]
17541    fn sample_temperature_zero_is_argmax() {
17542        // Very low temperature → distribution collapses on argmax.
17543        // Same seed → same output bit-for-bit.
17544        let f = DType::F32;
17545        let mut g = Graph::new("samp");
17546        let logits = g.input("logits", Shape::new(&[1, 8], f));
17547        let s = g.sample(logits, 0, 1.0, 1e-3, 42, Shape::new(&[1], f));
17548        g.set_outputs(vec![s]);
17549        let plan = rlx_opt::memory::plan_memory(&g);
17550        let mut arena = crate::arena::Arena::from_plan(plan);
17551        let sched = compile_thunks(&g, &arena);
17552
17553        let logits_off = arena.byte_offset(logits);
17554        let s_off = arena.byte_offset(s);
17555        let buf = arena.raw_buf_mut();
17556        unsafe {
17557            let p = buf.as_mut_ptr().add(logits_off) as *mut f32;
17558            // argmax = index 5 (value 9.0).
17559            let inputs = [0.1f32, 0.2, 0.3, 0.4, 0.5, 9.0, 0.7, 0.8];
17560            for (i, &v) in inputs.iter().enumerate() {
17561                *p.add(i) = v;
17562            }
17563        }
17564        execute_thunks(&sched, arena.raw_buf_mut());
17565
17566        let token = unsafe {
17567            let p = arena.raw_buf().as_ptr().add(s_off) as *const f32;
17568            *p as usize
17569        };
17570        assert_eq!(token, 5, "low-temp sampling should pick the argmax");
17571    }
17572
17573    #[test]
17574    fn sample_top_k_one_is_deterministic() {
17575        // top_k=1 forces only the argmax to have nonzero probability.
17576        let f = DType::F32;
17577        let mut g = Graph::new("samp_k1");
17578        let logits = g.input("logits", Shape::new(&[1, 4], f));
17579        let s = g.sample(logits, 1, 1.0, 1.0, 7, Shape::new(&[1], f));
17580        g.set_outputs(vec![s]);
17581        let plan = rlx_opt::memory::plan_memory(&g);
17582        let mut arena = crate::arena::Arena::from_plan(plan);
17583        let sched = compile_thunks(&g, &arena);
17584
17585        let logits_off = arena.byte_offset(logits);
17586        let s_off = arena.byte_offset(s);
17587        let buf = arena.raw_buf_mut();
17588        unsafe {
17589            let p = buf.as_mut_ptr().add(logits_off) as *mut f32;
17590            let inputs = [0.1f32, 5.0, 0.3, 0.4]; // argmax = 1
17591            for (i, &v) in inputs.iter().enumerate() {
17592                *p.add(i) = v;
17593            }
17594        }
17595        execute_thunks(&sched, arena.raw_buf_mut());
17596        let token = unsafe {
17597            let p = arena.raw_buf().as_ptr().add(s_off) as *const f32;
17598            *p as usize
17599        };
17600        assert_eq!(token, 1);
17601    }
17602
17603    /// Plan #44: cumsum primitive parity vs. naive scan.
17604    #[test]
17605    fn cumsum_inclusive_matches_naive() {
17606        let f = DType::F32;
17607        let mut g = Graph::new("cumsum");
17608        let x = g.input("x", Shape::new(&[2, 4], f));
17609        let cs = g.cumsum(x, -1, false, Shape::new(&[2, 4], f));
17610        g.set_outputs(vec![cs]);
17611        let plan = rlx_opt::memory::plan_memory(&g);
17612        let mut arena = crate::arena::Arena::from_plan(plan);
17613        let sched = compile_thunks(&g, &arena);
17614
17615        // Cache offsets up-front so we can drop the immutable borrow.
17616        let x_off = arena.byte_offset(x);
17617        let out_off = arena.byte_offset(cs);
17618        let buf = arena.raw_buf_mut();
17619        unsafe {
17620            let p = buf.as_mut_ptr().add(x_off) as *mut f32;
17621            let inputs = [1.0f32, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0];
17622            for (i, &v) in inputs.iter().enumerate() {
17623                *p.add(i) = v;
17624            }
17625        }
17626        execute_thunks(&sched, arena.raw_buf_mut());
17627
17628        let out: Vec<f32> = unsafe {
17629            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
17630            (0..8).map(|i| *p.add(i)).collect()
17631        };
17632        assert_eq!(out, vec![1.0, 3.0, 6.0, 10.0, 10.0, 30.0, 60.0, 100.0]);
17633    }
17634
17635    /// Plan #46 deep: Narrow×3 → Attention fusion. The three QKV
17636    /// narrows that BERT/Nomic emit on the unfused (batch*seq > 64)
17637    /// path collapse into a single strided-Attention thunk.
17638    #[test]
17639    fn narrow_attention_fuses_in_unfused_path() {
17640        let f = DType::F32;
17641        let mut g = Graph::new("nattn_fuse");
17642        // batch*seq = 8*16 = 128 > 64 so FusedAttnBlock skips.
17643        let qkv = g.input("qkv", Shape::new(&[8, 16, 192], f)); // 3*64 = 192
17644        let mask = g.input("mask", Shape::new(&[8, 16], f));
17645        let q = g.narrow_(qkv, 2, 0, 64);
17646        let k = g.narrow_(qkv, 2, 64, 64);
17647        let v = g.narrow_(qkv, 2, 128, 64);
17648        let attn = g.attention(q, k, v, mask, 4, 16, Shape::new(&[8, 16, 64], f));
17649        g.set_outputs(vec![attn]);
17650
17651        let plan = rlx_opt::memory::plan_memory(&g);
17652        let arena = crate::arena::Arena::from_plan(plan);
17653        let sched = compile_thunks(&g, &arena);
17654
17655        let mut narrow_count = 0;
17656        let mut attn_strides: Option<(u32, u32, u32)> = None;
17657        for t in &sched.thunks {
17658            match t {
17659                Thunk::Narrow { .. } => narrow_count += 1,
17660                Thunk::Attention {
17661                    q_row_stride,
17662                    k_row_stride,
17663                    v_row_stride,
17664                    ..
17665                } => attn_strides = Some((*q_row_stride, *k_row_stride, *v_row_stride)),
17666                _ => {}
17667            }
17668        }
17669        // After fusion the 3 narrows are gone; Attention now walks the
17670        // QKV with parent row stride = 192 (3 × 64) on all three inputs.
17671        assert_eq!(
17672            narrow_count, 0,
17673            "Narrow×3→Attention fusion should eliminate all 3 narrows; saw {narrow_count}"
17674        );
17675        assert_eq!(
17676            attn_strides,
17677            Some((192, 192, 192)),
17678            "Attention should walk Q/K/V with parent row stride 192"
17679        );
17680    }
17681
17682    // ── Backward / training op parity tests ────────────────────
17683    //
17684    // Strategy: build a graph that contains exactly the backward op
17685    // under test (plus its inputs as graph Inputs), execute, and
17686    // compare against a hand-rolled scalar reference. For
17687    // Conv2dBackwardInput we additionally check against the numerical
17688    // gradient of the forward Conv2D — that's the gold-standard test
17689    // that validates the math, not just consistency between two
17690    // implementations of the same formula.
17691
17692    fn run_graph(
17693        g: &Graph,
17694        inputs: &[(NodeId, &[f32])],
17695        out_id: NodeId,
17696        out_len: usize,
17697    ) -> Vec<f32> {
17698        let plan = rlx_opt::memory::plan_memory(g);
17699        let mut arena = crate::arena::Arena::from_plan(plan);
17700        let sched = compile_thunks(g, &arena);
17701        for &(id, data) in inputs {
17702            let off = arena.byte_offset(id);
17703            let buf = arena.raw_buf_mut();
17704            unsafe {
17705                let p = buf.as_mut_ptr().add(off) as *mut f32;
17706                for (i, &v) in data.iter().enumerate() {
17707                    *p.add(i) = v;
17708                }
17709            }
17710        }
17711        execute_thunks(&sched, arena.raw_buf_mut());
17712        let off = arena.byte_offset(out_id);
17713        unsafe {
17714            let p = arena.raw_buf().as_ptr().add(off) as *const f32;
17715            (0..out_len).map(|i| *p.add(i)).collect()
17716        }
17717    }
17718
17719    #[test]
17720    fn relu_backward_matches_mask() {
17721        let f = DType::F32;
17722        let len = 7usize;
17723        let x: Vec<f32> = vec![-2.0, -0.1, 0.0, 0.1, 1.0, 3.0, -5.0];
17724        let dy: Vec<f32> = vec![0.5, 1.5, 2.5, -0.7, 4.0, -1.0, 9.0];
17725
17726        let mut g = Graph::new("relu_bw");
17727        let xn = g.input("x", Shape::new(&[len], f));
17728        let dyn_ = g.input("dy", Shape::new(&[len], f));
17729        let dx = g.relu_backward(xn, dyn_);
17730        g.set_outputs(vec![dx]);
17731
17732        let actual = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dx, len);
17733        // Reference: gradient is dy where x>0 strictly, else 0.
17734        // (zero is not "positive" — the forward applied max(0, x), and at
17735        // x=0 the subgradient could be anything in [0, dy]; we pick 0.)
17736        let expected: Vec<f32> = x
17737            .iter()
17738            .zip(&dy)
17739            .map(|(&xi, &dyi)| if xi > 0.0 { dyi } else { 0.0 })
17740            .collect();
17741        for (a, e) in actual.iter().zip(&expected) {
17742            assert!((a - e).abs() < 1e-6, "relu_bw mismatch: {a} vs {e}");
17743        }
17744    }
17745
17746    #[test]
17747    fn maxpool2d_backward_routes_to_argmax() {
17748        let f = DType::F32;
17749        // [N=1, C=1, H=4, W=4] → 2x2 max-pool stride 2 → [1,1,2,2].
17750        let x: Vec<f32> = vec![
17751            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
17752        ];
17753        // Argmax of each 2x2 window:
17754        //   (0,0)→6 (idx 5), (0,1)→8 (idx 7),
17755        //   (1,0)→14(idx 13),(1,1)→16(idx 15).
17756        let dy: Vec<f32> = vec![0.5, 1.0, 2.0, 4.0];
17757
17758        let mut g = Graph::new("maxpool_bw");
17759        let xn = g.input("x", Shape::new(&[1, 1, 4, 4], f));
17760        let dyn_ = g.input("dy", Shape::new(&[1, 1, 2, 2], f));
17761        let dx = g.maxpool2d_backward(xn, dyn_, vec![2, 2], vec![2, 2], vec![0, 0]);
17762        g.set_outputs(vec![dx]);
17763
17764        let actual = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dx, 16);
17765        let mut expected = vec![0f32; 16];
17766        expected[5] = 0.5;
17767        expected[7] = 1.0;
17768        expected[13] = 2.0;
17769        expected[15] = 4.0;
17770        for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
17771            assert!((a - e).abs() < 1e-6, "maxpool_bw[{i}] mismatch: {a} vs {e}");
17772        }
17773    }
17774
17775    #[test]
17776    fn conv2d_backward_input_matches_numerical_gradient() {
17777        use rlx_ir::Philox4x32;
17778        // Small enough to numerically differentiate exhaustively but
17779        // big enough to exercise stride/padding edge cases.
17780        let n = 1usize;
17781        let c_in = 2usize;
17782        let h = 4usize;
17783        let w = 4usize;
17784        let c_out = 3usize;
17785        let kh = 3usize;
17786        let kw = 3usize;
17787        let ph = 1usize;
17788        let pw = 1usize;
17789        let sh = 1usize;
17790        let sw = 1usize;
17791        // Output dims with padding=1, stride=1: same as input.
17792        let h_out = (h + 2 * ph - kh) / sh + 1;
17793        let w_out = (w + 2 * pw - kw) / sw + 1;
17794        assert_eq!(h_out, 4);
17795        assert_eq!(w_out, 4);
17796
17797        let mut rng = Philox4x32::new(7);
17798        let mut x = vec![0f32; n * c_in * h * w];
17799        rng.fill_normal(&mut x);
17800        let mut wt = vec![0f32; c_out * c_in * kh * kw];
17801        rng.fill_normal(&mut wt);
17802        let mut dy = vec![0f32; n * c_out * h_out * w_out];
17803        rng.fill_normal(&mut dy);
17804
17805        // Analytical: Conv2dBackwardInput on (dy, w).
17806        let f = DType::F32;
17807        let mut g = Graph::new("conv_bwi");
17808        let dy_in = g.input("dy", Shape::new(&[n, c_out, h_out, w_out], f));
17809        let w_in = g.input("w", Shape::new(&[c_out, c_in, kh, kw], f));
17810        let dx = g.conv2d_backward_input(
17811            dy_in,
17812            w_in,
17813            Shape::new(&[n, c_in, h, w], f),
17814            vec![kh, kw],
17815            vec![sh, sw],
17816            vec![ph, pw],
17817            vec![1, 1],
17818            1,
17819        );
17820        g.set_outputs(vec![dx]);
17821        let analytical = run_graph(&g, &[(dy_in, &dy), (w_in, &wt)], dx, n * c_in * h * w);
17822
17823        // Numerical: for each x[i], finite-difference forward conv twice.
17824        // Forward: y[j] = sum over filter window of w * x ; dot(dy, y) is
17825        // the scalar we differentiate. Then dx[i] = ∂(dot(dy, y))/∂x[i].
17826        let forward = |x: &[f32]| -> Vec<f32> {
17827            let mut out = vec![0f32; n * c_out * h_out * w_out];
17828            for ni in 0..n {
17829                for co in 0..c_out {
17830                    for ho in 0..h_out {
17831                        for wo in 0..w_out {
17832                            let mut acc = 0f32;
17833                            for ci in 0..c_in {
17834                                for ki in 0..kh {
17835                                    for kj in 0..kw {
17836                                        let hi = ho * sh + ki;
17837                                        let wi = wo * sw + kj;
17838                                        if hi < ph || wi < pw {
17839                                            continue;
17840                                        }
17841                                        let hi = hi - ph;
17842                                        let wi = wi - pw;
17843                                        if hi >= h || wi >= w {
17844                                            continue;
17845                                        }
17846                                        let xv = x[((ni * c_in) + ci) * h * w + hi * w + wi];
17847                                        let wv = wt[((co * c_in) + ci) * kh * kw + ki * kw + kj];
17848                                        acc += xv * wv;
17849                                    }
17850                                }
17851                            }
17852                            out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = acc;
17853                        }
17854                    }
17855                }
17856            }
17857            out
17858        };
17859        let dot = |a: &[f32], b: &[f32]| -> f32 { a.iter().zip(b).map(|(&u, &v)| u * v).sum() };
17860        let eps = 1e-3f32;
17861        let mut numerical = vec![0f32; x.len()];
17862        for i in 0..x.len() {
17863            let saved = x[i];
17864            x[i] = saved + eps;
17865            let plus = dot(&forward(&x), &dy);
17866            x[i] = saved - eps;
17867            let minus = dot(&forward(&x), &dy);
17868            x[i] = saved;
17869            numerical[i] = (plus - minus) / (2.0 * eps);
17870        }
17871        for (i, (a, n)) in analytical.iter().zip(&numerical).enumerate() {
17872            // f32 + eps=1e-3 numerical grad → ~1e-3 absolute is realistic.
17873            assert!(
17874                (a - n).abs() < 5e-3,
17875                "conv_bw_input[{i}]: analytical {a} vs numerical {n}"
17876            );
17877        }
17878    }
17879
17880    #[test]
17881    fn conv2d_backward_weight_matches_numerical_gradient() {
17882        use rlx_ir::Philox4x32;
17883        let n = 2usize;
17884        let c_in = 2usize;
17885        let h = 4usize;
17886        let w = 4usize;
17887        let c_out = 2usize;
17888        let kh = 3usize;
17889        let kw = 3usize;
17890        let ph = 0usize;
17891        let pw = 0usize;
17892        let sh = 1usize;
17893        let sw = 1usize;
17894        let h_out = (h + 2 * ph - kh) / sh + 1;
17895        let w_out = (w + 2 * pw - kw) / sw + 1;
17896
17897        let mut rng = Philox4x32::new(11);
17898        let mut x = vec![0f32; n * c_in * h * w];
17899        rng.fill_normal(&mut x);
17900        let mut wt = vec![0f32; c_out * c_in * kh * kw];
17901        rng.fill_normal(&mut wt);
17902        let mut dy = vec![0f32; n * c_out * h_out * w_out];
17903        rng.fill_normal(&mut dy);
17904
17905        let f = DType::F32;
17906        let mut g = Graph::new("conv_bww");
17907        let xn = g.input("x", Shape::new(&[n, c_in, h, w], f));
17908        let dyn_ = g.input("dy", Shape::new(&[n, c_out, h_out, w_out], f));
17909        let dwn = g.conv2d_backward_weight(
17910            xn,
17911            dyn_,
17912            Shape::new(&[c_out, c_in, kh, kw], f),
17913            vec![kh, kw],
17914            vec![sh, sw],
17915            vec![ph, pw],
17916            vec![1, 1],
17917            1,
17918        );
17919        g.set_outputs(vec![dwn]);
17920        let analytical = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dwn, c_out * c_in * kh * kw);
17921
17922        let forward = |wt: &[f32]| -> Vec<f32> {
17923            let mut out = vec![0f32; n * c_out * h_out * w_out];
17924            for ni in 0..n {
17925                for co in 0..c_out {
17926                    for ho in 0..h_out {
17927                        for wo in 0..w_out {
17928                            let mut acc = 0f32;
17929                            for ci in 0..c_in {
17930                                for ki in 0..kh {
17931                                    for kj in 0..kw {
17932                                        let hi = ho + ki;
17933                                        let wi = wo + kj;
17934                                        let xv = x[((ni * c_in) + ci) * h * w + hi * w + wi];
17935                                        let wv = wt[((co * c_in) + ci) * kh * kw + ki * kw + kj];
17936                                        acc += xv * wv;
17937                                    }
17938                                }
17939                            }
17940                            out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = acc;
17941                        }
17942                    }
17943                }
17944            }
17945            out
17946        };
17947        let dot = |a: &[f32], b: &[f32]| -> f32 { a.iter().zip(b).map(|(&u, &v)| u * v).sum() };
17948        let eps = 1e-3f32;
17949        let mut numerical = vec![0f32; wt.len()];
17950        for i in 0..wt.len() {
17951            let saved = wt[i];
17952            wt[i] = saved + eps;
17953            let plus = dot(&forward(&wt), &dy);
17954            wt[i] = saved - eps;
17955            let minus = dot(&forward(&wt), &dy);
17956            wt[i] = saved;
17957            numerical[i] = (plus - minus) / (2.0 * eps);
17958        }
17959        for (i, (a, n)) in analytical.iter().zip(&numerical).enumerate() {
17960            assert!(
17961                (a - n).abs() < 5e-3,
17962                "conv_bw_weight[{i}]: analytical {a} vs numerical {n}"
17963            );
17964        }
17965    }
17966
17967    #[test]
17968    fn softmax_cross_entropy_matches_reference() {
17969        let f = DType::F32;
17970        let logits: Vec<f32> = vec![
17971            1.0, 2.0, 3.0, // row 0: max=3 (idx 2)
17972            -1.0, 0.0, 4.0, // row 1: max=4 (idx 2)
17973            5.0, 5.0, 5.0, // row 2: uniform
17974        ];
17975        let labels: Vec<f32> = vec![2.0, 0.0, 1.0];
17976
17977        let mut g = Graph::new("sce");
17978        let lg = g.input("logits", Shape::new(&[3, 3], f));
17979        let lb = g.input("labels", Shape::new(&[3], f));
17980        let loss = g.softmax_cross_entropy_with_logits(lg, lb);
17981        g.set_outputs(vec![loss]);
17982        let actual = run_graph(&g, &[(lg, &logits), (lb, &labels)], loss, 3);
17983
17984        // Reference per-row: -log(softmax(row)[label]).
17985        let mut expected = vec![0f32; 3];
17986        for ni in 0..3 {
17987            let row = &logits[ni * 3..(ni + 1) * 3];
17988            let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
17989            let sum: f32 = row.iter().map(|&v| (v - m).exp()).sum();
17990            let lse = m + sum.ln();
17991            let label_idx = labels[ni] as usize;
17992            expected[ni] = lse - row[label_idx];
17993        }
17994        for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
17995            assert!((a - e).abs() < 1e-5, "sce loss[{i}]: {a} vs {e}");
17996        }
17997    }
17998
17999    #[test]
18000    fn softmax_cross_entropy_backward_matches_numerical_gradient() {
18001        use rlx_ir::Philox4x32;
18002        let n = 4usize;
18003        let c = 5usize;
18004        let mut rng = Philox4x32::new(23);
18005        let mut logits = vec![0f32; n * c];
18006        rng.fill_normal(&mut logits);
18007        let labels: Vec<f32> = (0..n).map(|i| (i % c) as f32).collect();
18008        let mut d_loss = vec![0f32; n];
18009        rng.fill_normal(&mut d_loss);
18010
18011        let f = DType::F32;
18012        let mut g = Graph::new("sce_bw");
18013        let lg = g.input("logits", Shape::new(&[n, c], f));
18014        let lb = g.input("labels", Shape::new(&[n], f));
18015        let dl = g.input("d_loss", Shape::new(&[n], f));
18016        let dlogits = g.softmax_cross_entropy_backward(lg, lb, dl);
18017        g.set_outputs(vec![dlogits]);
18018        let analytical = run_graph(
18019            &g,
18020            &[(lg, &logits), (lb, &labels), (dl, &d_loss)],
18021            dlogits,
18022            n * c,
18023        );
18024
18025        // Numerical: differentiate dot(d_loss, sce_loss(logits)) w.r.t. each logit.
18026        let sce_loss = |logits: &[f32]| -> Vec<f32> {
18027            let mut out = vec![0f32; n];
18028            for ni in 0..n {
18029                let row = &logits[ni * c..(ni + 1) * c];
18030                let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
18031                let sum: f32 = row.iter().map(|&v| (v - m).exp()).sum();
18032                out[ni] = (m + sum.ln()) - row[labels[ni] as usize];
18033            }
18034            out
18035        };
18036        let dot = |a: &[f32], b: &[f32]| a.iter().zip(b).map(|(&u, &v)| u * v).sum::<f32>();
18037        let eps = 1e-3f32;
18038        let mut numerical = vec![0f32; logits.len()];
18039        for i in 0..logits.len() {
18040            let saved = logits[i];
18041            logits[i] = saved + eps;
18042            let plus = dot(&sce_loss(&logits), &d_loss);
18043            logits[i] = saved - eps;
18044            let minus = dot(&sce_loss(&logits), &d_loss);
18045            logits[i] = saved;
18046            numerical[i] = (plus - minus) / (2.0 * eps);
18047        }
18048        for (i, (a, num)) in analytical.iter().zip(&numerical).enumerate() {
18049            assert!(
18050                (a - num).abs() < 5e-3,
18051                "sce_bw[{i}]: analytical {a} vs numerical {num}"
18052            );
18053        }
18054    }
18055
18056    // ── End-to-end autodiff parity tests ──────────────────────
18057    //
18058    // Build a forward graph, run `grad_with_loss` to produce a graph
18059    // that emits [loss, gradients...], execute it through rlx-cpu,
18060    // and compare each gradient to a finite-difference estimate
18061    // produced by re-running the forward graph with each parameter
18062    // entry perturbed. f32 + ε=1e-3 puts the tolerance floor around
18063    // 5e-3 absolute error.
18064
18065    /// Initialize Op::Constant slots in the arena with their literal
18066    /// data. Mirrors the loop in rlx_runtime::backend (which serves
18067    /// the same role for production runs).
18068    fn fill_constants_into_arena(graph: &Graph, arena: &mut crate::arena::Arena) {
18069        for node in graph.nodes() {
18070            if let Op::Constant { data } = &node.op
18071                && arena.has_buffer(node.id)
18072                && !data.is_empty()
18073            {
18074                let buf = arena.slice_mut(node.id);
18075                let n_floats = data.len() / 4;
18076                let n = buf.len().min(n_floats);
18077                for i in 0..n {
18078                    let bytes = [
18079                        data[i * 4],
18080                        data[i * 4 + 1],
18081                        data[i * 4 + 2],
18082                        data[i * 4 + 3],
18083                    ];
18084                    buf[i] = f32::from_le_bytes(bytes);
18085                }
18086            }
18087        }
18088    }
18089
18090    /// Compile + arena-prep helper for these tests. Returns the
18091    /// schedule and a populated arena. `seed_inputs` writes f32 input
18092    /// data into the arena slot for each (NodeId, &[f32]) pair.
18093    fn prepare(
18094        graph: &Graph,
18095        seed_inputs: &[(NodeId, &[f32])],
18096    ) -> (ThunkSchedule, crate::arena::Arena) {
18097        let plan = rlx_opt::memory::plan_memory(graph);
18098        let mut arena = crate::arena::Arena::from_plan(plan);
18099        let sched = compile_thunks(graph, &arena);
18100        fill_constants_into_arena(graph, &mut arena);
18101        for &(id, data) in seed_inputs {
18102            let off = arena.byte_offset(id);
18103            let buf = arena.raw_buf_mut();
18104            unsafe {
18105                let p = buf.as_mut_ptr().add(off) as *mut f32;
18106                for (i, &v) in data.iter().enumerate() {
18107                    *p.add(i) = v;
18108                }
18109            }
18110        }
18111        (sched, arena)
18112    }
18113
18114    fn read_arena(arena: &crate::arena::Arena, id: NodeId, len: usize) -> Vec<f32> {
18115        let off = arena.byte_offset(id);
18116        unsafe {
18117            let p = arena.raw_buf().as_ptr().add(off) as *const f32;
18118            (0..len).map(|i| *p.add(i)).collect()
18119        }
18120    }
18121
18122    fn write_arena(arena: &mut crate::arena::Arena, id: NodeId, data: &[f32]) {
18123        let off = arena.byte_offset(id);
18124        let buf = arena.raw_buf_mut();
18125        unsafe {
18126            let p = buf.as_mut_ptr().add(off) as *mut f32;
18127            for (i, &v) in data.iter().enumerate() {
18128                *p.add(i) = v;
18129            }
18130        }
18131    }
18132
18133    /// f64 sibling of `prepare`. Writes f64 input data into the arena.
18134    fn prepare_f64(
18135        graph: &Graph,
18136        seed_inputs: &[(NodeId, &[f64])],
18137    ) -> (ThunkSchedule, crate::arena::Arena) {
18138        let plan = rlx_opt::memory::plan_memory(graph);
18139        let mut arena = crate::arena::Arena::from_plan(plan);
18140        let sched = compile_thunks(graph, &arena);
18141        fill_constants_into_arena(graph, &mut arena);
18142        for &(id, data) in seed_inputs {
18143            let off = arena.byte_offset(id);
18144            let buf = arena.raw_buf_mut();
18145            unsafe {
18146                let p = buf.as_mut_ptr().add(off) as *mut f64;
18147                for (i, &v) in data.iter().enumerate() {
18148                    *p.add(i) = v;
18149                }
18150            }
18151        }
18152        (sched, arena)
18153    }
18154
18155    fn read_arena_f64(arena: &crate::arena::Arena, id: NodeId, len: usize) -> Vec<f64> {
18156        let off = arena.byte_offset(id);
18157        unsafe {
18158            let p = arena.raw_buf().as_ptr().add(off) as *const f64;
18159            (0..len).map(|i| *p.add(i)).collect()
18160        }
18161    }
18162
18163    /// End-to-end f64 DenseSolve through the full compile + execute
18164    /// path. Validates: IR shape inference, memory planner f64 sizing,
18165    /// arena f64 accessors, Thunk::DenseSolveF64 lowering, executor
18166    /// dispatch, Accelerate dgesv FFI.
18167    ///
18168    /// System:
18169    ///   A = [[2, 1],
18170    ///        [1, 3]]   b = [5, 10]
18171    ///   ⇒  x = [1, 3]   (verified by hand)
18172    #[test]
18173    fn dense_solve_f64_end_to_end() {
18174        let mut g = Graph::new("solve_e2e");
18175        let a = g.input("A", Shape::new(&[2, 2], DType::F64));
18176        let b = g.input("b", Shape::new(&[2], DType::F64));
18177        let x = g.dense_solve(a, b, Shape::new(&[2], DType::F64));
18178        g.set_outputs(vec![x]);
18179
18180        let a_data = [2.0, 1.0, 1.0, 3.0_f64];
18181        let b_data = [5.0, 10.0_f64];
18182        let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
18183        execute_thunks(&sched, arena.raw_buf_mut());
18184
18185        let got = read_arena_f64(&arena, x, 2);
18186        let want = [1.0, 3.0_f64];
18187        for i in 0..2 {
18188            assert!(
18189                (got[i] - want[i]).abs() < 1e-12,
18190                "x[{i}] = {} (expected {})",
18191                got[i],
18192                want[i]
18193            );
18194        }
18195    }
18196
18197    /// Scaled-up f64 DenseSolve — tridiagonal Laplacian-shape (typical
18198    /// MNA structure for a passive RC mesh in Circulax). Validates
18199    /// that the solve scales beyond the trivial 2×2 and that the
18200    /// row-major ↔ col-major dance in `dgesv` is correct for the
18201    /// general case.
18202    #[test]
18203    fn dense_solve_f64_5x5_laplacian() {
18204        let n = 5usize;
18205        let mut g = Graph::new("solve_5x5");
18206        let a = g.input("A", Shape::new(&[n, n], DType::F64));
18207        let b = g.input("b", Shape::new(&[n], DType::F64));
18208        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
18209        g.set_outputs(vec![x]);
18210
18211        // 1-D Laplacian: 2 on diagonal, -1 on off-diagonals, 0 elsewhere.
18212        let mut a_data = vec![0.0_f64; n * n];
18213        for i in 0..n {
18214            a_data[i * n + i] = 2.0;
18215            if i > 0 {
18216                a_data[i * n + (i - 1)] = -1.0;
18217            }
18218            if i + 1 < n {
18219                a_data[i * n + (i + 1)] = -1.0;
18220            }
18221        }
18222        let b_data: Vec<f64> = (0..n).map(|i| (i + 1) as f64).collect();
18223        let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
18224        execute_thunks(&sched, arena.raw_buf_mut());
18225
18226        let got = read_arena_f64(&arena, x, n);
18227        // Verify A·x ≈ b by computing the residual.
18228        let mut residual = vec![0.0_f64; n];
18229        for i in 0..n {
18230            for j in 0..n {
18231                residual[i] += a_data[i * n + j] * got[j];
18232            }
18233        }
18234        for i in 0..n {
18235            assert!(
18236                (residual[i] - b_data[i]).abs() < 1e-10,
18237                "row {i}: residual {} vs b {}",
18238                residual[i],
18239                b_data[i]
18240            );
18241        }
18242    }
18243
18244    /// Hello Resistor: end-to-end f64 gradient through a dense solve.
18245    ///
18246    /// Forward:
18247    ///   A      : Param  [N, N]   f64
18248    ///   b      : Input  [N]      f64
18249    ///   x      = solve(A, b)            (DenseSolve)
18250    ///   loss   = sum(x)                 (Reduce::Sum)
18251    ///
18252    /// Backward (via grad_with_loss):
18253    ///   ones [N] = expand(d_output, [N])      (Reduce::Sum VJP)
18254    ///   dx_int   = solve(Aᵀ, ones)             (DenseSolve VJP step 1)
18255    ///   dA       = -outer(dx_int, x)           (DenseSolve VJP step 2)
18256    ///   db       = dx_int                       (DenseSolve VJP step 3)
18257    ///
18258    /// Closed form: with loss = sum(solve(A, b)) = ones·x and
18259    /// implicit-function calculus, db = (Aᵀ)⁻¹·ones, dA = -db ⊗ x.
18260    /// We verify this against the autodiff-emitted graph's output and
18261    /// against a finite-difference baseline.
18262    #[test]
18263    fn hello_resistor_gradient_end_to_end() {
18264        use rlx_opt::autodiff::grad_with_loss;
18265        let n = 3usize;
18266
18267        // ── Build forward graph ──
18268        let mut g = Graph::new("hello_resistor");
18269        let a = g.param("A", Shape::new(&[n, n], DType::F64));
18270        let b = g.input("b", Shape::new(&[n], DType::F64));
18271        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
18272        let loss = g.reduce(
18273            x,
18274            ReduceOp::Sum,
18275            vec![0],
18276            false,
18277            Shape::new(&[1], DType::F64),
18278        );
18279        g.set_outputs(vec![loss]);
18280
18281        // ── Run reverse-mode AD ──
18282        let bwd = grad_with_loss(&g, &[a, b]);
18283        assert_eq!(bwd.outputs.len(), 3, "expect [loss, dA, db]");
18284
18285        // ── Locate the inputs the bwd graph still needs from us ──
18286        // grad_with_loss copies forward nodes into bwd, so A/b/d_output
18287        // appear under their original names. Find them by name.
18288        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
18289            for node in graph.nodes() {
18290                let name = match &node.op {
18291                    rlx_ir::Op::Input { name } => Some(name.as_str()),
18292                    rlx_ir::Op::Param { name } => Some(name.as_str()),
18293                    _ => None,
18294                };
18295                if name == Some(want) {
18296                    return node.id;
18297                }
18298            }
18299            panic!("no node named {want:?} in bwd graph");
18300        };
18301        let a_bwd = find_by_name(&bwd, "A");
18302        let b_bwd = find_by_name(&bwd, "b");
18303        let d_out_bwd = find_by_name(&bwd, "d_output");
18304
18305        // ── Test data ──
18306        // A = [[2,1,0],[1,3,1],[0,1,2]]   (SPD tridiagonal, well-conditioned)
18307        // b = [1,2,3]
18308        let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
18309        let b_data = [1.0, 2.0, 3.0_f64];
18310        let d_output = [1.0_f64]; // ∂loss/∂loss
18311
18312        // ── Compile + execute backward graph ──
18313        let (sched, mut arena) = prepare_f64(
18314            &bwd,
18315            &[(a_bwd, &a_data), (b_bwd, &b_data), (d_out_bwd, &d_output)],
18316        );
18317        execute_thunks(&sched, arena.raw_buf_mut());
18318
18319        let loss_out = read_arena_f64(&arena, bwd.outputs[0], 1);
18320        let da_out = read_arena_f64(&arena, bwd.outputs[1], n * n);
18321        let db_out = read_arena_f64(&arena, bwd.outputs[2], n);
18322
18323        // ── Closed-form reference ──
18324        // x = A⁻¹ b ; loss = sum(x).
18325        let x_ref = {
18326            let mut a = a_data;
18327            let mut b = b_data;
18328            let info = crate::blas::dgesv(&mut a, &mut b, n, 1);
18329            assert_eq!(info, 0);
18330            b
18331        };
18332        let loss_ref: f64 = x_ref.iter().sum();
18333        // db = (Aᵀ)⁻¹ · 1
18334        let db_ref = {
18335            let mut at = [0.0_f64; 9];
18336            for i in 0..n {
18337                for j in 0..n {
18338                    at[i * n + j] = a_data[j * n + i];
18339                }
18340            }
18341            let mut ones = [1.0_f64; 3];
18342            let info = crate::blas::dgesv(&mut at, &mut ones, n, 1);
18343            assert_eq!(info, 0);
18344            ones
18345        };
18346        // dA = -outer(db, x) ; dA[i,j] = -db[i] * x[j]
18347        let mut da_ref = [0.0_f64; 9];
18348        for i in 0..n {
18349            for j in 0..n {
18350                da_ref[i * n + j] = -db_ref[i] * x_ref[j];
18351            }
18352        }
18353
18354        // ── Assertions vs analytic answer ──
18355        assert!(
18356            (loss_out[0] - loss_ref).abs() < 1e-10,
18357            "loss: got {}, want {}",
18358            loss_out[0],
18359            loss_ref
18360        );
18361        for i in 0..n {
18362            assert!(
18363                (db_out[i] - db_ref[i]).abs() < 1e-10,
18364                "db[{i}]: got {}, want {}",
18365                db_out[i],
18366                db_ref[i]
18367            );
18368        }
18369        for i in 0..n * n {
18370            assert!(
18371                (da_out[i] - da_ref[i]).abs() < 1e-10,
18372                "dA[{i}]: got {}, want {}",
18373                da_out[i],
18374                da_ref[i]
18375            );
18376        }
18377
18378        // ── Cross-check vs finite differences on db (a few entries) ──
18379        // ∂loss/∂b[k] ≈ (loss(b + h·e_k) - loss(b - h·e_k)) / (2h).
18380        let h = 1e-6_f64;
18381        for k in 0..n {
18382            let mut bp = b_data;
18383            bp[k] += h;
18384            let mut bm = b_data;
18385            bm[k] -= h;
18386            let lp = {
18387                let mut ac = a_data;
18388                let info = crate::blas::dgesv(&mut ac, &mut bp, n, 1);
18389                assert_eq!(info, 0);
18390                bp.iter().sum::<f64>()
18391            };
18392            let lm = {
18393                let mut ac = a_data;
18394                let info = crate::blas::dgesv(&mut ac, &mut bm, n, 1);
18395                assert_eq!(info, 0);
18396                bm.iter().sum::<f64>()
18397            };
18398            let fd = (lp - lm) / (2.0 * h);
18399            assert!(
18400                (db_out[k] - fd).abs() < 1e-7,
18401                "FD mismatch on db[{k}]: AD={} FD={}",
18402                db_out[k],
18403                fd
18404            );
18405        }
18406    }
18407
18408    /// Smallest possible Op::Scan basic test: geometric growth.
18409    /// init = [1, 1, 1] f64, body = (x → x + 0.1·x) = (x → 1.1·x),
18410    /// length = 10. Final carry must equal init·(1.1)^10 ≈ 2.5937…
18411    /// to f64 precision.
18412    #[test]
18413    fn scan_geometric_growth_f64() {
18414        let n = 3usize;
18415        let length = 10u32;
18416
18417        // Body: (x) → x + 0.1·x. One Input, one output, same shape/dtype.
18418        let mut body = Graph::new("scan_body");
18419        let x = body.input("carry", Shape::new(&[n], DType::F64));
18420        let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 0.1_f64.to_le_bytes()).collect();
18421        let scale = body.add_node(
18422            Op::Constant { data: scale_bytes },
18423            vec![],
18424            Shape::new(&[n], DType::F64),
18425        );
18426        let scaled = body.binary(BinaryOp::Mul, x, scale, Shape::new(&[n], DType::F64));
18427        let next = body.binary(BinaryOp::Add, x, scaled, Shape::new(&[n], DType::F64));
18428        body.set_outputs(vec![next]);
18429
18430        // Outer graph: scan(init, body, length).
18431        let mut g = Graph::new("scan_outer");
18432        let init = g.input("init", Shape::new(&[n], DType::F64));
18433        let final_carry = g.scan(init, body, length);
18434        g.set_outputs(vec![final_carry]);
18435
18436        let init_data = vec![1.0_f64; n];
18437        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
18438        execute_thunks(&sched, arena.raw_buf_mut());
18439        let got = read_arena_f64(&arena, final_carry, n);
18440        let want: f64 = 1.1_f64.powi(length as i32);
18441        for i in 0..n {
18442            assert!(
18443                (got[i] - want).abs() < 1e-12,
18444                "got[{i}] = {} want {}",
18445                got[i],
18446                want
18447            );
18448        }
18449    }
18450
18451    /// Per-step xs scan: cumulative-sum.
18452    ///   carry_0 = init
18453    ///   carry_{t+1} = carry_t + xs\[t\]
18454    ///   final = sum_{t<length} xs\[t\] + init
18455    /// Body has 2 inputs (carry, x_t) in that NodeId order; one output
18456    /// (next carry). Validates the per-step-input plumbing end-to-end.
18457    #[test]
18458    fn scan_with_xs_cumulative_sum() {
18459        let n = 3usize;
18460        let length = 4u32;
18461
18462        let mut body = Graph::new("cumsum_body");
18463        // carry must come first in NodeId order — declare it first.
18464        let carry = body.input("carry", Shape::new(&[n], DType::F64));
18465        let x_t = body.input("x_t", Shape::new(&[n], DType::F64));
18466        let next = body.binary(BinaryOp::Add, carry, x_t, Shape::new(&[n], DType::F64));
18467        body.set_outputs(vec![next]);
18468
18469        let mut g = Graph::new("cumsum_outer");
18470        let init = g.input("init", Shape::new(&[n], DType::F64));
18471        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
18472        let final_carry = g.scan_with_xs(init, &[xs], body, length);
18473        g.set_outputs(vec![final_carry]);
18474
18475        let init_data = vec![0.0_f64; n];
18476        let xs_data: Vec<f64> = (0..length as usize * n).map(|i| (i + 1) as f64).collect(); // 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12
18477        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data), (xs, &xs_data)]);
18478        execute_thunks(&sched, arena.raw_buf_mut());
18479        let got = read_arena_f64(&arena, final_carry, n);
18480
18481        // Reference: column-wise sum of xs rows + init. With our row-major
18482        // layout, column j of xs is xs_data[j], xs_data[n+j], xs_data[2n+j], ...
18483        // (per-step row at offset t*n contributes element j to slot j).
18484        let mut want = init_data.clone();
18485        for t in 0..length as usize {
18486            for j in 0..n {
18487                want[j] += xs_data[t * n + j];
18488            }
18489        }
18490        for i in 0..n {
18491            assert!(
18492                (got[i] - want[i]).abs() < 1e-12,
18493                "got[{i}] = {} want {}",
18494                got[i],
18495                want[i]
18496            );
18497        }
18498    }
18499
18500    /// Per-step xs scan composing with DenseSolve — Circulax-shaped:
18501    ///   carry_{t+1} = solve(M, carry_t + xs\[t\])
18502    /// Models a Backward-Euler step driven by a time-varying source.
18503    #[test]
18504    fn scan_with_xs_be_with_drive() {
18505        let n = 3usize;
18506        let length = 4u32;
18507        let dt = 0.1_f64;
18508
18509        let mut m_data = vec![0.0_f64; n * n];
18510        for i in 0..n {
18511            m_data[i * n + i] = 1.0 + dt * 2.0;
18512            if i > 0 {
18513                m_data[i * n + (i - 1)] = -dt;
18514            }
18515            if i + 1 < n {
18516                m_data[i * n + (i + 1)] = -dt;
18517            }
18518        }
18519        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
18520
18521        let mut body = Graph::new("be_drive_body");
18522        let carry = body.input("carry", Shape::new(&[n], DType::F64));
18523        let drive = body.input("drive", Shape::new(&[n], DType::F64));
18524        let m = body.add_node(
18525            Op::Constant { data: m_bytes },
18526            vec![],
18527            Shape::new(&[n, n], DType::F64),
18528        );
18529        let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
18530        let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
18531        body.set_outputs(vec![next]);
18532
18533        let mut g = Graph::new("be_drive_outer");
18534        let init = g.input("init", Shape::new(&[n], DType::F64));
18535        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
18536        let final_carry = g.scan_with_xs(init, &[xs], body, length);
18537        g.set_outputs(vec![final_carry]);
18538
18539        let init_data = vec![0.0_f64; n];
18540        // Drive the system with a unit pulse on element 0 at t=0,
18541        // zeros after.
18542        let mut xs_data = vec![0.0_f64; length as usize * n];
18543        xs_data[0] = 1.0;
18544
18545        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data), (xs, &xs_data)]);
18546        execute_thunks(&sched, arena.raw_buf_mut());
18547        let got = read_arena_f64(&arena, final_carry, n);
18548
18549        // Reference: per-step in pure Rust.
18550        let mut x = init_data.clone();
18551        for t in 0..length as usize {
18552            for j in 0..n {
18553                x[j] += xs_data[t * n + j];
18554            }
18555            let mut a_copy = m_data.clone();
18556            crate::blas::dgesv(&mut a_copy, &mut x, n, 1);
18557        }
18558        for i in 0..n {
18559            assert!(
18560                (got[i] - x[i]).abs() < 1e-12,
18561                "got[{i}] = {} ref {}",
18562                got[i],
18563                x[i]
18564            );
18565        }
18566    }
18567
18568    /// Reverse-mode AD through Op::BatchedDenseSolve. Forward solves
18569    /// `[B, N, N] · x = [B, N]`; loss = sum of all entries. Closed
18570    /// form: dB = (Aᵀ)⁻¹·1, dA = -(Aᵀ)⁻¹·1 ⊗ x. Verified analytically
18571    /// per batch (each slice matches what the unbatched DenseSolve VJP
18572    /// would compute).
18573    #[test]
18574    fn batched_dense_solve_gradient_matches_per_batch_analytic() {
18575        use rlx_opt::autodiff::grad_with_loss;
18576        let n = 3usize;
18577        let batch = 4usize;
18578
18579        let mut g = Graph::new("bds_grad");
18580        let a = g.param("A", Shape::new(&[batch, n, n], DType::F64));
18581        let b = g.input("b", Shape::new(&[batch, n], DType::F64));
18582        let x = g.batched_dense_solve(a, b, Shape::new(&[batch, n], DType::F64));
18583        let loss = g.reduce(
18584            x,
18585            ReduceOp::Sum,
18586            vec![0, 1],
18587            false,
18588            Shape::new(&[1], DType::F64),
18589        );
18590        g.set_outputs(vec![loss]);
18591
18592        let bwd = grad_with_loss(&g, &[a, b]);
18593
18594        let find = |graph: &Graph, want: &str| -> NodeId {
18595            for node in graph.nodes() {
18596                let name = match &node.op {
18597                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
18598                    _ => None,
18599                };
18600                if name == Some(want) {
18601                    return node.id;
18602                }
18603            }
18604            panic!("no node named {want}");
18605        };
18606        let a_id = find(&bwd, "A");
18607        let b_id = find(&bwd, "b");
18608        let d_out_id = find(&bwd, "d_output");
18609
18610        let mut rng = rlx_ir::Philox4x32::new(0x57e1_u64);
18611        let mut a_data = vec![0.0_f64; batch * n * n];
18612        let mut b_data = vec![0.0_f64; batch * n];
18613        for bi in 0..batch {
18614            for i in 0..n {
18615                for j in 0..n {
18616                    a_data[bi * n * n + i * n + j] = rng.next_f32() as f64 * 0.1;
18617                }
18618                a_data[bi * n * n + i * n + i] += 1.0 + n as f64;
18619            }
18620            for i in 0..n {
18621                b_data[bi * n + i] = rng.next_f32() as f64;
18622            }
18623        }
18624        let d_seed = [1.0_f64];
18625
18626        let (sched, mut arena) = prepare_f64(
18627            &bwd,
18628            &[(a_id, &a_data), (b_id, &b_data), (d_out_id, &d_seed)],
18629        );
18630        execute_thunks(&sched, arena.raw_buf_mut());
18631        let da_out = read_arena_f64(&arena, bwd.outputs[1], batch * n * n);
18632        let db_out = read_arena_f64(&arena, bwd.outputs[2], batch * n);
18633
18634        // Reference: per-batch analytic solve. dB_i = (A_iᵀ)⁻¹ · 1,
18635        // dA_i = -dB_i ⊗ x_i.
18636        for bi in 0..batch {
18637            let a_slice: Vec<f64> = a_data[bi * n * n..(bi + 1) * n * n].to_vec();
18638            let mut b_slice: Vec<f64> = b_data[bi * n..(bi + 1) * n].to_vec();
18639            let mut a_copy = a_slice.clone();
18640            crate::blas::dgesv(&mut a_copy, &mut b_slice, n, 1);
18641            let x_ref = b_slice.clone();
18642            // dB: solve(A^T, ones)
18643            let mut at = vec![0.0_f64; n * n];
18644            for i in 0..n {
18645                for j in 0..n {
18646                    at[i * n + j] = a_slice[j * n + i];
18647                }
18648            }
18649            let mut ones = vec![1.0_f64; n];
18650            crate::blas::dgesv(&mut at, &mut ones, n, 1);
18651            let db_ref = ones;
18652            for i in 0..n {
18653                let got = db_out[bi * n + i];
18654                assert!(
18655                    (got - db_ref[i]).abs() < 1e-10,
18656                    "batch {bi}, db[{i}]: got {got} ref {}",
18657                    db_ref[i]
18658                );
18659            }
18660            // dA: -outer(db, x)
18661            for i in 0..n {
18662                for j in 0..n {
18663                    let got = da_out[bi * n * n + i * n + j];
18664                    let want = -db_ref[i] * x_ref[j];
18665                    assert!(
18666                        (got - want).abs() < 1e-10,
18667                        "batch {bi}, dA[{i},{j}]: got {got} ref {want}"
18668                    );
18669                }
18670            }
18671        }
18672    }
18673
18674    /// AD knob: gradient through `scan_checkpointed` automatically
18675    /// uses the recompute backward path. Compares dinit from a plain
18676    /// scan against the same forward written with `scan_checkpointed`,
18677    /// both run through `grad_with_loss`. They must match to f64.
18678    #[test]
18679    fn scan_checkpointed_grad_matches_plain_scan_grad() {
18680        use rlx_opt::autodiff::grad_with_loss;
18681        let n = 2usize;
18682        let length = 6u32;
18683
18684        let make_body = || {
18685            let mut body = Graph::new("ck_body");
18686            let carry = body.input("carry", Shape::new(&[n], DType::F64));
18687            let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 1.05_f64.to_le_bytes()).collect();
18688            let scale = body.add_node(
18689                Op::Constant { data: scale_bytes },
18690                vec![],
18691                Shape::new(&[n], DType::F64),
18692            );
18693            let next = body.binary(BinaryOp::Mul, carry, scale, Shape::new(&[n], DType::F64));
18694            body.set_outputs(vec![next]);
18695            body
18696        };
18697
18698        // Plain scan path.
18699        let mut g_plain = Graph::new("ck_plain");
18700        let init_p = g_plain.input("init", Shape::new(&[n], DType::F64));
18701        let final_p = g_plain.scan(init_p, make_body(), length);
18702        let loss_p = g_plain.reduce(
18703            final_p,
18704            ReduceOp::Sum,
18705            vec![0],
18706            false,
18707            Shape::new(&[1], DType::F64),
18708        );
18709        g_plain.set_outputs(vec![loss_p]);
18710        let bwd_p = grad_with_loss(&g_plain, &[init_p]);
18711
18712        // Checkpointed scan path with K=2 (length=6).
18713        let mut g_ck = Graph::new("ck_ckpt");
18714        let init_c = g_ck.input("init", Shape::new(&[n], DType::F64));
18715        let final_c = g_ck.scan_checkpointed(init_c, make_body(), length, 2);
18716        let loss_c = g_ck.reduce(
18717            final_c,
18718            ReduceOp::Sum,
18719            vec![0],
18720            false,
18721            Shape::new(&[1], DType::F64),
18722        );
18723        g_ck.set_outputs(vec![loss_c]);
18724        let bwd_c = grad_with_loss(&g_ck, &[init_c]);
18725
18726        let find = |graph: &Graph, want: &str| -> NodeId {
18727            for node in graph.nodes() {
18728                let name = match &node.op {
18729                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
18730                    _ => None,
18731                };
18732                if name == Some(want) {
18733                    return node.id;
18734                }
18735            }
18736            panic!("no {want}");
18737        };
18738
18739        let init_data = vec![0.5_f64, -0.5];
18740        let d_seed = [1.0_f64];
18741
18742        let (s_p, mut a_p) = prepare_f64(
18743            &bwd_p,
18744            &[
18745                (find(&bwd_p, "init"), &init_data),
18746                (find(&bwd_p, "d_output"), &d_seed),
18747            ],
18748        );
18749        execute_thunks(&s_p, a_p.raw_buf_mut());
18750        let dinit_p = read_arena_f64(&a_p, bwd_p.outputs[1], n);
18751
18752        let (s_c, mut a_c) = prepare_f64(
18753            &bwd_c,
18754            &[
18755                (find(&bwd_c, "init"), &init_data),
18756                (find(&bwd_c, "d_output"), &d_seed),
18757            ],
18758        );
18759        execute_thunks(&s_c, a_c.raw_buf_mut());
18760        let dinit_c = read_arena_f64(&a_c, bwd_c.outputs[1], n);
18761
18762        for i in 0..n {
18763            assert!(
18764                (dinit_p[i] - dinit_c[i]).abs() < 1e-12,
18765                "dinit[{i}]: plain={} checkpointed={}",
18766                dinit_p[i],
18767                dinit_c[i]
18768            );
18769        }
18770    }
18771
18772    /// Recursive checkpointing end-to-end: build a ScanBackward
18773    /// configured with K=2 checkpoints (for length=4), and compare
18774    /// dinit against the same backward graph with full trajectory
18775    /// (K=0). Forward computes a cumulative-sum-style scan; loss = sum.
18776    /// Both paths must agree to f64 precision.
18777    #[test]
18778    fn recursive_checkpointing_matches_full_trajectory() {
18779        let n = 2usize;
18780        let length = 4u32;
18781
18782        // Body: carry + ones (deterministic, no xs)
18783        let build_body = || -> Graph {
18784            let mut body = Graph::new("rc_body");
18785            let carry = body.input("carry", Shape::new(&[n], DType::F64));
18786            let ones_bytes: Vec<u8> = (0..n).flat_map(|_| 1.0_f64.to_le_bytes()).collect();
18787            let ones = body.add_node(
18788                Op::Constant { data: ones_bytes },
18789                vec![],
18790                Shape::new(&[n], DType::F64),
18791            );
18792            let next = body.binary(BinaryOp::Add, carry, ones, Shape::new(&[n], DType::F64));
18793            body.set_outputs(vec![next]);
18794            body
18795        };
18796
18797        // body_vjp: same body + d_output, output dcarry. body_vjp is
18798        // used by ScanBackward to walk the chain rule per step.
18799        let body_vjp_for = || -> Graph {
18800            use rlx_opt::autodiff::grad;
18801            let body = build_body();
18802            // grad(body, [carry_id]) → graph with dcarry as the output.
18803            let carry_id = body
18804                .nodes()
18805                .iter()
18806                .find(|n| matches!(n.op, Op::Input { .. }))
18807                .map(|n| n.id)
18808                .unwrap();
18809            grad(&body, &[carry_id])
18810        };
18811
18812        // ── Forward (All-strategy): scan with full trajectory ──
18813        let mut g_full = Graph::new("rc_outer_full");
18814        let init_full = g_full.input("init", Shape::new(&[n], DType::F64));
18815        let traj_full_id = g_full.scan_trajectory(init_full, build_body(), length);
18816        // Hand-build a ScanBackward node that reads the full trajectory.
18817        let upstream_full = g_full.input("upstream", Shape::new(&[length as usize, n], DType::F64));
18818        let dinit_full_id = g_full.scan_backward(
18819            init_full,
18820            traj_full_id,
18821            upstream_full,
18822            &[],
18823            body_vjp_for(),
18824            length,
18825            true,
18826            Shape::new(&[n], DType::F64),
18827        );
18828        g_full.set_outputs(vec![dinit_full_id]);
18829
18830        // ── Forward (Recursive-2): scan saves only K=2 rows ──
18831        // Build the trajectory shape [K, *carry] = [2, 2].
18832        let k = 2u32;
18833        let mut g_rec = Graph::new("rc_outer_rec");
18834        let init_rec = g_rec.input("init", Shape::new(&[n], DType::F64));
18835        let traj_rec_id = g_rec.add_node(
18836            Op::Scan {
18837                body: Box::new(build_body()),
18838                length,
18839                save_trajectory: true,
18840                num_bcast: 0,
18841                num_xs: 0,
18842                num_checkpoints: k,
18843            },
18844            vec![init_rec],
18845            Shape::new(&[k as usize, n], DType::F64),
18846        );
18847        // Same upstream shape as the full version (the upstream is per
18848        // *forward step*, length rows — independent of K).
18849        let upstream_rec = g_rec.input("upstream", Shape::new(&[length as usize, n], DType::F64));
18850        let dinit_rec_id = g_rec.add_node(
18851            Op::ScanBackward {
18852                body_vjp: Box::new(body_vjp_for()),
18853                length,
18854                save_trajectory: true,
18855                num_xs: 0,
18856                num_checkpoints: k,
18857                forward_body: Some(Box::new(build_body())),
18858            },
18859            vec![init_rec, traj_rec_id, upstream_rec],
18860            Shape::new(&[n], DType::F64),
18861        );
18862        g_rec.set_outputs(vec![dinit_rec_id]);
18863
18864        // ── Run both, same inputs ──
18865        let init_data = vec![0.5_f64, -0.5];
18866        let upstream_data: Vec<f64> = (0..length as usize * n).map(|i| (i as f64) * 0.1).collect();
18867
18868        let find = |graph: &Graph, want: &str| -> NodeId {
18869            for node in graph.nodes() {
18870                if let Op::Input { name } = &node.op
18871                    && name == want
18872                {
18873                    return node.id;
18874                }
18875            }
18876            panic!("no input {want}");
18877        };
18878
18879        let (s_full, mut a_full) = prepare_f64(
18880            &g_full,
18881            &[
18882                (find(&g_full, "init"), &init_data),
18883                (find(&g_full, "upstream"), &upstream_data),
18884            ],
18885        );
18886        execute_thunks(&s_full, a_full.raw_buf_mut());
18887        let dinit_full = read_arena_f64(&a_full, g_full.outputs[0], n);
18888
18889        let (s_rec, mut a_rec) = prepare_f64(
18890            &g_rec,
18891            &[
18892                (find(&g_rec, "init"), &init_data),
18893                (find(&g_rec, "upstream"), &upstream_data),
18894            ],
18895        );
18896        execute_thunks(&s_rec, a_rec.raw_buf_mut());
18897        let dinit_rec = read_arena_f64(&a_rec, g_rec.outputs[0], n);
18898
18899        for i in 0..n {
18900            assert!(
18901                (dinit_full[i] - dinit_rec[i]).abs() < 1e-12,
18902                "i={i}: full={} rec={}",
18903                dinit_full[i],
18904                dinit_rec[i]
18905            );
18906        }
18907    }
18908
18909    /// vmap-of-grad: gradient through Scan, vmap'd over init.
18910    /// Forward (per row):
18911    ///   carry_{t+1} = carry_t + ones    (body adds a constant)
18912    ///   loss = sum(carry_length) = sum(init) + length·n
18913    /// Closed form: dloss/dinit_i = 1 for every i. vmap over init at
18914    /// batch=3 → dinit_batched is all-ones [3, n]. Cross-checks
18915    /// against per-row grad_with_loss runs. Validates the vmap rule
18916    /// for Op::ScanBackward.
18917    #[test]
18918    fn vmap_of_grad_scan_matches_per_row_runs() {
18919        use rlx_opt::autodiff::grad_with_loss;
18920        use rlx_opt::vmap::vmap;
18921        let n = 2usize;
18922        let length = 3u32;
18923        let batch = 3usize;
18924
18925        let mut body = Graph::new("scan_grad_body");
18926        let carry = body.input("carry", Shape::new(&[n], DType::F64));
18927        let ones_bytes: Vec<u8> = (0..n).flat_map(|_| 1.0_f64.to_le_bytes()).collect();
18928        let ones = body.add_node(
18929            Op::Constant { data: ones_bytes },
18930            vec![],
18931            Shape::new(&[n], DType::F64),
18932        );
18933        let next = body.binary(BinaryOp::Add, carry, ones, Shape::new(&[n], DType::F64));
18934        body.set_outputs(vec![next]);
18935
18936        let mut g = Graph::new("scan_grad_outer");
18937        let init = g.input("init", Shape::new(&[n], DType::F64));
18938        let final_x = g.scan(init, body, length);
18939        let loss = g.reduce(
18940            final_x,
18941            ReduceOp::Sum,
18942            vec![0],
18943            false,
18944            Shape::new(&[1], DType::F64),
18945        );
18946        g.set_outputs(vec![loss]);
18947
18948        let bwd = grad_with_loss(&g, &[init]);
18949        let bg = vmap(&bwd, &["init"], batch);
18950
18951        let find = |graph: &Graph, want: &str| -> NodeId {
18952            for node in graph.nodes() {
18953                let name = match &node.op {
18954                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
18955                    _ => None,
18956                };
18957                if name == Some(want) {
18958                    return node.id;
18959                }
18960            }
18961            panic!("no node named {want}");
18962        };
18963        let init_b = find(&bg, "init");
18964        let d_out_b = find(&bg, "d_output");
18965
18966        let init_data: Vec<f64> = (0..batch * n).map(|i| (i as f64) * 0.5).collect();
18967        let d_seed = [1.0_f64];
18968
18969        let (sched, mut arena) = prepare_f64(&bg, &[(init_b, &init_data), (d_out_b, &d_seed)]);
18970        execute_thunks(&sched, arena.raw_buf_mut());
18971        let dinit_b = read_arena_f64(&arena, bg.outputs[1], batch * n);
18972
18973        for i in 0..batch * n {
18974            assert!(
18975                (dinit_b[i] - 1.0).abs() < 1e-12,
18976                "dinit[{i}] = {} (expected 1.0)",
18977                dinit_b[i]
18978            );
18979        }
18980
18981        // Cross-check vs per-row grad_with_loss.
18982        for bi in 0..batch {
18983            let row = &init_data[bi * n..(bi + 1) * n];
18984            let mut g2 = Graph::new("per_row_grad");
18985            let init2 = g2.input("init", Shape::new(&[n], DType::F64));
18986            let mut body2 = Graph::new("per_row_body");
18987            let c2 = body2.input("carry", Shape::new(&[n], DType::F64));
18988            let ones2_bytes: Vec<u8> = (0..n).flat_map(|_| 1.0_f64.to_le_bytes()).collect();
18989            let ones2 = body2.add_node(
18990                Op::Constant { data: ones2_bytes },
18991                vec![],
18992                Shape::new(&[n], DType::F64),
18993            );
18994            let next2 = body2.binary(BinaryOp::Add, c2, ones2, Shape::new(&[n], DType::F64));
18995            body2.set_outputs(vec![next2]);
18996            let final2 = g2.scan(init2, body2, length);
18997            let loss2 = g2.reduce(
18998                final2,
18999                ReduceOp::Sum,
19000                vec![0],
19001                false,
19002                Shape::new(&[1], DType::F64),
19003            );
19004            g2.set_outputs(vec![loss2]);
19005            let bwd2 = grad_with_loss(&g2, &[init2]);
19006            let init2_id = find(&bwd2, "init");
19007            let d_out2_id = find(&bwd2, "d_output");
19008            let (s2, mut a2) = prepare_f64(&bwd2, &[(init2_id, row), (d_out2_id, &d_seed)]);
19009            execute_thunks(&s2, a2.raw_buf_mut());
19010            let row_dinit = read_arena_f64(&a2, bwd2.outputs[1], n);
19011            for j in 0..n {
19012                let got = dinit_b[bi * n + j];
19013                let want = row_dinit[j];
19014                assert!(
19015                    (got - want).abs() < 1e-12,
19016                    "row {bi}, j {j}: vmap'd={got} per-row={want}"
19017                );
19018            }
19019        }
19020    }
19021
19022    /// vmap of Op::Scan: batched cumulative-sum. Forward
19023    ///   carry_{t+1} = carry_t + xs\[t\]
19024    ///   final = init + sum(xs)
19025    /// vmap over both init and xs at batch=3. Each batch row should
19026    /// equal the scalar run of the same body+xs subset.
19027    #[test]
19028    fn vmap_scan_cumulative_sum_matches_scalar_runs() {
19029        use rlx_opt::vmap::vmap;
19030        let n = 2usize;
19031        let length = 4u32;
19032        let batch = 3usize;
19033
19034        // Body: (carry, x_t) → carry + x_t
19035        let mut body = Graph::new("scan_body_cumsum");
19036        let carry = body.input("carry", Shape::new(&[n], DType::F64));
19037        let x_t = body.input("x_t", Shape::new(&[n], DType::F64));
19038        let next = body.binary(BinaryOp::Add, carry, x_t, Shape::new(&[n], DType::F64));
19039        body.set_outputs(vec![next]);
19040
19041        let mut g = Graph::new("scan_outer_cumsum");
19042        let init = g.input("init", Shape::new(&[n], DType::F64));
19043        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
19044        let final_carry = g.scan_with_xs(init, &[xs], body, length);
19045        g.set_outputs(vec![final_carry]);
19046
19047        // vmap over both init and xs.
19048        let bg = vmap(&g, &["init", "xs"], batch);
19049
19050        // Test data — distinct per-batch rows.
19051        let init_data: Vec<f64> = (0..batch * n).map(|i| (i + 1) as f64).collect();
19052        // xs has shape [B, length, n] after vmap (the outer's xs is
19053        // [length, n]; vmap lifts it to [B, length, n]).
19054        let xs_data: Vec<f64> = (0..batch * length as usize * n)
19055            .map(|i| 0.1 * (i as f64))
19056            .collect();
19057
19058        let find = |graph: &Graph, want: &str| -> NodeId {
19059            for node in graph.nodes() {
19060                if let Op::Input { name } = &node.op
19061                    && name == want
19062                {
19063                    return node.id;
19064                }
19065            }
19066            panic!("no input {want}");
19067        };
19068        let init_b = find(&bg, "init");
19069        let xs_b = find(&bg, "xs");
19070        let (sched, mut arena) = prepare_f64(&bg, &[(init_b, &init_data), (xs_b, &xs_data)]);
19071        execute_thunks(&sched, arena.raw_buf_mut());
19072        let batched_out = read_arena_f64(&arena, bg.outputs[0], batch * n);
19073
19074        // Reference: per-batch scalar Scan.
19075        for bi in 0..batch {
19076            let init_slice = &init_data[bi * n..(bi + 1) * n];
19077            let mut x = init_slice.to_vec();
19078            for t in 0..length as usize {
19079                for j in 0..n {
19080                    x[j] += xs_data[bi * length as usize * n + t * n + j];
19081                }
19082            }
19083
19084            for i in 0..n {
19085                let got = batched_out[bi * n + i];
19086                assert!(
19087                    (got - x[i]).abs() < 1e-12,
19088                    "row {bi}, i {i}: got {got} ref {}",
19089                    x[i]
19090                );
19091            }
19092        }
19093    }
19094
19095    /// vmap of dense solve — Circulax-shaped batched parameter sweep.
19096    /// Forward: x = solve(A, b). vmap over both A (batched [B,N,N])
19097    /// and b (batched [B,N]). Run on CPU and compare each batch row
19098    /// against an independent scalar dgesv.
19099    #[test]
19100    fn vmap_dense_solve_matches_scalar_runs() {
19101        use rlx_opt::vmap::vmap;
19102        let n = 3usize;
19103        let batch = 4usize;
19104
19105        let mut g = Graph::new("solve_forward");
19106        let a = g.input("A", Shape::new(&[n, n], DType::F64));
19107        let b = g.input("b", Shape::new(&[n], DType::F64));
19108        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
19109        g.set_outputs(vec![x]);
19110
19111        // vmap both A and b across the batch.
19112        let bg = vmap(&g, &["A", "b"], batch);
19113
19114        // Independent A and b per batch row.
19115        let mut rng = rlx_ir::Philox4x32::new(0xb47c_u64);
19116        let mut a_data = vec![0.0_f64; batch * n * n];
19117        let mut b_data = vec![0.0_f64; batch * n];
19118        for bi in 0..batch {
19119            // Diagonally dominant A — guaranteed non-singular.
19120            for i in 0..n {
19121                for j in 0..n {
19122                    a_data[bi * n * n + i * n + j] = rng.next_f32() as f64 * 0.1;
19123                }
19124                a_data[bi * n * n + i * n + i] += 1.0 + n as f64;
19125            }
19126            for i in 0..n {
19127                b_data[bi * n + i] = rng.next_f32() as f64;
19128            }
19129        }
19130
19131        let find = |graph: &Graph, want: &str| -> NodeId {
19132            for node in graph.nodes() {
19133                if let Op::Input { name } = &node.op
19134                    && name == want
19135                {
19136                    return node.id;
19137                }
19138            }
19139            panic!("no input named {want}");
19140        };
19141        let ba = find(&bg, "A");
19142        let bb = find(&bg, "b");
19143        let (sched, mut arena) = prepare_f64(&bg, &[(ba, &a_data), (bb, &b_data)]);
19144        execute_thunks(&sched, arena.raw_buf_mut());
19145        let batched_x = read_arena_f64(&arena, bg.outputs[0], batch * n);
19146
19147        // Reference: per-batch dgesv.
19148        for bi in 0..batch {
19149            let mut a_slice: Vec<f64> = a_data[bi * n * n..(bi + 1) * n * n].to_vec();
19150            let mut b_slice: Vec<f64> = b_data[bi * n..(bi + 1) * n].to_vec();
19151            crate::blas::dgesv(&mut a_slice, &mut b_slice, n, 1);
19152            for i in 0..n {
19153                let got = batched_x[bi * n + i];
19154                let want = b_slice[i];
19155                assert!(
19156                    (got - want).abs() < 1e-12,
19157                    "row {bi}, i {i}: got {got} want {want}"
19158                );
19159            }
19160        }
19161    }
19162
19163    /// vmap end-to-end: build a graph that computes y = MatMul(x, w) + b
19164    /// and reduces to a per-element loss. vmap over x with batch=4.
19165    /// Run the batched graph and compare each output row against an
19166    /// independent scalar run of the original graph. Validates the
19167    /// structural lift + the runtime path for batched MatMul +
19168    /// batched Binary + batched Reduce.
19169    #[test]
19170    fn vmap_matmul_add_reduce_matches_scalar_runs() {
19171        use rlx_opt::vmap::vmap;
19172        let n = 3usize;
19173        let batch = 4usize;
19174
19175        // Forward graph: y = MatMul(reshape(x, [1,n]), w) + b ; loss = sum(y).
19176        let mut g = Graph::new("vmap_e2e_forward");
19177        let x = g.input("x", Shape::new(&[n], DType::F64));
19178        let w = g.input("w", Shape::new(&[n, n], DType::F64));
19179        let b = g.input("b", Shape::new(&[n], DType::F64));
19180        let x_row = g.add_node(
19181            Op::Reshape {
19182                new_shape: vec![1, n as i64],
19183            },
19184            vec![x],
19185            Shape::new(&[1, n], DType::F64),
19186        );
19187        let mm = g.matmul(x_row, w, Shape::new(&[1, n], DType::F64));
19188        let mm_flat = g.add_node(
19189            Op::Reshape {
19190                new_shape: vec![n as i64],
19191            },
19192            vec![mm],
19193            Shape::new(&[n], DType::F64),
19194        );
19195        let yv = g.binary(BinaryOp::Add, mm_flat, b, Shape::new(&[n], DType::F64));
19196        let loss = g.reduce(
19197            yv,
19198            ReduceOp::Sum,
19199            vec![0],
19200            false,
19201            Shape::new(&[1], DType::F64),
19202        );
19203        g.set_outputs(vec![loss]);
19204
19205        // Build the vmap'd version (batch over x; w and b shared).
19206        let bg = vmap(&g, &["x"], batch);
19207
19208        // Test data — distinct rows so we can verify the per-row dispatch.
19209        let mut rng = rlx_ir::Philox4x32::new(0xc1c0_u64);
19210        let n_w = n * n;
19211        let w_data: Vec<f64> = (0..n_w).map(|_| rng.next_f32() as f64).collect();
19212        let b_data: Vec<f64> = (0..n).map(|_| rng.next_f32() as f64).collect();
19213        let mut x_data_batched: Vec<f64> = Vec::with_capacity(batch * n);
19214        for _ in 0..batch * n {
19215            x_data_batched.push(rng.next_f32() as f64);
19216        }
19217
19218        // Run the batched graph.
19219        let find = |graph: &Graph, want: &str| -> NodeId {
19220            for node in graph.nodes() {
19221                if let Op::Input { name } = &node.op
19222                    && name == want
19223                {
19224                    return node.id;
19225                }
19226            }
19227            panic!("no input named {want}");
19228        };
19229        let bx = find(&bg, "x");
19230        let bw = find(&bg, "w");
19231        let bb = find(&bg, "b");
19232        let (sched, mut arena) =
19233            prepare_f64(&bg, &[(bx, &x_data_batched), (bw, &w_data), (bb, &b_data)]);
19234        execute_thunks(&sched, arena.raw_buf_mut());
19235        // Reduce::Sum on shifted axis 1 with keep_dim=false → output [B, 1]
19236        // (it preserves the leading batch axis but reduces what was [n] to [].
19237        // Since the original output was [1] f64 and the reduce was over
19238        // axis 0, after vmap the leading-axis-shifted reduce keeps the
19239        // leading 1 from the original output's [1] shape.)
19240        let batched_out = read_arena_f64(&arena, bg.outputs[0], batch);
19241
19242        // Reference: run the original (un-batched) graph once per batch row.
19243        for bi in 0..batch {
19244            let xs_slice = &x_data_batched[bi * n..(bi + 1) * n];
19245            let mut g2 = Graph::new("scalar_run");
19246            let x2 = g2.input("x", Shape::new(&[n], DType::F64));
19247            let w2 = g2.input("w", Shape::new(&[n, n], DType::F64));
19248            let b2 = g2.input("b", Shape::new(&[n], DType::F64));
19249            let xr = g2.add_node(
19250                Op::Reshape {
19251                    new_shape: vec![1, n as i64],
19252                },
19253                vec![x2],
19254                Shape::new(&[1, n], DType::F64),
19255            );
19256            let m = g2.matmul(xr, w2, Shape::new(&[1, n], DType::F64));
19257            let mf = g2.add_node(
19258                Op::Reshape {
19259                    new_shape: vec![n as i64],
19260                },
19261                vec![m],
19262                Shape::new(&[n], DType::F64),
19263            );
19264            let yv2 = g2.binary(BinaryOp::Add, mf, b2, Shape::new(&[n], DType::F64));
19265            let l2 = g2.reduce(
19266                yv2,
19267                ReduceOp::Sum,
19268                vec![0],
19269                false,
19270                Shape::new(&[1], DType::F64),
19271            );
19272            g2.set_outputs(vec![l2]);
19273            let (s2, mut a2) = prepare_f64(&g2, &[(x2, xs_slice), (w2, &w_data), (b2, &b_data)]);
19274            execute_thunks(&s2, a2.raw_buf_mut());
19275            let scalar_out = read_arena_f64(&a2, l2, 1);
19276            assert!(
19277                (batched_out[bi] - scalar_out[0]).abs() < 1e-12,
19278                "row {bi}: batched={} scalar={}",
19279                batched_out[bi],
19280                scalar_out[0]
19281            );
19282        }
19283    }
19284
19285    /// Full gradient through scan-with-xs: dinit AND dxs both checked
19286    /// against finite differences. Forward
19287    ///   carry_{t+1} = solve(M, carry_t + xs\[t\])
19288    ///   loss        = sum(carry_length)
19289    /// Verifies that grad_with_loss returns gradients w.r.t. both
19290    /// `init` and `xs` and that dxs matches per-element FD.
19291    #[test]
19292    fn scan_with_xs_dxs_matches_fd() {
19293        use rlx_opt::autodiff::grad_with_loss;
19294        let n = 3usize;
19295        let length = 3u32;
19296        let dt = 0.1_f64;
19297
19298        let mut m_data = vec![0.0_f64; n * n];
19299        for i in 0..n {
19300            m_data[i * n + i] = 1.0 + dt * 2.0;
19301            if i > 0 {
19302                m_data[i * n + (i - 1)] = -dt;
19303            }
19304            if i + 1 < n {
19305                m_data[i * n + (i + 1)] = -dt;
19306            }
19307        }
19308        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
19309
19310        let mut body = Graph::new("be_dxs_body");
19311        let carry = body.input("carry", Shape::new(&[n], DType::F64));
19312        let drive = body.input("drive", Shape::new(&[n], DType::F64));
19313        let m = body.add_node(
19314            Op::Constant { data: m_bytes },
19315            vec![],
19316            Shape::new(&[n, n], DType::F64),
19317        );
19318        let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
19319        let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
19320        body.set_outputs(vec![next]);
19321
19322        let mut g = Graph::new("be_dxs_outer");
19323        let init = g.input("init", Shape::new(&[n], DType::F64));
19324        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
19325        let final_carry = g.scan_with_xs(init, &[xs], body, length);
19326        let loss = g.reduce(
19327            final_carry,
19328            ReduceOp::Sum,
19329            vec![0],
19330            false,
19331            Shape::new(&[1], DType::F64),
19332        );
19333        g.set_outputs(vec![loss]);
19334
19335        // wrt = [init, xs] — get both gradients back.
19336        let bwd = grad_with_loss(&g, &[init, xs]);
19337        assert_eq!(bwd.outputs.len(), 3, "[loss, dinit, dxs]");
19338
19339        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
19340            for node in graph.nodes() {
19341                let name = match &node.op {
19342                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19343                    _ => None,
19344                };
19345                if name == Some(want) {
19346                    return node.id;
19347                }
19348            }
19349            panic!("no node named {want:?}");
19350        };
19351        let init_bwd = find_by_name(&bwd, "init");
19352        let xs_bwd = find_by_name(&bwd, "xs");
19353        let d_out_bwd = find_by_name(&bwd, "d_output");
19354
19355        let init_data = vec![0.5_f64, 0.0, -0.5];
19356        let xs_data: Vec<f64> = (0..length as usize * n)
19357            .map(|i| 0.1_f64 * ((i as f64) - 4.0))
19358            .collect();
19359        let d_seed = [1.0_f64];
19360
19361        let (sched, mut arena) = prepare_f64(
19362            &bwd,
19363            &[
19364                (init_bwd, &init_data),
19365                (xs_bwd, &xs_data),
19366                (d_out_bwd, &d_seed),
19367            ],
19368        );
19369        execute_thunks(&sched, arena.raw_buf_mut());
19370        let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
19371        let dxs = read_arena_f64(&arena, bwd.outputs[2], length as usize * n);
19372
19373        let h = 1e-6;
19374        let loss_at = |x0: &[f64], xs_in: &[f64]| -> f64 {
19375            let mut acc = x0.to_vec();
19376            for t in 0..length as usize {
19377                for j in 0..n {
19378                    acc[j] += xs_in[t * n + j];
19379                }
19380                let mut a_copy = m_data.clone();
19381                crate::blas::dgesv(&mut a_copy, &mut acc, n, 1);
19382            }
19383            acc.iter().sum()
19384        };
19385
19386        // FD on dinit (sanity).
19387        for i in 0..n {
19388            let mut ip = init_data.to_vec();
19389            ip[i] += h;
19390            let mut im = init_data.to_vec();
19391            im[i] -= h;
19392            let fd = (loss_at(&ip, &xs_data) - loss_at(&im, &xs_data)) / (2.0 * h);
19393            assert!(
19394                (dinit[i] - fd).abs() < 1e-7,
19395                "FD dinit[{i}]: AD={} FD={}",
19396                dinit[i],
19397                fd
19398            );
19399        }
19400
19401        // FD on every dxs entry — full per-step gradient check.
19402        for t in 0..length as usize {
19403            for j in 0..n {
19404                let idx = t * n + j;
19405                let mut xp = xs_data.clone();
19406                xp[idx] += h;
19407                let mut xm = xs_data.clone();
19408                xm[idx] -= h;
19409                let fd = (loss_at(&init_data, &xp) - loss_at(&init_data, &xm)) / (2.0 * h);
19410                assert!(
19411                    (dxs[idx] - fd).abs() < 1e-7,
19412                    "FD dxs[t={t},j={j}]: AD={} FD={}",
19413                    dxs[idx],
19414                    fd
19415                );
19416            }
19417        }
19418    }
19419
19420    /// Gradient through a scan with per-step xs (Circulax-shaped).
19421    /// Forward:
19422    ///   carry_{t+1} = solve(M, carry_t + xs\[t\])
19423    ///   loss = sum(carry_length)
19424    /// dxs is out of MVP (asserted in the VJP rule's body_vjp `wrt`),
19425    /// but `dinit` flows correctly through the body's reverse Jacobian
19426    /// even with xs in the chain. Verify dinit against finite differences.
19427    #[test]
19428    fn scan_with_xs_gradient_dinit_matches_fd() {
19429        use rlx_opt::autodiff::grad_with_loss;
19430        let n = 3usize;
19431        let length = 3u32;
19432        let dt = 0.1_f64;
19433
19434        let mut m_data = vec![0.0_f64; n * n];
19435        for i in 0..n {
19436            m_data[i * n + i] = 1.0 + dt * 2.0;
19437            if i > 0 {
19438                m_data[i * n + (i - 1)] = -dt;
19439            }
19440            if i + 1 < n {
19441                m_data[i * n + (i + 1)] = -dt;
19442            }
19443        }
19444        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
19445
19446        let mut body = Graph::new("be_xs_grad_body");
19447        let carry = body.input("carry", Shape::new(&[n], DType::F64));
19448        let drive = body.input("drive", Shape::new(&[n], DType::F64));
19449        let m = body.add_node(
19450            Op::Constant { data: m_bytes },
19451            vec![],
19452            Shape::new(&[n, n], DType::F64),
19453        );
19454        let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
19455        let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
19456        body.set_outputs(vec![next]);
19457
19458        let mut g = Graph::new("be_xs_grad_outer");
19459        let init = g.input("init", Shape::new(&[n], DType::F64));
19460        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
19461        let final_carry = g.scan_with_xs(init, &[xs], body, length);
19462        let loss = g.reduce(
19463            final_carry,
19464            ReduceOp::Sum,
19465            vec![0],
19466            false,
19467            Shape::new(&[1], DType::F64),
19468        );
19469        g.set_outputs(vec![loss]);
19470
19471        let bwd = grad_with_loss(&g, &[init]);
19472
19473        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
19474            for node in graph.nodes() {
19475                let name = match &node.op {
19476                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19477                    _ => None,
19478                };
19479                if name == Some(want) {
19480                    return node.id;
19481                }
19482            }
19483            panic!("no node named {want:?}");
19484        };
19485        let init_bwd = find_by_name(&bwd, "init");
19486        let xs_bwd = find_by_name(&bwd, "xs");
19487        let d_out_bwd = find_by_name(&bwd, "d_output");
19488
19489        let init_data = vec![0.5_f64, 0.0, -0.5];
19490        // Drive: small per-step pulse, varying per element.
19491        let xs_data: Vec<f64> = (0..length as usize * n)
19492            .map(|i| 0.1_f64 * ((i as f64) - 4.0))
19493            .collect();
19494        let d_seed = [1.0_f64];
19495
19496        let (sched, mut arena) = prepare_f64(
19497            &bwd,
19498            &[
19499                (init_bwd, &init_data),
19500                (xs_bwd, &xs_data),
19501                (d_out_bwd, &d_seed),
19502            ],
19503        );
19504        execute_thunks(&sched, arena.raw_buf_mut());
19505        let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
19506
19507        let h = 1e-6;
19508        let loss_at = |x0: &[f64]| -> f64 {
19509            let mut acc = x0.to_vec();
19510            for t in 0..length as usize {
19511                for j in 0..n {
19512                    acc[j] += xs_data[t * n + j];
19513                }
19514                let mut a_copy = m_data.clone();
19515                crate::blas::dgesv(&mut a_copy, &mut acc, n, 1);
19516            }
19517            acc.iter().sum()
19518        };
19519        for i in 0..n {
19520            let mut ip = init_data.to_vec();
19521            ip[i] += h;
19522            let mut im = init_data.to_vec();
19523            im[i] -= h;
19524            let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
19525            assert!(
19526                (dinit[i] - fd).abs() < 1e-7,
19527                "FD dinit[{i}]: AD={} FD={}",
19528                dinit[i],
19529                fd
19530            );
19531        }
19532    }
19533
19534    /// Gradient through a geometric-growth scan: forward
19535    ///   x_{t+1} = 1.1 · x_t,    x_0 = init
19536    ///   final   = x_length     = init · 1.1^length
19537    ///   loss    = sum(final)
19538    /// closed-form ∂loss/∂init\[i\] = 1.1^length for every i.
19539    /// Validates the VJP path: AD pre-pass rewrites save_trajectory=false
19540    /// to true, autodiff emits Op::ScanBackward, executor walks t back.
19541    #[test]
19542    fn scan_gradient_geometric_matches_closed_form() {
19543        use rlx_opt::autodiff::grad_with_loss;
19544        let n = 3usize;
19545        let length = 5u32;
19546
19547        let mut body = Graph::new("scan_grad_body");
19548        let x = body.input("carry", Shape::new(&[n], DType::F64));
19549        let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 1.1_f64.to_le_bytes()).collect();
19550        let scale = body.add_node(
19551            Op::Constant { data: scale_bytes },
19552            vec![],
19553            Shape::new(&[n], DType::F64),
19554        );
19555        let next = body.binary(BinaryOp::Mul, x, scale, Shape::new(&[n], DType::F64));
19556        body.set_outputs(vec![next]);
19557
19558        let mut g = Graph::new("scan_grad_outer");
19559        let init = g.input("init", Shape::new(&[n], DType::F64));
19560        let final_x = g.scan(init, body, length);
19561        let loss = g.reduce(
19562            final_x,
19563            ReduceOp::Sum,
19564            vec![0],
19565            false,
19566            Shape::new(&[1], DType::F64),
19567        );
19568        g.set_outputs(vec![loss]);
19569
19570        let bwd = grad_with_loss(&g, &[init]);
19571        assert_eq!(bwd.outputs.len(), 2);
19572
19573        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
19574            for node in graph.nodes() {
19575                let name = match &node.op {
19576                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19577                    _ => None,
19578                };
19579                if name == Some(want) {
19580                    return node.id;
19581                }
19582            }
19583            panic!("no node named {want:?}");
19584        };
19585        let init_bwd = find_by_name(&bwd, "init");
19586        let d_out_bwd = find_by_name(&bwd, "d_output");
19587
19588        let init_data = vec![1.0_f64; n];
19589        let d_seed = [1.0_f64];
19590        let (sched, mut arena) = prepare_f64(&bwd, &[(init_bwd, &init_data), (d_out_bwd, &d_seed)]);
19591        execute_thunks(&sched, arena.raw_buf_mut());
19592        let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
19593
19594        let want = 1.1_f64.powi(length as i32);
19595        for i in 0..n {
19596            assert!(
19597                (dinit[i] - want).abs() < 1e-12,
19598                "dinit[{i}] = {} want {}",
19599                dinit[i],
19600                want
19601            );
19602        }
19603
19604        // Finite-difference cross-check on init[0].
19605        let h = 1e-6;
19606        let loss_at = |x: &[f64]| -> f64 {
19607            let mut acc = x.to_vec();
19608            for _ in 0..length {
19609                for v in acc.iter_mut() {
19610                    *v *= 1.1;
19611                }
19612            }
19613            acc.iter().sum()
19614        };
19615        let mut ip = init_data.clone();
19616        ip[0] += h;
19617        let mut im = init_data.clone();
19618        im[0] -= h;
19619        let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
19620        assert!(
19621            (dinit[0] - fd).abs() < 1e-7,
19622            "FD dinit[0]: AD={} FD={}",
19623            dinit[0],
19624            fd
19625        );
19626    }
19627
19628    /// Gradient through Backward Euler scan composing with DenseSolve.
19629    /// Asserts dinit matches finite-difference per coordinate.
19630    #[test]
19631    fn scan_gradient_backward_euler_matches_fd() {
19632        use rlx_opt::autodiff::grad_with_loss;
19633        let n = 4usize;
19634        let length = 3u32;
19635        let dt = 0.05_f64;
19636
19637        let mut m_data = vec![0.0_f64; n * n];
19638        for i in 0..n {
19639            m_data[i * n + i] = 1.0 + dt * 2.0;
19640            if i > 0 {
19641                m_data[i * n + (i - 1)] = -dt;
19642            }
19643            if i + 1 < n {
19644                m_data[i * n + (i + 1)] = -dt;
19645            }
19646        }
19647        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
19648
19649        let mut body = Graph::new("be_grad_body");
19650        let x = body.input("x", Shape::new(&[n], DType::F64));
19651        let m = body.add_node(
19652            Op::Constant { data: m_bytes },
19653            vec![],
19654            Shape::new(&[n, n], DType::F64),
19655        );
19656        let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
19657        body.set_outputs(vec![next]);
19658
19659        let mut g = Graph::new("be_grad_outer");
19660        let init = g.input("x0", Shape::new(&[n], DType::F64));
19661        let final_x = g.scan(init, body, length);
19662        let loss = g.reduce(
19663            final_x,
19664            ReduceOp::Sum,
19665            vec![0],
19666            false,
19667            Shape::new(&[1], DType::F64),
19668        );
19669        g.set_outputs(vec![loss]);
19670
19671        let bwd = grad_with_loss(&g, &[init]);
19672
19673        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
19674            for node in graph.nodes() {
19675                let name = match &node.op {
19676                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19677                    _ => None,
19678                };
19679                if name == Some(want) {
19680                    return node.id;
19681                }
19682            }
19683            panic!("no node named {want:?}");
19684        };
19685        let init_bwd = find_by_name(&bwd, "x0");
19686        let d_out_bwd = find_by_name(&bwd, "d_output");
19687
19688        let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
19689        let d_seed = [1.0_f64];
19690        let (sched, mut arena) = prepare_f64(&bwd, &[(init_bwd, &init_data), (d_out_bwd, &d_seed)]);
19691        execute_thunks(&sched, arena.raw_buf_mut());
19692        let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
19693
19694        let h = 1e-6;
19695        let loss_at = |x0: &[f64]| -> f64 {
19696            let mut acc = x0.to_vec();
19697            for _ in 0..length {
19698                let mut a_copy = m_data.clone();
19699                crate::blas::dgesv(&mut a_copy, &mut acc, n, 1);
19700            }
19701            acc.iter().sum()
19702        };
19703        for i in 0..n {
19704            let mut ip = init_data.to_vec();
19705            ip[i] += h;
19706            let mut im = init_data.to_vec();
19707            im[i] -= h;
19708            let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
19709            assert!(
19710                (dinit[i] - fd).abs() < 1e-7,
19711                "FD dinit[{i}]: AD={} FD={}",
19712                dinit[i],
19713                fd
19714            );
19715        }
19716    }
19717
19718    /// Trajectory-mode scan: same Backward Euler body, but record the
19719    /// carry at every step. Output is `[length, n]` — row `t` is the
19720    /// state after step `t+1`. Validates the SaveAt-style waveform
19721    /// recording end-to-end, including that the last row equals what
19722    /// the no-trajectory variant would have returned.
19723    #[test]
19724    fn scan_trajectory_backward_euler_records_waveform() {
19725        let n = 4usize;
19726        let length = 5u32;
19727        let dt = 0.05_f64;
19728
19729        let mut m_data = vec![0.0_f64; n * n];
19730        for i in 0..n {
19731            m_data[i * n + i] = 1.0 + dt * 2.0;
19732            if i > 0 {
19733                m_data[i * n + (i - 1)] = -dt;
19734            }
19735            if i + 1 < n {
19736                m_data[i * n + (i + 1)] = -dt;
19737            }
19738        }
19739        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
19740
19741        let mut body = Graph::new("be_traj_body");
19742        let x = body.input("x", Shape::new(&[n], DType::F64));
19743        let m = body.add_node(
19744            Op::Constant { data: m_bytes },
19745            vec![],
19746            Shape::new(&[n, n], DType::F64),
19747        );
19748        let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
19749        body.set_outputs(vec![next]);
19750
19751        let mut g = Graph::new("be_traj_outer");
19752        let init = g.input("x0", Shape::new(&[n], DType::F64));
19753        let traj = g.scan_trajectory(init, body, length);
19754        g.set_outputs(vec![traj]);
19755
19756        let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
19757        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
19758        execute_thunks(&sched, arena.raw_buf_mut());
19759        let got = read_arena_f64(&arena, traj, length as usize * n);
19760
19761        // Reference: each step's solve, recorded.
19762        let mut want = Vec::<f64>::with_capacity(length as usize * n);
19763        let mut x_ref = init_data.to_vec();
19764        for _ in 0..length {
19765            let mut a_copy = m_data.clone();
19766            crate::blas::dgesv(&mut a_copy, &mut x_ref, n, 1);
19767            want.extend_from_slice(&x_ref);
19768        }
19769        for i in 0..length as usize * n {
19770            assert!(
19771                (got[i] - want[i]).abs() < 1e-12,
19772                "got[{i}] = {} ref {}",
19773                got[i],
19774                want[i]
19775            );
19776        }
19777
19778        // Sanity: trajectory rows are monotone-decreasing in mass
19779        // (Backward Euler diffuses; boundary leak removes mass).
19780        for t in 1..length as usize {
19781            let prev: f64 = got[(t - 1) * n..t * n].iter().sum();
19782            let curr: f64 = got[t * n..(t + 1) * n].iter().sum();
19783            assert!(
19784                curr <= prev + 1e-15,
19785                "mass should decay: row {} sum {prev}, row {t} sum {curr}",
19786                t - 1
19787            );
19788        }
19789
19790        // Last row of the trajectory equals what a non-trajectory
19791        // scan returns — verify by running the same forward through
19792        // the simpler API and comparing.
19793        let mut body2 = Graph::new("be_final_body");
19794        let x2 = body2.input("x", Shape::new(&[n], DType::F64));
19795        let m_bytes2: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
19796        let m2 = body2.add_node(
19797            Op::Constant { data: m_bytes2 },
19798            vec![],
19799            Shape::new(&[n, n], DType::F64),
19800        );
19801        let next2 = body2.dense_solve(m2, x2, Shape::new(&[n], DType::F64));
19802        body2.set_outputs(vec![next2]);
19803
19804        let mut g2 = Graph::new("be_final_outer");
19805        let init2 = g2.input("x0", Shape::new(&[n], DType::F64));
19806        let final_x = g2.scan(init2, body2, length);
19807        g2.set_outputs(vec![final_x]);
19808        let (sched2, mut arena2) = prepare_f64(&g2, &[(init2, &init_data)]);
19809        execute_thunks(&sched2, arena2.raw_buf_mut());
19810        let final_got = read_arena_f64(&arena2, final_x, n);
19811
19812        let last_row = &got[(length as usize - 1) * n..length as usize * n];
19813        for i in 0..n {
19814            assert!(
19815                (last_row[i] - final_got[i]).abs() < 1e-15,
19816                "last trajectory row[{i}] = {} vs final-scan = {}",
19817                last_row[i],
19818                final_got[i]
19819            );
19820        }
19821    }
19822
19823    /// Op::Scan composing with Op::DenseSolve — the Circulax-shaped
19824    /// pattern for Backward Euler.
19825    /// Body: x_{t+1} = solve(I + dt·A, x_t).
19826    /// 1-D heat-equation Laplacian A; analytic ground truth from
19827    /// composing the same per-step solve in Rust.
19828    #[test]
19829    fn scan_backward_euler_heat_f64() {
19830        let n = 4usize;
19831        let length = 5u32;
19832        let dt = 0.05_f64;
19833
19834        // Construct M = I + dt · L  where L is the Laplacian (-1, 2, -1).
19835        // M is constant across iterations; embed it in the body via Op::Constant.
19836        let mut m_data = vec![0.0_f64; n * n];
19837        for i in 0..n {
19838            m_data[i * n + i] = 1.0 + dt * 2.0;
19839            if i > 0 {
19840                m_data[i * n + (i - 1)] = -dt;
19841            }
19842            if i + 1 < n {
19843                m_data[i * n + (i + 1)] = -dt;
19844            }
19845        }
19846        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
19847
19848        let mut body = Graph::new("be_body");
19849        let x = body.input("x", Shape::new(&[n], DType::F64));
19850        let m = body.add_node(
19851            Op::Constant { data: m_bytes },
19852            vec![],
19853            Shape::new(&[n, n], DType::F64),
19854        );
19855        let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
19856        body.set_outputs(vec![next]);
19857
19858        let mut g = Graph::new("be_outer");
19859        let init = g.input("x0", Shape::new(&[n], DType::F64));
19860        let final_x = g.scan(init, body, length);
19861        g.set_outputs(vec![final_x]);
19862
19863        // Initial: a sharp pulse at index 1.
19864        let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
19865        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
19866        execute_thunks(&sched, arena.raw_buf_mut());
19867        let got = read_arena_f64(&arena, final_x, n);
19868
19869        // Reference: apply the same M-solve `length` times in pure Rust.
19870        let mut ref_x = init_data.to_vec();
19871        for _ in 0..length {
19872            let mut a_copy = m_data.clone();
19873            crate::blas::dgesv(&mut a_copy, &mut ref_x, n, 1);
19874        }
19875        for i in 0..n {
19876            assert!(
19877                (got[i] - ref_x[i]).abs() < 1e-12,
19878                "got[{i}] = {} ref {}",
19879                got[i],
19880                ref_x[i]
19881            );
19882        }
19883        // Sanity: pulse should diffuse, mass should be conserved-ish
19884        // (Backward Euler is mass-conserving for this stencil with
19885        // zero-flux boundaries — but our boundaries leak, so check
19886        // that mass strictly decreases instead).
19887        let mass: f64 = got.iter().sum();
19888        assert!(mass > 0.0 && mass < 1.0, "diffusion mass: {mass}");
19889    }
19890
19891    /// Multi-RHS forward DenseSolve: X = solve(A, B) with B [N, K]
19892    /// stays correct end-to-end. Verifies the executor/lowering and
19893    /// the LAPACK column-major dance both honour `nrhs > 1`.
19894    #[test]
19895    fn dense_solve_f64_multi_rhs_forward() {
19896        let n = 3usize;
19897        let k = 2usize;
19898        let mut g = Graph::new("solve_multi_rhs");
19899        let a = g.input("A", Shape::new(&[n, n], DType::F64));
19900        let b = g.input("B", Shape::new(&[n, k], DType::F64));
19901        let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
19902        g.set_outputs(vec![x]);
19903
19904        let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
19905        let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
19906        let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
19907        execute_thunks(&sched, arena.raw_buf_mut());
19908        let x_got = read_arena_f64(&arena, x, n * k);
19909        for c in 0..k {
19910            for i in 0..n {
19911                let mut acc = 0.0_f64;
19912                for j in 0..n {
19913                    acc += a_data[i * n + j] * x_got[j * k + c];
19914                }
19915                let want = b_data[i * k + c];
19916                assert!(
19917                    (acc - want).abs() < 1e-10,
19918                    "col {c} row {i}: got {acc} want {want}"
19919                );
19920            }
19921        }
19922    }
19923
19924    /// Multi-RHS reverse-mode VJP: dB = (Aᵀ)⁻¹·1, dA = -dB · Xᵀ.
19925    /// Verified analytically + finite differences on dB[0,0].
19926    #[test]
19927    fn dense_solve_f64_multi_rhs_gradient() {
19928        use rlx_opt::autodiff::grad_with_loss;
19929        let n = 3usize;
19930        let k = 2usize;
19931        let mut g = Graph::new("solve_mrhs_grad");
19932        let a = g.param("A", Shape::new(&[n, n], DType::F64));
19933        let b = g.input("B", Shape::new(&[n, k], DType::F64));
19934        let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
19935        let loss = g.reduce(
19936            x,
19937            ReduceOp::Sum,
19938            vec![0, 1],
19939            false,
19940            Shape::new(&[1], DType::F64),
19941        );
19942        g.set_outputs(vec![loss]);
19943
19944        let bwd = grad_with_loss(&g, &[a, b]);
19945        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
19946            for node in graph.nodes() {
19947                let name = match &node.op {
19948                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19949                    _ => None,
19950                };
19951                if name == Some(want) {
19952                    return node.id;
19953                }
19954            }
19955            panic!("no node named {want:?}");
19956        };
19957        let a_bwd = find_by_name(&bwd, "A");
19958        let b_bwd = find_by_name(&bwd, "B");
19959        let d_out = find_by_name(&bwd, "d_output");
19960
19961        let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
19962        let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
19963        let d_seed = [1.0_f64];
19964
19965        let (sched, mut arena) = prepare_f64(
19966            &bwd,
19967            &[(a_bwd, &a_data), (b_bwd, &b_data), (d_out, &d_seed)],
19968        );
19969        execute_thunks(&sched, arena.raw_buf_mut());
19970        let da_got = read_arena_f64(&arena, bwd.outputs[1], n * n);
19971        let db_got = read_arena_f64(&arena, bwd.outputs[2], n * k);
19972
19973        // Reference.
19974        let mut x_ref = b_data;
19975        {
19976            let mut a_copy = a_data;
19977            crate::blas::dgesv(&mut a_copy, &mut x_ref, n, k);
19978        }
19979        let mut at = [0.0_f64; 9];
19980        for i in 0..n {
19981            for j in 0..n {
19982                at[i * n + j] = a_data[j * n + i];
19983            }
19984        }
19985        let mut ones_nk = vec![1.0_f64; n * k];
19986        crate::blas::dgesv(&mut at, &mut ones_nk, n, k);
19987        let db_ref = ones_nk;
19988        let mut da_ref = [0.0_f64; 9];
19989        for i in 0..n {
19990            for j in 0..n {
19991                let mut acc = 0.0_f64;
19992                for c in 0..k {
19993                    acc += db_ref[i * k + c] * x_ref[j * k + c];
19994                }
19995                da_ref[i * n + j] = -acc;
19996            }
19997        }
19998        for i in 0..n * k {
19999            assert!(
20000                (db_got[i] - db_ref[i]).abs() < 1e-10,
20001                "dB[{i}]: got {} want {}",
20002                db_got[i],
20003                db_ref[i]
20004            );
20005        }
20006        for i in 0..n * n {
20007            assert!(
20008                (da_got[i] - da_ref[i]).abs() < 1e-10,
20009                "dA[{i}]: got {} want {}",
20010                da_got[i],
20011                da_ref[i]
20012            );
20013        }
20014
20015        // FD on dB[0,0].
20016        let h = 1e-6;
20017        let mut bp = b_data;
20018        bp[0] += h;
20019        let mut bm = b_data;
20020        bm[0] -= h;
20021        let xp = {
20022            let mut a_copy = a_data;
20023            crate::blas::dgesv(&mut a_copy, &mut bp, n, k);
20024            bp
20025        };
20026        let xm = {
20027            let mut a_copy = a_data;
20028            crate::blas::dgesv(&mut a_copy, &mut bm, n, k);
20029            bm
20030        };
20031        let lp: f64 = xp.iter().sum();
20032        let lm: f64 = xm.iter().sum();
20033        let fd = (lp - lm) / (2.0 * h);
20034        assert!(
20035            (db_got[0] - fd).abs() < 1e-7,
20036            "FD dB[0,0]: AD={} FD={}",
20037            db_got[0],
20038            fd
20039        );
20040    }
20041
20042    /// Multi-RHS forward-mode JVP w.r.t. B. Closed form: t_X = solve(A, t_B).
20043    #[test]
20044    fn dense_solve_f64_multi_rhs_jvp() {
20045        use rlx_opt::autodiff_fwd::jvp;
20046        let n = 3usize;
20047        let k = 2usize;
20048        let mut g = Graph::new("solve_mrhs_jvp");
20049        let a = g.input("A", Shape::new(&[n, n], DType::F64));
20050        let b = g.input("B", Shape::new(&[n, k], DType::F64));
20051        let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
20052        g.set_outputs(vec![x]);
20053
20054        let jg = jvp(&g, &[b]);
20055        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
20056            for node in graph.nodes() {
20057                let name = match &node.op {
20058                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
20059                    _ => None,
20060                };
20061                if name == Some(want) {
20062                    return node.id;
20063                }
20064            }
20065            panic!("no node named {want:?}");
20066        };
20067        let a_id = find_by_name(&jg, "A");
20068        let b_id = find_by_name(&jg, "B");
20069        let tb_id = find_by_name(&jg, "tangent_B");
20070
20071        let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
20072        let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
20073        let tb_data = [0.5, 0.0, -0.25, 1.0, 1.0, -0.5_f64];
20074
20075        let (sched, mut arena) =
20076            prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (tb_id, &tb_data)]);
20077        execute_thunks(&sched, arena.raw_buf_mut());
20078        let tangent_x = read_arena_f64(&arena, jg.outputs[1], n * k);
20079
20080        let mut a_copy = a_data;
20081        let mut tb_copy = tb_data;
20082        crate::blas::dgesv(&mut a_copy, &mut tb_copy, n, k);
20083        for i in 0..n * k {
20084            assert!(
20085                (tangent_x[i] - tb_copy[i]).abs() < 1e-10,
20086                "t_X[{i}]: AD={} ref={}",
20087                tangent_x[i],
20088                tb_copy[i]
20089            );
20090        }
20091
20092        let h = 1e-6;
20093        let mut bp = b_data;
20094        let mut bm = b_data;
20095        for i in 0..n * k {
20096            bp[i] += h * tb_data[i];
20097            bm[i] -= h * tb_data[i];
20098        }
20099        let xp = {
20100            let mut a_copy = a_data;
20101            crate::blas::dgesv(&mut a_copy, &mut bp, n, k);
20102            bp
20103        };
20104        let xm = {
20105            let mut a_copy = a_data;
20106            crate::blas::dgesv(&mut a_copy, &mut bm, n, k);
20107            bm
20108        };
20109        for i in 0..n * k {
20110            let fd = (xp[i] - xm[i]) / (2.0 * h);
20111            assert!(
20112                (tangent_x[i] - fd).abs() < 1e-7,
20113                "FD t_X[{i}]: AD={} FD={}",
20114                tangent_x[i],
20115                fd
20116            );
20117        }
20118    }
20119
20120    /// Forward-mode JVP through DenseSolve, end-to-end at f64.
20121    ///
20122    /// Build forward x = solve(A, b), call `jvp(forward, [b])`,
20123    /// compile + run, and check the tangent output matches the
20124    /// closed form `t_x = solve(A, t_b)` plus a finite-difference
20125    /// cross-check `(solve(A, b + h·t_b) − solve(A, b − h·t_b)) / 2h`.
20126    #[test]
20127    fn jvp_dense_solve_b_runs_and_matches_fd() {
20128        use rlx_opt::autodiff_fwd::jvp;
20129        let n = 3usize;
20130
20131        // Forward.
20132        let mut g = Graph::new("jvp_b_e2e");
20133        let a = g.input("A", Shape::new(&[n, n], DType::F64));
20134        let b = g.input("b", Shape::new(&[n], DType::F64));
20135        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
20136        g.set_outputs(vec![x]);
20137
20138        // JVP graph perturbing b only.
20139        let jg = jvp(&g, &[b]);
20140        // The JVP graph holds a fresh "tangent_b" Input on top of A and b.
20141        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
20142            for node in graph.nodes() {
20143                let name = match &node.op {
20144                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
20145                    _ => None,
20146                };
20147                if name == Some(want) {
20148                    return node.id;
20149                }
20150            }
20151            panic!("no node named {want:?}");
20152        };
20153        let a_id = find_by_name(&jg, "A");
20154        let b_id = find_by_name(&jg, "b");
20155        let tb_id = find_by_name(&jg, "tangent_b");
20156
20157        let a_data: [f64; 9] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
20158        let b_data: [f64; 3] = [1.0, 2.0, 3.0];
20159        // Pick an arbitrary perturbation direction.
20160        let tb_data: [f64; 3] = [0.5, -0.25, 1.0];
20161
20162        let (sched, mut arena) =
20163            prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (tb_id, &tb_data)]);
20164        execute_thunks(&sched, arena.raw_buf_mut());
20165
20166        // Outputs: [primal_x, tangent_x].
20167        let primal_x = read_arena_f64(&arena, jg.outputs[0], n);
20168        let tangent_x = read_arena_f64(&arena, jg.outputs[1], n);
20169
20170        // Closed form: t_x = solve(A, t_b).
20171        let t_x_ref = {
20172            let mut a = a_data;
20173            let mut tb = tb_data;
20174            let info = crate::blas::dgesv(&mut a, &mut tb, n, 1);
20175            assert_eq!(info, 0);
20176            tb
20177        };
20178        for i in 0..n {
20179            assert!(
20180                (tangent_x[i] - t_x_ref[i]).abs() < 1e-10,
20181                "t_x[{i}]: got {} want {}",
20182                tangent_x[i],
20183                t_x_ref[i]
20184            );
20185        }
20186
20187        // FD: x(b + h·tb) − x(b − h·tb)) / 2h
20188        let h = 1e-6;
20189        let mut bp = b_data;
20190        let mut bm = b_data;
20191        for i in 0..n {
20192            bp[i] += h * tb_data[i];
20193            bm[i] -= h * tb_data[i];
20194        }
20195        let xp = {
20196            let mut a = a_data;
20197            let info = crate::blas::dgesv(&mut a, &mut bp, n, 1);
20198            assert_eq!(info, 0);
20199            bp
20200        };
20201        let xm = {
20202            let mut a = a_data;
20203            let info = crate::blas::dgesv(&mut a, &mut bm, n, 1);
20204            assert_eq!(info, 0);
20205            bm
20206        };
20207        let fd: Vec<f64> = (0..n).map(|i| (xp[i] - xm[i]) / (2.0 * h)).collect();
20208        for i in 0..n {
20209            assert!(
20210                (tangent_x[i] - fd[i]).abs() < 1e-7,
20211                "FD mismatch t_x[{i}]: AD={} FD={}",
20212                tangent_x[i],
20213                fd[i]
20214            );
20215        }
20216        // Sanity: primal output is the actual solve.
20217        let primal_ref = {
20218            let mut a = a_data;
20219            let mut b = b_data;
20220            crate::blas::dgesv(&mut a, &mut b, n, 1);
20221            b
20222        };
20223        for i in 0..n {
20224            assert!((primal_x[i] - primal_ref[i]).abs() < 1e-10);
20225        }
20226    }
20227
20228    /// Forward-mode JVP through DenseSolve perturbing A. The tangent
20229    /// path includes the −t_A·x correction term.
20230    /// `t_x = −solve(A, t_A · x)` should match a finite-difference
20231    /// directional derivative of `solve(A, b)` w.r.t. A in the
20232    /// `t_A` direction.
20233    #[test]
20234    fn jvp_dense_solve_a_runs_and_matches_fd() {
20235        use rlx_opt::autodiff_fwd::jvp;
20236        let n = 3usize;
20237
20238        let mut g = Graph::new("jvp_a_e2e");
20239        let a = g.input("A", Shape::new(&[n, n], DType::F64));
20240        let b = g.input("b", Shape::new(&[n], DType::F64));
20241        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
20242        g.set_outputs(vec![x]);
20243
20244        let jg = jvp(&g, &[a]);
20245        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
20246            for node in graph.nodes() {
20247                let name = match &node.op {
20248                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
20249                    _ => None,
20250                };
20251                if name == Some(want) {
20252                    return node.id;
20253                }
20254            }
20255            panic!("no node named {want:?}");
20256        };
20257        let a_id = find_by_name(&jg, "A");
20258        let b_id = find_by_name(&jg, "b");
20259        let ta_id = find_by_name(&jg, "tangent_A");
20260
20261        let a_data: [f64; 9] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
20262        let b_data: [f64; 3] = [1.0, 2.0, 3.0];
20263        // Asymmetric perturbation direction for A.
20264        let ta_data: [f64; 9] = [0.10, -0.05, 0.02, 0.03, 0.20, -0.04, -0.01, 0.07, 0.15];
20265
20266        let (sched, mut arena) =
20267            prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (ta_id, &ta_data)]);
20268        execute_thunks(&sched, arena.raw_buf_mut());
20269
20270        let tangent_x = read_arena_f64(&arena, jg.outputs[1], n);
20271
20272        // Closed form: x = solve(A, b); t_x = −solve(A, t_A · x).
20273        let x_ref = {
20274            let mut a = a_data;
20275            let mut b = b_data;
20276            crate::blas::dgesv(&mut a, &mut b, n, 1);
20277            b
20278        };
20279        let mut prod = [0.0_f64; 3];
20280        for i in 0..n {
20281            for j in 0..n {
20282                prod[i] += ta_data[i * n + j] * x_ref[j];
20283            }
20284        }
20285        let t_x_ref = {
20286            let mut a = a_data;
20287            let mut p = prod;
20288            crate::blas::dgesv(&mut a, &mut p, n, 1);
20289            [-p[0], -p[1], -p[2]]
20290        };
20291        for i in 0..n {
20292            assert!(
20293                (tangent_x[i] - t_x_ref[i]).abs() < 1e-10,
20294                "closed-form t_x[{i}]: AD={} ref={}",
20295                tangent_x[i],
20296                t_x_ref[i]
20297            );
20298        }
20299
20300        // FD: solve(A + h·t_A, b) and solve(A − h·t_A, b).
20301        let h = 1e-6;
20302        let mut ap = a_data;
20303        let mut am = a_data;
20304        for i in 0..n * n {
20305            ap[i] += h * ta_data[i];
20306            am[i] -= h * ta_data[i];
20307        }
20308        let xp = {
20309            let mut a = ap;
20310            let mut b = b_data;
20311            crate::blas::dgesv(&mut a, &mut b, n, 1);
20312            b
20313        };
20314        let xm = {
20315            let mut a = am;
20316            let mut b = b_data;
20317            crate::blas::dgesv(&mut a, &mut b, n, 1);
20318            b
20319        };
20320        for i in 0..n {
20321            let fd = (xp[i] - xm[i]) / (2.0 * h);
20322            assert!(
20323                (tangent_x[i] - fd).abs() < 1e-7,
20324                "FD t_x[{i}]: AD={} FD={}",
20325                tangent_x[i],
20326                fd
20327            );
20328        }
20329    }
20330
20331    /// Real INT8 conv2d parity. Same setup as QMatMul: pre-quantize
20332    /// f32 inputs to i8, run `Op::QConv2d`, compare against an
20333    /// in-test reference loop that does the same i32 accumulation
20334    /// and requantize math. Symmetric quant (zp=0) to keep the math
20335    /// head-to-head.
20336    #[test]
20337    fn q_conv2d_matches_reference() {
20338        use rlx_ir::Philox4x32;
20339        // Small NCHW shape — enough to exercise stride/padding edges.
20340        let n = 1usize;
20341        let c_in = 2usize;
20342        let h = 5usize;
20343        let w_in = 5usize;
20344        let c_out = 3usize;
20345        let kh = 3usize;
20346        let kw = 3usize;
20347        let ph = 1usize;
20348        let pw = 1usize;
20349        let sh = 1usize;
20350        let sw = 1usize;
20351        let h_out = (h + 2 * ph - kh) / sh + 1;
20352        let w_out = (w_in + 2 * pw - kw) / sw + 1;
20353
20354        let x_scale = 0.04f32;
20355        let w_scale = 0.02f32;
20356        let out_scale = 0.5f32;
20357        let mult = x_scale * w_scale / out_scale;
20358
20359        let mut rng = Philox4x32::new(2099);
20360        let mut xf = vec![0f32; n * c_in * h * w_in];
20361        rng.fill_normal(&mut xf);
20362        let mut wf = vec![0f32; c_out * c_in * kh * kw];
20363        rng.fill_normal(&mut wf);
20364        let xq: Vec<i8> = xf
20365            .iter()
20366            .map(|&v| ((v / x_scale).round() as i32).clamp(-128, 127) as i8)
20367            .collect();
20368        let wq: Vec<i8> = wf
20369            .iter()
20370            .map(|&v| ((v / w_scale).round() as i32).clamp(-128, 127) as i8)
20371            .collect();
20372        let bias: Vec<i32> = vec![0i32; c_out];
20373
20374        let mut g = Graph::new("qconv");
20375        let xn = g.input("x", Shape::new(&[n, c_in, h, w_in], DType::I8));
20376        let wn = g.input("w", Shape::new(&[c_out, c_in, kh, kw], DType::I8));
20377        let bn = g.input("b", Shape::new(&[c_out], DType::I32));
20378        let out = g.q_conv2d(
20379            xn,
20380            wn,
20381            bn,
20382            vec![kh, kw],
20383            vec![sh, sw],
20384            vec![ph, pw],
20385            vec![1, 1],
20386            1,
20387            0,
20388            0,
20389            0,
20390            mult,
20391            Shape::new(&[n, c_out, h_out, w_out], DType::I8),
20392        );
20393        g.set_outputs(vec![out]);
20394
20395        let plan = rlx_opt::memory::plan_memory(&g);
20396        let mut arena = crate::arena::Arena::from_plan(plan);
20397        let sched = compile_thunks(&g, &arena);
20398        // Capture offsets before borrowing the buf mutably (avoids
20399        // overlap between &mut and the &arena.byte_offset reads).
20400        let xn_off = arena.byte_offset(xn);
20401        let wn_off = arena.byte_offset(wn);
20402        let bn_off = arena.byte_offset(bn);
20403        let out_off = arena.byte_offset(out);
20404        let buf = arena.raw_buf_mut();
20405        unsafe {
20406            let p = buf.as_mut_ptr().add(xn_off) as *mut i8;
20407            for (i, &v) in xq.iter().enumerate() {
20408                *p.add(i) = v;
20409            }
20410            let p = buf.as_mut_ptr().add(wn_off) as *mut i8;
20411            for (i, &v) in wq.iter().enumerate() {
20412                *p.add(i) = v;
20413            }
20414            let p = buf.as_mut_ptr().add(bn_off) as *mut i32;
20415            for (i, &v) in bias.iter().enumerate() {
20416                *p.add(i) = v;
20417            }
20418        }
20419        execute_thunks(&sched, arena.raw_buf_mut());
20420        let out_q: Vec<i8> = unsafe {
20421            let p = arena.raw_buf().as_ptr().add(out_off) as *const i8;
20422            (0..n * c_out * h_out * w_out).map(|i| *p.add(i)).collect()
20423        };
20424
20425        // Reference: scalar loop in NCHW with the same requantize.
20426        let mut out_ref = vec![0i8; n * c_out * h_out * w_out];
20427        for ni in 0..n {
20428            for co in 0..c_out {
20429                for ho in 0..h_out {
20430                    for wo in 0..w_out {
20431                        let mut acc: i32 = 0;
20432                        for ci in 0..c_in {
20433                            for ki in 0..kh {
20434                                for kj in 0..kw {
20435                                    let hi = ho * sh + ki;
20436                                    let wi = wo * sw + kj;
20437                                    if hi < ph || wi < pw {
20438                                        continue;
20439                                    }
20440                                    let hi = hi - ph;
20441                                    let wi = wi - pw;
20442                                    if hi >= h || wi >= w_in {
20443                                        continue;
20444                                    }
20445                                    let xv =
20446                                        xq[((ni * c_in) + ci) * h * w_in + hi * w_in + wi] as i32;
20447                                    let wv = wq[((co * c_in) + ci) * kh * kw + ki * kw + kj] as i32;
20448                                    acc += xv * wv;
20449                                }
20450                            }
20451                        }
20452                        let r = (acc as f32 * mult).round() as i32;
20453                        let r = r.clamp(-128, 127) as i8;
20454                        out_ref[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = r;
20455                    }
20456                }
20457            }
20458        }
20459
20460        for (i, (a, r)) in out_q.iter().zip(&out_ref).enumerate() {
20461            assert_eq!(a, r, "q_conv2d[{i}]: kernel {a} vs reference {r}");
20462        }
20463    }
20464
20465    /// Real INT8 matmul parity: compare `Op::QMatMul` against the
20466    /// fake-quant reference `Dequantize → MatMul → Quantize` that
20467    /// would produce the same output if we round-tripped through
20468    /// f32. Both should agree element-for-element (or within ±1 i8
20469    /// step, since rounding in the requantize uses different code
20470    /// paths). Symmetric quantization (zp=0) for both paths to keep
20471    /// the math head-to-head.
20472    #[test]
20473    fn q_matmul_matches_fake_quant_reference() {
20474        use rlx_ir::Philox4x32;
20475        let m = 3usize;
20476        let k = 8usize;
20477        let n = 5usize;
20478        let mut rng = Philox4x32::new(2031);
20479
20480        // Pick scales and quantize random f32 inputs to i8.
20481        let x_scale = 0.05f32;
20482        let w_scale = 0.03f32;
20483        let out_scale = 0.4f32;
20484        let mult = x_scale * w_scale / out_scale;
20485        let mut xf = vec![0f32; m * k];
20486        rng.fill_normal(&mut xf);
20487        let mut wf = vec![0f32; k * n];
20488        rng.fill_normal(&mut wf);
20489        let xq: Vec<i8> = xf
20490            .iter()
20491            .map(|&v| ((v / x_scale).round() as i32).clamp(-128, 127) as i8)
20492            .collect();
20493        let wq: Vec<i8> = wf
20494            .iter()
20495            .map(|&v| ((v / w_scale).round() as i32).clamp(-128, 127) as i8)
20496            .collect();
20497        let bias: Vec<i32> = vec![0i32; n];
20498
20499        // ── Direct INT8 path ──
20500        let _f = DType::F32;
20501        let mut g_q = Graph::new("qmm_direct");
20502        let xn = g_q.input("x", Shape::new(&[m, k], DType::I8));
20503        let wn = g_q.input("w", Shape::new(&[k, n], DType::I8));
20504        let bn = g_q.input("b", Shape::new(&[n], DType::I32));
20505        let out = g_q.q_matmul(xn, wn, bn, 0, 0, 0, mult, Shape::new(&[m, n], DType::I8));
20506        g_q.set_outputs(vec![out]);
20507        let plan = rlx_opt::memory::plan_memory(&g_q);
20508        let mut arena = crate::arena::Arena::from_plan(plan);
20509        let sched = compile_thunks(&g_q, &arena);
20510
20511        // Fill inputs.
20512        let xn_off = arena.byte_offset(xn);
20513        let wn_off = arena.byte_offset(wn);
20514        let bn_off = arena.byte_offset(bn);
20515        let out_off = arena.byte_offset(out);
20516        let buf = arena.raw_buf_mut();
20517        unsafe {
20518            let p = buf.as_mut_ptr().add(xn_off) as *mut i8;
20519            for (i, &v) in xq.iter().enumerate() {
20520                *p.add(i) = v;
20521            }
20522            let p = buf.as_mut_ptr().add(wn_off) as *mut i8;
20523            for (i, &v) in wq.iter().enumerate() {
20524                *p.add(i) = v;
20525            }
20526            let p = buf.as_mut_ptr().add(bn_off) as *mut i32;
20527            for (i, &v) in bias.iter().enumerate() {
20528                *p.add(i) = v;
20529            }
20530        }
20531        execute_thunks(&sched, arena.raw_buf_mut());
20532        let out_q: Vec<i8> = unsafe {
20533            let p = arena.raw_buf().as_ptr().add(out_off) as *const i8;
20534            (0..m * n).map(|i| *p.add(i)).collect()
20535        };
20536
20537        // ── Fake-quant reference: scalar emulation in plain Rust ──
20538        // Same arithmetic the kernel does, but in a verifier loop:
20539        //   acc = Σ (x[m,k]) · (w[k,n]),  // zps are 0
20540        //   out[m,n] = saturate_i8(round(acc · mult) + 0)
20541        let mut out_ref = vec![0i8; m * n];
20542        for mi in 0..m {
20543            for ni in 0..n {
20544                let mut acc: i32 = 0;
20545                for ki in 0..k {
20546                    acc += (xq[mi * k + ki] as i32) * (wq[ki * n + ni] as i32);
20547                }
20548                let r = (acc as f32 * mult).round() as i32;
20549                out_ref[mi * n + ni] = r.clamp(-128, 127) as i8;
20550            }
20551        }
20552
20553        for (i, (a, r)) in out_q.iter().zip(&out_ref).enumerate() {
20554            assert_eq!(a, r, "q_matmul[{i}]: kernel {a} vs reference {r}");
20555        }
20556    }
20557
20558    /// Quantize/Dequantize round-trip — quantize an f32 tensor, then
20559    /// dequantize back, and confirm the result tracks the input
20560    /// within the per-element scale (the inevitable rounding error).
20561    /// Also pins the kernel's saturation behavior at the i8 limits.
20562    #[test]
20563    fn quantize_dequantize_round_trip() {
20564        use rlx_ir::Philox4x32;
20565        let len = 64;
20566        let mut rng = Philox4x32::new(2027);
20567        let mut x = vec![0f32; len];
20568        rng.fill_normal(&mut x);
20569        // Stretch a couple values past the +/- saturation cliff so
20570        // the saturate_i8 path is exercised.
20571        x[0] = 999.0;
20572        x[1] = -999.0;
20573
20574        let scale = 0.05f32;
20575        let zp = 3i32;
20576
20577        let f = DType::F32;
20578        let mut g = Graph::new("qdq");
20579        let xn = g.input("x", Shape::new(&[len], f));
20580        let q = g.quantize(xn, scale, zp);
20581        let dq = g.dequantize(q, scale, zp);
20582        g.set_outputs(vec![dq]);
20583
20584        let plan = rlx_opt::memory::plan_memory(&g);
20585        let mut arena = crate::arena::Arena::from_plan(plan);
20586        let sched = compile_thunks(&g, &arena);
20587        let xn_off = arena.byte_offset(xn);
20588        let dq_off = arena.byte_offset(dq);
20589        let buf = arena.raw_buf_mut();
20590        unsafe {
20591            let p = buf.as_mut_ptr().add(xn_off) as *mut f32;
20592            for (i, &v) in x.iter().enumerate() {
20593                *p.add(i) = v;
20594            }
20595        }
20596        execute_thunks(&sched, arena.raw_buf_mut());
20597        let out: Vec<f32> = unsafe {
20598            let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
20599            (0..len).map(|i| *p.add(i)).collect()
20600        };
20601
20602        // Saturated values at i=0,1 should clamp to ±127's dequant
20603        // range (= (±127 - zp) · scale).
20604        let sat_pos = (127 - zp) as f32 * scale;
20605        let sat_neg = (-128 - zp) as f32 * scale;
20606        assert!((out[0] - sat_pos).abs() < 1e-6, "+sat: {}", out[0]);
20607        assert!((out[1] - sat_neg).abs() < 1e-6, "-sat: {}", out[1]);
20608
20609        // Everything else should round-trip within `scale` (one quant
20610        // step = the worst-case rounding error).
20611        for i in 2..len {
20612            assert!(
20613                (out[i] - x[i]).abs() <= scale + 1e-5,
20614                "qdq[{i}]: {} → {}, scale={scale}",
20615                x[i],
20616                out[i]
20617            );
20618        }
20619    }
20620
20621    /// Per-channel quantize / dequantize: independent scale and zp
20622    /// per slice along an axis. Verifies (a) each channel uses its
20623    /// own scale (not a shared one), (b) saturation still respects
20624    /// the i8 range, (c) channel data layout decomposition is
20625    /// correct (no cross-channel leakage).
20626    #[test]
20627    fn quantize_per_channel_round_trip() {
20628        let c = 4usize;
20629        let inner = 5usize;
20630        // Different magnitudes per channel — proves the per-channel
20631        // scale is actually being read for each row.
20632        let mags = [0.01f32, 0.5, 5.0, 50.0];
20633        let mut x = vec![0f32; c * inner];
20634        for ci in 0..c {
20635            for ii in 0..inner {
20636                // Sweep through values that span [-max_abs, +max_abs]
20637                // for each channel, plus one value past the cliff to
20638                // trigger saturation.
20639                x[ci * inner + ii] = match ii {
20640                    0 => -mags[ci],
20641                    1 => 0.0,
20642                    2 => mags[ci],
20643                    3 => mags[ci] * 1000.0,  // saturates +
20644                    _ => -mags[ci] * 1000.0, // saturates -
20645                };
20646            }
20647        }
20648        let scales: Vec<f32> = mags.iter().map(|&m| m / 127.0).collect();
20649        let zps: Vec<i32> = vec![0, 0, 0, 0];
20650
20651        let f = DType::F32;
20652        let mut g = Graph::new("qdq_pc");
20653        let xn = g.input("x", Shape::new(&[c, inner], f));
20654        let q = g.quantize_per_channel(xn, 0, scales.clone(), zps.clone());
20655        let dq = g.dequantize_per_channel(q, 0, scales.clone(), zps);
20656        g.set_outputs(vec![dq]);
20657
20658        let plan = rlx_opt::memory::plan_memory(&g);
20659        let mut arena = crate::arena::Arena::from_plan(plan);
20660        let sched = compile_thunks(&g, &arena);
20661        let xn_off = arena.byte_offset(xn);
20662        let dq_off = arena.byte_offset(dq);
20663        let buf = arena.raw_buf_mut();
20664        unsafe {
20665            let p = buf.as_mut_ptr().add(xn_off) as *mut f32;
20666            for (i, &v) in x.iter().enumerate() {
20667                *p.add(i) = v;
20668            }
20669        }
20670        execute_thunks(&sched, arena.raw_buf_mut());
20671        let out: Vec<f32> = unsafe {
20672            let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
20673            (0..c * inner).map(|i| *p.add(i)).collect()
20674        };
20675
20676        for ci in 0..c {
20677            // Within-range entries (positions 0, 1, 2) must round-trip
20678            // within one quant step of *that channel's* scale.
20679            for ii in 0..3 {
20680                let idx = ci * inner + ii;
20681                assert!(
20682                    (out[idx] - x[idx]).abs() <= scales[ci] + 1e-5,
20683                    "ch {ci} idx {ii}: {} vs {}",
20684                    x[idx],
20685                    out[idx]
20686                );
20687            }
20688            // Saturated positions clamp to ±127 · scale[ci].
20689            let sat_pos = 127.0 * scales[ci];
20690            let sat_neg = -128.0 * scales[ci];
20691            assert!(
20692                (out[ci * inner + 3] - sat_pos).abs() < 1e-5,
20693                "ch {ci} +sat: {}",
20694                out[ci * inner + 3]
20695            );
20696            assert!(
20697                (out[ci * inner + 4] - sat_neg).abs() < 1e-5,
20698                "ch {ci} -sat: {}",
20699                out[ci * inner + 4]
20700            );
20701        }
20702    }
20703
20704    /// `Op::ActivationBackward` parity for every supported kind.
20705    /// Builds a single-op graph `dx = activation_backward(x, dy)` and
20706    /// compares each `dx[i]` to the central-difference `(act(x+ε) -
20707    /// act(x-ε)) / (2ε) · dy\[i\]`. Sweeps the closed-form covered by
20708    /// the kernel.
20709    #[test]
20710    fn activation_backward_matches_numerical_per_kind() {
20711        use rlx_ir::Philox4x32;
20712        use rlx_ir::op::Activation;
20713        let mut rng = Philox4x32::new(91);
20714        let len = 32;
20715        // x sampled away from kink/branch points: shifted positive
20716        // (exp/sqrt/log domain) for the unary-positive activations;
20717        // wide range otherwise. Two parallel tests would be cleaner
20718        // but this is concise enough.
20719        let mut x_pos = vec![0f32; len];
20720        rng.fill_normal(&mut x_pos);
20721        for v in x_pos.iter_mut() {
20722            *v = v.abs() + 0.5;
20723        }
20724        let mut x_any = vec![0f32; len];
20725        rng.fill_normal(&mut x_any);
20726        let mut dy = vec![0f32; len];
20727        rng.fill_normal(&mut dy);
20728
20729        for &(kind, x_data, eps, tol) in &[
20730            (Activation::Sigmoid, &x_any[..], 1e-3, 5e-3),
20731            (Activation::Tanh, &x_any[..], 1e-3, 5e-3),
20732            (Activation::Silu, &x_any[..], 1e-3, 5e-3),
20733            (Activation::Gelu, &x_any[..], 1e-3, 5e-3),
20734            (Activation::GeluApprox, &x_any[..], 1e-3, 5e-3),
20735            (Activation::Exp, &x_any[..], 1e-4, 5e-3),
20736            (Activation::Log, &x_pos[..], 1e-4, 5e-3),
20737            (Activation::Sqrt, &x_pos[..], 1e-4, 5e-3),
20738            (Activation::Rsqrt, &x_pos[..], 1e-4, 5e-3),
20739            (Activation::Neg, &x_any[..], 1e-3, 5e-4),
20740        ] {
20741            let f = DType::F32;
20742            let mut g = Graph::new("act_bw");
20743            let xn = g.input("x", Shape::new(&[len], f));
20744            let dyn_ = g.input("dy", Shape::new(&[len], f));
20745            let dx = g.activation_backward(kind, xn, dyn_);
20746            g.set_outputs(vec![dx]);
20747
20748            let plan = rlx_opt::memory::plan_memory(&g);
20749            let mut arena = crate::arena::Arena::from_plan(plan);
20750            let sched = compile_thunks(&g, &arena);
20751
20752            let xn_off = arena.byte_offset(xn);
20753            let dyn_off = arena.byte_offset(dyn_);
20754            let dx_off = arena.byte_offset(dx);
20755            let buf = arena.raw_buf_mut();
20756            unsafe {
20757                let p = buf.as_mut_ptr().add(xn_off) as *mut f32;
20758                for (i, &v) in x_data.iter().enumerate() {
20759                    *p.add(i) = v;
20760                }
20761                let p = buf.as_mut_ptr().add(dyn_off) as *mut f32;
20762                for (i, &v) in dy.iter().enumerate() {
20763                    *p.add(i) = v;
20764                }
20765            }
20766            execute_thunks(&sched, arena.raw_buf_mut());
20767            let analytical: Vec<f32> = unsafe {
20768                let p = arena.raw_buf().as_ptr().add(dx_off) as *const f32;
20769                (0..len).map(|i| *p.add(i)).collect()
20770            };
20771
20772            // Apply the forward activation manually; finite-difference
20773            // each element.
20774            let act_apply = |kind: Activation, x: f32| -> f32 {
20775                match kind {
20776                    Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
20777                    Activation::Tanh => x.tanh(),
20778                    Activation::Silu => x / (1.0 + (-x).exp()),
20779                    Activation::Gelu => {
20780                        // Match the kernel's exact erf form.
20781                        const INV_SQRT2: f32 = 0.707_106_77;
20782                        0.5 * x * (1.0 + erf_f32(x * INV_SQRT2))
20783                    }
20784                    Activation::GeluApprox => {
20785                        const C: f32 = 0.797_884_6;
20786                        const A: f32 = 0.044_715;
20787                        let inner = C * (x + A * x * x * x);
20788                        0.5 * x * (1.0 + inner.tanh())
20789                    }
20790                    Activation::Exp => x.exp(),
20791                    Activation::Log => x.ln(),
20792                    Activation::Sqrt => x.sqrt(),
20793                    Activation::Rsqrt => 1.0 / x.sqrt(),
20794                    Activation::Neg => -x,
20795                    Activation::Relu => x.max(0.0),
20796                    Activation::Abs => x.abs(),
20797                    Activation::Round => x.round(),
20798                    Activation::Sin => x.sin(),
20799                    Activation::Cos => x.cos(),
20800                    Activation::Tan => x.tan(),
20801                    Activation::Atan => x.atan(),
20802                }
20803            };
20804            for i in 0..len {
20805                let xv = x_data[i];
20806                let plus = act_apply(kind, xv + eps);
20807                let minus = act_apply(kind, xv - eps);
20808                let num = (plus - minus) / (2.0 * eps) * dy[i];
20809                assert!(
20810                    (analytical[i] - num).abs() < tol,
20811                    "{kind:?}[{i}]: analytical {} vs numerical {num}",
20812                    analytical[i]
20813                );
20814            }
20815        }
20816    }
20817
20818    /// Batched 3-D MatMul VJP — the transformer-attention shape
20819    /// `[B, M, K] @ [B, K, N] = [B, M, N]`. Both gradients flow through
20820    /// `Op::Transpose` with a perm that swaps the last two dims.
20821    #[test]
20822    fn matmul_3d_gradient_matches_numerical() {
20823        use rlx_ir::Philox4x32;
20824        let batch = 2usize;
20825        let m = 3usize;
20826        let k = 4usize;
20827        let n = 5usize;
20828        let mut rng = Philox4x32::new(101);
20829        let mut a_data = vec![0f32; batch * m * k];
20830        rng.fill_normal(&mut a_data);
20831        let mut b_data = vec![0f32; batch * k * n];
20832        rng.fill_normal(&mut b_data);
20833
20834        let f = DType::F32;
20835        let mut fwd = Graph::new("matmul_3d");
20836        let an = fwd.input("a", Shape::new(&[batch, m, k], f));
20837        let bp = fwd.param("b", Shape::new(&[batch, k, n], f));
20838        let mm = fwd.matmul(an, bp, Shape::new(&[batch, m, n], f));
20839        let loss = fwd.add_node(
20840            Op::Reduce {
20841                op: ReduceOp::Sum,
20842                axes: vec![0, 1, 2],
20843                keep_dim: false,
20844            },
20845            vec![mm],
20846            Shape::from_dims(&[], f),
20847        );
20848        fwd.set_outputs(vec![loss]);
20849
20850        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[bp]);
20851        let d_out = bwd_graph
20852            .nodes()
20853            .iter()
20854            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
20855            .map(|n| n.id)
20856            .unwrap();
20857
20858        let plan = rlx_opt::memory::plan_memory(&bwd_graph);
20859        let mut arena = crate::arena::Arena::from_plan(plan);
20860        let sched = compile_thunks(&bwd_graph, &arena);
20861        for &(id, data) in &[(an, &a_data), (bp, &b_data), (d_out, &vec![1.0f32])] {
20862            let off = arena.byte_offset(id);
20863            let buf = arena.raw_buf_mut();
20864            unsafe {
20865                let p = buf.as_mut_ptr().add(off) as *mut f32;
20866                for (i, &v) in data.iter().enumerate() {
20867                    *p.add(i) = v;
20868                }
20869            }
20870        }
20871        execute_thunks(&sched, arena.raw_buf_mut());
20872        let gb_id = bwd_graph.outputs[1];
20873        let g_b: Vec<f32> = unsafe {
20874            let p = arena.raw_buf().as_ptr().add(arena.byte_offset(gb_id)) as *const f32;
20875            (0..batch * k * n).map(|i| *p.add(i)).collect()
20876        };
20877
20878        // Numerical gradient: differentiate sum(a @ b) w.r.t. each b entry.
20879        let forward_loss = |b_vals: &[f32]| -> f32 {
20880            let mut out = vec![0f32; batch * m * n];
20881            for bi in 0..batch {
20882                for mi in 0..m {
20883                    for ni in 0..n {
20884                        let mut acc = 0f32;
20885                        for ki in 0..k {
20886                            acc +=
20887                                a_data[bi * m * k + mi * k + ki] * b_vals[bi * k * n + ki * n + ni];
20888                        }
20889                        out[bi * m * n + mi * n + ni] = acc;
20890                    }
20891                }
20892            }
20893            out.iter().sum()
20894        };
20895        let eps = 1e-3f32;
20896        let mut bp_p = b_data.clone();
20897        let mut g_b_num = vec![0f32; b_data.len()];
20898        for i in 0..b_data.len() {
20899            let s = bp_p[i];
20900            bp_p[i] = s + eps;
20901            let lp = forward_loss(&bp_p);
20902            bp_p[i] = s - eps;
20903            let lm = forward_loss(&bp_p);
20904            bp_p[i] = s;
20905            g_b_num[i] = (lp - lm) / (2.0 * eps);
20906        }
20907        for (i, (a, n)) in g_b.iter().zip(&g_b_num).enumerate() {
20908            assert!(
20909                (a - n).abs() < 5e-3,
20910                "matmul_3d g_b[{i}]: analytical {a} vs numerical {n}"
20911            );
20912        }
20913    }
20914
20915    /// Composed `Op::Softmax` VJP — the gradient is built from
20916    /// `mul + reduce_sum + expand + sub + mul`, no dedicated
20917    /// SoftmaxBackward kernel. Verifies the closed-form
20918    /// `dx = y · (g - Σ y·g)` matches the FD gradient over a small
20919    /// 2-D logits tensor.
20920    #[test]
20921    fn softmax_gradient_matches_numerical() {
20922        use rlx_ir::Philox4x32;
20923        let n = 3usize;
20924        let c = 5usize;
20925        let mut rng = Philox4x32::new(57);
20926        let mut x_data = vec![0f32; n * c];
20927        rng.fill_normal(&mut x_data);
20928
20929        let f = DType::F32;
20930        let mut fwd = Graph::new("softmax_only");
20931        let xn = fwd.input("x", Shape::new(&[n, c], f));
20932        let sm = fwd.add_node(Op::Softmax { axis: -1 }, vec![xn], Shape::new(&[n, c], f));
20933        // Loss = sum(softmax · target) for some random fixed target —
20934        // any linear loss will do; sum-of-all is the simplest and gives
20935        // a uniform gradient flow into the softmax.
20936        let loss = fwd.add_node(
20937            Op::Reduce {
20938                op: ReduceOp::Sum,
20939                axes: vec![0, 1],
20940                keep_dim: false,
20941            },
20942            vec![sm],
20943            Shape::from_dims(&[], f),
20944        );
20945        fwd.set_outputs(vec![loss]);
20946
20947        // `wrt = [xn]` — autodiff exposes the gradient w.r.t. the
20948        // input so we can compare it directly. The forward NodeId for
20949        // `xn` doubles as its bwd-graph mirror.
20950        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[xn]);
20951        let d_out = bwd_graph
20952            .nodes()
20953            .iter()
20954            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
20955            .map(|n| n.id)
20956            .unwrap();
20957
20958        let plan = rlx_opt::memory::plan_memory(&bwd_graph);
20959        let mut arena = crate::arena::Arena::from_plan(plan);
20960        let sched = compile_thunks(&bwd_graph, &arena);
20961        for &(id, data) in &[(xn, &x_data), (d_out, &vec![1.0f32])] {
20962            let off = arena.byte_offset(id);
20963            let buf = arena.raw_buf_mut();
20964            unsafe {
20965                let p = buf.as_mut_ptr().add(off) as *mut f32;
20966                for (i, &v) in data.iter().enumerate() {
20967                    *p.add(i) = v;
20968                }
20969            }
20970        }
20971        execute_thunks(&sched, arena.raw_buf_mut());
20972        let g_x_id = bwd_graph.outputs[1];
20973        let g_x: Vec<f32> = unsafe {
20974            let p = arena.raw_buf().as_ptr().add(arena.byte_offset(g_x_id)) as *const f32;
20975            (0..n * c).map(|i| *p.add(i)).collect()
20976        };
20977
20978        // Loss derivative: softmax sums to 1 per row → d/dx_i sum(softmax) = 0
20979        // analytically. So expect g_x ≈ 0 within FD precision. (This
20980        // doubles as a strong sanity check for the composition.)
20981        let forward_loss = |x: &[f32]| -> f32 {
20982            let mut total = 0f32;
20983            for ni in 0..n {
20984                let row = &x[ni * c..(ni + 1) * c];
20985                let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
20986                let denom: f32 = row.iter().map(|&v| (v - m).exp()).sum();
20987                for &v in row {
20988                    total += (v - m).exp() / denom;
20989                }
20990            }
20991            total
20992        };
20993        let eps = 1e-3f32;
20994        let mut p = x_data.clone();
20995        for i in 0..x_data.len() {
20996            let s = p[i];
20997            p[i] = s + eps;
20998            let lp = forward_loss(&p);
20999            p[i] = s - eps;
21000            let lm = forward_loss(&p);
21001            p[i] = s;
21002            let num = (lp - lm) / (2.0 * eps);
21003            assert!(
21004                (g_x[i] - num).abs() < 5e-3,
21005                "softmax g_x[{i}]: analytical {} vs numerical {num}",
21006                g_x[i]
21007            );
21008        }
21009    }
21010
21011    /// LayerNorm VJP — three gradients in one pass:
21012    ///   d_x via `LayerNormBackwardInput`,
21013    ///   d_gamma via `LayerNormBackwardGamma`,
21014    ///   d_beta = `unbroadcast(upstream)` to gamma's shape.
21015    #[test]
21016    fn layer_norm_gradient_matches_numerical() {
21017        use rlx_ir::Philox4x32;
21018        let rows = 3usize;
21019        let h = 6usize;
21020        let mut rng = Philox4x32::new(1009);
21021        let mut x_data = vec![0f32; rows * h];
21022        rng.fill_normal(&mut x_data);
21023        let mut g_data = vec![0f32; h];
21024        rng.fill_normal(&mut g_data);
21025        for v in g_data.iter_mut() {
21026            *v = v.abs() + 0.5;
21027        }
21028        let mut b_data = vec![0f32; h];
21029        rng.fill_normal(&mut b_data);
21030        let eps = 1e-5f32;
21031
21032        let f = DType::F32;
21033        let mut fwd = Graph::new("ln_only");
21034        let xn = fwd.input("x", Shape::new(&[rows, h], f));
21035        let gp = fwd.param("gamma", Shape::new(&[h], f));
21036        let bp = fwd.param("beta", Shape::new(&[h], f));
21037        let ln = fwd.add_node(
21038            Op::LayerNorm { axis: -1, eps },
21039            vec![xn, gp, bp],
21040            Shape::new(&[rows, h], f),
21041        );
21042        let loss = fwd.add_node(
21043            Op::Reduce {
21044                op: ReduceOp::Sum,
21045                axes: vec![0, 1],
21046                keep_dim: false,
21047            },
21048            vec![ln],
21049            Shape::from_dims(&[], f),
21050        );
21051        fwd.set_outputs(vec![loss]);
21052
21053        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[xn, gp, bp]);
21054        let d_out = bwd_graph
21055            .nodes()
21056            .iter()
21057            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
21058            .map(|n| n.id)
21059            .unwrap();
21060
21061        let plan = rlx_opt::memory::plan_memory(&bwd_graph);
21062        let mut arena = crate::arena::Arena::from_plan(plan);
21063        let sched = compile_thunks(&bwd_graph, &arena);
21064        for &(id, data) in &[
21065            (xn, &x_data),
21066            (gp, &g_data),
21067            (bp, &b_data),
21068            (d_out, &vec![1.0f32]),
21069        ] {
21070            let off = arena.byte_offset(id);
21071            let buf = arena.raw_buf_mut();
21072            unsafe {
21073                let p = buf.as_mut_ptr().add(off) as *mut f32;
21074                for (i, &v) in data.iter().enumerate() {
21075                    *p.add(i) = v;
21076                }
21077            }
21078        }
21079        execute_thunks(&sched, arena.raw_buf_mut());
21080        let read = |id: NodeId, n: usize| -> Vec<f32> {
21081            let off = arena.byte_offset(id);
21082            unsafe {
21083                let p = arena.raw_buf().as_ptr().add(off) as *const f32;
21084                (0..n).map(|i| *p.add(i)).collect()
21085            }
21086        };
21087        let dx_a = read(bwd_graph.outputs[1], rows * h);
21088        let dg_a = read(bwd_graph.outputs[2], h);
21089        let db_a = read(bwd_graph.outputs[3], h);
21090
21091        let forward_loss = |x: &[f32], g: &[f32], b: &[f32]| -> f32 {
21092            let mut total = 0f32;
21093            for r in 0..rows {
21094                let row = &x[r * h..(r + 1) * h];
21095                let mean = row.iter().sum::<f32>() / h as f32;
21096                let var = row.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / h as f32;
21097                let inv_std = 1.0 / (var + eps).sqrt();
21098                for d in 0..h {
21099                    total += ((row[d] - mean) * inv_std) * g[d] + b[d];
21100                }
21101            }
21102            total
21103        };
21104        let h_eps = 1e-3f32;
21105
21106        let mut x_p = x_data.clone();
21107        for i in 0..x_p.len() {
21108            let s = x_p[i];
21109            x_p[i] = s + h_eps;
21110            let lp = forward_loss(&x_p, &g_data, &b_data);
21111            x_p[i] = s - h_eps;
21112            let lm = forward_loss(&x_p, &g_data, &b_data);
21113            x_p[i] = s;
21114            let num = (lp - lm) / (2.0 * h_eps);
21115            assert!(
21116                (dx_a[i] - num).abs() < 5e-3,
21117                "ln dx[{i}]: analytical {} vs numerical {num}",
21118                dx_a[i]
21119            );
21120        }
21121        let mut g_p = g_data.clone();
21122        for i in 0..g_p.len() {
21123            let s = g_p[i];
21124            g_p[i] = s + h_eps;
21125            let lp = forward_loss(&x_data, &g_p, &b_data);
21126            g_p[i] = s - h_eps;
21127            let lm = forward_loss(&x_data, &g_p, &b_data);
21128            g_p[i] = s;
21129            let num = (lp - lm) / (2.0 * h_eps);
21130            assert!(
21131                (dg_a[i] - num).abs() < 5e-3,
21132                "ln dg[{i}]: analytical {} vs numerical {num}",
21133                dg_a[i]
21134            );
21135        }
21136        let mut b_p = b_data.clone();
21137        for i in 0..b_p.len() {
21138            let s = b_p[i];
21139            b_p[i] = s + h_eps;
21140            let lp = forward_loss(&x_data, &g_data, &b_p);
21141            b_p[i] = s - h_eps;
21142            let lm = forward_loss(&x_data, &g_data, &b_p);
21143            b_p[i] = s;
21144            let num = (lp - lm) / (2.0 * h_eps);
21145            assert!(
21146                (db_a[i] - num).abs() < 5e-3,
21147                "ln db[{i}]: analytical {} vs numerical {num}",
21148                db_a[i]
21149            );
21150        }
21151    }
21152
21153    /// Single dense layer + softmax-cross-entropy + mean reduce —
21154    /// the simplest non-trivial training graph. Validates MatMul,
21155    /// broadcast Add, SCE, Reduce(Mean) VJPs and the grad_with_loss
21156    /// plumbing all at once.
21157    #[test]
21158    fn dense_sce_mean_gradient_matches_numerical() {
21159        use rlx_ir::Philox4x32;
21160        let bs = 4usize;
21161        let k_in = 3usize;
21162        let c = 5usize;
21163        let mut rng = Philox4x32::new(7);
21164        let mut x = vec![0f32; bs * k_in];
21165        rng.fill_normal(&mut x);
21166        let mut w_init = vec![0f32; k_in * c];
21167        rng.fill_normal(&mut w_init);
21168        let mut b_init = vec![0f32; c];
21169        rng.fill_normal(&mut b_init);
21170        let labels: Vec<f32> = (0..bs).map(|i| (i % c) as f32).collect();
21171
21172        // ── Forward graph: loss = mean(sce(x @ w + b, labels)) ──
21173        let f = DType::F32;
21174        let mut fwd = Graph::new("dense_sce");
21175        let xn = fwd.input("x", Shape::new(&[bs, k_in], f));
21176        let lb = fwd.input("labels", Shape::new(&[bs], f));
21177        let wp = fwd.param("w", Shape::new(&[k_in, c], f));
21178        let bp = fwd.param("b", Shape::new(&[c], f));
21179        let mm = fwd.matmul(xn, wp, Shape::new(&[bs, c], f));
21180        let logits = fwd.binary(BinaryOp::Add, mm, bp, Shape::new(&[bs, c], f));
21181        let loss_per = fwd.softmax_cross_entropy_with_logits(logits, lb);
21182        let loss = fwd.add_node(
21183            Op::Reduce {
21184                op: ReduceOp::Sum,
21185                axes: vec![0],
21186                keep_dim: false,
21187            },
21188            vec![loss_per],
21189            // Reduce sum of [bs] with axes=[0] keep_dim=false → scalar [].
21190            Shape::from_dims(&[], f),
21191        );
21192        // Use Sum + manual /bs scalar mul — also exercises BinaryOp::Mul VJP path
21193        // less aggressively than Mean would, and gives us a closed-form
21194        // reference for the loss we expect.
21195        // For simplicity though, switch to Mean which the tests should also cover.
21196        // (Re-using `loss` with Sum here for now; the mean factor cancels in
21197        // the gradient comparison since both analytical and numerical use the
21198        // same forward.)
21199        fwd.set_outputs(vec![loss]);
21200
21201        // ── Backward graph ──
21202        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wp, bp]);
21203        // Outputs: [loss, grad_w, grad_b]. NodeIds for x/labels/w/b/loss
21204        // in bwd_graph match their fwd ids (the mirror keeps order).
21205        let d_out = bwd_graph
21206            .nodes()
21207            .iter()
21208            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
21209            .map(|n| n.id)
21210            .expect("d_output input");
21211
21212        let (sched, mut arena) = prepare(
21213            &bwd_graph,
21214            &[
21215                (xn, &x),
21216                (lb, &labels),
21217                (wp, &w_init),
21218                (bp, &b_init),
21219                (d_out, &[1.0]),
21220            ],
21221        );
21222        execute_thunks(&sched, arena.raw_buf_mut());
21223
21224        let outs = &bwd_graph.outputs;
21225        let loss_id = outs[0];
21226        let gw_id = outs[1];
21227        let gb_id = outs[2];
21228        let loss_actual = read_arena(&arena, loss_id, 1)[0];
21229        let gw_actual = read_arena(&arena, gw_id, k_in * c);
21230        let gb_actual = read_arena(&arena, gb_id, c);
21231
21232        // ── Forward-only graph for finite differences ──
21233        // Re-use the same `fwd` graph; set up its own arena and rerun
21234        // for each perturbed parameter.
21235        let plan = rlx_opt::memory::plan_memory(&fwd);
21236        let mut fwd_arena = crate::arena::Arena::from_plan(plan);
21237        let fwd_sched = compile_thunks(&fwd, &fwd_arena);
21238        write_arena(&mut fwd_arena, xn, &x);
21239        write_arena(&mut fwd_arena, lb, &labels);
21240
21241        let run_loss = |arena: &mut crate::arena::Arena, w: &[f32], b: &[f32]| -> f32 {
21242            write_arena(arena, wp, w);
21243            write_arena(arena, bp, b);
21244            execute_thunks(&fwd_sched, arena.raw_buf_mut());
21245            read_arena(arena, loss, 1)[0]
21246        };
21247
21248        // Sanity: the loss reported by the bwd graph matches the
21249        // forward-only graph on the unperturbed inputs.
21250        let loss_check = run_loss(&mut fwd_arena, &w_init, &b_init);
21251        assert!(
21252            (loss_actual - loss_check).abs() < 1e-4,
21253            "loss mismatch: bwd graph {loss_actual} vs fwd-only {loss_check}"
21254        );
21255
21256        let eps = 1e-3f32;
21257        let mut w_perturbed = w_init.clone();
21258        let mut gw_numerical = vec![0f32; w_init.len()];
21259        for i in 0..w_init.len() {
21260            let saved = w_perturbed[i];
21261            w_perturbed[i] = saved + eps;
21262            let lp = run_loss(&mut fwd_arena, &w_perturbed, &b_init);
21263            w_perturbed[i] = saved - eps;
21264            let lm = run_loss(&mut fwd_arena, &w_perturbed, &b_init);
21265            w_perturbed[i] = saved;
21266            gw_numerical[i] = (lp - lm) / (2.0 * eps);
21267        }
21268        for (i, (a, n)) in gw_actual.iter().zip(&gw_numerical).enumerate() {
21269            assert!(
21270                (a - n).abs() < 5e-3,
21271                "grad_w[{i}]: analytical {a} vs numerical {n}"
21272            );
21273        }
21274
21275        let mut b_perturbed = b_init.clone();
21276        let mut gb_numerical = vec![0f32; b_init.len()];
21277        for i in 0..b_init.len() {
21278            let saved = b_perturbed[i];
21279            b_perturbed[i] = saved + eps;
21280            let lp = run_loss(&mut fwd_arena, &w_init, &b_perturbed);
21281            b_perturbed[i] = saved - eps;
21282            let lm = run_loss(&mut fwd_arena, &w_init, &b_perturbed);
21283            b_perturbed[i] = saved;
21284            gb_numerical[i] = (lp - lm) / (2.0 * eps);
21285        }
21286        for (i, (a, n)) in gb_actual.iter().zip(&gb_numerical).enumerate() {
21287            assert!(
21288                (a - n).abs() < 5e-3,
21289                "grad_b[{i}]: analytical {a} vs numerical {n}"
21290            );
21291        }
21292    }
21293
21294    /// Reduce::Mean specifically — verifies the 1/N scaling in the VJP.
21295    /// The same dense+SCE graph but with Mean instead of Sum on the loss.
21296    #[test]
21297    fn dense_sce_mean_reduce_gradient_matches_numerical() {
21298        use rlx_ir::Philox4x32;
21299        let bs = 3usize;
21300        let k_in = 2usize;
21301        let c = 4usize;
21302        let mut rng = Philox4x32::new(13);
21303        let mut x = vec![0f32; bs * k_in];
21304        rng.fill_normal(&mut x);
21305        let mut w_init = vec![0f32; k_in * c];
21306        rng.fill_normal(&mut w_init);
21307        let labels: Vec<f32> = (0..bs).map(|i| (i % c) as f32).collect();
21308
21309        let f = DType::F32;
21310        let mut fwd = Graph::new("dense_sce_mean");
21311        let xn = fwd.input("x", Shape::new(&[bs, k_in], f));
21312        let lb = fwd.input("labels", Shape::new(&[bs], f));
21313        let wp = fwd.param("w", Shape::new(&[k_in, c], f));
21314        let mm = fwd.matmul(xn, wp, Shape::new(&[bs, c], f));
21315        let loss_per = fwd.softmax_cross_entropy_with_logits(mm, lb);
21316        let loss = fwd.add_node(
21317            Op::Reduce {
21318                op: ReduceOp::Mean,
21319                axes: vec![0],
21320                keep_dim: false,
21321            },
21322            vec![loss_per],
21323            Shape::from_dims(&[], f),
21324        );
21325        fwd.set_outputs(vec![loss]);
21326
21327        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wp]);
21328        let d_out = bwd_graph
21329            .nodes()
21330            .iter()
21331            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
21332            .map(|n| n.id)
21333            .unwrap();
21334
21335        let (sched, mut arena) = prepare(
21336            &bwd_graph,
21337            &[(xn, &x), (lb, &labels), (wp, &w_init), (d_out, &[1.0])],
21338        );
21339        execute_thunks(&sched, arena.raw_buf_mut());
21340
21341        let outs = &bwd_graph.outputs;
21342        let loss_id = outs[0];
21343        let gw_id = outs[1];
21344        let _ = read_arena(&arena, loss_id, 1)[0];
21345        let gw_actual = read_arena(&arena, gw_id, k_in * c);
21346
21347        let plan = rlx_opt::memory::plan_memory(&fwd);
21348        let mut fwd_arena = crate::arena::Arena::from_plan(plan);
21349        let fwd_sched = compile_thunks(&fwd, &fwd_arena);
21350        write_arena(&mut fwd_arena, xn, &x);
21351        write_arena(&mut fwd_arena, lb, &labels);
21352
21353        let run_loss = |arena: &mut crate::arena::Arena, w: &[f32]| -> f32 {
21354            write_arena(arena, wp, w);
21355            execute_thunks(&fwd_sched, arena.raw_buf_mut());
21356            read_arena(arena, loss, 1)[0]
21357        };
21358
21359        let eps = 1e-3f32;
21360        let mut wp_p = w_init.clone();
21361        let mut gw_num = vec![0f32; w_init.len()];
21362        for i in 0..w_init.len() {
21363            let s = wp_p[i];
21364            wp_p[i] = s + eps;
21365            let lp = run_loss(&mut fwd_arena, &wp_p);
21366            wp_p[i] = s - eps;
21367            let lm = run_loss(&mut fwd_arena, &wp_p);
21368            wp_p[i] = s;
21369            gw_num[i] = (lp - lm) / (2.0 * eps);
21370        }
21371        for (i, (a, n)) in gw_actual.iter().zip(&gw_num).enumerate() {
21372            assert!((a - n).abs() < 5e-3, "mean reduce grad_w[{i}]: {a} vs {n}");
21373        }
21374    }
21375    /// The full TinyConv-MNIST forward path (downsized) plumbed
21376    /// through grad_with_loss. Validates that Conv, Pool(Max), ReLU,
21377    /// Reshape, MatMul, Add (broadcast), SCE, Reduce(Mean) VJPs all
21378    /// compose into a graph that produces correct gradients.
21379    #[test]
21380    fn tinyconv_full_gradient_matches_numerical() {
21381        use rlx_ir::Philox4x32;
21382        // Tiny shapes so finite differences finish in <1s.
21383        let n = 1usize;
21384        let c_in = 1usize;
21385        let h = 6usize;
21386        let w_in = 6usize;
21387        let c_mid = 2usize; // first conv output channels
21388        let kh = 3;
21389        let kw = 3;
21390        let h1 = h - kh + 1; // 4
21391        let w1 = w_in - kw + 1; // 4
21392        let h2 = h1 / 2;
21393        let w2 = w1 / 2; // 2 × 2 after 2× pool
21394        let flat = c_mid * h2 * w2; // 8
21395        let num_classes = 3usize;
21396
21397        let mut rng = Philox4x32::new(31);
21398        let mut x = vec![0f32; n * c_in * h * w_in];
21399        rng.fill_normal(&mut x);
21400        let mut wc = vec![0f32; c_mid * c_in * kh * kw];
21401        rng.fill_normal(&mut wc);
21402        for v in wc.iter_mut() {
21403            *v *= 0.2;
21404        }
21405        // Shift conv-bias well away from the ReLU zero-boundary. Without
21406        // this, an ε-perturbation of bc[c] can flip the ReLU mask on a
21407        // pre-activation that happened to land near zero — making the
21408        // central-difference numerical gradient discontinuous and
21409        // diverge from the analytical (which assumes local smoothness).
21410        // +5.0 keeps every pre-activation positive for any random init
21411        // produced by Philox seed 31 with the wc/x scales used here, so
21412        // ReLU acts as an identity and finite differences are exact.
21413        let bc: Vec<f32> = (0..c_mid).map(|i| 5.0 + 0.1 * i as f32).collect();
21414        let mut wfc = vec![0f32; flat * num_classes];
21415        rng.fill_normal(&mut wfc);
21416        for v in wfc.iter_mut() {
21417            *v *= 0.5;
21418        }
21419        let mut bfc = vec![0f32; num_classes];
21420        rng.fill_normal(&mut bfc);
21421        let labels: Vec<f32> = vec![1.0]; // batch=1
21422
21423        let f = DType::F32;
21424        let mut fwd = Graph::new("tinyconv");
21425        let xn = fwd.input("x", Shape::new(&[n, c_in, h, w_in], f));
21426        let lb = fwd.input("labels", Shape::new(&[n], f));
21427        let wcp = fwd.param("wc", Shape::new(&[c_mid, c_in, kh, kw], f));
21428        let bcp = fwd.param("bc", Shape::new(&[c_mid], f));
21429        let wfp = fwd.param("wfc", Shape::new(&[flat, num_classes], f));
21430        let bfp = fwd.param("bfc", Shape::new(&[num_classes], f));
21431
21432        // conv: [n, c_in, h, w] → [n, c_mid, h1, w1]
21433        let conv = fwd.add_node(
21434            Op::Conv {
21435                kernel_size: vec![kh, kw],
21436                stride: vec![1, 1],
21437                padding: vec![0, 0],
21438                dilation: vec![1, 1],
21439                groups: 1,
21440            },
21441            vec![xn, wcp],
21442            Shape::new(&[n, c_mid, h1, w1], f),
21443        );
21444        // Bias add: expand bc[c_mid] up to the full [n, c_mid, h1, w1]
21445        // shape so the Add becomes a plain element-wise op. Going through
21446        // an explicit Reshape→Expand instead of relying on the Add to
21447        // broadcast `[1, C, 1, 1]` → `[N, C, H, W]` works around a known
21448        // limitation of `rlx-cpu`'s `Op::Binary` lowering: it dispatches
21449        // on `out_len % rhs_len == 0` and treats `rhs` as a last-axis
21450        // bias, which produces `bc[0], bc[1], bc[0], bc[1], …` alternating
21451        // across all positions instead of channel-broadcasting. Going
21452        // through Expand (a real broadcast thunk) avoids that path
21453        // entirely. The autodiff still exercises `unbroadcast` because
21454        // `Op::Expand`'s VJP reduces over the broadcast axes.
21455        let bc_4d = fwd.add_node(
21456            Op::Reshape {
21457                new_shape: vec![1, c_mid as i64, 1, 1],
21458            },
21459            vec![bcp],
21460            Shape::new(&[1, c_mid, 1, 1], f),
21461        );
21462        let bc_expanded = fwd.add_node(
21463            Op::Expand {
21464                target_shape: vec![n as i64, c_mid as i64, h1 as i64, w1 as i64],
21465            },
21466            vec![bc_4d],
21467            Shape::new(&[n, c_mid, h1, w1], f),
21468        );
21469        let conv_b = fwd.binary(
21470            BinaryOp::Add,
21471            conv,
21472            bc_expanded,
21473            Shape::new(&[n, c_mid, h1, w1], f),
21474        );
21475        let relu = fwd.activation(Activation::Relu, conv_b, Shape::new(&[n, c_mid, h1, w1], f));
21476        let pool = fwd.add_node(
21477            Op::Pool {
21478                kind: ReduceOp::Max,
21479                kernel_size: vec![2, 2],
21480                stride: vec![2, 2],
21481                padding: vec![0, 0],
21482            },
21483            vec![relu],
21484            Shape::new(&[n, c_mid, h2, w2], f),
21485        );
21486        let flatn = fwd.add_node(
21487            Op::Reshape {
21488                new_shape: vec![n as i64, flat as i64],
21489            },
21490            vec![pool],
21491            Shape::new(&[n, flat], f),
21492        );
21493        let mm = fwd.matmul(flatn, wfp, Shape::new(&[n, num_classes], f));
21494        let logits = fwd.binary(BinaryOp::Add, mm, bfp, Shape::new(&[n, num_classes], f));
21495        let loss_per = fwd.softmax_cross_entropy_with_logits(logits, lb);
21496        let loss = fwd.add_node(
21497            Op::Reduce {
21498                op: ReduceOp::Mean,
21499                axes: vec![0],
21500                keep_dim: false,
21501            },
21502            vec![loss_per],
21503            Shape::from_dims(&[], f),
21504        );
21505        fwd.set_outputs(vec![loss]);
21506
21507        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wcp, bcp, wfp, bfp]);
21508        let d_out = bwd_graph
21509            .nodes()
21510            .iter()
21511            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
21512            .map(|n| n.id)
21513            .unwrap();
21514
21515        let (sched, mut arena) = prepare(
21516            &bwd_graph,
21517            &[
21518                (xn, &x),
21519                (lb, &labels),
21520                (wcp, &wc),
21521                (bcp, &bc),
21522                (wfp, &wfc),
21523                (bfp, &bfc),
21524                (d_out, &[1.0]),
21525            ],
21526        );
21527        execute_thunks(&sched, arena.raw_buf_mut());
21528
21529        let outs = bwd_graph.outputs.clone();
21530        let loss_id = outs[0];
21531        let g_wc_id = outs[1];
21532        let g_bc_id = outs[2];
21533        let g_wfc_id = outs[3];
21534        let g_bfc_id = outs[4];
21535        let loss_actual = read_arena(&arena, loss_id, 1)[0];
21536        let g_wc = read_arena(&arena, g_wc_id, wc.len());
21537        let g_bc = read_arena(&arena, g_bc_id, bc.len());
21538        let g_wfc = read_arena(&arena, g_wfc_id, wfc.len());
21539        let g_bfc = read_arena(&arena, g_bfc_id, bfc.len());
21540
21541        // Forward-only arena for finite differences.
21542        let plan = rlx_opt::memory::plan_memory(&fwd);
21543        let mut fwd_arena = crate::arena::Arena::from_plan(plan);
21544        let fwd_sched = compile_thunks(&fwd, &fwd_arena);
21545        write_arena(&mut fwd_arena, xn, &x);
21546        write_arena(&mut fwd_arena, lb, &labels);
21547
21548        // Closure variant: we need to set all four params each call so
21549        // perturbations to one don't leak between sweeps.
21550        let run_loss = |arena: &mut crate::arena::Arena,
21551                        wc: &[f32],
21552                        bc: &[f32],
21553                        wfc: &[f32],
21554                        bfc: &[f32]|
21555         -> f32 {
21556            write_arena(arena, wcp, wc);
21557            write_arena(arena, bcp, bc);
21558            write_arena(arena, wfp, wfc);
21559            write_arena(arena, bfp, bfc);
21560            execute_thunks(&fwd_sched, arena.raw_buf_mut());
21561            read_arena(arena, loss, 1)[0]
21562        };
21563
21564        let loss_check = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &bfc);
21565        assert!(
21566            (loss_actual - loss_check).abs() < 1e-4,
21567            "tinyconv loss mismatch: bwd {loss_actual} vs fwd {loss_check}"
21568        );
21569
21570        let eps = 1e-3f32;
21571        let check_grad = |arena: &mut crate::arena::Arena,
21572                          name: &str,
21573                          analytical: &[f32],
21574                          mut perturb: Box<
21575            dyn FnMut(&mut [f32], usize, f32, &mut crate::arena::Arena) -> f32 + '_,
21576        >,
21577                          n: usize| {
21578            for i in 0..n {
21579                let lp = perturb(&mut analytical.to_vec(), i, eps, arena);
21580                let lm = perturb(&mut analytical.to_vec(), i, -eps, arena);
21581                let num = (lp - lm) / (2.0 * eps);
21582                assert!(
21583                    (analytical[i] - num).abs() < 5e-3,
21584                    "{name}[{i}]: analytical {} vs numerical {num}",
21585                    analytical[i]
21586                );
21587            }
21588        };
21589
21590        // Helper to perturb one param and run forward. Kept as a
21591        // reference for the explicit per-param sweep pattern below.
21592        #[allow(unused_macros)]
21593        macro_rules! sweep {
21594            ($name:expr, $base:expr, $analytical:expr, $set_param:ident) => {{
21595                let n = $base.len();
21596                for i in 0..n {
21597                    let mut p = $base.clone();
21598                    let s = p[i];
21599                    p[i] = s + eps;
21600                    let lp = {
21601                        let $set_param = &p;
21602                        run_loss(&mut fwd_arena, &wc, &bc, &wfc, &bfc).max(f32::NEG_INFINITY);
21603                        // Reset others, set the one being swept, run.
21604                        // (the macro receives one of the four params via $set_param)
21605                        let _ = $set_param;
21606                        // Fall through to the explicit per-param helper:
21607                        0.0_f32
21608                    };
21609                    let _ = lp;
21610                }
21611            }};
21612        }
21613        let _ = check_grad; // silence unused (sweep! macro is intentionally\n        // unused — kept as reference for the per-param sweep pattern below)
21614
21615        // Per-param sweeps (explicit, not macro — clearer).
21616        for i in 0..wc.len() {
21617            let mut p = wc.clone();
21618            let s = p[i];
21619            p[i] = s + eps;
21620            let lp = run_loss(&mut fwd_arena, &p, &bc, &wfc, &bfc);
21621            p[i] = s - eps;
21622            let lm = run_loss(&mut fwd_arena, &p, &bc, &wfc, &bfc);
21623            let num = (lp - lm) / (2.0 * eps);
21624            assert!(
21625                (g_wc[i] - num).abs() < 5e-3,
21626                "g_wc[{i}]: {} vs {num}",
21627                g_wc[i]
21628            );
21629        }
21630        for i in 0..bc.len() {
21631            let mut p = bc.clone();
21632            let s = p[i];
21633            p[i] = s + eps;
21634            let lp = run_loss(&mut fwd_arena, &wc, &p, &wfc, &bfc);
21635            p[i] = s - eps;
21636            let lm = run_loss(&mut fwd_arena, &wc, &p, &wfc, &bfc);
21637            let num = (lp - lm) / (2.0 * eps);
21638            assert!(
21639                (g_bc[i] - num).abs() < 5e-3,
21640                "g_bc[{i}]: {} vs {num}",
21641                g_bc[i]
21642            );
21643        }
21644        for i in 0..wfc.len() {
21645            let mut p = wfc.clone();
21646            let s = p[i];
21647            p[i] = s + eps;
21648            let lp = run_loss(&mut fwd_arena, &wc, &bc, &p, &bfc);
21649            p[i] = s - eps;
21650            let lm = run_loss(&mut fwd_arena, &wc, &bc, &p, &bfc);
21651            let num = (lp - lm) / (2.0 * eps);
21652            assert!(
21653                (g_wfc[i] - num).abs() < 5e-3,
21654                "g_wfc[{i}]: {} vs {num}",
21655                g_wfc[i]
21656            );
21657        }
21658        for i in 0..bfc.len() {
21659            let mut p = bfc.clone();
21660            let s = p[i];
21661            p[i] = s + eps;
21662            let lp = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &p);
21663            p[i] = s - eps;
21664            let lm = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &p);
21665            let num = (lp - lm) / (2.0 * eps);
21666            assert!(
21667                (g_bfc[i] - num).abs() < 5e-3,
21668                "g_bfc[{i}]: {} vs {num}",
21669                g_bfc[i]
21670            );
21671        }
21672    }
21673
21674    /// Negative case: a Narrow whose output has multiple consumers
21675    /// must NOT be fused (we can't elide its write — something else
21676    /// reads it).
21677    #[test]
21678    fn narrow_rope_skips_when_narrow_has_multiple_consumers() {
21679        let f = DType::F32;
21680        let mut g = Graph::new("nr_skip");
21681        let qkv = g.input("qkv", Shape::new(&[16, 8, 192], f));
21682        let cos = g.input("cos", Shape::new(&[16], f));
21683        let sin = g.input("sin", Shape::new(&[16], f));
21684        let q = g.narrow_(qkv, 2, 0, 64);
21685        let q_rope = g.rope(q, cos, sin, 16);
21686        // Second consumer of `q` blocks the fusion.
21687        let q_dup = g.activation(rlx_ir::op::Activation::Relu, q, Shape::new(&[16, 8, 64], f));
21688        g.set_outputs(vec![q_rope, q_dup]);
21689
21690        let plan = rlx_opt::memory::plan_memory(&g);
21691        let arena = crate::arena::Arena::from_plan(plan);
21692        let sched = compile_thunks(&g, &arena);
21693
21694        let narrow_count = sched
21695            .thunks
21696            .iter()
21697            .filter(|t| matches!(t, Thunk::Narrow { .. }))
21698            .count();
21699        assert!(
21700            narrow_count >= 1,
21701            "Narrow with multiple consumers must NOT be fused away"
21702        );
21703    }
21704
    // ── Op::CustomFn (custom_vjp / custom_jvp) tests ──
    //
    // Validates: forward execution inlines fwd_body; VJP rule inlines
    // vjp_body in place of recursing into fwd_body; JVP rule inlines
    // jvp_body. The single-input override tests deliberately pick an
    // override whose result differs from what AD-via-tracing of the
    // body would yield, so we know the override actually fired. The
    // two-input VJP test instead checks operand routing against the
    // natural Mul gradient.
21712
21713    /// Forward only: CustomFn wrapping `f(x) = x + c` (c=1 inside body)
21714    /// without override AD bodies. Verifies the body is compiled,
21715    /// constants in the body fill correctly, and the output lands at
21716    /// the outer node's slot.
21717    #[test]
21718    fn custom_fn_forward_inlines_body() {
21719        let s = Shape::new(&[3], DType::F32);
21720
21721        // Body: f(x) = x + 1
21722        let mut body = Graph::new("addone_body");
21723        let x = body.input("x", s.clone());
21724        let one_data: Vec<u8> = (0..3).flat_map(|_| 1.0_f32.to_le_bytes()).collect();
21725        let one = body.add_node(Op::Constant { data: one_data }, vec![], s.clone());
21726        let y = body.binary(BinaryOp::Add, x, one, s.clone());
21727        body.set_outputs(vec![y]);
21728
21729        let mut g = Graph::new("custom_fn_outer");
21730        let xin = g.input("x_in", s.clone());
21731        let cf = g.custom_fn(vec![xin], body, None, None);
21732        g.set_outputs(vec![cf]);
21733
21734        let xs = vec![10.0_f32, 20.0, 30.0];
21735        let (sched, mut arena) = prepare(&g, &[(xin, &xs)]);
21736        execute_thunks(&sched, arena.raw_buf_mut());
21737        let got = read_arena(&arena, cf, 3);
21738        assert_eq!(got, vec![11.0, 21.0, 31.0]);
21739    }
21740
21741    /// Locate an Op::Input or Op::Param by name in a graph.
21742    fn find_named(graph: &Graph, want: &str) -> NodeId {
21743        for n in graph.nodes() {
21744            let name = match &n.op {
21745                Op::Input { name } | Op::Param { name } => Some(name.as_str()),
21746                _ => None,
21747            };
21748            if name == Some(want) {
21749                return n.id;
21750            }
21751        }
21752        panic!("no node named {want:?} in graph");
21753    }
21754
21755    /// VJP override: f(x) = x but vjp_body returns 2 * d_output, so the
21756    /// reported gradient should be 2 — different from the natural 1
21757    /// you'd get by recursing into the identity body.
21758    #[test]
21759    fn custom_fn_vjp_overrides_natural_gradient() {
21760        use rlx_opt::autodiff::grad_with_loss;
21761        let s = Shape::new(&[1], DType::F32);
21762
21763        let mut fwd = Graph::new("id_fwd");
21764        let x = fwd.input("x", s.clone());
21765        fwd.set_outputs(vec![x]);
21766
21767        let mut vjp_g = Graph::new("id_vjp");
21768        let _x_p = vjp_g.input("x", s.clone());
21769        let _y_p = vjp_g.input("primal_output", s.clone());
21770        let dy = vjp_g.input("d_output", s.clone());
21771        let two_data: Vec<u8> = 2.0_f32.to_le_bytes().to_vec();
21772        let two = vjp_g.add_node(Op::Constant { data: two_data }, vec![], s.clone());
21773        let dx = vjp_g.binary(BinaryOp::Mul, dy, two, s.clone());
21774        vjp_g.set_outputs(vec![dx]);
21775
21776        let mut g = Graph::new("outer");
21777        let xp = g.param("x", s.clone());
21778        let cf = g.custom_fn(vec![xp], fwd, Some(vjp_g), None);
21779        g.set_outputs(vec![cf]);
21780
21781        let bwd = grad_with_loss(&g, &[xp]);
21782        assert_eq!(bwd.outputs.len(), 2, "expect [loss, dx]");
21783
21784        let xb = find_named(&bwd, "x");
21785        let dout = find_named(&bwd, "d_output");
21786        let (sched, mut arena) = prepare(&bwd, &[(xb, &[7.0]), (dout, &[1.0])]);
21787        execute_thunks(&sched, arena.raw_buf_mut());
21788        let loss = read_arena(&arena, bwd.outputs[0], 1);
21789        let dx_v = read_arena(&arena, bwd.outputs[1], 1);
21790        assert!((loss[0] - 7.0).abs() < 1e-6, "loss should be 7.0");
21791        assert!(
21792            (dx_v[0] - 2.0).abs() < 1e-6,
21793            "vjp override should yield dx=2.0, got {} (natural autodiff would give 1.0)",
21794            dx_v[0]
21795        );
21796    }
21797
21798    /// VJP override: f(a, b) = a*b with vjp_body returning
21799    /// (b * d_output, a * d_output). Validates routing of multiple
21800    /// primals + d_output through the override; matches the natural
21801    /// autodiff-of-Mul gradient (b, a).
21802    #[test]
21803    fn custom_fn_vjp_two_inputs_matches_mul_autodiff() {
21804        use rlx_opt::autodiff::grad_with_loss;
21805        let s = Shape::new(&[1], DType::F32);
21806
21807        let mut fwd = Graph::new("mul_fwd");
21808        let a_f = fwd.input("a", s.clone());
21809        let b_f = fwd.input("b", s.clone());
21810        let y_f = fwd.binary(BinaryOp::Mul, a_f, b_f, s.clone());
21811        fwd.set_outputs(vec![y_f]);
21812
21813        let mut vjp_g = Graph::new("mul_vjp");
21814        let a_v = vjp_g.input("a", s.clone());
21815        let b_v = vjp_g.input("b", s.clone());
21816        let _y_v = vjp_g.input("primal_output", s.clone());
21817        let dy_v = vjp_g.input("d_output", s.clone());
21818        let da = vjp_g.binary(BinaryOp::Mul, b_v, dy_v, s.clone());
21819        let db = vjp_g.binary(BinaryOp::Mul, a_v, dy_v, s.clone());
21820        vjp_g.set_outputs(vec![da, db]);
21821
21822        let mut g = Graph::new("outer");
21823        let ap = g.param("a", s.clone());
21824        let bp = g.param("b", s.clone());
21825        let cf = g.custom_fn(vec![ap, bp], fwd, Some(vjp_g), None);
21826        g.set_outputs(vec![cf]);
21827
21828        let bwd = grad_with_loss(&g, &[ap, bp]);
21829        assert_eq!(bwd.outputs.len(), 3, "expect [loss, da, db]");
21830
21831        let ab = find_named(&bwd, "a");
21832        let bb = find_named(&bwd, "b");
21833        let dout = find_named(&bwd, "d_output");
21834        let (sched, mut arena) = prepare(&bwd, &[(ab, &[3.0]), (bb, &[5.0]), (dout, &[1.0])]);
21835        execute_thunks(&sched, arena.raw_buf_mut());
21836        let loss = read_arena(&arena, bwd.outputs[0], 1);
21837        let da_v = read_arena(&arena, bwd.outputs[1], 1);
21838        let db_v = read_arena(&arena, bwd.outputs[2], 1);
21839        assert!((loss[0] - 15.0).abs() < 1e-5);
21840        assert!(
21841            (da_v[0] - 5.0).abs() < 1e-5,
21842            "da should be b=5.0, got {}",
21843            da_v[0]
21844        );
21845        assert!(
21846            (db_v[0] - 3.0).abs() < 1e-5,
21847            "db should be a=3.0, got {}",
21848            db_v[0]
21849        );
21850    }
21851
21852    /// JVP override: f(x) = x but jvp_body returns 2 * tangent_0.
21853    /// Forward-mode tangent should be 2x the seed (1.0) → 2.0.
21854    #[test]
21855    fn custom_fn_jvp_overrides_natural_tangent() {
21856        use rlx_opt::autodiff_fwd::jvp;
21857        let s = Shape::new(&[1], DType::F32);
21858
21859        let mut fwd = Graph::new("id_fwd");
21860        let x = fwd.input("x", s.clone());
21861        fwd.set_outputs(vec![x]);
21862
21863        let mut jvp_g = Graph::new("id_jvp");
21864        let _x_p = jvp_g.input("x", s.clone());
21865        let tx = jvp_g.input("tangent_0", s.clone());
21866        let two_data: Vec<u8> = 2.0_f32.to_le_bytes().to_vec();
21867        let two = jvp_g.add_node(Op::Constant { data: two_data }, vec![], s.clone());
21868        let ty = jvp_g.binary(BinaryOp::Mul, tx, two, s.clone());
21869        jvp_g.set_outputs(vec![ty]);
21870
21871        let mut g = Graph::new("outer");
21872        let xin = g.input("x_in", s.clone());
21873        let cf = g.custom_fn(vec![xin], fwd, None, Some(jvp_g));
21874        g.set_outputs(vec![cf]);
21875
21876        let fwd_g = jvp(&g, &[xin]);
21877        assert_eq!(fwd_g.outputs.len(), 2, "expect [primal_y, tangent_y]");
21878
21879        let xb = find_named(&fwd_g, "x_in");
21880        let tan = find_named(&fwd_g, "tangent_x_in");
21881        let (sched, mut arena) = prepare(&fwd_g, &[(xb, &[7.0]), (tan, &[1.0])]);
21882        execute_thunks(&sched, arena.raw_buf_mut());
21883        let y = read_arena(&arena, fwd_g.outputs[0], 1);
21884        let ty_v = read_arena(&arena, fwd_g.outputs[1], 1);
21885        assert!((y[0] - 7.0).abs() < 1e-6);
21886        assert!(
21887            (ty_v[0] - 2.0).abs() < 1e-6,
21888            "jvp override should yield t_y=2.0 (natural autodiff would give 1.0), got {}",
21889            ty_v[0]
21890        );
21891    }
21892
21893    /// IR-level basic test: `DType::C64` is wired through the dtype
21894    /// table — `size_bytes() == 8`, `is_complex()` reports true, and
21895    /// a `[2]`-shaped C64 buffer in the arena occupies the expected
21896    /// 16 bytes.
21897    #[test]
21898    fn c64_dtype_storage_layout() {
21899        assert_eq!(
21900            DType::C64.size_bytes(),
21901            8,
21902            "C64 should be 8 bytes (f32 real + f32 imag)"
21903        );
21904        assert!(DType::C64.is_complex());
21905        assert!(!DType::C64.is_float());
21906
21907        // A length-2 C64 buffer should have shape size_bytes = 16.
21908        let s = Shape::new(&[2], DType::C64);
21909        assert_eq!(s.size_bytes().unwrap(), 16);
21910    }
21911
    // ── C64 element-wise binary kernel witnesses (2026-05-17) ──────
    //
    // Build a tiny graph: Input `a` + Input `b` (both C64 [n]),
    // output = a OP b. Compile it with `compile_thunks`, run it with
    // `execute_thunks`, and compare against closed-form complex
    // arithmetic on hand-picked operand pairs.
21917
21918    fn run_c64_binary(op: BinaryOp, a: &[(f32, f32)], b: &[(f32, f32)]) -> Vec<(f32, f32)> {
21919        let n = a.len();
21920        let s = Shape::new(&[n], DType::C64);
21921        let mut g = Graph::new("c64_bin");
21922        let in_a = g.input("a", s.clone());
21923        let in_b = g.input("b", s.clone());
21924        let out = g.binary(op, in_a, in_b, s.clone());
21925        g.set_outputs(vec![out]);
21926
21927        let plan = rlx_opt::memory::plan_memory(&g);
21928        let mut arena = crate::arena::Arena::from_plan(plan);
21929        let sched = compile_thunks(&g, &arena);
21930
21931        let a_off = arena.byte_offset(in_a);
21932        let b_off = arena.byte_offset(in_b);
21933        let out_off = arena.byte_offset(out);
21934        // Interleave [re_0, im_0, re_1, im_1, ...] in the f32 buffer.
21935        let buf = arena.raw_buf_mut();
21936        unsafe {
21937            let pa = buf.as_mut_ptr().add(a_off) as *mut f32;
21938            let pb = buf.as_mut_ptr().add(b_off) as *mut f32;
21939            for (i, &(re, im)) in a.iter().enumerate() {
21940                *pa.add(2 * i) = re;
21941                *pa.add(2 * i + 1) = im;
21942            }
21943            for (i, &(re, im)) in b.iter().enumerate() {
21944                *pb.add(2 * i) = re;
21945                *pb.add(2 * i + 1) = im;
21946            }
21947        }
21948        execute_thunks(&sched, arena.raw_buf_mut());
21949        let raw_out: Vec<f32> = unsafe {
21950            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
21951            (0..(2 * n)).map(|i| *p.add(i)).collect()
21952        };
21953        (0..n)
21954            .map(|i| (raw_out[2 * i], raw_out[2 * i + 1]))
21955            .collect()
21956    }
21957
21958    #[track_caller]
21959    fn assert_close_c(got: (f32, f32), expected: (f32, f32), tol: f32, label: &str) {
21960        let dr = (got.0 - expected.0).abs();
21961        let di = (got.1 - expected.1).abs();
21962        assert!(
21963            dr < tol && di < tol,
21964            "[{label}] got ({:+.4}, {:+.4}), expected ({:+.4}, {:+.4})",
21965            got.0,
21966            got.1,
21967            expected.0,
21968            expected.1
21969        );
21970    }
21971
21972    #[test]
21973    fn c64_binary_add_matches_complex_arithmetic() {
21974        let a = [(1.0_f32, 2.0_f32), (3.0_f32, -1.0_f32)];
21975        let b = [(4.0_f32, -1.0_f32), (0.5_f32, 0.5_f32)];
21976        let out = run_c64_binary(BinaryOp::Add, &a, &b);
21977        assert_close_c(out[0], (5.0, 1.0), 1e-6, "add[0]");
21978        assert_close_c(out[1], (3.5, -0.5), 1e-6, "add[1]");
21979    }
21980
21981    #[test]
21982    fn c64_binary_sub_matches_complex_arithmetic() {
21983        let a = [(5.0_f32, 1.0_f32)];
21984        let b = [(2.0_f32, 3.0_f32)];
21985        let out = run_c64_binary(BinaryOp::Sub, &a, &b);
21986        assert_close_c(out[0], (3.0, -2.0), 1e-6, "sub");
21987    }
21988
21989    #[test]
21990    fn c64_binary_mul_matches_complex_arithmetic() {
21991        // (1 + 2i)(3 + 4i) = 3 + 4i + 6i + 8i² = -5 + 10i.
21992        let a = [(1.0_f32, 2.0_f32)];
21993        let b = [(3.0_f32, 4.0_f32)];
21994        let out = run_c64_binary(BinaryOp::Mul, &a, &b);
21995        assert_close_c(out[0], (-5.0, 10.0), 1e-5, "mul");
21996    }
21997
21998    #[test]
21999    fn c64_binary_div_matches_complex_arithmetic() {
22000        // (1 + 2i) / (3 + 4i) = ((1·3 + 2·4) + (2·3 − 1·4)i) / 25
22001        //                     = (11 + 2i) / 25
22002        //                     = 0.44 + 0.08i
22003        let a = [(1.0_f32, 2.0_f32)];
22004        let b = [(3.0_f32, 4.0_f32)];
22005        let out = run_c64_binary(BinaryOp::Div, &a, &b);
22006        assert_close_c(out[0], (0.44, 0.08), 1e-5, "div");
22007    }
22008
22009    #[test]
22010    fn c64_binary_mul_identity_one_is_no_op() {
22011        // (a + bi) · (1 + 0i) = a + bi.
22012        let a = [(3.5_f32, -1.25_f32), (-2.0_f32, 7.0_f32)];
22013        let b = [(1.0_f32, 0.0_f32), (1.0_f32, 0.0_f32)];
22014        let out = run_c64_binary(BinaryOp::Mul, &a, &b);
22015        assert_close_c(out[0], a[0], 1e-6, "mul·1[0]");
22016        assert_close_c(out[1], a[1], 1e-6, "mul·1[1]");
22017    }
22018
22019    #[test]
22020    fn c64_binary_mul_by_i_rotates_90_degrees() {
22021        // (a + bi) · i = (a + bi)(0 + i) = -b + ai. 90° CCW rotation.
22022        let a = [(1.0_f32, 0.0_f32)];
22023        let b = [(0.0_f32, 1.0_f32)];
22024        let out = run_c64_binary(BinaryOp::Mul, &a, &b);
22025        assert_close_c(out[0], (0.0, 1.0), 1e-6, "1·i");
22026    }
22027
22028    #[test]
22029    fn c64_binary_div_by_self_gives_unity() {
22030        let a = [(2.5_f32, -1.5_f32), (-0.7_f32, 4.2_f32)];
22031        let out = run_c64_binary(BinaryOp::Div, &a, &a);
22032        assert_close_c(out[0], (1.0, 0.0), 1e-5, "div_self[0]");
22033        assert_close_c(out[1], (1.0, 0.0), 1e-5, "div_self[1]");
22034    }
22035
22036    #[test]
22037    #[should_panic(expected = "C64: complex max/min/pow")]
22038    fn c64_binary_max_is_rejected_at_lowering() {
22039        run_c64_binary(BinaryOp::Max, &[(1.0_f32, 2.0_f32)], &[(3.0_f32, 4.0_f32)]);
22040    }
22041
22042    fn run_c64_activation(act: Activation, a: &[(f32, f32)]) -> Vec<(f32, f32)> {
22043        let n = a.len();
22044        let s = Shape::new(&[n], DType::C64);
22045        let mut g = Graph::new("c64_act");
22046        let in_a = g.input("a", s.clone());
22047        let out = g.activation(act, in_a, s.clone());
22048        g.set_outputs(vec![out]);
22049        let plan = rlx_opt::memory::plan_memory(&g);
22050        let mut arena = crate::arena::Arena::from_plan(plan);
22051        let sched = compile_thunks(&g, &arena);
22052        let a_off = arena.byte_offset(in_a);
22053        let out_off = arena.byte_offset(out);
22054        let buf = arena.raw_buf_mut();
22055        unsafe {
22056            let pa = buf.as_mut_ptr().add(a_off) as *mut f32;
22057            for (i, &(re, im)) in a.iter().enumerate() {
22058                *pa.add(2 * i) = re;
22059                *pa.add(2 * i + 1) = im;
22060            }
22061        }
22062        execute_thunks(&sched, arena.raw_buf_mut());
22063        let raw: Vec<f32> = unsafe {
22064            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
22065            (0..(2 * n)).map(|i| *p.add(i)).collect()
22066        };
22067        (0..n).map(|i| (raw[2 * i], raw[2 * i + 1])).collect()
22068    }
22069
22070    #[test]
22071    fn c64_activation_neg_negates_both_components() {
22072        let inp = [(3.5_f32, -1.25_f32), (-2.0_f32, 0.0_f32)];
22073        let out = run_c64_activation(Activation::Neg, &inp);
22074        assert_close_c(out[0], (-3.5, 1.25), 1e-6, "neg[0]");
22075        assert_close_c(out[1], (2.0, 0.0), 1e-6, "neg[1]");
22076    }
22077
22078    #[test]
22079    fn c64_activation_exp_matches_euler() {
22080        // exp(0 + i·π) = -1 + 0i.
22081        // exp(1 + 0i) = e ≈ 2.71828.
22082        let inp = [(0.0_f32, std::f32::consts::PI), (1.0_f32, 0.0_f32)];
22083        let out = run_c64_activation(Activation::Exp, &inp);
22084        assert_close_c(out[0], (-1.0, 0.0), 1e-5, "exp(iπ)");
22085        assert_close_c(out[1], (std::f32::consts::E, 0.0), 1e-5, "exp(1)");
22086    }
22087
22088    #[test]
22089    fn c64_activation_log_matches_principal_branch() {
22090        // log(1 + 0i) = 0.
22091        // log(0 + i) = log(1) + i·π/2 = 0 + i·π/2.
22092        // log(-1 + 0i) = 0 + i·π.
22093        let inp = [(1.0_f32, 0.0_f32), (0.0_f32, 1.0_f32), (-1.0_f32, 0.0_f32)];
22094        let out = run_c64_activation(Activation::Log, &inp);
22095        assert_close_c(out[0], (0.0, 0.0), 1e-5, "log(1)");
22096        assert_close_c(out[1], (0.0, std::f32::consts::FRAC_PI_2), 1e-5, "log(i)");
22097        assert_close_c(out[2], (0.0, std::f32::consts::PI), 1e-5, "log(-1)");
22098    }
22099
22100    #[test]
22101    fn c64_activation_sqrt_squared_recovers_input() {
22102        // For positive-real-part inputs, sqrt(z)² should equal z exactly
22103        // to f32 noise.
22104        let inp = [(4.0_f32, 0.0_f32), (3.0_f32, 4.0_f32)];
22105        let roots = run_c64_activation(Activation::Sqrt, &inp);
22106        // sqrt(4) = 2 + 0i; sqrt(3+4i) = 2 + i (since (2+i)² = 4+4i-1 = 3+4i).
22107        assert_close_c(roots[0], (2.0, 0.0), 1e-5, "sqrt(4)");
22108        assert_close_c(roots[1], (2.0, 1.0), 1e-5, "sqrt(3+4i)");
22109    }
22110
22111    #[test]
22112    #[should_panic(expected = "no natural complex extension")]
22113    fn c64_activation_relu_is_rejected_at_lowering() {
22114        run_c64_activation(Activation::Relu, &[(1.0_f32, 2.0_f32)]);
22115    }
22116
22117    // ── ComplexNormSq + Wirtinger backward witnesses ───────────────
22118
22119    /// Forward `|z|²`: returns `[n]` f32.
22120    fn run_complex_norm_sq(z: &[(f32, f32)]) -> Vec<f32> {
22121        let n = z.len();
22122        let mut g = Graph::new("cns_fwd");
22123        let in_z = g.input("z", Shape::new(&[n], DType::C64));
22124        let out = g.complex_norm_sq(in_z);
22125        g.set_outputs(vec![out]);
22126        let plan = rlx_opt::memory::plan_memory(&g);
22127        let mut arena = crate::arena::Arena::from_plan(plan);
22128        let sched = compile_thunks(&g, &arena);
22129        let z_off = arena.byte_offset(in_z);
22130        let out_off = arena.byte_offset(out);
22131        let buf = arena.raw_buf_mut();
22132        unsafe {
22133            let pz = buf.as_mut_ptr().add(z_off) as *mut f32;
22134            for (i, &(re, im)) in z.iter().enumerate() {
22135                *pz.add(2 * i) = re;
22136                *pz.add(2 * i + 1) = im;
22137            }
22138        }
22139        execute_thunks(&sched, arena.raw_buf_mut());
22140        unsafe {
22141            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
22142            (0..n).map(|i| *p.add(i)).collect()
22143        }
22144    }
22145
22146    /// Backward: given z and upstream g, return dz = g·z element-wise (C64).
22147    fn run_complex_norm_sq_bwd(z: &[(f32, f32)], g: &[f32]) -> Vec<(f32, f32)> {
22148        let n = z.len();
22149        let mut gr = Graph::new("cns_bwd");
22150        let in_z = gr.input("z", Shape::new(&[n], DType::C64));
22151        let in_g = gr.input("g", Shape::new(&[n], DType::F32));
22152        let out = gr.complex_norm_sq_backward(in_z, in_g);
22153        gr.set_outputs(vec![out]);
22154        let plan = rlx_opt::memory::plan_memory(&gr);
22155        let mut arena = crate::arena::Arena::from_plan(plan);
22156        let sched = compile_thunks(&gr, &arena);
22157        let z_off = arena.byte_offset(in_z);
22158        let g_off = arena.byte_offset(in_g);
22159        let out_off = arena.byte_offset(out);
22160        let buf = arena.raw_buf_mut();
22161        unsafe {
22162            let pz = buf.as_mut_ptr().add(z_off) as *mut f32;
22163            let pg = buf.as_mut_ptr().add(g_off) as *mut f32;
22164            for (i, &(re, im)) in z.iter().enumerate() {
22165                *pz.add(2 * i) = re;
22166                *pz.add(2 * i + 1) = im;
22167            }
22168            for (i, &v) in g.iter().enumerate() {
22169                *pg.add(i) = v;
22170            }
22171        }
22172        execute_thunks(&sched, arena.raw_buf_mut());
22173        unsafe {
22174            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
22175            (0..n).map(|i| (*p.add(2 * i), *p.add(2 * i + 1))).collect()
22176        }
22177    }
22178
22179    #[test]
22180    fn complex_norm_sq_matches_textbook() {
22181        // |3 + 4i|² = 9 + 16 = 25.
22182        // |1 + 0i|² = 1.
22183        // |0 + 0i|² = 0.
22184        let z = [(3.0_f32, 4.0_f32), (1.0_f32, 0.0_f32), (0.0_f32, 0.0_f32)];
22185        let out = run_complex_norm_sq(&z);
22186        assert!((out[0] - 25.0).abs() < 1e-5);
22187        assert!((out[1] - 1.0).abs() < 1e-6);
22188        assert!(out[2].abs() < 1e-6);
22189    }
22190
22191    #[test]
22192    fn complex_norm_sq_backward_matches_wirtinger_formula() {
22193        // Wirtinger: ∂|z|²/∂z̄ = z. With upstream g = 1, dz = z.
22194        let z = [(3.0_f32, 4.0_f32), (1.5_f32, -2.5_f32)];
22195        let g = [1.0_f32, 1.0_f32];
22196        let dz = run_complex_norm_sq_bwd(&z, &g);
22197        assert_close_c(dz[0], z[0], 1e-6, "dz[0] = g·z[0]");
22198        assert_close_c(dz[1], z[1], 1e-6, "dz[1] = g·z[1]");
22199    }
22200
22201    #[test]
22202    fn complex_norm_sq_backward_scales_with_upstream() {
22203        // With upstream g[i] ≠ 1: dz[i] = g[i]·z[i].
22204        let z = [(2.0_f32, 1.0_f32), (-1.0_f32, 3.0_f32)];
22205        let g = [0.5_f32, -2.0_f32];
22206        let dz = run_complex_norm_sq_bwd(&z, &g);
22207        assert_close_c(dz[0], (1.0, 0.5), 1e-6, "g=0.5 · (2,1)");
22208        assert_close_c(dz[1], (2.0, -6.0), 1e-6, "g=-2 · (-1,3)");
22209    }
22210
22211    /// Multi-output Op::CustomFn via the concat-with-Narrow design
22212    /// (rlx-ir::Graph::custom_fn_multi). Build a custom_fn whose
22213    /// fwd_body returns two outputs (x², 2x), then materialize each
22214    /// via the MultiOutputHandle and verify both numerically.
22215    #[test]
22216    fn custom_fn_multi_extracts_each_subgraph_output() {
22217        use rlx_ir::ops::special::MultiOutputHandle;
22218
22219        let _ = MultiOutputHandle {
22220            source: NodeId(0),
22221            sub_shapes: vec![],
22222            offsets: vec![],
22223        }; // import sanity
22224
22225        // Inner body: input x [3] f32, outputs (x², 2x) both [3] f32.
22226        let mut body = Graph::new("multi_body");
22227        let s3 = Shape::new(&[3], DType::F32);
22228        let x = body.input("x", s3.clone());
22229        let x_sq = body.binary(BinaryOp::Mul, x, x, s3.clone());
22230        let two = body.add_node(
22231            Op::Constant {
22232                data: vec![
22233                    2.0_f32.to_le_bytes(),
22234                    2.0_f32.to_le_bytes(),
22235                    2.0_f32.to_le_bytes(),
22236                ]
22237                .into_iter()
22238                .flatten()
22239                .collect(),
22240            },
22241            vec![],
22242            s3.clone(),
22243        );
22244        let two_x = body.binary(BinaryOp::Mul, two, x, s3.clone());
22245        body.set_outputs(vec![x_sq, two_x]);
22246
22247        // Outer graph: feed in_x → custom_fn_multi → handle.output(0/1).
22248        let mut outer = Graph::new("multi_outer");
22249        let in_x = outer.input("xin", s3.clone());
22250        let handle = outer.custom_fn_multi(vec![in_x], body);
22251        assert_eq!(handle.n_outputs(), 2);
22252        let out0 = handle.output(&mut outer, 0); // x²
22253        let out1 = handle.output(&mut outer, 1); // 2x
22254        outer.set_outputs(vec![out0, out1]);
22255
22256        let plan = rlx_opt::memory::plan_memory(&outer);
22257        let mut arena = crate::arena::Arena::from_plan(plan);
22258        let sched = compile_thunks(&outer, &arena);
22259        let xin_off = arena.byte_offset(in_x);
22260        let out0_off = arena.byte_offset(out0);
22261        let out1_off = arena.byte_offset(out1);
22262        let xs = [1.0_f32, 2.0, 3.0];
22263        unsafe {
22264            let p = arena.raw_buf_mut().as_mut_ptr().add(xin_off) as *mut f32;
22265            for (i, &v) in xs.iter().enumerate() {
22266                *p.add(i) = v;
22267            }
22268        }
22269        execute_thunks(&sched, arena.raw_buf_mut());
22270        let out0_v: Vec<f32> = unsafe {
22271            let p = arena.raw_buf().as_ptr().add(out0_off) as *const f32;
22272            (0..3).map(|i| *p.add(i)).collect()
22273        };
22274        let out1_v: Vec<f32> = unsafe {
22275            let p = arena.raw_buf().as_ptr().add(out1_off) as *const f32;
22276            (0..3).map(|i| *p.add(i)).collect()
22277        };
22278        // x² = [1, 4, 9]; 2x = [2, 4, 6].
22279        for i in 0..3 {
22280            assert!(
22281                (out0_v[i] - xs[i] * xs[i]).abs() < 1e-5,
22282                "out0[{i}] = {} != x² = {}",
22283                out0_v[i],
22284                xs[i] * xs[i]
22285            );
22286            assert!(
22287                (out1_v[i] - 2.0 * xs[i]).abs() < 1e-5,
22288                "out1[{i}] = {} != 2x = {}",
22289                out1_v[i],
22290                2.0 * xs[i]
22291            );
22292        }
22293    }
22294
22295    #[test]
22296    fn complex_norm_sq_gradient_matches_finite_difference() {
22297        // Numerical sanity: perturb z[0].re by ε, observe Δ|z|² ≈ 2·re·ε.
22298        let z = [(3.0_f32, 4.0_f32)];
22299        let eps = 1e-3_f32;
22300        let v0 = run_complex_norm_sq(&z)[0];
22301        let z_pert = [(3.0_f32 + eps, 4.0_f32)];
22302        let v1 = run_complex_norm_sq(&z_pert)[0];
22303        let fd_re = (v1 - v0) / eps;
22304        let analytic_re = 2.0 * z[0].0;
22305        assert!((fd_re - analytic_re).abs() < 1e-2);
22306
22307        // ∂/∂im at z = (3, 4) is 2·im = 8.
22308        let z_pert_im = [(3.0_f32, 4.0_f32 + eps)];
22309        let v2 = run_complex_norm_sq(&z_pert_im)[0];
22310        let fd_im = (v2 - v0) / eps;
22311        let analytic_im = 2.0 * z[0].1;
22312        assert!((fd_im - analytic_im).abs() < 1e-2);
22313
22314        // Compare with the Wirtinger backward at upstream g = 1.
22315        // Wirtinger ∂/∂z̄ = z gives dz = (re, im). The "real
22316        // gradient" wrt (re, im) is 2·(re, im), i.e. 2·dz = (2·re,
22317        // 2·im) — that's the factor 2 difference between Wirtinger
22318        // ∂/∂z̄ and the real-vector gradient on (re, im).
22319        let dz = run_complex_norm_sq_bwd(&z, &[1.0_f32]);
22320        assert!((2.0 * dz[0].0 - analytic_re).abs() < 1e-5);
22321        assert!((2.0 * dz[0].1 - analytic_im).abs() < 1e-5);
22322    }
22323
22324    /// Direct regression test for the 5-D mid-shape singleton broadcast
22325    /// (SAM rel_pos pattern: `[bh, h, w, 1, w] + [bh, h, w, h, w]`).
22326    /// The SAM port worked around this by `concat`-tiling the rhs; this
22327    /// test verifies the in-graph broadcast path is bit-correct.
22328    #[test]
22329    fn binary_full_5d_mid_singleton_broadcast() {
22330        let bh = 2usize;
22331        let h = 3;
22332        let w = 4;
22333        let f = DType::F32;
22334
22335        let mut g = Graph::new("bcast_5d");
22336        let lhs = g.input("lhs", Shape::new(&[bh, h, w, h, w], f));
22337        // rhs shape with size-1 at axis 3 (mid-shape singleton).
22338        let rhs = g.input("rhs", Shape::new(&[bh, h, w, 1, w], f));
22339        let out = g.binary(BinaryOp::Add, lhs, rhs, Shape::new(&[bh, h, w, h, w], f));
22340        g.set_outputs(vec![out]);
22341
22342        // Deterministic data.
22343        let lhs_data: Vec<f32> = (0..bh * h * w * h * w).map(|i| i as f32 * 0.01).collect();
22344        let rhs_data: Vec<f32> = (0..bh * h * w * w)
22345            .map(|i| (i as f32 + 100.0) * 0.01)
22346            .collect();
22347
22348        // Compute expected output by hand.
22349        let mut expected = vec![0f32; bh * h * w * h * w];
22350        for b_ in 0..bh {
22351            for hq in 0..h {
22352                for wq in 0..w {
22353                    for hk in 0..h {
22354                        for wk in 0..w {
22355                            let li = (((b_ * h + hq) * w + wq) * h + hk) * w + wk;
22356                            // rhs has hk dim = 1, so it's always index 0 there.
22357                            let ri = ((b_ * h + hq) * w + wq) * w + wk;
22358                            expected[li] = lhs_data[li] + rhs_data[ri];
22359                        }
22360                    }
22361                }
22362            }
22363        }
22364
22365        let plan = rlx_opt::memory::plan_memory(&g);
22366        let mut arena = crate::arena::Arena::from_plan(plan);
22367        let sched = compile_thunks(&g, &arena);
22368        let lhs_off = arena.byte_offset(lhs);
22369        let rhs_off = arena.byte_offset(rhs);
22370        let out_off = arena.byte_offset(out);
22371        let buf = arena.raw_buf_mut();
22372        unsafe {
22373            let p = buf.as_mut_ptr().add(lhs_off) as *mut f32;
22374            for (i, &v) in lhs_data.iter().enumerate() {
22375                *p.add(i) = v;
22376            }
22377            let p = buf.as_mut_ptr().add(rhs_off) as *mut f32;
22378            for (i, &v) in rhs_data.iter().enumerate() {
22379                *p.add(i) = v;
22380            }
22381        }
22382        execute_thunks(&sched, arena.raw_buf_mut());
22383        let actual: Vec<f32> = unsafe {
22384            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
22385            (0..bh * h * w * h * w).map(|i| *p.add(i)).collect()
22386        };
22387
22388        // Bit-exact check.
22389        let mut max_diff = 0f32;
22390        let mut max_idx = 0;
22391        for i in 0..actual.len() {
22392            let d = (actual[i] - expected[i]).abs();
22393            if d > max_diff {
22394                max_diff = d;
22395                max_idx = i;
22396            }
22397        }
22398        assert!(
22399            max_diff < 1e-6,
22400            "5D mid-shape singleton broadcast wrong: max |Δ| = {max_diff} at idx {max_idx} \
22401             (actual={}, expected={})",
22402            actual[max_idx],
22403            expected[max_idx]
22404        );
22405    }
22406
22407    #[test]
22408    fn layer_norm2d_and_conv_transpose2d_kernels() {
22409        let mut out = vec![0f32; 8];
22410        crate::kernels::layer_norm2d_nchw(
22411            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
22412            &[1.0, 1.0],
22413            &[0.0, 0.0],
22414            &mut out,
22415            1,
22416            2,
22417            2,
22418            2,
22419            1e-5,
22420        );
22421        let mean0: f32 = (1.0 + 3.0) / 2.0;
22422        assert!((out[0] - mean0).abs() > 0.1);
22423
22424        let mut up = vec![0f32; 4];
22425        crate::kernels::conv_transpose2d_nchw(
22426            &[2.0],
22427            &[1.0, 0.0, 0.0, 1.0],
22428            &mut up,
22429            1,
22430            1,
22431            1,
22432            1,
22433            1,
22434            2,
22435            2,
22436            2,
22437            2,
22438            2,
22439            2,
22440            0,
22441            0,
22442            1,
22443            1,
22444            1,
22445        );
22446        assert!((up[0] - 2.0).abs() < 1e-5);
22447        assert!((up[3] - 2.0).abs() < 1e-5);
22448    }
22449}