Skip to main content

rlx_autodiff/
autodiff.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Reverse-mode automatic differentiation (VJP transform).
17//!
18//! Takes a forward graph that produces a single scalar output (the
19//! loss) and returns a new graph whose outputs are
20//! `[loss, grad_w_param0, grad_w_param1, ...]` for the parameters
21//! listed in `wrt`. Running the returned graph through any backend
22//! gives the loss and all parameter gradients in one pass.
23//!
24//! ## Implementation
25//!
26//! Standard reverse-mode AD: walk the forward graph in reverse topo
27//! order; for each visited node, emit gradient nodes that contribute
28//! to the gradients of its inputs. Multiple consumers' contributions
29//! are summed via `BinaryOp::Add`.
30//!
31//! For ops with a closed-form gradient kernel (`ReluBackward`,
32//! `MaxPool2dBackward`, `Conv2dBackwardInput/Weight`,
33//! `AttentionBackward`, `SoftmaxCrossEntropyBackward` — added in the
34//! rlx-ir backward-op
35//! family), the VJP emits the dedicated kernel rather than composing
36//! the gradient from primitives.
37//!
38//! ## Broadcast handling
39//!
40//! Forward broadcasts (e.g. `[N, C] + [C]` → `[N, C]`) require the
41//! reverse pass to *un-broadcast* the gradient back to the broadcast
42//! input's smaller shape via a `Reduce::Sum` over the inserted /
43//! tiled axes. `unbroadcast` does this; both `Op::Binary` and
44//! `Op::Expand` VJPs use it.
45//!
46//! ## Coverage
47//!
48//! Element-wise: `Binary(Add/Sub/Mul/Div/Min/Max/Pow)`,
49//! `Activation(*)` (Relu via dedicated `ReluBackward`, others via
50//! generic `ActivationBackward`), `Compare` (zero gradient),
51//! `Where`, `Cast`, `Quantize/Dequantize` (straight-through).
52//!
53//! Linear / reductions: `MatMul`, `Conv`, `Pool{Max,Mean}`,
54//! `Reduce{Sum,Mean,Min,Max,Prod}`, `Softmax`, `LayerNorm`
55//! (dedicated kernels), `RmsNorm` (composed), `Rope` (composed
56//! via negated sin), `SoftmaxCrossEntropyWithLogits`.
57//!
58//! Shape: `Reshape`, `Transpose`, `Expand`, `Concat`, `Narrow`,
59//! `Gather` (axis=0), `ScatterAdd`.
60//!
61//! Attention: `Op::Attention` → three [`Op::AttentionBackward`]
62//! nodes (`dQ` / `dK` / `dV`) for all mask kinds. Causal /
63//! SlidingWindow masks are applied inside the backward kernel (no
//! mask tensor). `Custom` / `Bias` pass the forward mask input (the
//! one applied before softmax) through to the backward kernel;
//! Custom uses the user-provided mask tensor; None is the no-mask path.
67//!
68//! Pre-pass: `UnfuseElementwiseRegions` runs automatically before
69//! the gradient walk so `Op::ElementwiseRegion` decomposes into
70//! its primitive chain (covered op-by-op above).
71//!
72//! Sampling-style (`TopK`, `Sample`): non-differentiable — emit no
73//! gradient (forward is a discrete selector / draw).
74//!
75//! Pre-pass: [`crate::prepare_ad::prepare_graph_for_ad`] runs before
76//! the gradient walk (also exposed as [`PrepareForAutodiff`] pass).
77//! It unfuses elementwise regions and tier-2 fused ops
78//! (`FusedMatMulBiasAct`, `FusedResidualLN`, `FusedResidualRmsNorm`,
79//! `FusedSwiGLU`, `FusedAttentionBlock`, `FusedTransformerLayer`,
80//! `GatedDeltaNet`, `SelectiveScan`, …), lowers `DotGeneral`, inlines
81//! `If` / unrolls `While`, inlines `CustomFn` without `vjp_body`, and
82//! rewrites scans for trajectory AD.
83//!
84//! For HIR builders, [`rlx_ir::hir::FusionPolicy::for_autodiff`] lowers
85//! to primitive MIR; [`grad_with_loss_module`] accepts HIR or MIR
86//! [`GraphModule`] stages (not LIR).
87//!
88//! Cumsum: backward composed via matmul with a constant
89//! upper-triangular ones matrix (avoids needing a new `Op::Flip`
90//! primitive across all backends). Fine for typical sequence
91//! lengths; an L×L matmul where L is the sequence size.
92//!
93//! Quantized / MoE: `Op::DequantMatMul` (QAT straight-through),
94//! `Op::QMatMul`, `Op::QConv2d`, and `Op::GroupedMatMul` are all
95//! supported via composed straight-through VJPs. Plain
96//! `Op::Quantize/Dequantize` straight-through covers the typical
97//! fake-quant fp32 training path.
98//!
99//! Coverage today: every op in the IR has a VJP rule or a
100//! pre-pass that rewrites it into ones that do. SelectiveScan
101//! (Mamba SSM step) and GatedDeltaNet (Qwen3.5 linear-attention
102//! scan) decompose by unrolling the time loop into Mul / Add /
103//! MatMul / Activation::Exp / Concat / Narrow / Reshape / Expand
104//! primitives — same shape as the rlx-mlx lowering, just emitted
105//! as IR nodes instead of MLX arrays.
106//! FusedTransformerLayer / FusedAttentionBlock / FusedSwiGLU /
107//! LoraMatMul / FusedMatMulBiasAct / FusedResidualLN are all
108//! decomposed by `rlx_fusion::unfuse_fused_for_autodiff` likewise. Op::If
109//! is rewritten to `Where(predicate, then, else)` with both
110//! branches inlined; Op::While is bounded-unrolled up to
111//! `max_iterations`.
112
113use rlx_ir::op::*;
114use rlx_ir::shape::Dim;
115use rlx_ir::*;
116use std::collections::HashMap;
117
118pub use crate::prepare_ad::{
119    AutodiffError, PrepareForAutodiff, grad_with_loss_module, jvp_module, prepare_graph_for_ad,
120    prepare_mir_for_ad, prepare_module_for_ad,
121};
122
123/// Compute the reverse-mode gradient graph and the loss value.
124///
125/// Returns a graph whose outputs are
126/// `[loss, aux₁, aux₂, …, grad_wrt[0], grad_wrt[1], …]`.
127///
128/// The **first** forward output is treated as the scalar loss and is
129/// differentiated. Any **additional** forward outputs (`aux₁ …`) are
130/// mirrored from the forward graph and emitted unchanged — gradients are
131/// not propagated through them. This is the canonical hook for emitting
132/// training-side statistics (BatchNorm batch mean/variance, debug probes,
133/// …) alongside the loss in a single forward+backward pass.
134///
135/// The returned graph contains a copy of the entire forward graph so
136/// activations needed by gradient kernels are recomputed from inputs;
137/// it also exposes a new `Op::Input` named `"d_output"` which the
138/// caller seeds with the upstream gradient of the loss (typically a
139/// scalar `1.0` for "differentiate the loss directly"). Auxiliary outputs
140/// have no `d_output`-equivalent — by construction they don't contribute
141/// to the gradient path.
142///
143/// ## Limitations
144/// - Forward graph must have **≥ 1** output. The first is the loss.
145/// - All ops in the forward graph must have an implemented VJP rule.
146///   Hitting an op without one is a panic, not a silent miscompute.
147/// Options for [`grad_with_loss_opts`].
148#[derive(Debug, Clone, Copy, PartialEq, Eq)]
149pub struct GradWithLossOptions {
150    /// When true, parameters in `wrt` with no gradient path receive an
151    /// explicit zero tensor instead of panicking (e.g. unused `logit_bias`).
152    pub zero_missing_wrt: bool,
153}
154
155impl GradWithLossOptions {
156    pub const STRICT: Self = Self {
157        zero_missing_wrt: false,
158    };
159    pub const TRAINING: Self = Self {
160        zero_missing_wrt: true,
161    };
162}
163
164/// Build a backward graph with scalar loss + gradients w.r.t. `wrt`.
165///
166/// Panics if any `wrt` parameter receives no gradient (use
167/// [`grad_with_loss_opts`] with [`GradWithLossOptions::TRAINING`] to
168/// zero-fill instead).
169pub fn grad_with_loss(forward: &Graph, wrt: &[NodeId]) -> Graph {
170    grad_with_loss_opts(forward, wrt, GradWithLossOptions::STRICT)
171}
172
173/// Like [`grad_with_loss`] with configurable unused-parameter handling.
174pub fn grad_with_loss_opts(forward: &Graph, wrt: &[NodeId], opts: GradWithLossOptions) -> Graph {
175    assert!(
176        !forward.outputs.is_empty(),
177        "grad_with_loss: forward must have at least one output (the loss)"
178    );
179
180    // Pre-autodiff unfuse: decompose fused ops back to primitives so
181    // the per-op VJP rules cover them. Two layers:
182    //   1. `UnfuseElementwiseRegions` — splits the chain back to
183    //      Activation/Cast/Binary/Compare/Where ops.
184    //   2. `rlx_fusion::unfuse_fused_for_autodiff` (below) — handles the
185    //      tier-2 fused ops with closed-form decompositions:
186    //      FusedMatMulBiasAct, FusedResidualLN, LoraMatMul.
187    //
188    // FusedSwiGLU / FusedAttentionBlock / FusedTransformerLayer
189    // are all decomposed by `rlx_fusion::unfuse_fused_for_autodiff` (each is
190    // a multi-stage sub-graph; mirrors what `rlx-tpu/src/unfuse.rs`
191    // does for HLO emission).
192    let forward_owned = crate::prepare_ad::prepare_graph_for_ad(forward.clone());
193    let forward = &forward_owned;
194
195    let mut bwd = Graph::new(format!("{}_grad", forward.name));
196
197    // Mirror every forward node into bwd. The activations needed by
198    // gradient kernels (`x` for ReluBackward, `logits` for
199    // SoftmaxCrossEntropyBackward, etc.) are looked up by these
200    // mirrored ids.
201    let mut fwd_to_bwd: HashMap<NodeId, NodeId> = HashMap::new();
202    for node in forward.nodes() {
203        let inputs: Vec<NodeId> = node.inputs.iter().map(|i| fwd_to_bwd[i]).collect();
204        let new_id = bwd.add_node(node.op.clone(), inputs, node.shape.clone());
205        fwd_to_bwd.insert(node.id, new_id);
206    }
207
208    // Seed: the gradient of the loss w.r.t. itself is an external
209    // input the caller provides (typically `[1.0]` for a scalar loss).
210    let loss_fwd = forward.outputs[0];
211    let loss_bwd = fwd_to_bwd[&loss_fwd];
212    let loss_shape = forward.node(loss_fwd).shape.clone();
213    let d_output = bwd.input("d_output", loss_shape);
214
215    let mut grads: HashMap<NodeId, NodeId> = HashMap::new();
216    grads.insert(loss_bwd, d_output);
217
218    for fwd_node in forward.nodes().iter().rev() {
219        let bwd_id = fwd_to_bwd[&fwd_node.id];
220        let upstream = match grads.get(&bwd_id) {
221            Some(g) => *g,
222            None => continue,
223        };
224        let input_grads = vjp(fwd_node, upstream, &fwd_to_bwd, &mut bwd);
225        for (idx, grad_id) in input_grads {
226            let target = fwd_node.inputs[idx];
227            let bwd_target = fwd_to_bwd[&target];
228            // Per-consumer gradients accumulate (`+=`).
229            let new_grad = if let Some(&prev) = grads.get(&bwd_target) {
230                let shape = bwd.node(prev).shape.clone();
231                bwd.binary(BinaryOp::Add, prev, grad_id, shape)
232            } else {
233                grad_id
234            };
235            grads.insert(bwd_target, new_grad);
236        }
237    }
238
239    let n_aux = forward.outputs.len().saturating_sub(1);
240    let mut outputs = Vec::with_capacity(1 + n_aux + wrt.len());
241    outputs.push(loss_bwd);
242    // Auxiliary forward outputs (everything past `outputs[0]`): mirrored
243    // from the forward graph, no gradient propagation.
244    for &aux in &forward.outputs[1..] {
245        outputs.push(fwd_to_bwd[&aux]);
246    }
247    for &id in wrt {
248        let g = match grads.get(&fwd_to_bwd[&id]).copied() {
249            Some(g) => g,
250            None if opts.zero_missing_wrt => {
251                let shape = forward.node(id).shape.clone();
252                let n = shape.num_elements().unwrap_or(0);
253                let data: Vec<u8> = (0..n).flat_map(|_| 0.0f32.to_le_bytes()).collect();
254                bwd.add_node(Op::Constant { data }, vec![], shape)
255            }
256            None => {
257                panic!(
258                    "no gradient flowed to {id} — \
259                    either the forward graph doesn't depend on it, or one \
260                    of its consumer ops has no VJP rule"
261                )
262            }
263        };
264        outputs.push(g);
265    }
266    bwd.set_outputs(outputs);
267    bwd
268}
269
270/// Backwards-compatible single-output alias (parameter gradients only,
271/// no loss). Kept for the existing tests; prefer `grad_with_loss` for
272/// training.
273pub fn grad(forward: &Graph, wrt: &[NodeId]) -> Graph {
274    let g = grad_with_loss(forward, wrt);
275    let mut g = g;
276    // Drop the loss output, keep only gradients.
277    let outs = g.outputs.iter().skip(1).copied().collect();
278    g.set_outputs(outs);
279    g
280}
281
282/// Project a gradient back to a smaller shape it was broadcasted from.
283/// `target_shape` is the broadcast *source* shape (e.g. `[C]` for a
284/// bias added to `[N, C, H, W]`). Sums over leading prepended axes
285/// and over any axis where target was 1 but the gradient is larger.
286/// Then reshapes to drop the size-1 axes if the rank shrunk.
287/// Returns `Some(bits)` if `node_id` is the output of an
288/// `Op::FakeQuantize { bits, .. }` (or `FakeQuantizeLSQ`) in the
289/// forward graph. Used by the autodiff Conv backward to detect the
290/// QAT pattern and emit a specialized weight-grad kernel that can
291/// skip dead bins (weights that round to the same code share the
292/// gradient). Today only the detection is exposed — the
293/// specialization is a follow-up commit.
294pub fn quantized_weight_bits(forward: &Graph, node_id: NodeId) -> Option<u8> {
295    match &forward.node(node_id).op {
296        Op::FakeQuantize { bits, .. } => Some(*bits),
297        Op::FakeQuantizeLSQ { bits, .. } => Some(*bits),
298        _ => None,
299    }
300}
301
302fn unbroadcast(grad: NodeId, target_shape: &Shape, bwd: &mut Graph) -> NodeId {
303    let grad_shape = bwd.node(grad).shape.clone();
304    if grad_shape == *target_shape {
305        return grad;
306    }
307    let g_rank = grad_shape.rank();
308    let t_rank = target_shape.rank();
309    let extra = g_rank.saturating_sub(t_rank);
310
311    // Axes (in grad's coordinate system) that need summing.
312    let mut axes: Vec<usize> = (0..extra).collect();
313    for i in 0..t_rank {
314        let g_dim = grad_shape.dim(extra + i);
315        let t_dim = target_shape.dim(i);
316        if matches!(t_dim, Dim::Static(1)) && !matches!(g_dim, Dim::Static(1)) {
317            axes.push(extra + i);
318        }
319    }
320
321    let mut current = grad;
322    if !axes.is_empty() {
323        // The CPU `Op::Reduce` thunk only supports a *single contiguous*
324        // range of axes — `[0, 2, 3]` (the canonical conv-bias-gradient
325        // pattern) would silently fall through to a `Nop`. Decompose into
326        // a chain of single-axis reductions with `keep_dim=true` so rank
327        // stays at `g_rank` and earlier axis indices remain valid; the
328        // rank shrink (if any) happens in the reshape step below.
329        let mut running_dims: Vec<Dim> = (0..g_rank).map(|i| grad_shape.dim(i)).collect();
330        for &ax in &axes {
331            running_dims[ax] = Dim::Static(1);
332            let step_shape = Shape::from_dims(&running_dims, target_shape.dtype());
333            current = bwd.add_node(
334                Op::Reduce {
335                    op: ReduceOp::Sum,
336                    axes: vec![ax],
337                    keep_dim: true,
338                },
339                vec![current],
340                step_shape,
341            );
342        }
343    }
344
345    // Drop leading 1-axes via Reshape if the target rank is smaller.
346    if bwd.node(current).shape.rank() != t_rank {
347        let new_shape: Vec<i64> = target_shape
348            .dims()
349            .iter()
350            .map(|d| match d {
351                Dim::Static(n) => *n as i64,
352                Dim::Dynamic(_) => -1,
353            })
354            .collect();
355        current = bwd.add_node(
356            Op::Reshape { new_shape },
357            vec![current],
358            target_shape.clone(),
359        );
360    }
361    current
362}
363
364/// Reshape a gradient to a target shape (used by Reshape / Mean VJPs).
365fn reshape_to(grad: NodeId, target_shape: &Shape, bwd: &mut Graph) -> NodeId {
366    if bwd.node(grad).shape == *target_shape {
367        return grad;
368    }
369    let new_shape: Vec<i64> = target_shape
370        .dims()
371        .iter()
372        .map(|d| match d {
373            Dim::Static(n) => *n as i64,
374            Dim::Dynamic(_) => -1,
375        })
376        .collect();
377    bwd.add_node(Op::Reshape { new_shape }, vec![grad], target_shape.clone())
378}
379
380/// VJP for `Op::GroupedMatMul` / dequantized MoE matmul (`dx`, `dw`).
381fn grouped_matmul_vjp(
382    bwd: &mut Graph,
383    upstream: NodeId,
384    x_bwd: NodeId,
385    w_bwd: NodeId,
386    expert_bwd: NodeId,
387    x_shape: &Shape,
388    w_shape: &Shape,
389) -> (NodeId, NodeId) {
390    let dtype = x_shape.dtype();
391    let m = x_shape.dim(0);
392    let k = x_shape.dim(1);
393    let e = w_shape.dim(0);
394    let n_out = w_shape.dim(2);
395    let m_static = match m {
396        Dim::Static(v) => v,
397        _ => panic!("GroupedMatMul VJP: M must be static"),
398    };
399    let k_static = match k {
400        Dim::Static(v) => v,
401        _ => panic!("GroupedMatMul VJP: K must be static"),
402    };
403    let n_static = match n_out {
404        Dim::Static(v) => v,
405        _ => panic!("GroupedMatMul VJP: N must be static"),
406    };
407
408    let w_per = bwd.add_node(
409        Op::Gather { axis: 0 },
410        vec![w_bwd, expert_bwd],
411        Shape::from_dims(&[m, k, n_out], dtype),
412    );
413
414    let up_3d_shape = Shape::from_dims(&[m, Dim::Static(1), n_out], dtype);
415    let up_3d = bwd.reshape(
416        upstream,
417        vec![m_static as i64, 1, n_static as i64],
418        up_3d_shape,
419    );
420    let w_per_t = bwd.add_node(
421        Op::Transpose {
422            perm: vec![0, 2, 1],
423        },
424        vec![w_per],
425        Shape::from_dims(&[m, n_out, k], dtype),
426    );
427    let dx_3d_shape = Shape::from_dims(&[m, Dim::Static(1), k], dtype);
428    let dx_3d = bwd.matmul(up_3d, w_per_t, dx_3d_shape);
429    let dx = bwd.reshape(
430        dx_3d,
431        vec![m_static as i64, k_static as i64],
432        x_shape.clone(),
433    );
434
435    let x_3d = bwd.reshape(
436        x_bwd,
437        vec![m_static as i64, k_static as i64, 1],
438        Shape::from_dims(&[m, k, Dim::Static(1)], dtype),
439    );
440    let up_for_outer = bwd.reshape(
441        upstream,
442        vec![m_static as i64, 1, n_static as i64],
443        Shape::from_dims(&[m, Dim::Static(1), n_out], dtype),
444    );
445    let dw_per = bwd.matmul(x_3d, up_for_outer, Shape::from_dims(&[m, k, n_out], dtype));
446    let dw = bwd.add_node(
447        Op::ScatterAdd,
448        vec![dw_per, expert_bwd],
449        Shape::from_dims(&[e, k, n_out], dtype),
450    );
451    (dx, dw)
452}
453
454/// Build a scalar f32 `Op::Constant` node.
455fn scalar_const(value: f32, bwd: &mut Graph) -> NodeId {
456    let bytes = value.to_le_bytes().to_vec();
457    let shape = Shape::from_dims(&[Dim::Static(1)], DType::F32);
458    bwd.add_node(Op::Constant { data: bytes }, vec![], shape)
459}
460
461/// Per-op VJP rule. Returns (input_index, gradient_node_id) pairs;
462/// inputs not listed receive no gradient (e.g. the labels argument
463/// of `SoftmaxCrossEntropyWithLogits` is non-differentiable).
464fn vjp(
465    node: &Node,
466    upstream: NodeId,
467    fwd_map: &HashMap<NodeId, NodeId>,
468    bwd: &mut Graph,
469) -> Vec<(usize, NodeId)> {
470    let upstream_shape = bwd.node(upstream).shape.clone();
471    match &node.op {
472        // Leaves — no inputs → no gradients to attribute.
473        Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => vec![],
474
475        Op::Binary(BinaryOp::Add) => {
476            let a_bwd = fwd_map[&node.inputs[0]];
477            let b_bwd = fwd_map[&node.inputs[1]];
478            let a_shape = bwd.node(a_bwd).shape.clone();
479            let b_shape = bwd.node(b_bwd).shape.clone();
480            let g_a = unbroadcast(upstream, &a_shape, bwd);
481            let g_b = unbroadcast(upstream, &b_shape, bwd);
482            vec![(0, g_a), (1, g_b)]
483        }
484
485        Op::Binary(BinaryOp::Sub) => {
486            let a_bwd = fwd_map[&node.inputs[0]];
487            let b_bwd = fwd_map[&node.inputs[1]];
488            let a_shape = bwd.node(a_bwd).shape.clone();
489            let b_shape = bwd.node(b_bwd).shape.clone();
490            let neg = bwd.activation(Activation::Neg, upstream, upstream_shape.clone());
491            let g_a = unbroadcast(upstream, &a_shape, bwd);
492            let g_b = unbroadcast(neg, &b_shape, bwd);
493            vec![(0, g_a), (1, g_b)]
494        }
495
496        Op::Binary(BinaryOp::Mul) => {
497            let a_bwd = fwd_map[&node.inputs[0]];
498            let b_bwd = fwd_map[&node.inputs[1]];
499            let a_shape = bwd.node(a_bwd).shape.clone();
500            let b_shape = bwd.node(b_bwd).shape.clone();
501            // Wirtinger over C64: y = a·b → dL/dā = upstream · conj(b),
502            // dL/db̄ = upstream · conj(a). The conjugates turn the
503            // standard real Mul rule into the correct complex one
504            // without changing the kernel — `BinaryFullC64` does the
505            // native complex multiply on whatever inputs we hand it.
506            let is_c64 = upstream_shape.dtype() == DType::C64;
507            let b_for_a = if is_c64 { bwd.conjugate(b_bwd) } else { b_bwd };
508            let a_for_b = if is_c64 { bwd.conjugate(a_bwd) } else { a_bwd };
509            let g_a_full = bwd.binary(BinaryOp::Mul, upstream, b_for_a, upstream_shape.clone());
510            let g_b_full = bwd.binary(BinaryOp::Mul, upstream, a_for_b, upstream_shape);
511            let g_a = unbroadcast(g_a_full, &a_shape, bwd);
512            let g_b = unbroadcast(g_b_full, &b_shape, bwd);
513            vec![(0, g_a), (1, g_b)]
514        }
515
516        Op::Activation(kind) => {
517            let x_bwd = fwd_map[&node.inputs[0]];
518            // Dedicated `ReluBackward` kernel for the most common case
519            // (avoids the per-element kind-dispatch in
520            // `ActivationBackward`'s match). Every other activation
521            // family hits the generic kernel.
522            let dx = match kind {
523                Activation::Relu => bwd.relu_backward(x_bwd, upstream),
524                _ => bwd.activation_backward(*kind, x_bwd, upstream),
525            };
526            vec![(0, dx)]
527        }
528
529        Op::MatMul => {
530            // y [..batch, M, N] = a [..batch_a, M, K] @ b [..batch_b, K, N]
531            //   da = upstream @ b^T   (shape [..batch, M, K])
532            //   db = a^T   @ upstream (shape [..batch, K, N])
533            //
534            // The forward shape inference broadcasts `batch_a` and
535            // `batch_b` to a common `batch`; if either side was
536            // broadcasted, we sum the corresponding gradient back
537            // down via `unbroadcast` so it matches the param's actual
538            // shape. The transpose swaps the *last two* dims only,
539            // leaving batch untouched.
540            let a_bwd = fwd_map[&node.inputs[0]];
541            let b_bwd = fwd_map[&node.inputs[1]];
542            let a_shape = bwd.node(a_bwd).shape.clone();
543            let b_shape = bwd.node(b_bwd).shape.clone();
544            assert!(
545                a_shape.rank() >= 2 && b_shape.rank() >= 2,
546                "MatMul VJP: rank must be ≥ 2, got {} and {}",
547                a_shape.rank(),
548                b_shape.rank()
549            );
550            let dtype = upstream_shape.dtype();
551
552            // Transpose-last-two helper.
553            let trans_last_two = |bwd: &mut Graph, x: NodeId| -> NodeId {
554                let s = bwd.node(x).shape.clone();
555                let r = s.rank();
556                let mut perm: Vec<usize> = (0..r).collect();
557                perm.swap(r - 2, r - 1);
558                let mut dims: Vec<Dim> = s.dims().to_vec();
559                dims.swap(r - 2, r - 1);
560                let new_shape = Shape::from_dims(&dims, s.dtype());
561                bwd.add_node(Op::Transpose { perm }, vec![x], new_shape)
562            };
563
564            // Build the matmul output shape [..upstream_batch, M_or_K, K_or_N]
565            // by swapping in the trailing dims for each gradient.
566            let upstream_dims: Vec<Dim> = upstream_shape.dims().to_vec();
567            let r_up = upstream_dims.len();
568
569            // C64 Wirtinger (∂L/∂z̄ convention, matching Mul/Div): the
570            // gradient conjugates the *other* operand —
571            //   dA = upstream @ conj(B)ᵀ,  dB = conj(A)ᵀ @ upstream.
572            // `conj` and transpose commute elementwise, so we conjugate the
573            // transposed operand. No-op for real dtypes.
574            let is_c64 = dtype == DType::C64;
575
576            // ── grad-a = upstream @ b^T (output shape [..up_batch, M, K]) ──
577            let b_t = trans_last_two(bwd, b_bwd);
578            let b_t = if is_c64 { bwd.conjugate(b_t) } else { b_t };
579            let mut ga_dims = upstream_dims.clone();
580            ga_dims[r_up - 1] = a_shape.dim(a_shape.rank() - 1); // K
581            let ga_shape = Shape::from_dims(&ga_dims, dtype);
582            let g_a_full = bwd.matmul(upstream, b_t, ga_shape);
583            let g_a = unbroadcast(g_a_full, &a_shape, bwd);
584
585            // ── grad-b = a^T @ upstream (output shape [..up_batch, K, N]) ──
586            let a_t = trans_last_two(bwd, a_bwd);
587            let a_t = if is_c64 { bwd.conjugate(a_t) } else { a_t };
588            let mut gb_dims = upstream_dims.clone();
589            gb_dims[r_up - 2] = a_shape.dim(a_shape.rank() - 1); // K
590            let gb_shape = Shape::from_dims(&gb_dims, dtype);
591            let g_b_full = bwd.matmul(a_t, upstream, gb_shape);
592            let g_b = unbroadcast(g_b_full, &b_shape, bwd);
593
594            vec![(0, g_a), (1, g_b)]
595        }
596
597        Op::DenseSolve => {
598            // X = solve(A, B) ⇒ implicit-function VJP:
599            //   dB = solve(Aᵀ, upstream)        same shape as B / X
600            //   dA = -dB · Xᵀ                   [N, N]
601            //
602            // Rank-1 (b: [N]) and rank-2 (B: [N, K]) follow the same
603            // formula; rank-1 needs a reshape-to-column trick because
604            // we don't have a vector-outer-product op (matmul is
605            // matrix-only). Rank-2 is direct matmul.
606            let a_bwd = fwd_map[&node.inputs[0]];
607            let x_bwd = fwd_map[&node.id];
608            let a_shape = bwd.node(a_bwd).shape.clone();
609            let x_shape = bwd.node(x_bwd).shape.clone();
610            assert_eq!(a_shape.rank(), 2, "DenseSolve VJP: A must be 2-D");
611            let n = match a_shape.dim(0) {
612                Dim::Static(n) => n,
613                Dim::Dynamic(_) => panic!("DenseSolve VJP: dynamic N not supported"),
614            };
615            let dtype = a_shape.dtype();
616
617            // Aᵀ — shape [N, N] (square, transpose is just a perm).
618            let mut a_t_dims: Vec<Dim> = a_shape.dims().to_vec();
619            a_t_dims.swap(0, 1);
620            let a_t_shape = Shape::from_dims(&a_t_dims, dtype);
621            let a_t = bwd.add_node(Op::Transpose { perm: vec![1, 0] }, vec![a_bwd], a_t_shape);
622
623            // dB = solve(Aᵀ, upstream). Same shape as the original B.
624            let d_b = bwd.dense_solve(a_t, upstream, x_shape.clone());
625
626            // dA = -dB · Xᵀ.
627            let neg_outer = match x_shape.rank() {
628                1 => {
629                    // b: [N]. Reshape both vectors to matrices for matmul.
630                    let col_shape = Shape::from_dims(&[Dim::Static(n), Dim::Static(1)], dtype);
631                    let row_shape = Shape::from_dims(&[Dim::Static(1), Dim::Static(n)], dtype);
632                    let db_col = bwd.add_node(
633                        Op::Reshape {
634                            new_shape: vec![n as i64, 1],
635                        },
636                        vec![d_b],
637                        col_shape,
638                    );
639                    let x_row = bwd.add_node(
640                        Op::Reshape {
641                            new_shape: vec![1, n as i64],
642                        },
643                        vec![x_bwd],
644                        row_shape,
645                    );
646                    let outer = bwd.matmul(db_col, x_row, a_shape.clone());
647                    bwd.activation(Activation::Neg, outer, a_shape)
648                }
649                2 => {
650                    // B: [N, K]. dA = -MatMul(dB, Xᵀ) where Xᵀ: [K, N].
651                    let k = match x_shape.dim(1) {
652                        Dim::Static(k) => k,
653                        Dim::Dynamic(_) => panic!("DenseSolve VJP: dynamic K not supported"),
654                    };
655                    let xt_dims = vec![Dim::Static(k), Dim::Static(n)];
656                    let xt_shape = Shape::from_dims(&xt_dims, dtype);
657                    let x_t =
658                        bwd.add_node(Op::Transpose { perm: vec![1, 0] }, vec![x_bwd], xt_shape);
659                    let outer = bwd.matmul(d_b, x_t, a_shape.clone());
660                    bwd.activation(Activation::Neg, outer, a_shape)
661                }
662                r => panic!("DenseSolve VJP: B must be rank 1 or 2, got rank {r}"),
663            };
664
665            vec![(0, neg_outer), (1, d_b)]
666        }
667
668        Op::BatchedDenseSolve => {
669            // Per-batch independent. Same implicit-function VJP as
670            // DenseSolve, lifted with a leading B axis throughout:
671            //   dB = batched_solve(Aᵀ, upstream)        same shape as B/X
672            //   dA = -batched_matmul(dB, Xᵀ)            shape [B, N, N]
673            // where `Aᵀ` swaps the LAST TWO axes (perm = [0, 2, 1]).
674            let a_bwd = fwd_map[&node.inputs[0]];
675            let x_bwd = fwd_map[&node.id];
676            let a_shape = bwd.node(a_bwd).shape.clone();
677            let x_shape = bwd.node(x_bwd).shape.clone();
678            assert_eq!(
679                a_shape.rank(),
680                3,
681                "BatchedDenseSolve VJP: A must be rank-3 [B, N, N]"
682            );
683            let b_dim = match a_shape.dim(0) {
684                Dim::Static(b) => b,
685                Dim::Dynamic(_) => panic!("BatchedDenseSolve VJP: dynamic B not supported"),
686            };
687            let n = match a_shape.dim(1) {
688                Dim::Static(n) => n,
689                Dim::Dynamic(_) => panic!("BatchedDenseSolve VJP: dynamic N not supported"),
690            };
691            let dtype = a_shape.dtype();
692
693            // Aᵀ across last two dims — perm = [0, 2, 1]. Output shape
694            // is [B, N, N] (same as A; transpose of square is square).
695            let a_t = bwd.add_node(
696                Op::Transpose {
697                    perm: vec![0, 2, 1],
698                },
699                vec![a_bwd],
700                a_shape.clone(),
701            );
702
703            // dB = batched_solve(Aᵀ, upstream).
704            let d_b = bwd.batched_dense_solve(a_t, upstream, x_shape.clone());
705
706            // dA = -batched_matmul(dB, Xᵀ).
707            let neg_outer = match x_shape.rank() {
708                2 => {
709                    // b is [B, N]. Reshape to [B, N, 1] (column) for dB
710                    // and [B, 1, N] (row) for X, then batched matmul.
711                    let col_shape = Shape::from_dims(
712                        &[Dim::Static(b_dim), Dim::Static(n), Dim::Static(1)],
713                        dtype,
714                    );
715                    let row_shape = Shape::from_dims(
716                        &[Dim::Static(b_dim), Dim::Static(1), Dim::Static(n)],
717                        dtype,
718                    );
719                    let db_col = bwd.add_node(
720                        Op::Reshape {
721                            new_shape: vec![b_dim as i64, n as i64, 1],
722                        },
723                        vec![d_b],
724                        col_shape,
725                    );
726                    let x_row = bwd.add_node(
727                        Op::Reshape {
728                            new_shape: vec![b_dim as i64, 1, n as i64],
729                        },
730                        vec![x_bwd],
731                        row_shape,
732                    );
733                    let outer = bwd.matmul(db_col, x_row, a_shape.clone());
734                    bwd.activation(Activation::Neg, outer, a_shape)
735                }
736                3 => {
737                    // b is [B, N, K]. dA = -MatMul(dB, Xᵀ) with
738                    // Xᵀ = Transpose(perm=[0, 2, 1]) so [B, K, N].
739                    let k = match x_shape.dim(2) {
740                        Dim::Static(k) => k,
741                        Dim::Dynamic(_) => panic!("BatchedDenseSolve VJP: dynamic K not supported"),
742                    };
743                    let xt_shape = Shape::from_dims(
744                        &[Dim::Static(b_dim), Dim::Static(k), Dim::Static(n)],
745                        dtype,
746                    );
747                    let x_t = bwd.add_node(
748                        Op::Transpose {
749                            perm: vec![0, 2, 1],
750                        },
751                        vec![x_bwd],
752                        xt_shape,
753                    );
754                    let outer = bwd.matmul(d_b, x_t, a_shape.clone());
755                    bwd.activation(Activation::Neg, outer, a_shape)
756                }
757                r => panic!("BatchedDenseSolve VJP: b must be rank 2 or 3, got rank {r}"),
758            };
759
760            vec![(0, neg_outer), (1, d_b)]
761        }
762
763        Op::Scan {
764            body,
765            length,
766            save_trajectory,
767            num_bcast: _,
768            num_xs,
769            num_checkpoints,
770        } => {
771            // After `convert_scans_for_ad`, every scan reaching the AD
772            // walk carries its trajectory. Compile body's VJP once
773            // — w.r.t. carry AND every xs — so we can extract dinit
774            // (Op::ScanBackward) plus dxs_i for each xs
775            // (Op::ScanBackwardXs). Each variant runs its own backward
776            // sweep; this is `1 + num_xs` independent sweeps. A future
777            // optimization can fuse them via packed multi-output.
778            let init_bwd = fwd_map[&node.inputs[0]];
779            let traj_bwd = fwd_map[&node.id];
780            let init_shape = bwd.node(init_bwd).shape.clone();
781
782            // Body Inputs in NodeId order: first = carry, rest = x_t_i.
783            let mut body_input_ids: Vec<NodeId> = body
784                .nodes()
785                .iter()
786                .filter(|n| matches!(n.op, Op::Input { .. }))
787                .map(|n| n.id)
788                .collect();
789            body_input_ids.sort();
790
791            let body_vjp = grad(body, &body_input_ids);
792
793            let xs_bwd: Vec<NodeId> = (0..*num_xs as usize)
794                .map(|i| fwd_map[&node.inputs[1 + i]])
795                .collect();
796
797            // Recursive checkpointing: when num_checkpoints is set on
798            // the forward Scan, propagate it (and the forward body) to
799            // each emitted ScanBackward / ScanBackwardXs so the
800            // executor knows to recompute carries via `forward_body`
801            // between checkpoints.
802            let is_checkpointed = *num_checkpoints != 0 && *num_checkpoints != *length;
803            let forward_body_for_bwd = if is_checkpointed {
804                Some((**body).clone())
805            } else {
806                None
807            };
808
809            let dinit = bwd.scan_backward_with_checkpoints(
810                init_bwd,
811                traj_bwd,
812                upstream,
813                &xs_bwd,
814                body_vjp.clone(),
815                *length,
816                *save_trajectory,
817                *num_checkpoints,
818                forward_body_for_bwd.clone(),
819                init_shape,
820            );
821
822            let mut grads: Vec<(usize, NodeId)> = vec![(0, dinit)];
823            for i in 0..*num_xs as usize {
824                let outer_xs_id = node.inputs[1 + i];
825                let xs_shape = bwd.node(fwd_map[&outer_xs_id]).shape.clone();
826                let dxs_i = bwd.scan_backward_xs_with_checkpoints(
827                    init_bwd,
828                    traj_bwd,
829                    upstream,
830                    &xs_bwd,
831                    body_vjp.clone(),
832                    *length,
833                    *save_trajectory,
834                    i as u32,
835                    *num_checkpoints,
836                    forward_body_for_bwd.clone(),
837                    xs_shape,
838                );
839                grads.push((1 + i, dxs_i));
840            }
841            grads
842        }
843
844        Op::Conv {
845            kernel_size,
846            stride,
847            padding,
848            dilation,
849            groups,
850        } => {
851            let x_bwd = fwd_map[&node.inputs[0]];
852            let w_bwd = fwd_map[&node.inputs[1]];
853            let x_shape = bwd.node(x_bwd).shape.clone();
854            let w_shape = bwd.node(w_bwd).shape.clone();
855            let dx = bwd.conv2d_backward_input(
856                upstream,
857                w_bwd,
858                x_shape,
859                kernel_size.clone(),
860                stride.clone(),
861                padding.clone(),
862                dilation.clone(),
863                *groups,
864            );
865            // Detect the QAT pattern (`Conv` reading from a
866            // `FakeQuantize` weight) so a follow-up specialization
867            // can skip dead bins (weights that round to the same
868            // code share the gradient). For now we still emit the
869            // generic backward — the helper just exposes the bits
870            // for a future kernel variant.
871            // QAT-bits detection requires the forward graph, which isn't
872            // threaded through `vjp`. Skip for now; the generic backward
873            // is used unconditionally.
874            let _qat_bits: Option<u8> = None;
875            let dw = bwd.conv2d_backward_weight(
876                x_bwd,
877                upstream,
878                w_shape,
879                kernel_size.clone(),
880                stride.clone(),
881                padding.clone(),
882                dilation.clone(),
883                *groups,
884            );
885            vec![(0, dx), (1, dw)]
886        }
887
888        Op::Pool {
889            kind: ReduceOp::Max,
890            kernel_size,
891            stride,
892            padding,
893        } => {
894            let x_bwd = fwd_map[&node.inputs[0]];
895            let dx = bwd.maxpool2d_backward(
896                x_bwd,
897                upstream,
898                kernel_size.clone(),
899                stride.clone(),
900                padding.clone(),
901            );
902            vec![(0, dx)]
903        }
904
905        Op::SoftmaxCrossEntropyWithLogits => {
906            let logits_bwd = fwd_map[&node.inputs[0]];
907            let labels_bwd = fwd_map[&node.inputs[1]];
908            let dlogits = bwd.softmax_cross_entropy_backward(logits_bwd, labels_bwd, upstream);
909            // labels has no gradient.
910            vec![(0, dlogits)]
911        }
912
913        Op::Reduce {
914            op: ReduceOp::Sum,
915            axes,
916            keep_dim,
917        } => {
918            let x_bwd = fwd_map[&node.inputs[0]];
919            let x_shape = bwd.node(x_bwd).shape.clone();
920            let g = expand_to(upstream, &x_shape, axes, *keep_dim, bwd);
921            vec![(0, g)]
922        }
923
924        Op::Reduce {
925            op: ReduceOp::Mean,
926            axes,
927            keep_dim,
928        } => {
929            // Mean = Sum / N. Do the Sum-style expansion first, then
930            // multiply the broadcast result by 1/N. Multiplying after
931            // the expand keeps the broadcast cleanly anchored at the
932            // full input shape and sidesteps the rank-promotion when
933            // the reduced output is a scalar (shape `[]`).
934            let x_bwd = fwd_map[&node.inputs[0]];
935            let x_shape = bwd.node(x_bwd).shape.clone();
936            let count: usize = axes
937                .iter()
938                .map(|&a| match x_shape.dim(a) {
939                    Dim::Static(n) => n,
940                    _ => panic!("Reduce::Mean VJP requires static reduced dims"),
941                })
942                .product();
943            let expanded = expand_to(upstream, &x_shape, axes, *keep_dim, bwd);
944            let inv_count = scalar_const(1.0 / count as f32, bwd);
945            let g = bwd.binary(BinaryOp::Mul, expanded, inv_count, x_shape);
946            vec![(0, g)]
947        }
948
949        Op::Reshape { .. } => {
950            let x_bwd = fwd_map[&node.inputs[0]];
951            let x_shape = bwd.node(x_bwd).shape.clone();
952            let dx = reshape_to(upstream, &x_shape, bwd);
953            vec![(0, dx)]
954        }
955
956        Op::ComplexNormSq => {
957            // Wirtinger: ∂|z|²/∂z̄ = z. Cotangent g (real) maps to
958            // dz = g·z (complex, element-wise).
959            let z_bwd = fwd_map[&node.inputs[0]];
960            let dz = bwd.complex_norm_sq_backward(z_bwd, upstream);
961            vec![(0, dz)]
962        }
963
964        Op::Conjugate => {
965            // For w = conj(z): under the JAX-style cotangent (carrying
966            // ∂L/∂z̄ for a real-valued L), the rule reduces to
967            // cotangent_z = conj(cotangent_w). So the VJP of Conjugate
968            // is Conjugate itself. Symmetric — second-order derivatives
969            // through complex graphs stay consistent.
970            let dz = bwd.conjugate(upstream);
971            vec![(0, dz)]
972        }
973
974        Op::Cast { .. } => {
975            let x_bwd = fwd_map[&node.inputs[0]];
976            let x_shape = bwd.node(x_bwd).shape.clone();
977            let dx = bwd.add_node(
978                Op::Cast {
979                    to: x_shape.dtype(),
980                },
981                vec![upstream],
982                x_shape,
983            );
984            vec![(0, dx)]
985        }
986
987        // Stop-gradient (a.k.a. `detach`): forward identity, **no**
988        // gradient contribution to the input. Returning an empty list
989        // here keeps the reverse-mode walker from accumulating any
990        // upstream into `node.inputs[0]`, which is the whole point.
991        Op::StopGradient => vec![],
992
993        // Straight-through estimator: forward simulates the lossy
994        // round-trip (x → q → x'), backward pretends it was an
995        // identity. `dx = upstream` for both ops. The upstream is the
996        // f32 gradient computed by the consumer; the int8 dtype on
997        // the input/output is ignored for the gradient — we treat
998        // the entire Q/DQ chain as a real-valued no-op for autodiff
999        // purposes. This is the foundation for QAT (quantization-
1000        // aware training): the model trains in fp32 but every
1001        // forward pass tastes the int8 round-tripped activations,
1002        // so the learned weights are robust to deployment-time
1003        // quantization.
1004        Op::Quantize { .. } | Op::Dequantize { .. } => {
1005            vec![(0, upstream)]
1006        }
1007
1008        Op::FakeQuantizeLSQ { bits, axis } => {
1009            // LSQ has TWO gradients: dx (STE-clipped) and dscale
1010            // (closed-form). Route them to inputs[0] (x) and
1011            // inputs[1] (scale) respectively.
1012            let x_bwd = fwd_map[&node.inputs[0]];
1013            let scale_bwd = fwd_map[&node.inputs[1]];
1014            let x_shape = bwd.node(x_bwd).shape.clone();
1015            let scale_shape = bwd.node(scale_bwd).shape.clone();
1016            let dx = bwd.add_node(
1017                Op::FakeQuantizeLSQBackwardX {
1018                    bits: *bits,
1019                    axis: *axis,
1020                },
1021                vec![x_bwd, scale_bwd, upstream],
1022                x_shape,
1023            );
1024            let dscale = bwd.add_node(
1025                Op::FakeQuantizeLSQBackwardScale {
1026                    bits: *bits,
1027                    axis: *axis,
1028                },
1029                vec![x_bwd, scale_bwd, upstream],
1030                scale_shape,
1031            );
1032            vec![(0, dx), (1, dscale)]
1033        }
1034
1035        // FakeQuantize backward depends on the STE variant. The
1036        // default `Identity` is a clean passthrough; the others
1037        // attenuate the gradient based on `x` and the per-channel
1038        // scale, so we emit a dedicated `FakeQuantizeBackward` op.
1039        Op::FakeQuantize {
1040            bits, axis, ste, ..
1041        } => {
1042            use rlx_ir::op::SteKind;
1043            match ste {
1044                SteKind::Identity => vec![(0, upstream)],
1045                _ => {
1046                    let x_bwd = fwd_map[&node.inputs[0]];
1047                    let x_shape = bwd.node(x_bwd).shape.clone();
1048                    let dx = bwd.add_node(
1049                        Op::FakeQuantizeBackward {
1050                            bits: *bits,
1051                            axis: *axis,
1052                            ste: *ste,
1053                        },
1054                        vec![x_bwd, upstream],
1055                        x_shape,
1056                    );
1057                    vec![(0, dx)]
1058                }
1059            }
1060        }
1061
1062        Op::Expand { .. } => {
1063            let x_bwd = fwd_map[&node.inputs[0]];
1064            let x_shape = bwd.node(x_bwd).shape.clone();
1065            let dx = unbroadcast(upstream, &x_shape, bwd);
1066            vec![(0, dx)]
1067        }
1068
1069        Op::BatchNormInference { eps } => {
1070            let x_bwd = fwd_map[&node.inputs[0]];
1071            let gamma_bwd = fwd_map[&node.inputs[1]];
1072            let _beta_bwd = fwd_map[&node.inputs[2]];
1073            let mean_bwd = fwd_map[&node.inputs[3]];
1074            let var_bwd = fwd_map[&node.inputs[4]];
1075            let gamma_shape = bwd.node(gamma_bwd).shape.clone();
1076            let dx = bwd.batch_norm_inference_backward_input(
1077                x_bwd, gamma_bwd, mean_bwd, var_bwd, upstream, *eps,
1078            );
1079            let dgamma = bwd.batch_norm_inference_backward_gamma(
1080                x_bwd,
1081                mean_bwd,
1082                var_bwd,
1083                upstream,
1084                gamma_shape.clone(),
1085                *eps,
1086            );
1087            let dbeta = bwd.batch_norm_inference_backward_beta(upstream, gamma_shape);
1088            // mean/var are frozen — no gradients.
1089            vec![(0, dx), (1, dgamma), (2, dbeta)]
1090        }
1091
1092        Op::LayerNorm { axis, eps } => {
1093            // y = LayerNorm(x, gamma, beta) over the feature axis.
1094            // d_x via the dedicated `LayerNormBackwardInput` kernel
1095            // (closed-form, recomputes mean/var/x̂ inline).
1096            // d_gamma via `LayerNormBackwardGamma` (sums over batch axes).
1097            // d_beta = sum(upstream) over batch axes — composable with
1098            // an unbroadcast back to gamma's shape (gamma and beta share shape).
1099            let x_bwd = fwd_map[&node.inputs[0]];
1100            let gamma_bwd = fwd_map[&node.inputs[1]];
1101            let _beta_bwd = fwd_map[&node.inputs[2]];
1102            let gamma_shape = bwd.node(gamma_bwd).shape.clone();
1103
1104            let dx = bwd.layer_norm_backward_input(x_bwd, gamma_bwd, upstream, *axis, *eps);
1105            let dgamma =
1106                bwd.layer_norm_backward_gamma(x_bwd, upstream, gamma_shape.clone(), *axis, *eps);
1107            let dbeta = unbroadcast(upstream, &gamma_shape, bwd);
1108            vec![(0, dx), (1, dgamma), (2, dbeta)]
1109        }
1110
1111        Op::Softmax { axis } => {
1112            // y = softmax(x, axis)  →  dy/dx[i] = y[i] · (g[i] - Σⱼ y[j]·g[j])
1113            // where the Σⱼ is over the softmax axis. Compose from existing
1114            // primitives:  yg = y * upstream
1115            //              s  = reduce_sum(yg, axis, keep_dim=true)
1116            //              s' = expand(s, target=y.shape)
1117            //              dx = y * (upstream - s')
1118            //
1119            // The forward `y` lives at `fwd_to_bwd[node.id]` — bwd
1120            // graph mirrors every forward node so its slot survives
1121            // through this VJP. We *explicitly* expand `s` to `y.shape`
1122            // before the Sub rather than relying on `Op::Binary`'s
1123            // broadcast (which has a known shape-confusion bug for the
1124            // `[..., 1]` keep-dim case — see the rlx-cpu thunk
1125            // dispatch). Going through `Op::Expand` runs the
1126            // dedicated stride-aware broadcast thunk, which is correct.
1127            let y_bwd = fwd_map[&node.id];
1128            let y_shape = bwd.node(y_bwd).shape.clone();
1129            let dtype = y_shape.dtype();
1130            let rank = y_shape.rank();
1131            let axis_pos = if *axis < 0 {
1132                (rank as i32 + *axis) as usize
1133            } else {
1134                *axis as usize
1135            };
1136
1137            let yg = bwd.binary(BinaryOp::Mul, y_bwd, upstream, y_shape.clone());
1138
1139            let mut kept_dims: Vec<Dim> = y_shape.dims().to_vec();
1140            kept_dims[axis_pos] = Dim::Static(1);
1141            let kept_shape = Shape::from_dims(&kept_dims, dtype);
1142            let s = bwd.add_node(
1143                Op::Reduce {
1144                    op: ReduceOp::Sum,
1145                    axes: vec![axis_pos],
1146                    keep_dim: true,
1147                },
1148                vec![yg],
1149                kept_shape,
1150            );
1151
1152            let target_dims: Vec<i64> = y_shape
1153                .dims()
1154                .iter()
1155                .map(|d| match d {
1156                    Dim::Static(n) => *n as i64,
1157                    Dim::Dynamic(_) => -1,
1158                })
1159                .collect();
1160            let s_expanded = bwd.add_node(
1161                Op::Expand {
1162                    target_shape: target_dims,
1163                },
1164                vec![s],
1165                y_shape.clone(),
1166            );
1167
1168            let diff = bwd.binary(BinaryOp::Sub, upstream, s_expanded, y_shape.clone());
1169            let dx = bwd.binary(BinaryOp::Mul, y_bwd, diff, y_shape);
1170            vec![(0, dx)]
1171        }
1172
1173        // ── Shape ops: just route the upstream gradient through ──
1174        Op::Transpose { perm } => {
1175            // Inverse permutation: if forward maps axis i → perm[i],
1176            // backward maps perm[i] → i.
1177            let inv: Vec<usize> = {
1178                let mut v = vec![0usize; perm.len()];
1179                for (i, &p) in perm.iter().enumerate() {
1180                    v[p] = i;
1181                }
1182                v
1183            };
1184            let x_bwd = fwd_map[&node.inputs[0]];
1185            let x_shape = bwd.node(x_bwd).shape.clone();
1186            let dx = bwd.add_node(Op::Transpose { perm: inv }, vec![upstream], x_shape);
1187            vec![(0, dx)]
1188        }
1189
1190        Op::Concat { axis } => {
1191            // Split upstream along the concat axis: each input gets
1192            // `Narrow(upstream, axis, offset, x_i.dim(axis))`.
1193            let mut grads = Vec::with_capacity(node.inputs.len());
1194            let mut offset: usize = 0;
1195            for (i, &input_id) in node.inputs.iter().enumerate() {
1196                let x_bwd = fwd_map[&input_id];
1197                let x_shape = bwd.node(x_bwd).shape.clone();
1198                let len = match x_shape.dim(*axis) {
1199                    Dim::Static(n) => n,
1200                    _ => panic!("Concat VJP: dynamic concat dim"),
1201                };
1202                let dx = bwd.add_node(
1203                    Op::Narrow {
1204                        axis: *axis,
1205                        start: offset,
1206                        len,
1207                    },
1208                    vec![upstream],
1209                    x_shape,
1210                );
1211                grads.push((i, dx));
1212                offset += len;
1213            }
1214            grads
1215        }
1216
1217        Op::Narrow { axis, start, len } => {
1218            // Inverse of slicing: pad upstream with zeros on both
1219            // sides along `axis` so the result matches input shape.
1220            // Build via Concat[zeros_pre, upstream, zeros_post].
1221            let x_bwd = fwd_map[&node.inputs[0]];
1222            let x_shape = bwd.node(x_bwd).shape.clone();
1223            let full_n = match x_shape.dim(*axis) {
1224                Dim::Static(n) => n,
1225                _ => panic!("Narrow VJP: dynamic axis"),
1226            };
1227            let pre = *start;
1228            let post = full_n - *start - *len;
1229
1230            let zero_buf = |bwd: &mut Graph, len_axis: usize| -> NodeId {
1231                if len_axis == 0 {
1232                    return upstream; // sentinel, never used (filtered below)
1233                }
1234                let dtype = x_shape.dtype();
1235                let mut dims: Vec<Dim> = x_shape.dims().to_vec();
1236                dims[*axis] = Dim::Static(len_axis);
1237                let s = Shape::from_dims(&dims, dtype);
1238                let n_elems = dims.iter().fold(1usize, |a, d| match d {
1239                    Dim::Static(k) => a * k,
1240                    _ => a,
1241                });
1242                // Bytes per element scales with dtype; bytewise-zero is
1243                // a valid zero at any precision (IEEE +0.0 / signed 0 /
1244                // unsigned 0), so a vec of zero bytes is safe.
1245                let bytes = vec![0u8; n_elems * dtype.size_bytes()];
1246                bwd.add_node(Op::Constant { data: bytes }, vec![], s)
1247            };
1248
1249            let mut parts: Vec<NodeId> = Vec::new();
1250            if pre > 0 {
1251                parts.push(zero_buf(bwd, pre));
1252            }
1253            parts.push(upstream);
1254            if post > 0 {
1255                parts.push(zero_buf(bwd, post));
1256            }
1257
1258            let dx = if parts.len() == 1 {
1259                parts[0]
1260            } else {
1261                bwd.add_node(Op::Concat { axis: *axis }, parts, x_shape)
1262            };
1263            vec![(0, dx)]
1264        }
1265
1266        Op::Gather { axis } => {
1267            let table_bwd = fwd_map[&node.inputs[0]];
1268            let indices_bwd = fwd_map[&node.inputs[1]];
1269            let table_shape = bwd.node(table_bwd).shape.clone();
1270            if *axis == 0 {
1271                let dtable = bwd.add_node(Op::ScatterAdd, vec![upstream, indices_bwd], table_shape);
1272                vec![(0, dtable)]
1273            } else {
1274                let dtable = bwd.gather_backward(
1275                    upstream,
1276                    indices_bwd,
1277                    table_shape,
1278                    (*axis).try_into().unwrap(),
1279                );
1280                vec![(0, dtable)]
1281            }
1282        }
1283
1284        // ── Non-differentiable predicates / selectors ──
1285        Op::Compare(_) => {
1286            // Compare returns a boolean tensor; gradient w.r.t.
1287            // continuous inputs is zero almost everywhere. We don't
1288            // propagate (caller will see zero grads for any path
1289            // that flows through a Compare alone).
1290            vec![]
1291        }
1292
1293        Op::Where => {
1294            // out = where(cond, a, b). Cond has zero gradient
1295            // (it's a predicate); a's gradient is `where(cond,
1296            // upstream, 0)`; b's gradient is `where(cond, 0, upstream)`.
1297            let cond = fwd_map[&node.inputs[0]];
1298            let a_bwd = fwd_map[&node.inputs[1]];
1299            let b_bwd = fwd_map[&node.inputs[2]];
1300            let a_shape = bwd.node(a_bwd).shape.clone();
1301            let b_shape = bwd.node(b_bwd).shape.clone();
1302            let out_shape = upstream_shape.clone();
1303
1304            let zero_a_bytes = vec![0u8; a_shape.num_elements().expect("Where VJP: dynamic a") * 4];
1305            let zero_b_bytes = vec![0u8; b_shape.num_elements().expect("Where VJP: dynamic b") * 4];
1306            let zero_a = bwd.add_node(Op::Constant { data: zero_a_bytes }, vec![], a_shape.clone());
1307            let zero_b = bwd.add_node(Op::Constant { data: zero_b_bytes }, vec![], b_shape.clone());
1308            // Need to match shapes for Op::Where (cond, a, b same).
1309            // Upstream shape == out_shape == broadcast of a/b.
1310            let zero_a_bcast = unbroadcast_inverse(zero_a, &out_shape, bwd);
1311            let zero_b_bcast = unbroadcast_inverse(zero_b, &out_shape, bwd);
1312            let g_a_full = bwd.add_node(
1313                Op::Where,
1314                vec![cond, upstream, zero_a_bcast],
1315                out_shape.clone(),
1316            );
1317            let g_b_full = bwd.add_node(Op::Where, vec![cond, zero_b_bcast, upstream], out_shape);
1318            let g_a = unbroadcast(g_a_full, &a_shape, bwd);
1319            let g_b = unbroadcast(g_b_full, &b_shape, bwd);
1320            vec![(1, g_a), (2, g_b)]
1321        }
1322
1323        // ── Element-wise binary ops ──
1324        Op::Binary(BinaryOp::Div) => {
1325            // Real:  d/da (a/b) = 1/b,        d/db (a/b) = -a/b² = -y/b
1326            // C64 (Wirtinger):
1327            //        d/dā = upstream / conj(b)
1328            //        d/db̄ = -upstream · conj(y) / conj(b)
1329            // Substituting `b ↦ conj(b)` and `y ↦ conj(y)` in the real
1330            // rule recovers the complex one — the kernel itself is
1331            // unchanged.
1332            let a_bwd = fwd_map[&node.inputs[0]];
1333            let b_bwd = fwd_map[&node.inputs[1]];
1334            let y_bwd = fwd_map[&node.id];
1335            let a_shape = bwd.node(a_bwd).shape.clone();
1336            let b_shape = bwd.node(b_bwd).shape.clone();
1337            let is_c64 = upstream_shape.dtype() == DType::C64;
1338
1339            let b_term = if is_c64 { bwd.conjugate(b_bwd) } else { b_bwd };
1340            let y_term = if is_c64 { bwd.conjugate(y_bwd) } else { y_bwd };
1341
1342            // d/da: upstream / b_term
1343            let g_a_full = bwd.binary(BinaryOp::Div, upstream, b_term, upstream_shape.clone());
1344            let g_a = unbroadcast(g_a_full, &a_shape, bwd);
1345
1346            // d/db: -upstream * y_term / b_term
1347            let neg_up = bwd.activation(Activation::Neg, upstream, upstream_shape.clone());
1348            let neg_up_y = bwd.binary(BinaryOp::Mul, neg_up, y_term, upstream_shape.clone());
1349            let g_b_full = bwd.binary(BinaryOp::Div, neg_up_y, b_term, upstream_shape);
1350            let g_b = unbroadcast(g_b_full, &b_shape, bwd);
1351
1352            vec![(0, g_a), (1, g_b)]
1353        }
1354
1355        // ── Reductions: gradient flows to where the reduction "saw" ──
1356        Op::Reduce {
1357            op: ReduceOp::Max,
1358            axes,
1359            keep_dim,
1360        }
1361        | Op::Reduce {
1362            op: ReduceOp::Min,
1363            axes,
1364            keep_dim,
1365        } => {
1366            // d_x[i] = upstream where x[i] equals the (broadcast)
1367            // reduce result, else 0. Composed via
1368            // expand(upstream) * (compare(x, expand(y), Eq) → 1.0).
1369            let is_max = matches!(
1370                node.op,
1371                Op::Reduce {
1372                    op: ReduceOp::Max,
1373                    ..
1374                }
1375            );
1376            let _ = is_max;
1377            let x_bwd = fwd_map[&node.inputs[0]];
1378            let y_bwd = fwd_map[&node.id];
1379            let x_shape = bwd.node(x_bwd).shape.clone();
1380            let y_expanded = expand_to(y_bwd, &x_shape, axes, *keep_dim, bwd);
1381            let mask_bool = bwd.add_node(
1382                Op::Compare(CmpOp::Eq),
1383                vec![x_bwd, y_expanded],
1384                Shape::from_dims(x_shape.dims(), DType::F32),
1385            );
1386            // Convert bool→f32 via Cast (the IR encodes bool/PRED as
1387            // F32 in our backends already; this is a no-op cast on
1388            // most paths).
1389            let mask_f32 = bwd.add_node(
1390                Op::Cast {
1391                    to: x_shape.dtype(),
1392                },
1393                vec![mask_bool],
1394                x_shape.clone(),
1395            );
1396            let upstream_expanded = expand_to(upstream, &x_shape, axes, *keep_dim, bwd);
1397            let dx = bwd.binary(BinaryOp::Mul, upstream_expanded, mask_f32, x_shape);
1398            vec![(0, dx)]
1399        }
1400
1401        // ── Rope: backward is forward with negated sin ──
1402        //
1403        //   forward:  out = x * cos + rotate(x) * sin
1404        //   reverse:  dx  = dy * cos + rotate(dy) * (-sin)
1405        //         =  rope(dy, cos, neg(sin))
1406        Op::Rope { head_dim, n_rot } => {
1407            let cos = fwd_map[&node.inputs[1]];
1408            let sin = fwd_map[&node.inputs[2]];
1409            let dx = bwd.rope_backward(upstream, cos, sin, *head_dim, *n_rot);
1410            vec![(0, dx)]
1411        }
1412
1413        Op::RmsNorm { axis, eps } => {
1414            let x = fwd_map[&node.inputs[0]];
1415            let gamma = fwd_map[&node.inputs[1]];
1416            let beta = fwd_map[&node.inputs[2]];
1417            let dx = bwd.rms_norm_backward_input(x, gamma, beta, upstream, *axis, *eps);
1418            let dgamma = bwd.rms_norm_backward_gamma(x, gamma, beta, upstream, *axis, *eps);
1419            let dbeta = bwd.rms_norm_backward_beta(x, gamma, beta, upstream, *axis, *eps);
1420            vec![(0, dx), (1, dgamma), (2, dbeta)]
1421        }
1422
1423        Op::GroupNorm { num_groups, eps } => {
1424            let x = fwd_map[&node.inputs[0]];
1425            let gamma = fwd_map[&node.inputs[1]];
1426            let beta = fwd_map[&node.inputs[2]];
1427            let gamma_shape = bwd.node(gamma).shape.clone();
1428            let beta_shape = bwd.node(beta).shape.clone();
1429            let dx = bwd.group_norm_backward_input(x, gamma, beta, upstream, *num_groups, *eps);
1430            let dgamma = bwd.group_norm_backward_gamma(x, upstream, gamma_shape, *num_groups, *eps);
1431            let dbeta = bwd.group_norm_backward_beta(x, upstream, beta_shape, *num_groups, *eps);
1432            vec![(0, dx), (1, dgamma), (2, dbeta)]
1433        }
1434
1435        // ── Attention → dedicated backward kernels ──────────────
1436        Op::Attention {
1437            num_heads,
1438            head_dim,
1439            mask_kind,
1440            score_scale: _,
1441            attn_logit_softcap: _,
1442        } => {
1443            let q = fwd_map[&node.inputs[0]];
1444            let k = fwd_map[&node.inputs[1]];
1445            let v = fwd_map[&node.inputs[2]];
1446            let mask = match mask_kind {
1447                MaskKind::Custom | MaskKind::Bias => Some(fwd_map[&node.inputs[3]]),
1448                _ => None,
1449            };
1450            let (dq, dk, dv) = bwd
1451                .attention_backward_all(q, k, v, upstream, *num_heads, *head_dim, *mask_kind, mask);
1452            vec![(0, dq), (1, dk), (2, dv)]
1453        }
1454
1455        // ── Reduce(Prod) ────────────────────────────────────────
1456        //
1457        // Forward: y[axes_reduced] = ∏ x[axes_reduced…]
1458        // Backward: dx[i] = upstream · y / x[i]   (per-row).
1459        // (Numerically dicey when any x[i] = 0; production users
1460        //  needing zero-safe Prod-grad should pre-mask.)
1461        Op::Reduce {
1462            op: ReduceOp::Prod,
1463            axes,
1464            keep_dim,
1465        } => {
1466            let x_bwd = fwd_map[&node.inputs[0]];
1467            let y_bwd = fwd_map[&node.id];
1468            let x_shape = bwd.node(x_bwd).shape.clone();
1469            let y_expanded = expand_to(y_bwd, &x_shape, axes, *keep_dim, bwd);
1470            let upstream_expanded = expand_to(upstream, &x_shape, axes, *keep_dim, bwd);
1471            // dx = upstream_b · y_b / x
1472            let num = bwd.binary(
1473                BinaryOp::Mul,
1474                upstream_expanded,
1475                y_expanded,
1476                x_shape.clone(),
1477            );
1478            let dx = bwd.binary(BinaryOp::Div, num, x_bwd, x_shape);
1479            vec![(0, dx)]
1480        }
1481
1482        // ── Pool(Mean) ──────────────────────────────────────────
1483        //
1484        // Forward: y[..., h_out, w_out] = mean(window).
1485        // Backward: dx[i] = upstream[output_pos(i)] / |window|
1486        //   distributed across each pool window.
1487        //
1488        // Compose via a Conv2dBackwardInput with a constant
1489        // 1/|window| kernel of shape [C, 1, kH, kW] and groups=C
1490        // (depthwise — no channel mixing). This gives the correct
1491        // "spread upstream over window" behavior including stride
1492        // and padding handling.
1493        Op::Pool {
1494            kind: ReduceOp::Mean,
1495            kernel_size,
1496            stride,
1497            padding,
1498        } => {
1499            assert_eq!(kernel_size.len(), 2, "Pool(Mean) VJP: 2-D pool only");
1500            let x_bwd = fwd_map[&node.inputs[0]];
1501            let x_shape = bwd.node(x_bwd).shape.clone();
1502            let dtype = x_shape.dtype();
1503            // Channels = x_shape.dim(1).
1504            let c = match x_shape.dim(1) {
1505                Dim::Static(n) => n,
1506                _ => panic!("Pool(Mean) VJP: dynamic channel dim"),
1507            };
1508            let kh = kernel_size[0];
1509            let kw = kernel_size[1];
1510            let inv_n = 1.0_f32 / (kh as f32 * kw as f32);
1511            let kernel_n = c * kh * kw;
1512            let mut bytes: Vec<u8> = Vec::with_capacity(kernel_n * 4);
1513            for _ in 0..kernel_n {
1514                bytes.extend_from_slice(&inv_n.to_le_bytes());
1515            }
1516            let kernel_shape = Shape::from_dims(
1517                &[
1518                    Dim::Static(c),
1519                    Dim::Static(1),
1520                    Dim::Static(kh),
1521                    Dim::Static(kw),
1522                ],
1523                dtype,
1524            );
1525            let kernel = bwd.add_node(Op::Constant { data: bytes }, vec![], kernel_shape);
1526            let dx = bwd.conv2d_backward_input(
1527                upstream,
1528                kernel,
1529                x_shape,
1530                kernel_size.clone(),
1531                stride.clone(),
1532                padding.clone(),
1533                vec![1, 1],
1534                c, // groups = c → depthwise
1535            );
1536            vec![(0, dx)]
1537        }
1538
1539        // ── Binary(Min/Max) ─────────────────────────────────────
1540        //
1541        // Element-wise min/max: gradient flows to whichever input
1542        // was selected (ties go to the first operand by convention).
1543        //   da = where(a == out, upstream, 0)
1544        //   db = where(a == out, 0, upstream)   ← exclusive
1545        Op::Binary(BinaryOp::Min) | Op::Binary(BinaryOp::Max) => {
1546            let a_bwd = fwd_map[&node.inputs[0]];
1547            let b_bwd = fwd_map[&node.inputs[1]];
1548            let y_bwd = fwd_map[&node.id];
1549            let a_shape = bwd.node(a_bwd).shape.clone();
1550            let b_shape = bwd.node(b_bwd).shape.clone();
1551            let dtype = upstream_shape.dtype();
1552
1553            let bool_shape = Shape::from_dims(upstream_shape.dims(), DType::Bool);
1554            let mask_pred = bwd.add_node(Op::Compare(CmpOp::Eq), vec![a_bwd, y_bwd], bool_shape);
1555            let mask_f32 = bwd.add_node(
1556                Op::Cast { to: dtype },
1557                vec![mask_pred],
1558                upstream_shape.clone(),
1559            );
1560            let zero_bytes = vec![
1561                0u8;
1562                upstream_shape
1563                    .num_elements()
1564                    .expect("Min/Max VJP: dyn shape")
1565                    * 4
1566            ];
1567            let zero = bwd.add_node(
1568                Op::Constant { data: zero_bytes },
1569                vec![],
1570                upstream_shape.clone(),
1571            );
1572            let g_a_full = bwd.add_node(
1573                Op::Where,
1574                vec![mask_f32, upstream, zero],
1575                upstream_shape.clone(),
1576            );
1577            let g_b_full = bwd.add_node(Op::Where, vec![mask_f32, zero, upstream], upstream_shape);
1578            let g_a = unbroadcast(g_a_full, &a_shape, bwd);
1579            let g_b = unbroadcast(g_b_full, &b_shape, bwd);
1580            vec![(0, g_a), (1, g_b)]
1581        }
1582
1583        // ── Binary(Pow) ─────────────────────────────────────────
1584        //
1585        //   d/da (aᵇ) = b · a^(b-1)
1586        //   d/db (aᵇ) = aᵇ · ln(a)
1587        //
1588            // We don't have a `Pow` activation, but for a positive base
1589            // `pow(a, b)` equals `exp(b · ln(a))`, so both derivatives reuse
1590            // the forward output `y = aᵇ` via `Activation::Log`, `Mul`, `Div`.
1591        Op::Binary(BinaryOp::Pow) => {
1592            let a_bwd = fwd_map[&node.inputs[0]];
1593            let b_bwd = fwd_map[&node.inputs[1]];
1594            let y_bwd = fwd_map[&node.id]; // a^b
1595            let a_shape = bwd.node(a_bwd).shape.clone();
1596            let b_shape = bwd.node(b_bwd).shape.clone();
1597
1598            // d/da: upstream · b · a^(b-1) = upstream · (y · b) / a,
1599            // since y = aᵇ. Reuse the forward output instead of a^(b-1).
1600            let yb = bwd.binary(BinaryOp::Mul, y_bwd, b_bwd, upstream_shape.clone());
1601            let yb_over_a = bwd.binary(BinaryOp::Div, yb, a_bwd, upstream_shape.clone());
1602            let g_a_full = bwd.binary(BinaryOp::Mul, upstream, yb_over_a, upstream_shape.clone());
1603            let g_a = unbroadcast(g_a_full, &a_shape, bwd);
1604
1605            // d/db: upstream · y · ln(a)
1606            let ln_a = bwd.activation(Activation::Log, a_bwd, a_shape);
1607            let ln_a_b = unbroadcast_inverse(ln_a, &upstream_shape, bwd);
1608            let yln = bwd.binary(BinaryOp::Mul, y_bwd, ln_a_b, upstream_shape.clone());
1609            let g_b_full = bwd.binary(BinaryOp::Mul, upstream, yln, upstream_shape);
1610            let g_b = unbroadcast(g_b_full, &b_shape, bwd);
1611
1612            vec![(0, g_a), (1, g_b)]
1613        }
1614
1615        // ── DequantMatMul (QAT-style straight-through) ─────────
1616        //
1617        // Forward (Int8BlockAsym):
1618        //   w_dq = (cast<f32>(w_q) - zp_b) * scale_b
1619        //   y    = x @ w_dq
1620        //
1621        // Backward (QAT convention — scale and zp are typically
1622        // frozen during fine-tuning; w_q's int8 cast is treated as
1623        // a no-op for the gradient via straight-through):
1624        //   dx     = upstream @ w_dq^T
1625        //   dw_q   = x^T @ upstream * scale_b   (straight-through;
1626        //            the user's optimizer would project back to
1627        //            int8 after the step)
1628        //   dscale = 0   (frozen)
1629        //   dzp    = 0   (frozen)
1630        //
1631        // For full QAT with learnable scale/zp, replace the zero
1632        // gradients with the closed-form ∂y/∂scale / ∂y/∂zp.
1633        Op::DequantMatMul { scheme: _ } => {
1634            let x_bwd = fwd_map[&node.inputs[0]];
1635            let w_q_bwd = fwd_map[&node.inputs[1]];
1636            let scale_bwd = fwd_map[&node.inputs[2]];
1637            let zp_bwd = fwd_map[&node.inputs[3]];
1638            let x_shape = bwd.node(x_bwd).shape.clone();
1639            let w_shape = bwd.node(w_q_bwd).shape.clone();
1640            let scale_shape = bwd.node(scale_bwd).shape.clone();
1641            let zp_shape = bwd.node(zp_bwd).shape.clone();
1642
1643            // dx = upstream @ w_dq^T. Recompute w_dq inline.
1644            // w_q is int8 in the IR — cast to f32 for the matmul
1645            // backward graph (straight-through equivalent).
1646            let dtype = x_shape.dtype();
1647            let w_q_f32 = bwd.add_node(
1648                Op::Cast { to: dtype },
1649                vec![w_q_bwd],
1650                Shape::from_dims(w_shape.dims(), dtype),
1651            );
1652            // Broadcast scale/zp to w_shape before subtract/mul.
1653            let scale_b =
1654                unbroadcast_inverse(scale_bwd, &Shape::from_dims(w_shape.dims(), dtype), bwd);
1655            let zp_b = unbroadcast_inverse(zp_bwd, &Shape::from_dims(w_shape.dims(), dtype), bwd);
1656            let w_centered = bwd.binary(
1657                BinaryOp::Sub,
1658                w_q_f32,
1659                zp_b,
1660                Shape::from_dims(w_shape.dims(), dtype),
1661            );
1662            let w_dq = bwd.binary(
1663                BinaryOp::Mul,
1664                w_centered,
1665                scale_b,
1666                Shape::from_dims(w_shape.dims(), dtype),
1667            );
1668
1669            // Transpose w_dq's last two dims for dx = upstream @ w_dq^T.
1670            let w_rank = w_shape.rank();
1671            let mut perm: Vec<usize> = (0..w_rank).collect();
1672            perm.swap(w_rank - 2, w_rank - 1);
1673            let mut wdt_dims: Vec<Dim> = w_shape.dims().to_vec();
1674            wdt_dims.swap(w_rank - 2, w_rank - 1);
1675            let w_dq_t_shape = Shape::from_dims(&wdt_dims, dtype);
1676            let w_dq_t = bwd.add_node(Op::Transpose { perm }, vec![w_dq], w_dq_t_shape);
1677            let dx = bwd.matmul(upstream, w_dq_t, x_shape.clone());
1678
1679            // dw_q = (x^T @ upstream) * scale_b   (straight-through).
1680            // The result is in the int8-weight space — caller's
1681            // optimizer is expected to project back. We emit it as
1682            // f32 here and let downstream cast.
1683            let x_rank = x_shape.rank();
1684            let mut x_perm: Vec<usize> = (0..x_rank).collect();
1685            x_perm.swap(x_rank - 2, x_rank - 1);
1686            let mut x_t_dims: Vec<Dim> = x_shape.dims().to_vec();
1687            x_t_dims.swap(x_rank - 2, x_rank - 1);
1688            let x_t = bwd.add_node(
1689                Op::Transpose { perm: x_perm },
1690                vec![x_bwd],
1691                Shape::from_dims(&x_t_dims, dtype),
1692            );
1693            let dw_unscaled = bwd.matmul(x_t, upstream, Shape::from_dims(w_shape.dims(), dtype));
1694            let dw_q_f32 = bwd.binary(
1695                BinaryOp::Mul,
1696                dw_unscaled,
1697                scale_b,
1698                Shape::from_dims(w_shape.dims(), dtype),
1699            );
1700            // Cast back to the IR's int8 dtype convention.
1701            let dw_q = bwd.add_node(
1702                Op::Cast {
1703                    to: w_shape.dtype(),
1704                },
1705                vec![dw_q_f32],
1706                w_shape,
1707            );
1708
1709            // scale and zp: zero gradients (frozen QAT convention).
1710            let zero_scale_bytes =
1711                vec![0u8; scale_shape.num_elements().expect("DQMM VJP: dyn scale") * 4];
1712            let zero_zp_bytes = vec![0u8; zp_shape.num_elements().expect("DQMM VJP: dyn zp") * 4];
1713            let dscale = bwd.add_node(
1714                Op::Constant {
1715                    data: zero_scale_bytes,
1716                },
1717                vec![],
1718                scale_shape,
1719            );
1720            let dzp = bwd.add_node(
1721                Op::Constant {
1722                    data: zero_zp_bytes,
1723                },
1724                vec![],
1725                zp_shape,
1726            );
1727
1728            vec![(0, dx), (1, dw_q), (2, dscale), (3, dzp)]
1729        }
1730
1731        // ── ScatterAdd ──────────────────────────────────────────
1732        //
1733        // Forward: out[indices[i], ...] += updates[i, ...].
1734        // Backward: d_updates[i, ...] = upstream[indices[i], ...]  (gather).
1735        //   Indices are non-differentiable.
1736        Op::ScatterAdd => {
1737            let updates_bwd = fwd_map[&node.inputs[0]];
1738            let indices_bwd = fwd_map[&node.inputs[1]];
1739            let updates_shape = bwd.node(updates_bwd).shape.clone();
1740            let dupdates = bwd.add_node(
1741                Op::Gather { axis: 0 },
1742                vec![upstream, indices_bwd],
1743                updates_shape,
1744            );
1745            vec![(0, dupdates)]
1746        }
1747
1748        // ── Cumsum ──────────────────────────────────────────────
1749        //
1750        Op::Cumsum { axis, exclusive } => {
1751            let x_bwd = fwd_map[&node.inputs[0]];
1752            let x_shape = bwd.node(x_bwd).shape.clone();
1753            let dx = bwd.cumsum_backward(upstream, x_shape, *axis, *exclusive);
1754            vec![(0, dx)]
1755        }
1756
1757        // ── GroupedMatMul (MoE primitive) ──────────────────────
1758        //
1759        // Forward: y[i] = x[i] @ w[expert[i]]
1760        //   x        [M, K]
1761        //   w        [E, K, N]
1762        //   expert   [M] (f32-encoded indices)
1763        //   y        [M, N]
1764        //
1765        // Backward (composed via Gather + batched-MatMul + ScatterAdd):
1766        //   dx[i] = upstream[i] @ w[expert[i]]^T
1767        //   dw[e, k, n] = sum_{i : expert[i]=e} x[i,k] · upstream[i,n]
1768        //   dexpert: zero (non-differentiable index input).
1769        Op::GroupedMatMul => {
1770            let x_bwd = fwd_map[&node.inputs[0]];
1771            let w_bwd = fwd_map[&node.inputs[1]];
1772            let expert_bwd = fwd_map[&node.inputs[2]];
1773            let x_shape = bwd.node(x_bwd).shape.clone();
1774            let w_shape = bwd.node(w_bwd).shape.clone();
1775            let (dx, dw) =
1776                grouped_matmul_vjp(bwd, upstream, x_bwd, w_bwd, expert_bwd, &x_shape, &w_shape);
1777            vec![(0, dx), (1, dw)]
1778        }
1779
1780        // ── DequantGroupedMatMul (frozen GGUF MoE weights) ─────
1781        //
1782        // Materialize w_dq via `Op::DequantMoEWeights`, then reuse the
1783        // GroupedMatMul VJP. Packed U8 weights and expert indices are
1784        // non-differentiable (inference / QAT-frozen convention).
1785        Op::DequantGroupedMatMul { scheme } => {
1786            let x_bwd = fwd_map[&node.inputs[0]];
1787            let w_packed = fwd_map[&node.inputs[1]];
1788            let expert_bwd = fwd_map[&node.inputs[2]];
1789            let x_shape = bwd.node(x_bwd).shape.clone();
1790            let w_packed_shape = bwd.node(w_packed).shape.clone();
1791            let dtype = x_shape.dtype();
1792            let k = x_shape.dim(1);
1793            let n_out = node.shape.dim(node.shape.rank() - 1);
1794            let k_static = match k {
1795                Dim::Static(v) => v,
1796                _ => panic!("DequantGroupedMatMul VJP: K must be static"),
1797            };
1798            let n_static = match n_out {
1799                Dim::Static(v) => v,
1800                _ => panic!("DequantGroupedMatMul VJP: N must be static"),
1801            };
1802            let block_elems = scheme.gguf_block_size() as usize;
1803            let block_bytes = scheme.gguf_block_bytes() as usize;
1804            let slab_bytes = (k_static * n_static) / block_elems * block_bytes;
1805            let total_bytes = w_packed_shape
1806                .num_elements()
1807                .expect("DequantGroupedMatMul VJP: dyn packed");
1808            let e_static = total_bytes / slab_bytes.max(1);
1809            let w_shape = Shape::from_dims(
1810                &[
1811                    Dim::Static(e_static),
1812                    Dim::Static(k_static),
1813                    Dim::Static(n_static),
1814                ],
1815                dtype,
1816            );
1817            let w_dq = bwd.add_node(
1818                Op::DequantMoEWeights { scheme: *scheme },
1819                vec![w_packed],
1820                w_shape.clone(),
1821            );
1822            let (dx, _dw) =
1823                grouped_matmul_vjp(bwd, upstream, x_bwd, w_dq, expert_bwd, &x_shape, &w_shape);
1824            vec![(0, dx)]
1825        }
1826
1827        // ── QMatMul / QConv2d (straight-through INT8 backward) ──
1828        //
1829        // Real INT8 inference kernels. The forward applies
1830        //   out = clamp(round((x − x_zp) · (w − w_zp) · mult + bias)
1831        //               + out_zp, [-128, 127])
1832        // and outputs i8. For training, the standard QAT recipe
1833        // treats the round/clamp/quantize as straight-through, so
1834        // the gradient is what a plain f32 MatMul (or Conv) backward
1835        // would give applied to the dequantized representations.
1836        // Zero-points and `mult` are typically frozen (calibration
1837        // outputs); we emit zero gradients for them. Bias gets the
1838        // standard sum-over-batch gradient.
1839        Op::QMatMul {
1840            x_zp,
1841            w_zp,
1842            out_zp: _,
1843            mult,
1844        } => {
1845            let x_bwd = fwd_map[&node.inputs[0]];
1846            let w_bwd = fwd_map[&node.inputs[1]];
1847            let bias_bwd = fwd_map[&node.inputs[2]];
1848            let x_shape = bwd.node(x_bwd).shape.clone();
1849            let w_shape = bwd.node(w_bwd).shape.clone();
1850            let bias_shape = bwd.node(bias_bwd).shape.clone();
1851            let dtype = upstream_shape.dtype();
1852
1853            // Promote x and w to f32 (straight-through); subtract zps.
1854            let x_f32 = bwd.add_node(
1855                Op::Cast { to: dtype },
1856                vec![x_bwd],
1857                Shape::from_dims(x_shape.dims(), dtype),
1858            );
1859            let w_f32 = bwd.add_node(
1860                Op::Cast { to: dtype },
1861                vec![w_bwd],
1862                Shape::from_dims(w_shape.dims(), dtype),
1863            );
1864            let xzp_c = scalar_const(*x_zp as f32, bwd);
1865            let xzp_b = unbroadcast_inverse(xzp_c, &Shape::from_dims(x_shape.dims(), dtype), bwd);
1866            let _ = bwd.binary(
1867                BinaryOp::Sub,
1868                x_f32,
1869                xzp_b,
1870                Shape::from_dims(x_shape.dims(), dtype),
1871            );
1872            let wzp_c = scalar_const(*w_zp as f32, bwd);
1873            let wzp_b = unbroadcast_inverse(wzp_c, &Shape::from_dims(w_shape.dims(), dtype), bwd);
1874            let w_centered = bwd.binary(
1875                BinaryOp::Sub,
1876                w_f32,
1877                wzp_b,
1878                Shape::from_dims(w_shape.dims(), dtype),
1879            );
1880
1881            // mult scaling.
1882            let mult_c = scalar_const(*mult, bwd);
1883            let mult_b = unbroadcast_inverse(mult_c, &upstream_shape, bwd);
1884            let upstream_scaled =
1885                bwd.binary(BinaryOp::Mul, upstream, mult_b, upstream_shape.clone());
1886
1887            // dx = upstream_scaled @ w_centered^T   (still i8 dtype
1888            //  on the input side; cast the gradient back).
1889            let w_rank = w_shape.rank();
1890            let mut perm: Vec<usize> = (0..w_rank).collect();
1891            perm.swap(w_rank - 2, w_rank - 1);
1892            let mut wt_dims: Vec<Dim> = w_shape.dims().to_vec();
1893            wt_dims.swap(w_rank - 2, w_rank - 1);
1894            let w_t = bwd.add_node(
1895                Op::Transpose { perm },
1896                vec![w_centered],
1897                Shape::from_dims(&wt_dims, dtype),
1898            );
1899            let dx_f32 = bwd.matmul(
1900                upstream_scaled,
1901                w_t,
1902                Shape::from_dims(x_shape.dims(), dtype),
1903            );
1904            let dx = bwd.add_node(
1905                Op::Cast {
1906                    to: x_shape.dtype(),
1907                },
1908                vec![dx_f32],
1909                x_shape.clone(),
1910            );
1911
1912            // dw = x_centered^T @ upstream_scaled  (similarly cast).
1913            let x_rank = x_shape.rank();
1914            let mut x_perm: Vec<usize> = (0..x_rank).collect();
1915            x_perm.swap(x_rank - 2, x_rank - 1);
1916            let mut xt_dims: Vec<Dim> = x_shape.dims().to_vec();
1917            xt_dims.swap(x_rank - 2, x_rank - 1);
1918            // Need to pull x_centered into scope — recompute inline.
1919            let x_f32_2 = bwd.add_node(
1920                Op::Cast { to: dtype },
1921                vec![x_bwd],
1922                Shape::from_dims(x_shape.dims(), dtype),
1923            );
1924            let x_centered = bwd.binary(
1925                BinaryOp::Sub,
1926                x_f32_2,
1927                xzp_b,
1928                Shape::from_dims(x_shape.dims(), dtype),
1929            );
1930            let x_t = bwd.add_node(
1931                Op::Transpose { perm: x_perm },
1932                vec![x_centered],
1933                Shape::from_dims(&xt_dims, dtype),
1934            );
1935            let dw_f32 = bwd.matmul(
1936                x_t,
1937                upstream_scaled,
1938                Shape::from_dims(w_shape.dims(), dtype),
1939            );
1940            let dw = bwd.add_node(
1941                Op::Cast {
1942                    to: w_shape.dtype(),
1943                },
1944                vec![dw_f32],
1945                w_shape,
1946            );
1947
1948            // dbias = sum upstream_scaled over batch axes (matches
1949            // f32 MatMul-with-bias backward shape).
1950            let bias_rank = bias_shape.rank();
1951            let reduce_axes: Vec<usize> = (0..upstream_shape.rank())
1952                .filter(|&i| i + bias_rank < upstream_shape.rank() || i == 0)
1953                .collect();
1954            let dbias_f32 = bwd.add_node(
1955                Op::Reduce {
1956                    op: ReduceOp::Sum,
1957                    axes: reduce_axes,
1958                    keep_dim: false,
1959                },
1960                vec![upstream_scaled],
1961                Shape::from_dims(bias_shape.dims(), dtype),
1962            );
1963            let dbias = bwd.add_node(
1964                Op::Cast {
1965                    to: bias_shape.dtype(),
1966                },
1967                vec![dbias_f32],
1968                bias_shape,
1969            );
1970
1971            vec![(0, dx), (1, dw), (2, dbias)]
1972        }
1973
1974        Op::QConv2d {
1975            kernel_size,
1976            stride,
1977            padding,
1978            dilation,
1979            groups,
1980            x_zp,
1981            w_zp,
1982            out_zp: _,
1983            mult,
1984        } => {
1985            // Same straight-through pattern as QMatMul, lifted to
1986            // 2-D conv via the existing Conv2dBackwardInput / Weight
1987            // kernels.
1988            let x_bwd = fwd_map[&node.inputs[0]];
1989            let w_bwd = fwd_map[&node.inputs[1]];
1990            let bias_bwd = fwd_map[&node.inputs[2]];
1991            let x_shape = bwd.node(x_bwd).shape.clone();
1992            let w_shape = bwd.node(w_bwd).shape.clone();
1993            let bias_shape = bwd.node(bias_bwd).shape.clone();
1994            let dtype = upstream_shape.dtype();
1995
1996            // Promote and dequantize.
1997            let x_f32 = bwd.add_node(
1998                Op::Cast { to: dtype },
1999                vec![x_bwd],
2000                Shape::from_dims(x_shape.dims(), dtype),
2001            );
2002            let w_f32 = bwd.add_node(
2003                Op::Cast { to: dtype },
2004                vec![w_bwd],
2005                Shape::from_dims(w_shape.dims(), dtype),
2006            );
2007            let xzp_c = scalar_const(*x_zp as f32, bwd);
2008            let xzp_b = unbroadcast_inverse(xzp_c, &Shape::from_dims(x_shape.dims(), dtype), bwd);
2009            let x_centered = bwd.binary(
2010                BinaryOp::Sub,
2011                x_f32,
2012                xzp_b,
2013                Shape::from_dims(x_shape.dims(), dtype),
2014            );
2015            let wzp_c = scalar_const(*w_zp as f32, bwd);
2016            let wzp_b = unbroadcast_inverse(wzp_c, &Shape::from_dims(w_shape.dims(), dtype), bwd);
2017            let w_centered = bwd.binary(
2018                BinaryOp::Sub,
2019                w_f32,
2020                wzp_b,
2021                Shape::from_dims(w_shape.dims(), dtype),
2022            );
2023
2024            // mult scaling on upstream.
2025            let mult_c = scalar_const(*mult, bwd);
2026            let mult_b = unbroadcast_inverse(mult_c, &upstream_shape, bwd);
2027            let upstream_scaled =
2028                bwd.binary(BinaryOp::Mul, upstream, mult_b, upstream_shape.clone());
2029
2030            // dx, dw via the existing conv-backward kernels.
2031            let dx_f32 = bwd.conv2d_backward_input(
2032                upstream_scaled,
2033                w_centered,
2034                Shape::from_dims(x_shape.dims(), dtype),
2035                kernel_size.clone(),
2036                stride.clone(),
2037                padding.clone(),
2038                dilation.clone(),
2039                *groups,
2040            );
2041            let dx = bwd.add_node(
2042                Op::Cast {
2043                    to: x_shape.dtype(),
2044                },
2045                vec![dx_f32],
2046                x_shape,
2047            );
2048            let dw_f32 = bwd.conv2d_backward_weight(
2049                x_centered,
2050                upstream_scaled,
2051                Shape::from_dims(w_shape.dims(), dtype),
2052                kernel_size.clone(),
2053                stride.clone(),
2054                padding.clone(),
2055                dilation.clone(),
2056                *groups,
2057            );
2058            let dw = bwd.add_node(
2059                Op::Cast {
2060                    to: w_shape.dtype(),
2061                },
2062                vec![dw_f32],
2063                w_shape,
2064            );
2065
2066            // dbias = sum upstream_scaled over (N, H_out, W_out) keeping C_out.
2067            let dbias_f32 = bwd.add_node(
2068                Op::Reduce {
2069                    op: ReduceOp::Sum,
2070                    axes: vec![0, 2, 3],
2071                    keep_dim: false,
2072                },
2073                vec![upstream_scaled],
2074                Shape::from_dims(bias_shape.dims(), dtype),
2075            );
2076            let dbias = bwd.add_node(
2077                Op::Cast {
2078                    to: bias_shape.dtype(),
2079                },
2080                vec![dbias_f32],
2081                bias_shape,
2082            );
2083
2084            vec![(0, dx), (1, dw), (2, dbias)]
2085        }
2086
2087        // ── Sampling-style ops: non-differentiable ──
2088        Op::TopK { .. } | Op::Sample { .. } | Op::RngNormal { .. } | Op::RngUniform { .. } => {
2089            // TopK selects; Sample multinomial-draws; RngNormal / RngUniform
2090            // share this arm. Gradient w.r.t. the inputs is undefined / zero
2091            // in the standard sense, so no input gradients are emitted.
2092            vec![]
2093        }
2094
2095        Op::GaussianSplatRender {
2096            width,
2097            height,
2098            tile_size,
2099            radius_scale,
2100            alpha_cutoff,
2101            max_splat_steps,
2102            transmittance_threshold,
2103            max_list_entries,
2104            ..
2105        } => {
2106            use rlx_ir::ops::splat::{
2107                GaussianSplatBackwardParams, GaussianSplatInputs, GaussianSplatRenderParams,
2108                unpack_gaussian_splat_packed_grads,
2109            };
2110            let render = GaussianSplatRenderParams {
2111                width: *width,
2112                height: *height,
2113                tile_size: *tile_size,
2114                radius_scale: *radius_scale,
2115                alpha_cutoff: *alpha_cutoff,
2116                max_splat_steps: *max_splat_steps,
2117                transmittance_threshold: *transmittance_threshold,
2118                max_list_entries: *max_list_entries,
2119            };
2120            let inputs = GaussianSplatInputs {
2121                positions: fwd_map[&node.inputs[0]],
2122                scales: fwd_map[&node.inputs[1]],
2123                rotations: fwd_map[&node.inputs[2]],
2124                opacities: fwd_map[&node.inputs[3]],
2125                colors: fwd_map[&node.inputs[4]],
2126                sh_coeffs: fwd_map[&node.inputs[5]],
2127                meta: fwd_map[&node.inputs[6]],
2128            };
2129            let count = bwd.shape(inputs.positions).num_elements().unwrap_or(0) / 3;
2130            let sh_len = bwd.shape(inputs.sh_coeffs).num_elements().unwrap_or(0);
2131            let meta_shape = bwd.shape(inputs.meta).clone();
2132            let packed = bwd.gaussian_splat_render_backward(
2133                inputs,
2134                upstream,
2135                GaussianSplatBackwardParams {
2136                    render,
2137                    loss_grad_clip: 1.0,
2138                    sh_band: 0,
2139                    max_anisotropy: 10.0,
2140                },
2141            );
2142            let sh_coeff_count = if count == 0 {
2143                1
2144            } else {
2145                (sh_len / (count * 3)).max(1)
2146            };
2147            let grads = unpack_gaussian_splat_packed_grads(bwd, packed, count, sh_coeff_count);
2148            let meta_n = meta_shape.num_elements().unwrap_or(0);
2149            let zero_meta = bwd.add_node(
2150                Op::Constant {
2151                    data: vec![0u8; meta_n * meta_shape.dtype().size_bytes()],
2152                },
2153                vec![],
2154                meta_shape,
2155            );
2156            vec![
2157                (0, grads.positions),
2158                (1, grads.scales),
2159                (2, grads.rotations),
2160                (3, grads.opacities),
2161                (4, grads.colors),
2162                (5, grads.sh_coeffs),
2163                (6, zero_meta),
2164            ]
2165        }
2166
2167        Op::GaussianSplatRenderBackward { .. } => {
2168            // Scene/meta inputs are not differentiated through this op in v1.
2169            vec![]
2170        }
2171
2172        Op::GaussianSplatPrepare { .. } | Op::GaussianSplatRasterize { .. } => {
2173            panic!(
2174                "autodiff: decomposed splat ops must be fused before AD — \
2175                 `prepare_graph_for_ad` rewrites Prepare→Rasterize into \
2176                 `GaussianSplatRender`, or use `Op::GaussianSplatRender` directly"
2177            );
2178        }
2179
2180        // ── Anything else: explicit panic with op name ──
2181        //
2182        // All ops in the IR have either a per-op VJP rule above
2183        // or a pre-pass rewrite that decomposes them into ops
2184        // that do:
2185        //   * Op::If → control_flow::inline_if (Where + inlined
2186        //     branches).
2187        //   * Op::While → control_flow::unroll_while (bounded
2188        //     unroll up to max_iterations).
2189        //   * Op::SelectiveScan / Op::FusedTransformerLayer /
2190        //     Op::FusedAttentionBlock / Op::FusedSwiGLU /
2191        //     Op::LoraMatMul / Op::FusedMatMulBiasAct /
2192        //     Op::FusedResidualLN → rlx_fusion::unfuse_fused_for_autodiff.
2193        //
2194        // User-defined sub-graph (Op::CustomFn) with override AD.
2195        // When `vjp_body` is supplied, inline it into `bwd`: each
2196        // primal Op::Input maps to the outer forward NodeId for that
2197        // primal; the special-named "primal_output" Input maps to the
2198        // forward NodeId of this CustomFn node; "d_output" maps to
2199        // `upstream`. The body's N outputs become this op's N input
2200        // gradients in declaration order.
2201        Op::CustomFn {
2202            vjp_body: Some(vjp_body),
2203            num_inputs,
2204            ..
2205        } => {
2206            // Map vjp_body NodeIds → bwd NodeIds.
2207            let mut sub_to_bwd: HashMap<NodeId, NodeId> = HashMap::new();
2208
2209            // Collect primal-input NodeIds from vjp_body (excluding
2210            // special names), sorted by NodeId. Position k in this list
2211            // matches the outer node's input k.
2212            let mut primal_input_ids: Vec<NodeId> = vjp_body
2213                .nodes()
2214                .iter()
2215                .filter_map(|n| match &n.op {
2216                    Op::Input { name } if name != "primal_output" && name != "d_output" => {
2217                        Some(n.id)
2218                    }
2219                    _ => None,
2220                })
2221                .collect();
2222            primal_input_ids.sort();
2223            assert_eq!(primal_input_ids.len(), *num_inputs as usize);
2224
2225            // Walk vjp_body in declaration order, cloning each non-Input
2226            // node into bwd with input remapping.
2227            for sub_node in vjp_body.nodes() {
2228                let new_id = match &sub_node.op {
2229                    Op::Input { name } if name == "primal_output" => fwd_map[&node.id],
2230                    Op::Input { name } if name == "d_output" => upstream,
2231                    Op::Input { .. } => {
2232                        // Find this Input's index in primal_input_ids.
2233                        let idx = primal_input_ids
2234                            .iter()
2235                            .position(|&id| id == sub_node.id)
2236                            .expect(
2237                                "custom_fn vjp_body: primal Input \
2238                                     not found in primal list",
2239                            );
2240                        fwd_map[&node.inputs[idx]]
2241                    }
2242                    _ => {
2243                        let new_inputs: Vec<NodeId> =
2244                            sub_node.inputs.iter().map(|i| sub_to_bwd[i]).collect();
2245                        bwd.add_node(sub_node.op.clone(), new_inputs, sub_node.shape.clone())
2246                    }
2247                };
2248                sub_to_bwd.insert(sub_node.id, new_id);
2249            }
2250
2251            // Collect outputs in set_outputs order — each maps to a
2252            // primal-input gradient.
2253            let mut grads: Vec<(usize, NodeId)> = Vec::with_capacity(*num_inputs as usize);
2254            for (i, out_id) in vjp_body.outputs.iter().enumerate() {
2255                grads.push((i, sub_to_bwd[out_id]));
2256            }
2257            grads
2258        }
2259
2260        // CustomFn without vjp_body is inlined by `inline_custom_fn_for_autodiff`
2261        // before the reverse walk — reaching here means the pre-pass missed it.
2262        Op::CustomFn { vjp_body: None, .. } => {
2263            panic!(
2264                "autodiff: Op::CustomFn has no vjp_body and was not inlined. \
2265                 This is an internal error in inline_custom_fn_for_autodiff."
2266            )
2267        }
2268
2269        // User-registered custom op — dispatch the VJP through the
2270        // op registry. The impl emits gradient nodes via the same
2271        // `bwd` builder built-in arms use; default impl returns
2272        // `vec![]` (non-differentiable).
2273        Op::Custom { name, .. } => {
2274            let ext = rlx_ir::lookup_op(name).unwrap_or_else(|| {
2275                panic!(
2276                    "autodiff: Op::Custom('{name}') is not registered \
2277                        in the op registry — register it via \
2278                        rlx_ir::register_op before compiling the graph"
2279                )
2280            });
2281            let mut ctx = rlx_ir::VjpContext {
2282                upstream,
2283                fwd_map,
2284                bwd,
2285            };
2286            ext.vjp(node, &mut ctx)
2287        }
2288
2289        Op::Conv2dBackwardInput {
2290            kernel_size,
2291            stride,
2292            padding,
2293            dilation,
2294            groups,
2295        } => {
2296            let dy_bwd = fwd_map[&node.inputs[0]];
2297            let w_bwd = fwd_map[&node.inputs[1]];
2298            let dy_shape = bwd.node(dy_bwd).shape.clone();
2299            let _x_shape = node.shape.clone();
2300            let d_dy = bwd.add_node(
2301                Op::Conv {
2302                    kernel_size: kernel_size.clone(),
2303                    stride: stride.clone(),
2304                    padding: padding.clone(),
2305                    dilation: dilation.clone(),
2306                    groups: *groups,
2307                },
2308                vec![upstream, w_bwd],
2309                dy_shape,
2310            );
2311            vec![(0, d_dy)]
2312        }
2313
2314        Op::Conv2dBackwardWeight {
2315            kernel_size,
2316            stride,
2317            padding,
2318            dilation,
2319            groups,
2320        } => {
2321            let x_bwd = fwd_map[&node.inputs[0]];
2322            let dy_bwd = fwd_map[&node.inputs[1]];
2323            let x_shape = bwd.node(x_bwd).shape.clone();
2324            let dy_shape = bwd.node(dy_bwd).shape.clone();
2325            let d_x = bwd.conv2d_backward_input(
2326                dy_bwd,
2327                upstream,
2328                x_shape,
2329                kernel_size.clone(),
2330                stride.clone(),
2331                padding.clone(),
2332                dilation.clone(),
2333                *groups,
2334            );
2335            let d_dy = bwd.add_node(
2336                Op::Conv {
2337                    kernel_size: kernel_size.clone(),
2338                    stride: stride.clone(),
2339                    padding: padding.clone(),
2340                    dilation: dilation.clone(),
2341                    groups: *groups,
2342                },
2343                vec![x_bwd, upstream],
2344                dy_shape,
2345            );
2346            vec![(0, d_x), (1, d_dy)]
2347        }
2348
        // 1D FFT: y = fft(x; inverse, norm). The forward and inverse
        // transforms are linear operators on the 2N real-block layout,
        // and the unnormalized DFT matrix's transpose (over the
        // real-block view) equals the unnormalized DFT in the opposite
        // direction. So, with `s = norm.output_scale(n, inverse)` the
        // normalization factor applied by the forward op:
        //   VJP(fft)  = ifft(s · upstream)
        //   VJP(ifft) = fft(s · upstream)
        // When `s == 1` the multiply is skipped and the chain rule is a
        // flag flip and nothing else.
2357        Op::Fft { inverse, norm } => {
2358            let n = rlx_ir::fft::fft_meta(bwd.shape(node.inputs[0])).n_complex;
2359            let s = norm.output_scale(n, *inverse) as f32;
2360            let z = if s != 1.0 {
2361                let sc = scalar_const(s, bwd);
2362                bwd.mul(upstream, sc)
2363            } else {
2364                upstream
2365            };
2366            let dx = bwd.fft(z, !*inverse);
2367            vec![(0, dx)]
2368        }
2369
2370        Op::LogMel => {
2371            let spec_bwd = fwd_map[&node.inputs[0]];
2372            let filt_bwd = fwd_map[&node.inputs[1]];
2373            let dx = bwd.log_mel_backward(spec_bwd, filt_bwd, upstream);
2374            vec![(0, dx)]
2375        }
2376
2377        // The catch-all below remains as a safety net: if a
2378        // future op is added without a VJP rule, this panic
2379        // names it for the implementer.
2380        other => panic!(
2381            "autodiff: no VJP rule for {other}. See the matching \
2382             entry in rlx-opt/src/autodiff.rs (catch-all panic) for \
2383             a pointer to what's needed to differentiate this op.",
2384        ),
2385    }
2386}
2387
// NOTE: the two comment paragraphs below are leftovers that describe
// other passes, not the function defined right after them. They are
// kept as plain comments so rustdoc does not attach them to
// `materialize_bcasts_for_ad`.
//
// (1) Fused-op decomposition (presumably
//     `rlx_fusion::unfuse_fused_for_autodiff` — confirm):
//
// Decompose tier-2 fused ops back to their primitive components so
// the per-op VJP rules cover them. Conceptually identical to what a
// "training-aware compile" pipeline would do as a pre-pass: avoid
// running `FuseMatMulBiasAct` / `FuseResidualLN` / `FuseSwiGLU` /
// `FuseAttentionBlock` if you plan to autodiff afterward. This
// helper handles the case where they're already in the graph (e.g.
// from a re-trained inference model).
//
// Decomposed today: `FusedMatMulBiasAct`, `FusedResidualLN`,
// `LoraMatMul`, `FusedSwiGLU`, `FusedAttentionBlock`,
// `FusedTransformerLayer`, and `SelectiveScan` / `GatedDeltaNet` —
// each rewritten back to its primitive chain (matmul / narrow / attention /
// layer_norm / residual / activation, plus reduce-sum / concat / …).
// (The end of the sentence above is missing from this file.)
//
// (2) `convert_scans_for_ad`:
//
// Pre-AD pass: convert every `Op::Scan { save_trajectory: false }`
// into `Op::Scan { save_trajectory: true }` followed by `Narrow` +
// `Reshape` to extract the final carry. After this rewrite, every
// scan in the graph carries its full trajectory — which is what the
// VJP rule needs to compute backward through time. The user-facing
// shape is unchanged (Narrow + Reshape collapse [length, *carry]
// back down to *carry).
//
// Memory cost: trajectory storage is now `O(length × carry_size)`
// for the duration of the forward + backward pass. For Diffrax-style
// transients this is the same as Diffrax's `RecursiveCheckpointAdjoint::All`
// strategy. Recursive checkpointing is a future pass.
//
/// Pre-AD pass: rewrite `Op::Scan` nodes with `num_bcast > 0` into
/// equivalent `num_bcast = 0` scans by materialising each broadcast
/// input `b` of shape `*bcast` into a per-step xs of shape
/// `[length, *bcast]` (built as `ones([length, 1, …, 1]) × b`, with the
/// `Mul` broadcasting the ones against `b`). The reverse-mode AD walk
/// and the rest of `convert_scans_for_ad` then see only carry+xs
/// scans — the bcast channel is a forward-only memory optimisation,
/// transparent to backward.
2420fn materialize_bcasts_for_ad(g: Graph) -> Graph {
2421    use rlx_ir::op::BinaryOp;
2422
2423    let needs = g.nodes().iter().any(|n| {
2424        matches!(
2425            &n.op, Op::Scan { num_bcast, .. } if *num_bcast > 0
2426        )
2427    });
2428    if !needs {
2429        return g;
2430    }
2431
2432    let mut out = Graph::new(g.name.clone());
2433    let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
2434
2435    for node in g.nodes() {
2436        let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
2437        match &node.op {
2438            Op::Scan {
2439                body,
2440                length,
2441                save_trajectory,
2442                num_bcast,
2443                num_xs,
2444                num_checkpoints,
2445            } if *num_bcast > 0 => {
2446                // Each bcast input gets multiplied by an
2447                // `[length, 1, ..., 1]` ones constant of matching dtype
2448                // (broadcast against the bcast's natural shape) to
2449                // produce a `[length, *bcast]` materialised xs.
2450                let bcast_base = 1;
2451                let xs_base = 1 + *num_bcast as usize;
2452
2453                let mut new_scan_inputs = vec![new_inputs[0]];
2454
2455                // Original xs first remain xs.
2456                let mut materialised_xs: Vec<NodeId> = Vec::new();
2457                for i in 0..*num_bcast as usize {
2458                    let b_id = new_inputs[bcast_base + i];
2459                    let b_shape = out.node(b_id).shape.clone();
2460                    let dtype = b_shape.dtype();
2461
2462                    // ones with shape [length, 1, 1, ...] (matching b's rank
2463                    // beyond the leading axis we're prepending). Broadcast
2464                    // against b of shape [*bcast] gives [length, *bcast].
2465                    let mut ones_dims: Vec<rlx_ir::Dim> =
2466                        vec![rlx_ir::Dim::Static(*length as usize)];
2467                    for _ in 0..b_shape.rank() {
2468                        ones_dims.push(rlx_ir::Dim::Static(1));
2469                    }
2470                    let ones_shape = rlx_ir::Shape::from_dims(&ones_dims, dtype);
2471                    let n_elems: usize = ones_dims
2472                        .iter()
2473                        .map(|d| match d {
2474                            rlx_ir::Dim::Static(n) => *n,
2475                            rlx_ir::Dim::Dynamic(_) => 1,
2476                        })
2477                        .product();
2478                    let elem_size = dtype.size_bytes();
2479                    let mut data = Vec::with_capacity(n_elems * elem_size);
2480                    match dtype {
2481                        rlx_ir::DType::F64 => {
2482                            for _ in 0..n_elems {
2483                                data.extend_from_slice(&1.0_f64.to_le_bytes());
2484                            }
2485                        }
2486                        rlx_ir::DType::F32 => {
2487                            for _ in 0..n_elems {
2488                                data.extend_from_slice(&1.0_f32.to_le_bytes());
2489                            }
2490                        }
2491                        other => {
2492                            panic!("materialize_bcasts_for_ad: unsupported bcast dtype {other:?}")
2493                        }
2494                    }
2495                    let ones = out.add_node(Op::Constant { data }, vec![], ones_shape);
2496
2497                    // Output shape of broadcast Mul: [length, *bcast].
2498                    let mut xs_dims: Vec<rlx_ir::Dim> = vec![rlx_ir::Dim::Static(*length as usize)];
2499                    for i in 0..b_shape.rank() {
2500                        xs_dims.push(b_shape.dim(i));
2501                    }
2502                    let xs_shape = rlx_ir::Shape::from_dims(&xs_dims, dtype);
2503                    let xs_id = out.add_node(Op::Binary(BinaryOp::Mul), vec![ones, b_id], xs_shape);
2504                    materialised_xs.push(xs_id);
2505                }
2506
2507                new_scan_inputs.extend_from_slice(&materialised_xs);
2508                for i in 0..*num_xs as usize {
2509                    new_scan_inputs.push(new_inputs[xs_base + i]);
2510                }
2511
2512                let new_id = out.add_node(
2513                    Op::Scan {
2514                        body: body.clone(),
2515                        length: *length,
2516                        save_trajectory: *save_trajectory,
2517                        num_bcast: 0,
2518                        num_xs: *num_bcast + *num_xs,
2519                        num_checkpoints: *num_checkpoints,
2520                    },
2521                    new_scan_inputs,
2522                    node.shape.clone(),
2523                );
2524                id_map.insert(node.id, new_id);
2525            }
2526            _ => {
2527                let new_id = out.add_node(node.op.clone(), new_inputs, node.shape.clone());
2528                id_map.insert(node.id, new_id);
2529            }
2530        }
2531    }
2532
2533    let new_outputs: Vec<NodeId> = g.outputs.iter().map(|o| id_map[o]).collect();
2534    out.set_outputs(new_outputs);
2535    out
2536}
2537
2538pub fn convert_scans_for_ad(g: Graph) -> Graph {
2539    use rlx_ir::shape::Shape as IrShape;
2540
2541    // First, materialise broadcast inputs into per-step xs. The AD
2542    // walk and the rest of this pre-pass don't know about bcasts
2543    // (forward-only memory optimisation); after this rewrite the bwd
2544    // graph treats them as regular xs.
2545    let g = materialize_bcasts_for_ad(g);
2546
2547    // Quick check: does any scan need rewriting? Avoid a full graph
2548    // rebuild when the input is already trajectory-only.
2549    let needs = g.nodes().iter().any(|n| {
2550        matches!(
2551            &n.op,
2552            Op::Scan {
2553                save_trajectory: false,
2554                ..
2555            }
2556        )
2557    });
2558    if !needs {
2559        return g;
2560    }
2561
2562    let mut out = Graph::new(g.name.clone());
2563    let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
2564
2565    for node in g.nodes() {
2566        let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
2567        match &node.op {
2568            Op::Scan {
2569                body,
2570                length,
2571                save_trajectory: false,
2572                num_xs,
2573                num_checkpoints,
2574                ..
2575            } => {
2576                let carry_shape = node.shape.clone();
2577                // Trajectory shape: [length, *carry_shape].
2578                //
2579                // NB: when `num_checkpoints` is set (recursive
2580                // checkpointing), the executor only writes `K` rows
2581                // into this buffer (the saved checkpoints, indexed by
2582                // k=0..K-1 at offsets 0..K·cb). Rows K..length-1 stay
2583                // zero. The Narrow + Reshape below extracts row
2584                // `length-1`, which is **zero** in checkpointed mode
2585                // — i.e. the rewritten forward output is wrong (the
2586                // FORWARD value of `scan_checkpointed` followed by a
2587                // direct read is not currently supported).
2588                //
2589                // Backward gradients are still correct: Narrow's VJP
2590                // scatters the upstream into row `length-1` of the
2591                // gradient tensor, ScanBackward reads upstream[t·cb]
2592                // for t in 0..length finds zero everywhere except at
2593                // t=length-1 where it picks up `d_loss`, and the
2594                // segment-cached recompute uses the K saved
2595                // checkpoints (at offsets 0..K·cb) plus the forward
2596                // body to reconstruct intermediate carries.
2597                let mut traj_dims: Vec<Dim> = Vec::with_capacity(carry_shape.rank() + 1);
2598                traj_dims.push(Dim::Static(*length as usize));
2599                for i in 0..carry_shape.rank() {
2600                    traj_dims.push(carry_shape.dim(i));
2601                }
2602                let traj_shape = IrShape::from_dims(&traj_dims, carry_shape.dtype());
2603                let traj = out.add_node(
2604                    Op::Scan {
2605                        body: body.clone(),
2606                        length: *length,
2607                        save_trajectory: true,
2608                        num_bcast: 0,
2609                        num_xs: *num_xs,
2610                        num_checkpoints: *num_checkpoints,
2611                    },
2612                    new_inputs,
2613                    traj_shape,
2614                );
2615                // Narrow last row → [1, *carry].
2616                let mut narrow_dims: Vec<Dim> = Vec::with_capacity(carry_shape.rank() + 1);
2617                narrow_dims.push(Dim::Static(1));
2618                for i in 0..carry_shape.rank() {
2619                    narrow_dims.push(carry_shape.dim(i));
2620                }
2621                let narrow_shape = IrShape::from_dims(&narrow_dims, carry_shape.dtype());
2622                let narrowed = out.add_node(
2623                    Op::Narrow {
2624                        axis: 0,
2625                        start: (*length as usize).saturating_sub(1),
2626                        len: 1,
2627                    },
2628                    vec![traj],
2629                    narrow_shape,
2630                );
2631                // Reshape to drop the leading 1 → carry_shape.
2632                let new_shape: Vec<i64> = (0..carry_shape.rank())
2633                    .map(|i| match carry_shape.dim(i) {
2634                        Dim::Static(n) => n as i64,
2635                        Dim::Dynamic(_) => -1,
2636                    })
2637                    .collect();
2638                let final_id = out.add_node(Op::Reshape { new_shape }, vec![narrowed], carry_shape);
2639                id_map.insert(node.id, final_id);
2640            }
2641            _ => {
2642                let new_id = out.add_node(node.op.clone(), new_inputs, node.shape.clone());
2643                id_map.insert(node.id, new_id);
2644            }
2645        }
2646    }
2647
2648    let new_outputs: Vec<NodeId> = g.outputs.iter().map(|o| id_map[o]).collect();
2649    out.set_outputs(new_outputs);
2650    out
2651}
2652
2653/// Pre-AD pass: inline `Op::CustomFn` nodes that have neither a
2654/// `vjp_body` nor a `jvp_body` by expanding their `fwd_body` into the
2655/// parent graph. When either override body is present, keep the
2656/// `CustomFn` wrapper so reverse- / forward-mode AD can dispatch to it.
2657pub fn inline_custom_fn_for_autodiff(g: Graph) -> Graph {
2658    use rlx_fusion::control_flow::inline_subgraph_into;
2659
2660    let mut out = Graph::new(g.name.clone());
2661    let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
2662    let nodes: Vec<rlx_ir::Node> = g.nodes().to_vec();
2663
2664    for node in &nodes {
2665        let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
2666        let new_id = match &node.op {
2667            Op::CustomFn {
2668                vjp_body: None,
2669                jvp_body: None,
2670                fwd_body,
2671                num_inputs,
2672                ..
2673            } => {
2674                assert_eq!(
2675                    new_inputs.len(),
2676                    *num_inputs as usize,
2677                    "custom_fn: outer input count mismatch"
2678                );
2679                inline_subgraph_into(fwd_body, &new_inputs, &mut out)
2680            }
2681            _ => out.add_node(node.op.clone(), new_inputs, node.shape.clone()),
2682        };
2683        id_map.insert(node.id, new_id);
2684    }
2685
2686    let new_outputs: Vec<NodeId> = g.outputs.iter().map(|i| id_map[i]).collect();
2687    out.set_outputs(new_outputs);
2688    out
2689}
2690
2691/// Inverse of `unbroadcast`: broadcast a small tensor up to a target
2692/// shape via `Op::Expand`. Convenience wrapper for the few VJPs that
2693/// need it.
2694pub(crate) fn unbroadcast_inverse(x: NodeId, target: &Shape, bwd: &mut Graph) -> NodeId {
2695    let target_dims: Vec<i64> = target
2696        .dims()
2697        .iter()
2698        .map(|d| match d {
2699            Dim::Static(n) => *n as i64,
2700            Dim::Dynamic(_) => -1,
2701        })
2702        .collect();
2703    bwd.add_node(
2704        Op::Expand {
2705            target_shape: target_dims,
2706        },
2707        vec![x],
2708        target.clone(),
2709    )
2710}
2711
2712/// Expand a gradient back to its pre-reduction shape: optionally
2713/// reshape to insert size-1 axes (when forward had `keep_dim=false`),
2714/// then `Op::Expand` to broadcast to `x_shape`. The reverse of
2715/// `Reduce::Sum`.
2716fn expand_to(
2717    grad: NodeId,
2718    x_shape: &Shape,
2719    axes: &[usize],
2720    keep_dim: bool,
2721    bwd: &mut Graph,
2722) -> NodeId {
2723    let mut current = grad;
2724    if !keep_dim {
2725        // Insert size-1 axes at the reduced positions so the rank
2726        // matches x_shape and Expand can broadcast cleanly.
2727        let kept_dims: Vec<Dim> = (0..x_shape.rank())
2728            .map(|i| {
2729                if axes.contains(&i) {
2730                    Dim::Static(1)
2731                } else {
2732                    x_shape.dim(i)
2733                }
2734            })
2735            .collect();
2736        let kept = Shape::from_dims(&kept_dims, x_shape.dtype());
2737        current = reshape_to(current, &kept, bwd);
2738    }
2739    let target_shape: Vec<i64> = x_shape
2740        .dims()
2741        .iter()
2742        .map(|d| match d {
2743            Dim::Static(n) => *n as i64,
2744            Dim::Dynamic(_) => -1,
2745        })
2746        .collect();
2747    bwd.add_node(Op::Expand { target_shape }, vec![current], x_shape.clone())
2748}
2749
2750#[cfg(test)]
2751mod tests {
2752    use super::*;
2753
2754    #[test]
2755    fn grad_of_add_is_identity() {
2756        let mut g = Graph::new("test");
2757        let x = g.input("x", Shape::new(&[4], DType::F32));
2758        let y = g.input("y", Shape::new(&[4], DType::F32));
2759        let z = g.binary(BinaryOp::Add, x, y, Shape::new(&[4], DType::F32));
2760        g.set_outputs(vec![z]);
2761
2762        let bwd = grad(&g, &[x, y]);
2763        // bwd graph should expose two outputs: dz/dx and dz/dy, both = d_output.
2764        assert_eq!(bwd.outputs.len(), 2);
2765    }
2766
2767    #[test]
2768    fn grad_of_mul_uses_other_operand() {
2769        let mut g = Graph::new("test");
2770        let x = g.input("x", Shape::new(&[4], DType::F32));
2771        let y = g.input("y", Shape::new(&[4], DType::F32));
2772        let z = g.binary(BinaryOp::Mul, x, y, Shape::new(&[4], DType::F32));
2773        g.set_outputs(vec![z]);
2774
2775        let bwd = grad(&g, &[x, y]);
2776        // bwd should contain Mul nodes (upstream * y, upstream * x).
2777        assert!(
2778            bwd.nodes()
2779                .iter()
2780                .filter(|n| matches!(n.op, Op::Binary(BinaryOp::Mul)))
2781                .count()
2782                >= 2
2783        );
2784    }
2785
2786    #[test]
2787    fn grad_with_loss_returns_loss_first() {
2788        let mut g = Graph::new("loss");
2789        let x = g.input("x", Shape::new(&[4], DType::F32));
2790        let y = g.input("y", Shape::new(&[4], DType::F32));
2791        let z = g.binary(BinaryOp::Add, x, y, Shape::new(&[4], DType::F32));
2792        g.set_outputs(vec![z]);
2793
2794        let bwd = grad_with_loss(&g, &[x, y]);
2795        // [loss, dz/dx, dz/dy] — three outputs.
2796        assert_eq!(bwd.outputs.len(), 3);
2797    }
2798
2799    #[test]
2800    fn grad_of_dense_solve_emits_implicit_function_rule() {
2801        // Forward:
2802        //   A      : Param [2,2]
2803        //   b      : Input [2]
2804        //   x      = solve(A, b)
2805        //   loss   = sum(x)         (scalar)
2806        //
2807        // Backward must contain:
2808        //   - a Transpose of A
2809        //   - a second DenseSolve (dx_int = solve(Aᵀ, upstream))
2810        //   - a MatMul (the outer product dx_int · xᵀ)
2811        //   - a Neg (the −outer)
2812        //
2813        // Outputs are [loss, dA, db].
2814        let mut g = Graph::new("solve_test");
2815        let a = g.param("A", Shape::new(&[2, 2], DType::F32));
2816        let b = g.input("b", Shape::new(&[2], DType::F32));
2817        let x = g.dense_solve(a, b, Shape::new(&[2], DType::F32));
2818        let loss = g.reduce(
2819            x,
2820            ReduceOp::Sum,
2821            vec![0],
2822            false,
2823            Shape::new(&[1], DType::F32),
2824        );
2825        g.set_outputs(vec![loss]);
2826
2827        let bwd = grad_with_loss(&g, &[a, b]);
2828        assert_eq!(bwd.outputs.len(), 3, "expect [loss, dA, db]");
2829
2830        let count =
2831            |pred: fn(&Op) -> bool| -> usize { bwd.nodes().iter().filter(|n| pred(&n.op)).count() };
2832
2833        // Forward is mirrored into bwd, so we expect 1 + 1 = 2 DenseSolves
2834        // (forward copy + reverse).
2835        assert!(
2836            count(|o| matches!(o, Op::DenseSolve)) >= 2,
2837            "expected ≥2 DenseSolve nodes (forward mirror + reverse), got\n{bwd}"
2838        );
2839        assert!(
2840            count(|o| matches!(o, Op::Transpose { .. })) >= 1,
2841            "expected a Transpose for Aᵀ, got\n{bwd}"
2842        );
2843        assert!(
2844            count(|o| matches!(o, Op::MatMul)) >= 1,
2845            "expected a MatMul for the outer product, got\n{bwd}"
2846        );
2847        assert!(
2848            count(|o| matches!(o, Op::Activation(Activation::Neg))) >= 1,
2849            "expected a Neg for −outer, got\n{bwd}"
2850        );
2851    }
2852
2853    #[test]
2854    fn inline_if_replaces_with_where() {
2855        // Build a parent graph:
2856        //   x       : Input
2857        //   pred    : Input (scalar bool)
2858        //   then_b  : sub-graph with Input(0) → Activation(Relu)
2859        //   else_b  : sub-graph with Input(0) → Activation(Sigmoid)
2860        //   out     : If(pred, [x] -> then_b, else_b)
2861        let s = Shape::new(&[4], DType::F32);
2862        let pred_s = Shape::new(&[1], DType::F32);
2863
2864        let mut then_g = Graph::new("then_branch");
2865        let then_in = then_g.input("captured", s.clone());
2866        let then_out = then_g.activation(Activation::Relu, then_in, s.clone());
2867        then_g.set_outputs(vec![then_out]);
2868
2869        let mut else_g = Graph::new("else_branch");
2870        let else_in = else_g.input("captured", s.clone());
2871        let else_out = else_g.activation(Activation::Sigmoid, else_in, s.clone());
2872        else_g.set_outputs(vec![else_out]);
2873
2874        let mut g = Graph::new("parent");
2875        let x = g.input("x", s.clone());
2876        let pred = g.input("pred", pred_s);
2877        let if_out = g.add_node(
2878            Op::If {
2879                then_branch: Box::new(then_g),
2880                else_branch: Box::new(else_g),
2881            },
2882            vec![pred, x],
2883            s,
2884        );
2885        g.set_outputs(vec![if_out]);
2886
2887        let inlined = rlx_fusion::control_flow::inline_if(g);
2888
2889        // After inlining: no Op::If, exactly one Op::Where, one
2890        // Activation(Relu), one Activation(Sigmoid). Inputs (x,
2891        // pred) and the original output count are preserved.
2892        let has_if = inlined
2893            .nodes()
2894            .iter()
2895            .any(|n| matches!(n.op, Op::If { .. }));
2896        let has_where = inlined.nodes().iter().any(|n| matches!(n.op, Op::Where));
2897        let has_relu = inlined
2898            .nodes()
2899            .iter()
2900            .any(|n| matches!(n.op, Op::Activation(Activation::Relu)));
2901        let has_sigmoid = inlined
2902            .nodes()
2903            .iter()
2904            .any(|n| matches!(n.op, Op::Activation(Activation::Sigmoid)));
2905        assert!(!has_if, "Op::If should be inlined away");
2906        assert!(has_where, "Op::Where should replace the Op::If");
2907        assert!(has_relu, "then_branch's Activation(Relu) should be inlined");
2908        assert!(
2909            has_sigmoid,
2910            "else_branch's Activation(Sigmoid) should be inlined"
2911        );
2912        assert_eq!(inlined.outputs.len(), 1);
2913    }
2914
2915    #[test]
2916    fn grad_through_if_propagates() {
2917        // Sanity: autodiff a graph with Op::If and confirm it
2918        // produces a gradient (the Where VJP handles the join).
2919        let s = Shape::new(&[4], DType::F32);
2920        let pred_s = Shape::new(&[1], DType::F32);
2921
2922        let mut then_g = Graph::new("th");
2923        let ti = then_g.input("c", s.clone());
2924        let to = then_g.binary(BinaryOp::Mul, ti, ti, s.clone());
2925        then_g.set_outputs(vec![to]);
2926
2927        let mut else_g = Graph::new("el");
2928        let ei = else_g.input("c", s.clone());
2929        let eo = else_g.activation(Activation::Relu, ei, s.clone());
2930        else_g.set_outputs(vec![eo]);
2931
2932        let mut g = Graph::new("parent");
2933        let x = g.input("x", s.clone());
2934        let pred = g.input("pred", pred_s);
2935        let z = g.add_node(
2936            Op::If {
2937                then_branch: Box::new(then_g),
2938                else_branch: Box::new(else_g),
2939            },
2940            vec![pred, x],
2941            s,
2942        );
2943        g.set_outputs(vec![z]);
2944
2945        let bwd = grad_with_loss(&g, &[x]);
2946        // [loss, dz/dx] — two outputs.
2947        assert_eq!(bwd.outputs.len(), 2, "expected loss + 1 grad output");
2948    }
2949
2950    #[test]
2951    fn unroll_while_replicates_body_n_times() {
2952        // Build a parent graph:
2953        //   x   : Input
2954        //   out : While(cond=trivial, body=Activation(Relu), N=3)
2955        // After unrolling we expect zero Op::While, three Activation
2956        // (Relu) nodes (one per replica).
2957        let s = Shape::new(&[4], DType::F32);
2958        let bool_s = Shape::new(&[1], DType::F32);
2959
2960        let mut cond_g = Graph::new("cond");
2961        let ci = cond_g.input("c", s.clone());
2962        // dummy bool: just feed input through (cond is not evaluated
2963        // by the unroll pass, so its body doesn't matter).
2964        cond_g.set_outputs(vec![ci]);
2965        // Replace output shape: cond's output is logically a scalar
2966        // bool — but the unroll pass never inspects it.
2967        let _ = bool_s;
2968
2969        let mut body_g = Graph::new("body");
2970        let bi = body_g.input("c", s.clone());
2971        let bo = body_g.activation(Activation::Relu, bi, s.clone());
2972        body_g.set_outputs(vec![bo]);
2973
2974        let mut g = Graph::new("parent");
2975        let x = g.input("x", s.clone());
2976        let w = g.add_node(
2977            Op::While {
2978                cond: Box::new(cond_g),
2979                body: Box::new(body_g),
2980                max_iterations: Some(3),
2981            },
2982            vec![x],
2983            s,
2984        );
2985        g.set_outputs(vec![w]);
2986
2987        let unrolled = rlx_fusion::control_flow::unroll_while(g);
2988
2989        let has_while = unrolled
2990            .nodes()
2991            .iter()
2992            .any(|n| matches!(n.op, Op::While { .. }));
2993        let relu_count = unrolled
2994            .nodes()
2995            .iter()
2996            .filter(|n| matches!(n.op, Op::Activation(Activation::Relu)))
2997            .count();
2998        assert!(!has_while, "Op::While should be unrolled away");
2999        assert_eq!(
3000            relu_count, 3,
3001            "body's Activation(Relu) should appear once per iteration"
3002        );
3003        assert_eq!(unrolled.outputs.len(), 1);
3004    }
3005
3006    #[test]
3007    fn grad_through_while_propagates() {
3008        // Sanity: autodiff a graph with Op::While and confirm the
3009        // gradient pipeline produces a result (the unroll pass turns
3010        // it into a chain of body replicas before the gradient walk).
3011        let s = Shape::new(&[4], DType::F32);
3012
3013        let mut cond_g = Graph::new("cond");
3014        let ci = cond_g.input("c", s.clone());
3015        cond_g.set_outputs(vec![ci]);
3016
3017        let mut body_g = Graph::new("body");
3018        let bi = body_g.input("c", s.clone());
3019        let bo = body_g.binary(BinaryOp::Mul, bi, bi, s.clone());
3020        body_g.set_outputs(vec![bo]);
3021
3022        let mut g = Graph::new("parent");
3023        let x = g.input("x", s.clone());
3024        let w = g.add_node(
3025            Op::While {
3026                cond: Box::new(cond_g),
3027                body: Box::new(body_g),
3028                max_iterations: Some(2),
3029            },
3030            vec![x],
3031            s,
3032        );
3033        g.set_outputs(vec![w]);
3034
3035        let bwd = grad_with_loss(&g, &[x]);
3036        assert_eq!(bwd.outputs.len(), 2, "expected loss + 1 grad output");
3037    }
3038
3039    /// Build a tiny BERT-style FTL graph with the given bias mode.
3040    /// Returns (graph, hidden_input_id, all_param_ids).
3041    fn build_ftl_graph(has_bias: bool) -> (Graph, NodeId, Vec<NodeId>) {
3042        // B=1, S=2, hidden=4, heads=2, head_dim=2, intermediate=8.
3043        let mut g = Graph::new("ftl_test");
3044        let h_shape = Shape::new(&[1, 2, 4], DType::F32);
3045        let h = g.input("h", h_shape.clone());
3046        let qkv_w = g.param("qkv_w", Shape::new(&[4, 12], DType::F32));
3047        let out_w = g.param("out_w", Shape::new(&[4, 4], DType::F32));
3048        let ln1_g = g.param("ln1_g", Shape::new(&[4], DType::F32));
3049        let fc1_w = g.param("fc1_w", Shape::new(&[4, 8], DType::F32));
3050        let fc2_w = g.param("fc2_w", Shape::new(&[8, 4], DType::F32));
3051        let ln2_g = g.param("ln2_g", Shape::new(&[4], DType::F32));
3052        let mask = g.input("mask", Shape::new(&[1, 2, 2, 2], DType::F32));
3053
3054        let (inputs, params) = if has_bias {
3055            let qkv_b = g.param("qkv_b", Shape::new(&[12], DType::F32));
3056            let out_b = g.param("out_b", Shape::new(&[4], DType::F32));
3057            let ln1_b = g.param("ln1_b", Shape::new(&[4], DType::F32));
3058            let fc1_b = g.param("fc1_b", Shape::new(&[8], DType::F32));
3059            let fc2_b = g.param("fc2_b", Shape::new(&[4], DType::F32));
3060            let ln2_b = g.param("ln2_b", Shape::new(&[4], DType::F32));
3061            (
3062                vec![
3063                    h, qkv_w, qkv_b, out_w, out_b, ln1_g, ln1_b, fc1_w, fc1_b, fc2_w, fc2_b, ln2_g,
3064                    ln2_b, mask,
3065                ],
3066                vec![
3067                    qkv_w, qkv_b, out_w, out_b, ln1_g, ln1_b, fc1_w, fc1_b, fc2_w, fc2_b, ln2_g,
3068                    ln2_b,
3069                ],
3070            )
3071        } else {
3072            (
3073                vec![h, qkv_w, out_w, ln1_g, fc1_w, fc2_w, ln2_g, mask],
3074                vec![qkv_w, out_w, ln1_g, fc1_w, fc2_w, ln2_g],
3075            )
3076        };
3077        let y = g.add_node(
3078            Op::FusedTransformerLayer {
3079                num_heads: 2,
3080                head_dim: 2,
3081                intermediate_size: 8,
3082                eps1: 1e-5,
3083                eps2: 1e-5,
3084                activation: rlx_ir::op::Activation::Gelu,
3085                has_bias,
3086            },
3087            inputs,
3088            h_shape,
3089        );
3090        g.set_outputs(vec![y]);
3091        (g, h, params)
3092    }
3093
3094    #[test]
3095    fn unfuse_decomposes_fused_transformer_layer() {
3096        // After rlx_fusion::unfuse_fused_for_autodiff, the FTL node is gone and
3097        // primitives appear: at least 4 MatMul (qkv, out, fc1, fc2),
3098        // 1 Attention, 2 LayerNorm, plus narrows / adds / activation.
3099        let (g, _h, _params) = build_ftl_graph(true);
3100        let unfused = rlx_fusion::unfuse_fused_for_autodiff(g);
3101
3102        let has_ftl = unfused
3103            .nodes()
3104            .iter()
3105            .any(|n| matches!(n.op, Op::FusedTransformerLayer { .. }));
3106        assert!(!has_ftl, "Op::FusedTransformerLayer should be unfused");
3107
3108        let count = |pred: fn(&Op) -> bool| -> usize {
3109            unfused.nodes().iter().filter(|n| pred(&n.op)).count()
3110        };
3111        assert!(
3112            count(|o| matches!(o, Op::MatMul)) >= 4,
3113            "expected >=4 MatMul after FTL unfuse"
3114        );
3115        assert_eq!(
3116            count(|o| matches!(o, Op::Attention { .. })),
3117            1,
3118            "expected exactly 1 Attention after FTL unfuse"
3119        );
3120        assert_eq!(
3121            count(|o| matches!(o, Op::LayerNorm { .. })),
3122            2,
3123            "expected exactly 2 LayerNorm after FTL unfuse"
3124        );
3125        assert!(
3126            count(|o| matches!(o, Op::Narrow { .. })) >= 3,
3127            "expected >=3 Narrow (Q/K/V split) after FTL unfuse"
3128        );
3129        assert_eq!(
3130            count(|o| matches!(o, Op::Activation(_))),
3131            1,
3132            "expected exactly 1 Activation (FFN) after FTL unfuse"
3133        );
3134    }
3135
3136    #[test]
3137    fn grad_through_fused_transformer_layer_propagates() {
3138        // End-to-end: grad_with_loss through an FTL graph returns
3139        // [loss, ...grads]. Confirms every primitive emitted by the
3140        // unfuse has a VJP rule on the gradient walk.
3141        let (g, _h, params) = build_ftl_graph(true);
3142        let bwd = grad_with_loss(&g, &params);
3143        assert_eq!(
3144            bwd.outputs.len(),
3145            1 + params.len(),
3146            "expected loss + {} param grads",
3147            params.len()
3148        );
3149    }
3150
3151    #[test]
3152    fn grad_through_fused_transformer_layer_no_bias() {
3153        // No-bias variant exercises the synthesized zero-beta path
3154        // for both LayerNorms.
3155        let (g, _h, params) = build_ftl_graph(false);
3156        let bwd = grad_with_loss(&g, &params);
3157        assert_eq!(
3158            bwd.outputs.len(),
3159            1 + params.len(),
3160            "expected loss + {} param grads (no-bias)",
3161            params.len()
3162        );
3163    }
3164
3165    /// Build a tiny SelectiveScan graph: B=1, S=3, H=2, N=4.
3166    /// Returns (graph, [x, delta, a, b, c]).
3167    fn build_ssm_graph() -> (Graph, NodeId, Vec<NodeId>) {
3168        let mut g = Graph::new("ssm_test");
3169        let bsh = Shape::new(&[1, 3, 2], DType::F32);
3170        let hn = Shape::new(&[2, 4], DType::F32);
3171        let bsn = Shape::new(&[1, 3, 4], DType::F32);
3172
3173        let x = g.input("x", bsh.clone());
3174        let delta = g.input("delta", bsh.clone());
3175        let a = g.param("a", hn);
3176        let b = g.input("b", bsn.clone());
3177        let c = g.input("c", bsn);
3178        let y = g.selective_scan(x, delta, a, b, c, 4, bsh);
3179        g.set_outputs(vec![y]);
3180        (g, x, vec![a])
3181    }
3182
3183    #[test]
3184    fn unfuse_decomposes_selective_scan() {
3185        // After unfuse, no Op::SelectiveScan; instead we see Concat
3186        // (one for S>1), per-step Reduce(Sum), per-step Activation::Exp,
3187        // and many Mul / Add / Narrow / Reshape / Expand nodes.
3188        // S=3 → 3 timesteps.
3189        let (g, _x, _params) = build_ssm_graph();
3190        let unfused = rlx_fusion::unfuse_fused_for_autodiff(g);
3191
3192        let has_ssm = unfused
3193            .nodes()
3194            .iter()
3195            .any(|n| matches!(n.op, Op::SelectiveScan { .. }));
3196        assert!(!has_ssm, "Op::SelectiveScan should be unfused");
3197
3198        let count = |pred: fn(&Op) -> bool| -> usize {
3199            unfused.nodes().iter().filter(|n| pred(&n.op)).count()
3200        };
3201        assert_eq!(
3202            count(|o| matches!(o, Op::Concat { .. })),
3203            1,
3204            "expected 1 Concat (over the 3 time steps)"
3205        );
3206        assert_eq!(
3207            count(|o| matches!(
3208                o,
3209                Op::Reduce {
3210                    op: ReduceOp::Sum,
3211                    ..
3212                }
3213            )),
3214            3,
3215            "expected one Reduce(Sum) per time step (S=3)"
3216        );
3217        assert_eq!(
3218            count(|o| matches!(o, Op::Activation(Activation::Exp))),
3219            3,
3220            "expected one exp(δA) per time step (S=3)"
3221        );
3222        assert!(
3223            count(|o| matches!(o, Op::Narrow { .. })) >= 12,
3224            "expected >=12 Narrows (4 per step × 3 steps)"
3225        );
3226    }
3227
3228    #[test]
3229    fn grad_through_selective_scan_propagates() {
3230        // End-to-end: grad_with_loss through SelectiveScan returns
3231        // [loss, da] — confirms every primitive emitted by the
3232        // unroll has a VJP rule on the gradient walk (Mul, Add,
3233        // Activation::Exp, Reduce::Sum, Concat, Narrow, Reshape,
3234        // Expand).
3235        let (g, _x, params) = build_ssm_graph();
3236        let bwd = grad_with_loss(&g, &params);
3237        assert_eq!(
3238            bwd.outputs.len(),
3239            1 + params.len(),
3240            "expected loss + {} param grads",
3241            params.len()
3242        );
3243    }
3244
3245    /// Tiny GatedDeltaNet: B=1, S=3, H=2, N=4.
3246    fn build_gdn_graph() -> (Graph, NodeId, Vec<NodeId>) {
3247        let (b, s, h, n) = (1usize, 3, 2, 4);
3248        let mut g = Graph::new("gdn_test");
3249        let bshn = Shape::new(&[b, s, h, n], DType::F32);
3250        let bsh = Shape::new(&[b, s, h], DType::F32);
3251        let q = g.input("q", bshn.clone());
3252        let k = g.input("k", bshn.clone());
3253        let v = g.input("v", bshn.clone());
3254        let g_in = g.input("g", bsh.clone());
3255        let beta = g.input("beta", bsh);
3256        let y = g.gated_delta_net(q, k, v, g_in, beta, n, bshn);
3257        g.set_outputs(vec![y]);
3258        (g, q, vec![q, k, v, g_in, beta])
3259    }
3260
3261    #[test]
3262    fn unfuse_decomposes_gated_delta_net() {
3263        let (g, _q, _params) = build_gdn_graph();
3264        let unfused = rlx_fusion::unfuse_fused_for_autodiff(g);
3265
3266        let has_gdn = unfused
3267            .nodes()
3268            .iter()
3269            .any(|n| matches!(n.op, Op::GatedDeltaNet { .. }));
3270        assert!(!has_gdn, "Op::GatedDeltaNet should be unfused");
3271
3272        let count = |pred: fn(&Op) -> bool| -> usize {
3273            unfused.nodes().iter().filter(|n| pred(&n.op)).count()
3274        };
3275        assert_eq!(
3276            count(|o| matches!(o, Op::Concat { .. })),
3277            1,
3278            "expected 1 Concat over S=3 steps"
3279        );
3280        assert!(
3281            count(|o| matches!(o, Op::MatMul)) >= 3,
3282            "expected >=3 MatMul per step (sk + out) × S=3"
3283        );
3284        assert_eq!(
3285            count(|o| matches!(o, Op::Activation(Activation::Exp))),
3286            3,
3287            "expected one exp(g) per time step"
3288        );
3289    }
3290
3291    #[test]
3292    fn grad_through_gated_delta_net_propagates() {
3293        let (g, _q, params) = build_gdn_graph();
3294        let bwd = grad_with_loss(&g, &params);
3295        assert_eq!(
3296            bwd.outputs.len(),
3297            1 + params.len(),
3298            "expected loss + {} input grads",
3299            params.len()
3300        );
3301    }
3302
3303    #[test]
3304    fn custom_fn_vjp_body_is_inlined_into_bwd() {
3305        // Forward: y = x² via custom_fn (fwd_body = Mul(x, x)).
3306        // Override VJP to return Activation::Sin(d_output) — a unique
3307        // marker that natural autodiff of Mul would never emit. If
3308        // grad_with_loss inlines the override correctly, the bwd graph
3309        // must contain a Sin node; if it falls back to recursing into
3310        // fwd_body, it would emit two Muls (upstream·x + x·upstream)
3311        // and no Sin.
3312        let n = 4usize;
3313        let shape = Shape::new(&[n], DType::F32);
3314
3315        // fwd_body: x → x · x.
3316        let mut fwd_body = Graph::new("square_fwd");
3317        let xb = fwd_body.input("x", shape.clone());
3318        let yb = fwd_body.binary(BinaryOp::Mul, xb, xb, shape.clone());
3319        fwd_body.set_outputs(vec![yb]);
3320
3321        // vjp_body: (x, primal_output, d_output) → sin(d_output).
3322        let mut vjp_body = Graph::new("square_vjp");
3323        let _vx = vjp_body.input("x", shape.clone());
3324        let _vp = vjp_body.input("primal_output", shape.clone());
3325        let vd = vjp_body.input("d_output", shape.clone());
3326        let dx = vjp_body.activation(Activation::Sin, vd, shape.clone());
3327        vjp_body.set_outputs(vec![dx]);
3328
3329        let mut g = Graph::new("custom_fn_test");
3330        let x = g.input("x", shape.clone());
3331        let y = g.custom_fn(vec![x], fwd_body, Some(vjp_body), None);
3332        let loss = g.reduce(
3333            y,
3334            ReduceOp::Sum,
3335            vec![0],
3336            false,
3337            Shape::new(&[1], DType::F32),
3338        );
3339        g.set_outputs(vec![loss]);
3340
3341        let bwd = grad_with_loss(&g, &[x]);
3342        assert_eq!(bwd.outputs.len(), 2, "expect [loss, dx]");
3343        let sin_count = bwd
3344            .nodes()
3345            .iter()
3346            .filter(|n| matches!(n.op, Op::Activation(Activation::Sin)))
3347            .count();
3348        assert!(
3349            sin_count >= 1,
3350            "expected the vjp_body's Sin to be inlined into bwd, got\n{bwd}"
3351        );
3352    }
3353
3354    #[test]
3355    fn custom_fn_without_vjp_inlines_fwd_body_for_autodiff() {
3356        // Forward: y = x² via custom_fn without vjp_body. After the
3357        // inline pre-pass, autodiff should recurse into Mul and emit
3358        // dx = 2·x·d_output (two Mul nodes in the backward graph).
3359        let n = 4usize;
3360        let shape = Shape::new(&[n], DType::F32);
3361
3362        let mut fwd_body = Graph::new("square_fwd");
3363        let xb = fwd_body.input("x", shape.clone());
3364        let yb = fwd_body.binary(BinaryOp::Mul, xb, xb, shape.clone());
3365        fwd_body.set_outputs(vec![yb]);
3366
3367        let mut g = Graph::new("custom_fn_no_vjp");
3368        let x = g.input("x", shape.clone());
3369        let y = g.custom_fn(vec![x], fwd_body, None, None);
3370        let loss = g.reduce(
3371            y,
3372            ReduceOp::Sum,
3373            vec![0],
3374            false,
3375            Shape::new(&[1], DType::F32),
3376        );
3377        g.set_outputs(vec![loss]);
3378
3379        let bwd = grad_with_loss(&g, &[x]);
3380        assert_eq!(bwd.outputs.len(), 2, "expect [loss, dx]");
3381        let custom_fn_count = bwd
3382            .nodes()
3383            .iter()
3384            .filter(|n| matches!(n.op, Op::CustomFn { .. }))
3385            .count();
3386        assert_eq!(
3387            custom_fn_count, 0,
3388            "CustomFn should be inlined away before autodiff"
3389        );
3390        let mul_count = bwd
3391            .nodes()
3392            .iter()
3393            .filter(|n| matches!(n.op, Op::Binary(BinaryOp::Mul)))
3394            .count();
3395        assert!(mul_count >= 2, "expected Mul-based VJP for x², got\n{bwd}");
3396    }
3397
3398    #[test]
3399    fn convert_scans_for_ad_forces_save_trajectory_true() {
3400        // grad_with_loss runs `convert_scans_for_ad` as a pre-pass: any
3401        // forward Op::Scan with `save_trajectory: false` is rewritten
3402        // to `save_trajectory: true` followed by Narrow + Reshape so
3403        // the reverse pass has the trajectory it needs. This test
3404        // verifies the rewrite happens — the bwd graph should contain
3405        // at least one Scan with save_trajectory == true.
3406        let n = 2usize;
3407        let length = 3u32;
3408        let carry = Shape::new(&[n], DType::F32);
3409        let xs_shape = Shape::new(&[length as usize, n], DType::F32);
3410
3411        // body: (carry, x_t) → carry + x_t. One primal Input each.
3412        let mut body = Graph::new("scan_body");
3413        let bc = body.input("carry", carry.clone());
3414        let bx = body.input("x_t", carry.clone());
3415        let by = body.binary(BinaryOp::Add, bc, bx, carry.clone());
3416        body.set_outputs(vec![by]);
3417
3418        let mut g = Graph::new("scan_save_false");
3419        let init = g.input("init", carry.clone());
3420        let xs = g.input("xs", xs_shape);
3421        let scan_out = g.add_node(
3422            Op::Scan {
3423                body: Box::new(body),
3424                length,
3425                save_trajectory: false,
3426                num_bcast: 0,
3427                num_xs: 1,
3428                num_checkpoints: 0,
3429            },
3430            vec![init, xs],
3431            carry.clone(),
3432        );
3433        let loss = g.reduce(
3434            scan_out,
3435            ReduceOp::Sum,
3436            vec![0],
3437            false,
3438            Shape::new(&[1], DType::F32),
3439        );
3440        g.set_outputs(vec![loss]);
3441
3442        let bwd = grad_with_loss(&g, &[init, xs]);
3443        let saved_traj = bwd.nodes().iter().any(|n| {
3444            matches!(
3445                &n.op,
3446                Op::Scan {
3447                    save_trajectory: true,
3448                    ..
3449                }
3450            )
3451        });
3452        assert!(
3453            saved_traj,
3454            "convert_scans_for_ad should rewrite save_trajectory=false → \
3455             save_trajectory=true in the AD-prepared graph; got\n{bwd}"
3456        );
3457    }
3458}