Skip to main content

rlx_autodiff/
autodiff.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Reverse-mode automatic differentiation (VJP transform).
17//!
18//! Takes a forward graph that produces a single scalar output (the
19//! loss) and returns a new graph whose outputs are
20//! `[loss, grad_w_param0, grad_w_param1, ...]` for the parameters
21//! listed in `wrt`. Running the returned graph through any backend
22//! gives the loss and all parameter gradients in one pass.
23//!
24//! ## Implementation
25//!
26//! Standard reverse-mode AD: walk the forward graph in reverse topo
27//! order; for each visited node, emit gradient nodes that contribute
28//! to the gradients of its inputs. Multiple consumers' contributions
29//! are summed via `BinaryOp::Add`.
30//!
31//! For ops with a closed-form gradient kernel (`ReluBackward`,
32//! `MaxPool2dBackward`, `Conv2dBackwardInput/Weight`,
33//! `AttentionBackward`, `SoftmaxCrossEntropyBackward` — added in the
34//! rlx-ir backward-op
35//! family), the VJP emits the dedicated kernel rather than composing
36//! the gradient from primitives.
37//!
38//! ## Broadcast handling
39//!
40//! Forward broadcasts (e.g. `[N, C] + [C]` → `[N, C]`) require the
41//! reverse pass to *un-broadcast* the gradient back to the broadcast
42//! input's smaller shape via a `Reduce::Sum` over the inserted /
43//! tiled axes. `unbroadcast` does this; both `Op::Binary` and
44//! `Op::Expand` VJPs use it.
45//!
46//! ## Coverage
47//!
48//! Element-wise: `Binary(Add/Sub/Mul/Div/Min/Max/Pow)`,
49//! `Activation(*)` (Relu via dedicated `ReluBackward`, others via
50//! generic `ActivationBackward`), `Compare` (zero gradient),
51//! `Where`, `Cast`, `Quantize/Dequantize` (straight-through).
52//!
53//! Linear / reductions: `MatMul`, `Conv`, `Pool{Max,Mean}`,
54//! `Reduce{Sum,Mean,Min,Max,Prod}`, `Softmax`, `LayerNorm`
55//! (dedicated kernels), `RmsNorm` (composed), `Rope` (composed
56//! via negated sin), `SoftmaxCrossEntropyWithLogits`.
57//!
58//! Shape: `Reshape`, `Transpose`, `Expand`, `Concat`, `Narrow`,
59//! `Gather` (axis=0), `ScatterAdd`.
60//!
61//! Attention: `Op::Attention` → three [`Op::AttentionBackward`]
62//! nodes (`dQ` / `dK` / `dV`) for all mask kinds. Causal /
63//! SlidingWindow masks are applied inside the backward kernel (no
//! mask tensor). `Custom` / `Bias` pass the forward mask input
//! through to the backward nodes, where it is applied before
//! softmax; `Custom` uses the user-provided mask tensor; `None` is
//! the no-mask path.
67//!
//! Elementwise regions: `UnfuseElementwiseRegions` runs automatically
//! (as part of the pre-pass described below) before the gradient walk,
//! so `Op::ElementwiseRegion` decomposes into its primitive chain
//! (covered op-by-op above).
71//!
72//! Sampling-style (`TopK`, `Sample`): non-differentiable — emit no
73//! gradient (forward is a discrete selector / draw).
74//!
75//! Pre-pass: [`crate::prepare_ad::prepare_graph_for_ad`] runs before
76//! the gradient walk (also exposed as [`PrepareForAutodiff`] pass).
77//! It unfuses elementwise regions and tier-2 fused ops
78//! (`FusedMatMulBiasAct`, `FusedResidualLN`, `FusedResidualRmsNorm`,
79//! `FusedSwiGLU`, `FusedAttentionBlock`, `FusedTransformerLayer`,
80//! `GatedDeltaNet`, `SelectiveScan`, …), lowers `DotGeneral`, inlines
81//! `If` / unrolls `While`, inlines `CustomFn` without `vjp_body`, and
82//! rewrites scans for trajectory AD.
83//!
84//! For HIR builders, [`rlx_ir::hir::FusionPolicy::for_autodiff`] lowers
85//! to primitive MIR; [`grad_with_loss_module`] accepts HIR or MIR
86//! [`GraphModule`] stages (not LIR).
87//!
//! Cumsum: backward composed via matmul with a constant
//! upper-triangular ones matrix (avoids needing a new `Op::Flip`
//! primitive across all backends). The cost is one L×L matmul, where
//! L is the sequence length — fine for typical sequence lengths.
92//!
93//! Quantized / MoE: `Op::DequantMatMul` (QAT straight-through),
94//! `Op::QMatMul`, `Op::QConv2d`, and `Op::GroupedMatMul` are all
95//! supported via composed straight-through VJPs. Plain
96//! `Op::Quantize/Dequantize` straight-through covers the typical
97//! fake-quant fp32 training path.
98//!
99//! Coverage today: every op in the IR has a VJP rule or a
100//! pre-pass that rewrites it into ones that do. SelectiveScan
101//! (Mamba SSM step) and GatedDeltaNet (Qwen3.5 linear-attention
102//! scan) decompose by unrolling the time loop into Mul / Add /
103//! MatMul / Activation::Exp / Concat / Narrow / Reshape / Expand
104//! primitives — same shape as the rlx-mlx lowering, just emitted
105//! as IR nodes instead of MLX arrays.
106//! FusedTransformerLayer / FusedAttentionBlock / FusedSwiGLU /
107//! LoraMatMul / FusedMatMulBiasAct / FusedResidualLN are all
108//! decomposed by `rlx_fusion::unfuse_fused_for_autodiff` likewise. Op::If
109//! is rewritten to `Where(predicate, then, else)` with both
110//! branches inlined; Op::While is bounded-unrolled up to
111//! `max_iterations`.
112
113use rlx_ir::op::*;
114use rlx_ir::shape::Dim;
115use rlx_ir::*;
116use std::collections::HashMap;
117
118pub use crate::prepare_ad::{
119    AutodiffError, PrepareForAutodiff, grad_with_loss_module, jvp_module, prepare_graph_for_ad,
120    prepare_mir_for_ad, prepare_module_for_ad,
121};
122
123/// Compute the reverse-mode gradient graph and the loss value.
124///
125/// Returns a graph whose outputs are
126/// `[loss, aux₁, aux₂, …, grad_wrt[0], grad_wrt[1], …]`.
127///
128/// The **first** forward output is treated as the scalar loss and is
129/// differentiated. Any **additional** forward outputs (`aux₁ …`) are
130/// mirrored from the forward graph and emitted unchanged — gradients are
131/// not propagated through them. This is the canonical hook for emitting
132/// training-side statistics (BatchNorm batch mean/variance, debug probes,
133/// …) alongside the loss in a single forward+backward pass.
134///
135/// The returned graph contains a copy of the entire forward graph so
136/// activations needed by gradient kernels are recomputed from inputs;
137/// it also exposes a new `Op::Input` named `"d_output"` which the
138/// caller seeds with the upstream gradient of the loss (typically a
139/// scalar `1.0` for "differentiate the loss directly"). Auxiliary outputs
140/// have no `d_output`-equivalent — by construction they don't contribute
141/// to the gradient path.
142///
143/// ## Limitations
144/// - Forward graph must have **≥ 1** output. The first is the loss.
145/// - All ops in the forward graph must have an implemented VJP rule.
146///   Hitting an op without one is a panic, not a silent miscompute.
147/// Options for [`grad_with_loss_opts`].
148#[derive(Debug, Clone, Copy, PartialEq, Eq)]
149pub struct GradWithLossOptions {
150    /// When true, parameters in `wrt` with no gradient path receive an
151    /// explicit zero tensor instead of panicking (e.g. unused `logit_bias`).
152    pub zero_missing_wrt: bool,
153}
154
155impl GradWithLossOptions {
156    pub const STRICT: Self = Self {
157        zero_missing_wrt: false,
158    };
159    pub const TRAINING: Self = Self {
160        zero_missing_wrt: true,
161    };
162}
163
164/// Build a backward graph with scalar loss + gradients w.r.t. `wrt`.
165///
166/// Panics if any `wrt` parameter receives no gradient (use
167/// [`grad_with_loss_opts`] with [`GradWithLossOptions::TRAINING`] to
168/// zero-fill instead).
169pub fn grad_with_loss(forward: &Graph, wrt: &[NodeId]) -> Graph {
170    grad_with_loss_opts(forward, wrt, GradWithLossOptions::STRICT)
171}
172
/// Like [`grad_with_loss`] with configurable unused-parameter handling.
///
/// Builds `<forward.name>_grad`, whose outputs are
/// `[loss, aux…, grad_wrt[0], grad_wrt[1], …]`, plus a new `"d_output"`
/// input that seeds the gradient of the loss (the first forward output).
///
/// `opts.zero_missing_wrt` decides what happens when a `wrt` parameter
/// receives no gradient: a zero `Op::Constant` (true) or a panic (false).
///
/// # Panics
/// - If `forward` has no outputs.
/// - If a `wrt` parameter receives no gradient and
///   `opts.zero_missing_wrt` is false.
pub fn grad_with_loss_opts(forward: &Graph, wrt: &[NodeId], opts: GradWithLossOptions) -> Graph {
    assert!(
        !forward.outputs.is_empty(),
        "grad_with_loss: forward must have at least one output (the loss)"
    );

    // Pre-autodiff preparation: `prepare_graph_for_ad` rewrites a clone of
    // the forward graph so every remaining op is one the per-op VJP rules
    // cover (see the module docs: elementwise-region / tier-2 fused-op
    // unfusing, `DotGeneral` lowering, `If`/`While` inlining, …). Everything
    // below works on that prepared copy; the caller's graph is untouched.
    // NOTE(review): `wrt` ids (and `forward.node(id)` in the zero-fill path)
    // are looked up in the *prepared* graph, which assumes the prepare pass
    // preserves the ids of the `wrt` parameters. Confirm against prepare_ad.
    let forward_owned = crate::prepare_ad::prepare_graph_for_ad(forward.clone());
    let forward = &forward_owned;

    let mut bwd = Graph::new(format!("{}_grad", forward.name));

    // Mirror every forward node into bwd. The activations needed by
    // gradient kernels (`x` for ReluBackward, `logits` for
    // SoftmaxCrossEntropyBackward, etc.) are looked up by these
    // mirrored ids. Indexing `fwd_to_bwd[i]` for each input requires
    // `forward.nodes()` to list producers before consumers (topological
    // order); otherwise the lookup panics.
    let mut fwd_to_bwd: HashMap<NodeId, NodeId> = HashMap::new();
    for node in forward.nodes() {
        let inputs: Vec<NodeId> = node.inputs.iter().map(|i| fwd_to_bwd[i]).collect();
        let new_id = bwd.add_node(node.op.clone(), inputs, node.shape.clone());
        fwd_to_bwd.insert(node.id, new_id);
    }

    // Seed: the gradient of the loss w.r.t. itself is an external
    // input the caller provides (typically `[1.0]` for a scalar loss).
    let loss_fwd = forward.outputs[0];
    let loss_bwd = fwd_to_bwd[&loss_fwd];
    let loss_shape = forward.node(loss_fwd).shape.clone();
    let d_output = bwd.input("d_output", loss_shape);

    // Gradient accumulators, keyed by *bwd* (mirrored) node id.
    let mut grads: HashMap<NodeId, NodeId> = HashMap::new();
    grads.insert(loss_bwd, d_output);

    // Reverse walk: each node is visited after all its consumers, so its
    // accumulated gradient is complete by the time its VJP runs. Nodes
    // without any gradient (not on a path to the loss) are skipped.
    for fwd_node in forward.nodes().iter().rev() {
        let bwd_id = fwd_to_bwd[&fwd_node.id];
        let upstream = match grads.get(&bwd_id) {
            Some(g) => *g,
            None => continue,
        };
        let input_grads = vjp(fwd_node, upstream, &fwd_to_bwd, &mut bwd);
        for (idx, grad_id) in input_grads {
            let target = fwd_node.inputs[idx];
            let bwd_target = fwd_to_bwd[&target];
            // Per-consumer gradients accumulate (`+=`) via an emitted
            // `BinaryOp::Add` that takes the existing accumulator's shape.
            let new_grad = if let Some(&prev) = grads.get(&bwd_target) {
                let shape = bwd.node(prev).shape.clone();
                bwd.binary(BinaryOp::Add, prev, grad_id, shape)
            } else {
                grad_id
            };
            grads.insert(bwd_target, new_grad);
        }
    }

    // Output layout: [loss, aux…, grad_wrt…].
    let n_aux = forward.outputs.len().saturating_sub(1);
    let mut outputs = Vec::with_capacity(1 + n_aux + wrt.len());
    outputs.push(loss_bwd);
    // Auxiliary forward outputs (everything past `outputs[0]`): mirrored
    // from the forward graph, no gradient propagation.
    for &aux in &forward.outputs[1..] {
        outputs.push(fwd_to_bwd[&aux]);
    }
    for &id in wrt {
        let g = match grads.get(&fwd_to_bwd[&id]).copied() {
            Some(g) => g,
            None if opts.zero_missing_wrt => {
                // Zero gradient constant with the parameter's shape. The
                // payload is always 4 zero bytes (an f32 0.0) per element,
                // regardless of `shape.dtype()`.
                // NOTE(review): the byte length won't match dtypes whose
                // element size isn't 4 bytes, and a shape with no static
                // element count (presumably dynamic) gets an empty payload.
                // Confirm callers only hit this for static f32 params.
                let shape = forward.node(id).shape.clone();
                let n = shape.num_elements().unwrap_or(0);
                let data: Vec<u8> = (0..n).flat_map(|_| 0.0f32.to_le_bytes()).collect();
                bwd.add_node(Op::Constant { data }, vec![], shape)
            }
            None => {
                panic!(
                    "no gradient flowed to {id} — \
                    either the forward graph doesn't depend on it, or one \
                    of its consumer ops has no VJP rule"
                )
            }
        };
        outputs.push(g);
    }
    bwd.set_outputs(outputs);
    bwd
}
269
270/// Backwards-compatible single-output alias (parameter gradients only,
271/// no loss). Kept for the existing tests; prefer `grad_with_loss` for
272/// training.
273pub fn grad(forward: &Graph, wrt: &[NodeId]) -> Graph {
274    let g = grad_with_loss(forward, wrt);
275    let mut g = g;
276    // Drop the loss output, keep only gradients.
277    let outs = g.outputs.iter().skip(1).copied().collect();
278    g.set_outputs(outs);
279    g
280}
281
282/// Project a gradient back to a smaller shape it was broadcasted from.
283/// `target_shape` is the broadcast *source* shape (e.g. `[C]` for a
284/// bias added to `[N, C, H, W]`). Sums over leading prepended axes
285/// and over any axis where target was 1 but the gradient is larger.
286/// Then reshapes to drop the size-1 axes if the rank shrunk.
287/// Returns `Some(bits)` if `node_id` is the output of an
288/// `Op::FakeQuantize { bits, .. }` (or `FakeQuantizeLSQ`) in the
289/// forward graph. Used by the autodiff Conv backward to detect the
290/// QAT pattern and emit a specialized weight-grad kernel that can
291/// skip dead bins (weights that round to the same code share the
292/// gradient). Today only the detection is exposed — the
293/// specialization is a follow-up commit.
294pub fn quantized_weight_bits(forward: &Graph, node_id: NodeId) -> Option<u8> {
295    match &forward.node(node_id).op {
296        Op::FakeQuantize { bits, .. } => Some(*bits),
297        Op::FakeQuantizeLSQ { bits, .. } => Some(*bits),
298        _ => None,
299    }
300}
301
302fn unbroadcast(grad: NodeId, target_shape: &Shape, bwd: &mut Graph) -> NodeId {
303    let grad_shape = bwd.node(grad).shape.clone();
304    if grad_shape == *target_shape {
305        return grad;
306    }
307    let g_rank = grad_shape.rank();
308    let t_rank = target_shape.rank();
309    let extra = g_rank.saturating_sub(t_rank);
310
311    // Axes (in grad's coordinate system) that need summing.
312    let mut axes: Vec<usize> = (0..extra).collect();
313    for i in 0..t_rank {
314        let g_dim = grad_shape.dim(extra + i);
315        let t_dim = target_shape.dim(i);
316        if matches!(t_dim, Dim::Static(1)) && !matches!(g_dim, Dim::Static(1)) {
317            axes.push(extra + i);
318        }
319    }
320
321    let mut current = grad;
322    if !axes.is_empty() {
323        // The CPU `Op::Reduce` thunk only supports a *single contiguous*
324        // range of axes — `[0, 2, 3]` (the canonical conv-bias-gradient
325        // pattern) would silently fall through to a `Nop`. Decompose into
326        // a chain of single-axis reductions with `keep_dim=true` so rank
327        // stays at `g_rank` and earlier axis indices remain valid; the
328        // rank shrink (if any) happens in the reshape step below.
329        let mut running_dims: Vec<Dim> = (0..g_rank).map(|i| grad_shape.dim(i)).collect();
330        for &ax in &axes {
331            running_dims[ax] = Dim::Static(1);
332            let step_shape = Shape::from_dims(&running_dims, target_shape.dtype());
333            current = bwd.add_node(
334                Op::Reduce {
335                    op: ReduceOp::Sum,
336                    axes: vec![ax],
337                    keep_dim: true,
338                },
339                vec![current],
340                step_shape,
341            );
342        }
343    }
344
345    // Drop leading 1-axes via Reshape if the target rank is smaller.
346    if bwd.node(current).shape.rank() != t_rank {
347        let new_shape: Vec<i64> = target_shape
348            .dims()
349            .iter()
350            .map(|d| match d {
351                Dim::Static(n) => *n as i64,
352                Dim::Dynamic(_) => -1,
353            })
354            .collect();
355        current = bwd.add_node(
356            Op::Reshape { new_shape },
357            vec![current],
358            target_shape.clone(),
359        );
360    }
361    current
362}
363
364/// Reshape a gradient to a target shape (used by Reshape / Mean VJPs).
365fn reshape_to(grad: NodeId, target_shape: &Shape, bwd: &mut Graph) -> NodeId {
366    if bwd.node(grad).shape == *target_shape {
367        return grad;
368    }
369    let new_shape: Vec<i64> = target_shape
370        .dims()
371        .iter()
372        .map(|d| match d {
373            Dim::Static(n) => *n as i64,
374            Dim::Dynamic(_) => -1,
375        })
376        .collect();
377    bwd.add_node(Op::Reshape { new_shape }, vec![grad], target_shape.clone())
378}
379
380/// VJP for `Op::GroupedMatMul` / dequantized MoE matmul (`dx`, `dw`).
381fn grouped_matmul_vjp(
382    bwd: &mut Graph,
383    upstream: NodeId,
384    x_bwd: NodeId,
385    w_bwd: NodeId,
386    expert_bwd: NodeId,
387    x_shape: &Shape,
388    w_shape: &Shape,
389) -> (NodeId, NodeId) {
390    let dtype = x_shape.dtype();
391    let m = x_shape.dim(0);
392    let k = x_shape.dim(1);
393    let e = w_shape.dim(0);
394    let n_out = w_shape.dim(2);
395    let m_static = match m {
396        Dim::Static(v) => v,
397        _ => panic!("GroupedMatMul VJP: M must be static"),
398    };
399    let k_static = match k {
400        Dim::Static(v) => v,
401        _ => panic!("GroupedMatMul VJP: K must be static"),
402    };
403    let n_static = match n_out {
404        Dim::Static(v) => v,
405        _ => panic!("GroupedMatMul VJP: N must be static"),
406    };
407
408    let w_per = bwd.add_node(
409        Op::Gather { axis: 0 },
410        vec![w_bwd, expert_bwd],
411        Shape::from_dims(&[m, k, n_out], dtype),
412    );
413
414    let up_3d_shape = Shape::from_dims(&[m, Dim::Static(1), n_out], dtype);
415    let up_3d = bwd.reshape(
416        upstream,
417        vec![m_static as i64, 1, n_static as i64],
418        up_3d_shape,
419    );
420    let w_per_t = bwd.add_node(
421        Op::Transpose {
422            perm: vec![0, 2, 1],
423        },
424        vec![w_per],
425        Shape::from_dims(&[m, n_out, k], dtype),
426    );
427    let dx_3d_shape = Shape::from_dims(&[m, Dim::Static(1), k], dtype);
428    let dx_3d = bwd.matmul(up_3d, w_per_t, dx_3d_shape);
429    let dx = bwd.reshape(
430        dx_3d,
431        vec![m_static as i64, k_static as i64],
432        x_shape.clone(),
433    );
434
435    let x_3d = bwd.reshape(
436        x_bwd,
437        vec![m_static as i64, k_static as i64, 1],
438        Shape::from_dims(&[m, k, Dim::Static(1)], dtype),
439    );
440    let up_for_outer = bwd.reshape(
441        upstream,
442        vec![m_static as i64, 1, n_static as i64],
443        Shape::from_dims(&[m, Dim::Static(1), n_out], dtype),
444    );
445    let dw_per = bwd.matmul(x_3d, up_for_outer, Shape::from_dims(&[m, k, n_out], dtype));
446    let dw = bwd.add_node(
447        Op::ScatterAdd,
448        vec![dw_per, expert_bwd],
449        Shape::from_dims(&[e, k, n_out], dtype),
450    );
451    (dx, dw)
452}
453
454/// Build a scalar f32 `Op::Constant` node.
455fn scalar_const(value: f32, bwd: &mut Graph) -> NodeId {
456    let bytes = value.to_le_bytes().to_vec();
457    let shape = Shape::from_dims(&[Dim::Static(1)], DType::F32);
458    bwd.add_node(Op::Constant { data: bytes }, vec![], shape)
459}
460
461/// Per-op VJP rule. Returns (input_index, gradient_node_id) pairs;
462/// inputs not listed receive no gradient (e.g. the labels argument
463/// of `SoftmaxCrossEntropyWithLogits` is non-differentiable).
464fn vjp(
465    node: &Node,
466    upstream: NodeId,
467    fwd_map: &HashMap<NodeId, NodeId>,
468    bwd: &mut Graph,
469) -> Vec<(usize, NodeId)> {
470    let upstream_shape = bwd.node(upstream).shape.clone();
471    match &node.op {
472        // Leaves — no inputs → no gradients to attribute.
473        Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => vec![],
474
475        Op::Binary(BinaryOp::Add) => {
476            let a_bwd = fwd_map[&node.inputs[0]];
477            let b_bwd = fwd_map[&node.inputs[1]];
478            let a_shape = bwd.node(a_bwd).shape.clone();
479            let b_shape = bwd.node(b_bwd).shape.clone();
480            let g_a = unbroadcast(upstream, &a_shape, bwd);
481            let g_b = unbroadcast(upstream, &b_shape, bwd);
482            vec![(0, g_a), (1, g_b)]
483        }
484
485        Op::Binary(BinaryOp::Sub) => {
486            let a_bwd = fwd_map[&node.inputs[0]];
487            let b_bwd = fwd_map[&node.inputs[1]];
488            let a_shape = bwd.node(a_bwd).shape.clone();
489            let b_shape = bwd.node(b_bwd).shape.clone();
490            let neg = bwd.activation(Activation::Neg, upstream, upstream_shape.clone());
491            let g_a = unbroadcast(upstream, &a_shape, bwd);
492            let g_b = unbroadcast(neg, &b_shape, bwd);
493            vec![(0, g_a), (1, g_b)]
494        }
495
496        Op::Binary(BinaryOp::Mul) => {
497            let a_bwd = fwd_map[&node.inputs[0]];
498            let b_bwd = fwd_map[&node.inputs[1]];
499            let a_shape = bwd.node(a_bwd).shape.clone();
500            let b_shape = bwd.node(b_bwd).shape.clone();
501            // Wirtinger over C64: y = a·b → dL/dā = upstream · conj(b),
502            // dL/db̄ = upstream · conj(a). The conjugates turn the
503            // standard real Mul rule into the correct complex one
504            // without changing the kernel — `BinaryFullC64` does the
505            // native complex multiply on whatever inputs we hand it.
506            let is_c64 = upstream_shape.dtype() == DType::C64;
507            let b_for_a = if is_c64 { bwd.conjugate(b_bwd) } else { b_bwd };
508            let a_for_b = if is_c64 { bwd.conjugate(a_bwd) } else { a_bwd };
509            let g_a_full = bwd.binary(BinaryOp::Mul, upstream, b_for_a, upstream_shape.clone());
510            let g_b_full = bwd.binary(BinaryOp::Mul, upstream, a_for_b, upstream_shape);
511            let g_a = unbroadcast(g_a_full, &a_shape, bwd);
512            let g_b = unbroadcast(g_b_full, &b_shape, bwd);
513            vec![(0, g_a), (1, g_b)]
514        }
515
516        Op::Activation(kind) => {
517            let x_bwd = fwd_map[&node.inputs[0]];
518            // Dedicated `ReluBackward` kernel for the most common case
519            // (avoids the per-element kind-dispatch in
520            // `ActivationBackward`'s match). Every other activation
521            // family hits the generic kernel.
522            let dx = match kind {
523                Activation::Relu => bwd.relu_backward(x_bwd, upstream),
524                _ => bwd.activation_backward(*kind, x_bwd, upstream),
525            };
526            vec![(0, dx)]
527        }
528
529        Op::MatMul => {
530            // y [..batch, M, N] = a [..batch_a, M, K] @ b [..batch_b, K, N]
531            //   da = upstream @ b^T   (shape [..batch, M, K])
532            //   db = a^T   @ upstream (shape [..batch, K, N])
533            //
534            // The forward shape inference broadcasts `batch_a` and
535            // `batch_b` to a common `batch`; if either side was
536            // broadcasted, we sum the corresponding gradient back
537            // down via `unbroadcast` so it matches the param's actual
538            // shape. The transpose swaps the *last two* dims only,
539            // leaving batch untouched.
540            let a_bwd = fwd_map[&node.inputs[0]];
541            let b_bwd = fwd_map[&node.inputs[1]];
542            let a_shape = bwd.node(a_bwd).shape.clone();
543            let b_shape = bwd.node(b_bwd).shape.clone();
544            assert!(
545                a_shape.rank() >= 2 && b_shape.rank() >= 2,
546                "MatMul VJP: rank must be ≥ 2, got {} and {}",
547                a_shape.rank(),
548                b_shape.rank()
549            );
550            let dtype = upstream_shape.dtype();
551
552            // Transpose-last-two helper.
553            let trans_last_two = |bwd: &mut Graph, x: NodeId| -> NodeId {
554                let s = bwd.node(x).shape.clone();
555                let r = s.rank();
556                let mut perm: Vec<usize> = (0..r).collect();
557                perm.swap(r - 2, r - 1);
558                let mut dims: Vec<Dim> = s.dims().to_vec();
559                dims.swap(r - 2, r - 1);
560                let new_shape = Shape::from_dims(&dims, s.dtype());
561                bwd.add_node(Op::Transpose { perm }, vec![x], new_shape)
562            };
563
564            // Build the matmul output shape [..upstream_batch, M_or_K, K_or_N]
565            // by swapping in the trailing dims for each gradient.
566            let upstream_dims: Vec<Dim> = upstream_shape.dims().to_vec();
567            let r_up = upstream_dims.len();
568
569            // ── grad-a = upstream @ b^T (output shape [..up_batch, M, K]) ──
570            let b_t = trans_last_two(bwd, b_bwd);
571            let mut ga_dims = upstream_dims.clone();
572            ga_dims[r_up - 1] = a_shape.dim(a_shape.rank() - 1); // K
573            let ga_shape = Shape::from_dims(&ga_dims, dtype);
574            let g_a_full = bwd.matmul(upstream, b_t, ga_shape);
575            let g_a = unbroadcast(g_a_full, &a_shape, bwd);
576
577            // ── grad-b = a^T @ upstream (output shape [..up_batch, K, N]) ──
578            let a_t = trans_last_two(bwd, a_bwd);
579            let mut gb_dims = upstream_dims.clone();
580            gb_dims[r_up - 2] = a_shape.dim(a_shape.rank() - 1); // K
581            let gb_shape = Shape::from_dims(&gb_dims, dtype);
582            let g_b_full = bwd.matmul(a_t, upstream, gb_shape);
583            let g_b = unbroadcast(g_b_full, &b_shape, bwd);
584
585            vec![(0, g_a), (1, g_b)]
586        }
587
588        Op::DenseSolve => {
589            // X = solve(A, B) ⇒ implicit-function VJP:
590            //   dB = solve(Aᵀ, upstream)        same shape as B / X
591            //   dA = -dB · Xᵀ                   [N, N]
592            //
593            // Rank-1 (b: [N]) and rank-2 (B: [N, K]) follow the same
594            // formula; rank-1 needs a reshape-to-column trick because
595            // we don't have a vector-outer-product op (matmul is
596            // matrix-only). Rank-2 is direct matmul.
597            let a_bwd = fwd_map[&node.inputs[0]];
598            let x_bwd = fwd_map[&node.id];
599            let a_shape = bwd.node(a_bwd).shape.clone();
600            let x_shape = bwd.node(x_bwd).shape.clone();
601            assert_eq!(a_shape.rank(), 2, "DenseSolve VJP: A must be 2-D");
602            let n = match a_shape.dim(0) {
603                Dim::Static(n) => n,
604                Dim::Dynamic(_) => panic!("DenseSolve VJP: dynamic N not supported"),
605            };
606            let dtype = a_shape.dtype();
607
608            // Aᵀ — shape [N, N] (square, transpose is just a perm).
609            let mut a_t_dims: Vec<Dim> = a_shape.dims().to_vec();
610            a_t_dims.swap(0, 1);
611            let a_t_shape = Shape::from_dims(&a_t_dims, dtype);
612            let a_t = bwd.add_node(Op::Transpose { perm: vec![1, 0] }, vec![a_bwd], a_t_shape);
613
614            // dB = solve(Aᵀ, upstream). Same shape as the original B.
615            let d_b = bwd.dense_solve(a_t, upstream, x_shape.clone());
616
617            // dA = -dB · Xᵀ.
618            let neg_outer = match x_shape.rank() {
619                1 => {
620                    // b: [N]. Reshape both vectors to matrices for matmul.
621                    let col_shape = Shape::from_dims(&[Dim::Static(n), Dim::Static(1)], dtype);
622                    let row_shape = Shape::from_dims(&[Dim::Static(1), Dim::Static(n)], dtype);
623                    let db_col = bwd.add_node(
624                        Op::Reshape {
625                            new_shape: vec![n as i64, 1],
626                        },
627                        vec![d_b],
628                        col_shape,
629                    );
630                    let x_row = bwd.add_node(
631                        Op::Reshape {
632                            new_shape: vec![1, n as i64],
633                        },
634                        vec![x_bwd],
635                        row_shape,
636                    );
637                    let outer = bwd.matmul(db_col, x_row, a_shape.clone());
638                    bwd.activation(Activation::Neg, outer, a_shape)
639                }
640                2 => {
641                    // B: [N, K]. dA = -MatMul(dB, Xᵀ) where Xᵀ: [K, N].
642                    let k = match x_shape.dim(1) {
643                        Dim::Static(k) => k,
644                        Dim::Dynamic(_) => panic!("DenseSolve VJP: dynamic K not supported"),
645                    };
646                    let xt_dims = vec![Dim::Static(k), Dim::Static(n)];
647                    let xt_shape = Shape::from_dims(&xt_dims, dtype);
648                    let x_t =
649                        bwd.add_node(Op::Transpose { perm: vec![1, 0] }, vec![x_bwd], xt_shape);
650                    let outer = bwd.matmul(d_b, x_t, a_shape.clone());
651                    bwd.activation(Activation::Neg, outer, a_shape)
652                }
653                r => panic!("DenseSolve VJP: B must be rank 1 or 2, got rank {r}"),
654            };
655
656            vec![(0, neg_outer), (1, d_b)]
657        }
658
659        Op::BatchedDenseSolve => {
660            // Per-batch independent. Same implicit-function VJP as
661            // DenseSolve, lifted with a leading B axis throughout:
662            //   dB = batched_solve(Aᵀ, upstream)        same shape as B/X
663            //   dA = -batched_matmul(dB, Xᵀ)            shape [B, N, N]
664            // where `Aᵀ` swaps the LAST TWO axes (perm = [0, 2, 1]).
665            let a_bwd = fwd_map[&node.inputs[0]];
666            let x_bwd = fwd_map[&node.id];
667            let a_shape = bwd.node(a_bwd).shape.clone();
668            let x_shape = bwd.node(x_bwd).shape.clone();
669            assert_eq!(
670                a_shape.rank(),
671                3,
672                "BatchedDenseSolve VJP: A must be rank-3 [B, N, N]"
673            );
674            let b_dim = match a_shape.dim(0) {
675                Dim::Static(b) => b,
676                Dim::Dynamic(_) => panic!("BatchedDenseSolve VJP: dynamic B not supported"),
677            };
678            let n = match a_shape.dim(1) {
679                Dim::Static(n) => n,
680                Dim::Dynamic(_) => panic!("BatchedDenseSolve VJP: dynamic N not supported"),
681            };
682            let dtype = a_shape.dtype();
683
684            // Aᵀ across last two dims — perm = [0, 2, 1]. Output shape
685            // is [B, N, N] (same as A; transpose of square is square).
686            let a_t = bwd.add_node(
687                Op::Transpose {
688                    perm: vec![0, 2, 1],
689                },
690                vec![a_bwd],
691                a_shape.clone(),
692            );
693
694            // dB = batched_solve(Aᵀ, upstream).
695            let d_b = bwd.batched_dense_solve(a_t, upstream, x_shape.clone());
696
697            // dA = -batched_matmul(dB, Xᵀ).
698            let neg_outer = match x_shape.rank() {
699                2 => {
700                    // b is [B, N]. Reshape to [B, N, 1] (column) for dB
701                    // and [B, 1, N] (row) for X, then batched matmul.
702                    let col_shape = Shape::from_dims(
703                        &[Dim::Static(b_dim), Dim::Static(n), Dim::Static(1)],
704                        dtype,
705                    );
706                    let row_shape = Shape::from_dims(
707                        &[Dim::Static(b_dim), Dim::Static(1), Dim::Static(n)],
708                        dtype,
709                    );
710                    let db_col = bwd.add_node(
711                        Op::Reshape {
712                            new_shape: vec![b_dim as i64, n as i64, 1],
713                        },
714                        vec![d_b],
715                        col_shape,
716                    );
717                    let x_row = bwd.add_node(
718                        Op::Reshape {
719                            new_shape: vec![b_dim as i64, 1, n as i64],
720                        },
721                        vec![x_bwd],
722                        row_shape,
723                    );
724                    let outer = bwd.matmul(db_col, x_row, a_shape.clone());
725                    bwd.activation(Activation::Neg, outer, a_shape)
726                }
727                3 => {
728                    // b is [B, N, K]. dA = -MatMul(dB, Xᵀ) with
729                    // Xᵀ = Transpose(perm=[0, 2, 1]) so [B, K, N].
730                    let k = match x_shape.dim(2) {
731                        Dim::Static(k) => k,
732                        Dim::Dynamic(_) => panic!("BatchedDenseSolve VJP: dynamic K not supported"),
733                    };
734                    let xt_shape = Shape::from_dims(
735                        &[Dim::Static(b_dim), Dim::Static(k), Dim::Static(n)],
736                        dtype,
737                    );
738                    let x_t = bwd.add_node(
739                        Op::Transpose {
740                            perm: vec![0, 2, 1],
741                        },
742                        vec![x_bwd],
743                        xt_shape,
744                    );
745                    let outer = bwd.matmul(d_b, x_t, a_shape.clone());
746                    bwd.activation(Activation::Neg, outer, a_shape)
747                }
748                r => panic!("BatchedDenseSolve VJP: b must be rank 2 or 3, got rank {r}"),
749            };
750
751            vec![(0, neg_outer), (1, d_b)]
752        }
753
        Op::Scan {
            body,
            length,
            save_trajectory,
            num_bcast: _,
            num_xs,
            num_checkpoints,
        } => {
            // After `convert_scans_for_ad`, every scan reaching the AD
            // walk carries its trajectory. Compile body's VJP once
            // — w.r.t. carry AND every xs — so we can extract dinit
            // (Op::ScanBackward) plus dxs_i for each xs
            // (Op::ScanBackwardXs). Each variant runs its own backward
            // sweep; this is `1 + num_xs` independent sweeps. A future
            // optimization can fuse them via packed multi-output.
            //
            // Gradients are emitted for `node.inputs[0]` (init carry) and
            // `node.inputs[1..=num_xs]` (the scanned sequences) only.
            // `num_bcast` is ignored by this arm, so any inputs after the
            // xs (presumably the broadcast operands) get no gradient here
            // — NOTE(review): confirm that is intended.
            let init_bwd = fwd_map[&node.inputs[0]];
            // The forward Scan node's own output (its trajectory), which
            // every backward sweep below consumes.
            let traj_bwd = fwd_map[&node.id];
            let init_shape = bwd.node(init_bwd).shape.clone();

            // Body Inputs in NodeId order: first = carry, rest = x_t_i.
            // (Assumes the body graph created the carry Input before the
            // x_t_i Inputs — TODO confirm against the Scan builder.)
            let mut body_input_ids: Vec<NodeId> = body
                .nodes()
                .iter()
                .filter(|n| matches!(n.op, Op::Input { .. }))
                .map(|n| n.id)
                .collect();
            body_input_ids.sort();

            // Recursive reverse-mode transform of the body w.r.t. all of
            // its Inputs; the resulting graph is handed (cloned) to every
            // backward sweep emitted below.
            let body_vjp = grad(body, &body_input_ids);

            // Backward-graph ids of the outer xs operands, in xs order.
            let xs_bwd: Vec<NodeId> = (0..*num_xs as usize)
                .map(|i| fwd_map[&node.inputs[1 + i]])
                .collect();

            // Recursive checkpointing: when num_checkpoints is set on
            // the forward Scan, propagate it (and the forward body) to
            // each emitted ScanBackward / ScanBackwardXs so the
            // executor knows to recompute carries via `forward_body`
            // between checkpoints.
            // `num_checkpoints == 0` and `num_checkpoints == length` are
            // both treated as non-checkpointed: no forward body is passed.
            let is_checkpointed = *num_checkpoints != 0 && *num_checkpoints != *length;
            let forward_body_for_bwd = if is_checkpointed {
                Some((**body).clone())
            } else {
                None
            };

            // Gradient w.r.t. the initial carry (input 0).
            let dinit = bwd.scan_backward_with_checkpoints(
                init_bwd,
                traj_bwd,
                upstream,
                &xs_bwd,
                body_vjp.clone(),
                *length,
                *save_trajectory,
                *num_checkpoints,
                forward_body_for_bwd.clone(),
                init_shape,
            );

            let mut grads: Vec<(usize, NodeId)> = vec![(0, dinit)];
            // One independent backward sweep per xs operand; the `i as u32`
            // argument selects which xs this sweep differentiates.
            for i in 0..*num_xs as usize {
                let outer_xs_id = node.inputs[1 + i];
                let xs_shape = bwd.node(fwd_map[&outer_xs_id]).shape.clone();
                let dxs_i = bwd.scan_backward_xs_with_checkpoints(
                    init_bwd,
                    traj_bwd,
                    upstream,
                    &xs_bwd,
                    body_vjp.clone(),
                    *length,
                    *save_trajectory,
                    i as u32,
                    *num_checkpoints,
                    forward_body_for_bwd.clone(),
                    xs_shape,
                );
                grads.push((1 + i, dxs_i));
            }
            grads
        }
834
835        Op::Conv {
836            kernel_size,
837            stride,
838            padding,
839            dilation,
840            groups,
841        } => {
842            let x_bwd = fwd_map[&node.inputs[0]];
843            let w_bwd = fwd_map[&node.inputs[1]];
844            let x_shape = bwd.node(x_bwd).shape.clone();
845            let w_shape = bwd.node(w_bwd).shape.clone();
846            let dx = bwd.conv2d_backward_input(
847                upstream,
848                w_bwd,
849                x_shape,
850                kernel_size.clone(),
851                stride.clone(),
852                padding.clone(),
853                dilation.clone(),
854                *groups,
855            );
856            // Detect the QAT pattern (`Conv` reading from a
857            // `FakeQuantize` weight) so a follow-up specialization
858            // can skip dead bins (weights that round to the same
859            // code share the gradient). For now we still emit the
860            // generic backward — the helper just exposes the bits
861            // for a future kernel variant.
862            // QAT-bits detection requires the forward graph, which isn't
863            // threaded through `vjp`. Skip for now; the generic backward
864            // is used unconditionally.
865            let _qat_bits: Option<u8> = None;
866            let dw = bwd.conv2d_backward_weight(
867                x_bwd,
868                upstream,
869                w_shape,
870                kernel_size.clone(),
871                stride.clone(),
872                padding.clone(),
873                dilation.clone(),
874                *groups,
875            );
876            vec![(0, dx), (1, dw)]
877        }
878
879        Op::Pool {
880            kind: ReduceOp::Max,
881            kernel_size,
882            stride,
883            padding,
884        } => {
885            let x_bwd = fwd_map[&node.inputs[0]];
886            let dx = bwd.maxpool2d_backward(
887                x_bwd,
888                upstream,
889                kernel_size.clone(),
890                stride.clone(),
891                padding.clone(),
892            );
893            vec![(0, dx)]
894        }
895
896        Op::SoftmaxCrossEntropyWithLogits => {
897            let logits_bwd = fwd_map[&node.inputs[0]];
898            let labels_bwd = fwd_map[&node.inputs[1]];
899            let dlogits = bwd.softmax_cross_entropy_backward(logits_bwd, labels_bwd, upstream);
900            // labels has no gradient.
901            vec![(0, dlogits)]
902        }
903
904        Op::Reduce {
905            op: ReduceOp::Sum,
906            axes,
907            keep_dim,
908        } => {
909            let x_bwd = fwd_map[&node.inputs[0]];
910            let x_shape = bwd.node(x_bwd).shape.clone();
911            let g = expand_to(upstream, &x_shape, axes, *keep_dim, bwd);
912            vec![(0, g)]
913        }
914
915        Op::Reduce {
916            op: ReduceOp::Mean,
917            axes,
918            keep_dim,
919        } => {
920            // Mean = Sum / N. Do the Sum-style expansion first, then
921            // multiply the broadcast result by 1/N. Multiplying after
922            // the expand keeps the broadcast cleanly anchored at the
923            // full input shape and sidesteps the rank-promotion when
924            // the reduced output is a scalar (shape `[]`).
925            let x_bwd = fwd_map[&node.inputs[0]];
926            let x_shape = bwd.node(x_bwd).shape.clone();
927            let count: usize = axes
928                .iter()
929                .map(|&a| match x_shape.dim(a) {
930                    Dim::Static(n) => n,
931                    _ => panic!("Reduce::Mean VJP requires static reduced dims"),
932                })
933                .product();
934            let expanded = expand_to(upstream, &x_shape, axes, *keep_dim, bwd);
935            let inv_count = scalar_const(1.0 / count as f32, bwd);
936            let g = bwd.binary(BinaryOp::Mul, expanded, inv_count, x_shape);
937            vec![(0, g)]
938        }
939
940        Op::Reshape { .. } => {
941            let x_bwd = fwd_map[&node.inputs[0]];
942            let x_shape = bwd.node(x_bwd).shape.clone();
943            let dx = reshape_to(upstream, &x_shape, bwd);
944            vec![(0, dx)]
945        }
946
947        Op::ComplexNormSq => {
948            // Wirtinger: ∂|z|²/∂z̄ = z. Cotangent g (real) maps to
949            // dz = g·z (complex, element-wise).
950            let z_bwd = fwd_map[&node.inputs[0]];
951            let dz = bwd.complex_norm_sq_backward(z_bwd, upstream);
952            vec![(0, dz)]
953        }
954
955        Op::Conjugate => {
956            // For w = conj(z): under the JAX-style cotangent (carrying
957            // ∂L/∂z̄ for a real-valued L), the rule reduces to
958            // cotangent_z = conj(cotangent_w). So the VJP of Conjugate
959            // is Conjugate itself. Symmetric — second-order derivatives
960            // through complex graphs stay consistent.
961            let dz = bwd.conjugate(upstream);
962            vec![(0, dz)]
963        }
964
965        Op::Cast { .. } => {
966            let x_bwd = fwd_map[&node.inputs[0]];
967            let x_shape = bwd.node(x_bwd).shape.clone();
968            let dx = bwd.add_node(
969                Op::Cast {
970                    to: x_shape.dtype(),
971                },
972                vec![upstream],
973                x_shape,
974            );
975            vec![(0, dx)]
976        }
977
978        // Stop-gradient (a.k.a. `detach`): forward identity, **no**
979        // gradient contribution to the input. Returning an empty list
980        // here keeps the reverse-mode walker from accumulating any
981        // upstream into `node.inputs[0]`, which is the whole point.
982        Op::StopGradient => vec![],
983
984        // Straight-through estimator: forward simulates the lossy
985        // round-trip (x → q → x'), backward pretends it was an
986        // identity. `dx = upstream` for both ops. The upstream is the
987        // f32 gradient computed by the consumer; the int8 dtype on
988        // the input/output is ignored for the gradient — we treat
989        // the entire Q/DQ chain as a real-valued no-op for autodiff
990        // purposes. This is the foundation for QAT (quantization-
991        // aware training): the model trains in fp32 but every
992        // forward pass tastes the int8 round-tripped activations,
993        // so the learned weights are robust to deployment-time
994        // quantization.
995        Op::Quantize { .. } | Op::Dequantize { .. } => {
996            vec![(0, upstream)]
997        }
998
999        Op::FakeQuantizeLSQ { bits, axis } => {
1000            // LSQ has TWO gradients: dx (STE-clipped) and dscale
1001            // (closed-form). Route them to inputs[0] (x) and
1002            // inputs[1] (scale) respectively.
1003            let x_bwd = fwd_map[&node.inputs[0]];
1004            let scale_bwd = fwd_map[&node.inputs[1]];
1005            let x_shape = bwd.node(x_bwd).shape.clone();
1006            let scale_shape = bwd.node(scale_bwd).shape.clone();
1007            let dx = bwd.add_node(
1008                Op::FakeQuantizeLSQBackwardX {
1009                    bits: *bits,
1010                    axis: *axis,
1011                },
1012                vec![x_bwd, scale_bwd, upstream],
1013                x_shape,
1014            );
1015            let dscale = bwd.add_node(
1016                Op::FakeQuantizeLSQBackwardScale {
1017                    bits: *bits,
1018                    axis: *axis,
1019                },
1020                vec![x_bwd, scale_bwd, upstream],
1021                scale_shape,
1022            );
1023            vec![(0, dx), (1, dscale)]
1024        }
1025
1026        // FakeQuantize backward depends on the STE variant. The
1027        // default `Identity` is a clean passthrough; the others
1028        // attenuate the gradient based on `x` and the per-channel
1029        // scale, so we emit a dedicated `FakeQuantizeBackward` op.
1030        Op::FakeQuantize {
1031            bits, axis, ste, ..
1032        } => {
1033            use rlx_ir::op::SteKind;
1034            match ste {
1035                SteKind::Identity => vec![(0, upstream)],
1036                _ => {
1037                    let x_bwd = fwd_map[&node.inputs[0]];
1038                    let x_shape = bwd.node(x_bwd).shape.clone();
1039                    let dx = bwd.add_node(
1040                        Op::FakeQuantizeBackward {
1041                            bits: *bits,
1042                            axis: *axis,
1043                            ste: *ste,
1044                        },
1045                        vec![x_bwd, upstream],
1046                        x_shape,
1047                    );
1048                    vec![(0, dx)]
1049                }
1050            }
1051        }
1052
1053        Op::Expand { .. } => {
1054            let x_bwd = fwd_map[&node.inputs[0]];
1055            let x_shape = bwd.node(x_bwd).shape.clone();
1056            let dx = unbroadcast(upstream, &x_shape, bwd);
1057            vec![(0, dx)]
1058        }
1059
1060        Op::BatchNormInference { eps } => {
1061            let x_bwd = fwd_map[&node.inputs[0]];
1062            let gamma_bwd = fwd_map[&node.inputs[1]];
1063            let _beta_bwd = fwd_map[&node.inputs[2]];
1064            let mean_bwd = fwd_map[&node.inputs[3]];
1065            let var_bwd = fwd_map[&node.inputs[4]];
1066            let gamma_shape = bwd.node(gamma_bwd).shape.clone();
1067            let dx = bwd.batch_norm_inference_backward_input(
1068                x_bwd, gamma_bwd, mean_bwd, var_bwd, upstream, *eps,
1069            );
1070            let dgamma = bwd.batch_norm_inference_backward_gamma(
1071                x_bwd,
1072                mean_bwd,
1073                var_bwd,
1074                upstream,
1075                gamma_shape.clone(),
1076                *eps,
1077            );
1078            let dbeta = bwd.batch_norm_inference_backward_beta(upstream, gamma_shape);
1079            // mean/var are frozen — no gradients.
1080            vec![(0, dx), (1, dgamma), (2, dbeta)]
1081        }
1082
1083        Op::LayerNorm { axis, eps } => {
1084            // y = LayerNorm(x, gamma, beta) over the feature axis.
1085            // d_x via the dedicated `LayerNormBackwardInput` kernel
1086            // (closed-form, recomputes mean/var/x̂ inline).
1087            // d_gamma via `LayerNormBackwardGamma` (sums over batch axes).
1088            // d_beta = sum(upstream) over batch axes — composable with
1089            // an unbroadcast back to gamma's shape (gamma and beta share shape).
1090            let x_bwd = fwd_map[&node.inputs[0]];
1091            let gamma_bwd = fwd_map[&node.inputs[1]];
1092            let _beta_bwd = fwd_map[&node.inputs[2]];
1093            let gamma_shape = bwd.node(gamma_bwd).shape.clone();
1094
1095            let dx = bwd.layer_norm_backward_input(x_bwd, gamma_bwd, upstream, *axis, *eps);
1096            let dgamma =
1097                bwd.layer_norm_backward_gamma(x_bwd, upstream, gamma_shape.clone(), *axis, *eps);
1098            let dbeta = unbroadcast(upstream, &gamma_shape, bwd);
1099            vec![(0, dx), (1, dgamma), (2, dbeta)]
1100        }
1101
1102        Op::Softmax { axis } => {
1103            // y = softmax(x, axis)  →  dy/dx[i] = y[i] · (g[i] - Σⱼ y[j]·g[j])
1104            // where the Σⱼ is over the softmax axis. Compose from existing
1105            // primitives:  yg = y * upstream
1106            //              s  = reduce_sum(yg, axis, keep_dim=true)
1107            //              s' = expand(s, target=y.shape)
1108            //              dx = y * (upstream - s')
1109            //
1110            // The forward `y` lives at `fwd_to_bwd[node.id]` — bwd
1111            // graph mirrors every forward node so its slot survives
1112            // through this VJP. We *explicitly* expand `s` to `y.shape`
1113            // before the Sub rather than relying on `Op::Binary`'s
1114            // broadcast (which has a known shape-confusion bug for the
1115            // `[..., 1]` keep-dim case — see the rlx-cpu thunk
1116            // dispatch). Going through `Op::Expand` runs the
1117            // dedicated stride-aware broadcast thunk, which is correct.
1118            let y_bwd = fwd_map[&node.id];
1119            let y_shape = bwd.node(y_bwd).shape.clone();
1120            let dtype = y_shape.dtype();
1121            let rank = y_shape.rank();
1122            let axis_pos = if *axis < 0 {
1123                (rank as i32 + *axis) as usize
1124            } else {
1125                *axis as usize
1126            };
1127
1128            let yg = bwd.binary(BinaryOp::Mul, y_bwd, upstream, y_shape.clone());
1129
1130            let mut kept_dims: Vec<Dim> = y_shape.dims().to_vec();
1131            kept_dims[axis_pos] = Dim::Static(1);
1132            let kept_shape = Shape::from_dims(&kept_dims, dtype);
1133            let s = bwd.add_node(
1134                Op::Reduce {
1135                    op: ReduceOp::Sum,
1136                    axes: vec![axis_pos],
1137                    keep_dim: true,
1138                },
1139                vec![yg],
1140                kept_shape,
1141            );
1142
1143            let target_dims: Vec<i64> = y_shape
1144                .dims()
1145                .iter()
1146                .map(|d| match d {
1147                    Dim::Static(n) => *n as i64,
1148                    Dim::Dynamic(_) => -1,
1149                })
1150                .collect();
1151            let s_expanded = bwd.add_node(
1152                Op::Expand {
1153                    target_shape: target_dims,
1154                },
1155                vec![s],
1156                y_shape.clone(),
1157            );
1158
1159            let diff = bwd.binary(BinaryOp::Sub, upstream, s_expanded, y_shape.clone());
1160            let dx = bwd.binary(BinaryOp::Mul, y_bwd, diff, y_shape);
1161            vec![(0, dx)]
1162        }
1163
1164        // ── Shape ops: just route the upstream gradient through ──
1165        Op::Transpose { perm } => {
1166            // Inverse permutation: if forward maps axis i → perm[i],
1167            // backward maps perm[i] → i.
1168            let inv: Vec<usize> = {
1169                let mut v = vec![0usize; perm.len()];
1170                for (i, &p) in perm.iter().enumerate() {
1171                    v[p] = i;
1172                }
1173                v
1174            };
1175            let x_bwd = fwd_map[&node.inputs[0]];
1176            let x_shape = bwd.node(x_bwd).shape.clone();
1177            let dx = bwd.add_node(Op::Transpose { perm: inv }, vec![upstream], x_shape);
1178            vec![(0, dx)]
1179        }
1180
1181        Op::Concat { axis } => {
1182            // Split upstream along the concat axis: each input gets
1183            // `Narrow(upstream, axis, offset, x_i.dim(axis))`.
1184            let mut grads = Vec::with_capacity(node.inputs.len());
1185            let mut offset: usize = 0;
1186            for (i, &input_id) in node.inputs.iter().enumerate() {
1187                let x_bwd = fwd_map[&input_id];
1188                let x_shape = bwd.node(x_bwd).shape.clone();
1189                let len = match x_shape.dim(*axis) {
1190                    Dim::Static(n) => n,
1191                    _ => panic!("Concat VJP: dynamic concat dim"),
1192                };
1193                let dx = bwd.add_node(
1194                    Op::Narrow {
1195                        axis: *axis,
1196                        start: offset,
1197                        len,
1198                    },
1199                    vec![upstream],
1200                    x_shape,
1201                );
1202                grads.push((i, dx));
1203                offset += len;
1204            }
1205            grads
1206        }
1207
1208        Op::Narrow { axis, start, len } => {
1209            // Inverse of slicing: pad upstream with zeros on both
1210            // sides along `axis` so the result matches input shape.
1211            // Build via Concat[zeros_pre, upstream, zeros_post].
1212            let x_bwd = fwd_map[&node.inputs[0]];
1213            let x_shape = bwd.node(x_bwd).shape.clone();
1214            let full_n = match x_shape.dim(*axis) {
1215                Dim::Static(n) => n,
1216                _ => panic!("Narrow VJP: dynamic axis"),
1217            };
1218            let pre = *start;
1219            let post = full_n - *start - *len;
1220
1221            let zero_buf = |bwd: &mut Graph, len_axis: usize| -> NodeId {
1222                if len_axis == 0 {
1223                    return upstream; // sentinel, never used (filtered below)
1224                }
1225                let dtype = x_shape.dtype();
1226                let mut dims: Vec<Dim> = x_shape.dims().to_vec();
1227                dims[*axis] = Dim::Static(len_axis);
1228                let s = Shape::from_dims(&dims, dtype);
1229                let n_elems = dims.iter().fold(1usize, |a, d| match d {
1230                    Dim::Static(k) => a * k,
1231                    _ => a,
1232                });
1233                // Bytes per element scales with dtype; bytewise-zero is
1234                // a valid zero at any precision (IEEE +0.0 / signed 0 /
1235                // unsigned 0), so a vec of zero bytes is safe.
1236                let bytes = vec![0u8; n_elems * dtype.size_bytes()];
1237                bwd.add_node(Op::Constant { data: bytes }, vec![], s)
1238            };
1239
1240            let mut parts: Vec<NodeId> = Vec::new();
1241            if pre > 0 {
1242                parts.push(zero_buf(bwd, pre));
1243            }
1244            parts.push(upstream);
1245            if post > 0 {
1246                parts.push(zero_buf(bwd, post));
1247            }
1248
1249            let dx = if parts.len() == 1 {
1250                parts[0]
1251            } else {
1252                bwd.add_node(Op::Concat { axis: *axis }, parts, x_shape)
1253            };
1254            vec![(0, dx)]
1255        }
1256
1257        Op::Gather { axis } => {
1258            let table_bwd = fwd_map[&node.inputs[0]];
1259            let indices_bwd = fwd_map[&node.inputs[1]];
1260            let table_shape = bwd.node(table_bwd).shape.clone();
1261            if *axis == 0 {
1262                let dtable = bwd.add_node(Op::ScatterAdd, vec![upstream, indices_bwd], table_shape);
1263                vec![(0, dtable)]
1264            } else {
1265                let dtable = bwd.gather_backward(
1266                    upstream,
1267                    indices_bwd,
1268                    table_shape,
1269                    (*axis).try_into().unwrap(),
1270                );
1271                vec![(0, dtable)]
1272            }
1273        }
1274
1275        // ── Non-differentiable predicates / selectors ──
1276        Op::Compare(_) => {
1277            // Compare returns a boolean tensor; gradient w.r.t.
1278            // continuous inputs is zero almost everywhere. We don't
1279            // propagate (caller will see zero grads for any path
1280            // that flows through a Compare alone).
1281            vec![]
1282        }
1283
1284        Op::Where => {
1285            // out = where(cond, a, b). Cond has zero gradient
1286            // (it's a predicate); a's gradient is `where(cond,
1287            // upstream, 0)`; b's gradient is `where(cond, 0, upstream)`.
1288            let cond = fwd_map[&node.inputs[0]];
1289            let a_bwd = fwd_map[&node.inputs[1]];
1290            let b_bwd = fwd_map[&node.inputs[2]];
1291            let a_shape = bwd.node(a_bwd).shape.clone();
1292            let b_shape = bwd.node(b_bwd).shape.clone();
1293            let out_shape = upstream_shape.clone();
1294
1295            let zero_a_bytes = vec![0u8; a_shape.num_elements().expect("Where VJP: dynamic a") * 4];
1296            let zero_b_bytes = vec![0u8; b_shape.num_elements().expect("Where VJP: dynamic b") * 4];
1297            let zero_a = bwd.add_node(Op::Constant { data: zero_a_bytes }, vec![], a_shape.clone());
1298            let zero_b = bwd.add_node(Op::Constant { data: zero_b_bytes }, vec![], b_shape.clone());
1299            // Need to match shapes for Op::Where (cond, a, b same).
1300            // Upstream shape == out_shape == broadcast of a/b.
1301            let zero_a_bcast = unbroadcast_inverse(zero_a, &out_shape, bwd);
1302            let zero_b_bcast = unbroadcast_inverse(zero_b, &out_shape, bwd);
1303            let g_a_full = bwd.add_node(
1304                Op::Where,
1305                vec![cond, upstream, zero_a_bcast],
1306                out_shape.clone(),
1307            );
1308            let g_b_full = bwd.add_node(Op::Where, vec![cond, zero_b_bcast, upstream], out_shape);
1309            let g_a = unbroadcast(g_a_full, &a_shape, bwd);
1310            let g_b = unbroadcast(g_b_full, &b_shape, bwd);
1311            vec![(1, g_a), (2, g_b)]
1312        }
1313
1314        // ── Element-wise binary ops ──
1315        Op::Binary(BinaryOp::Div) => {
1316            // Real:  d/da (a/b) = 1/b,        d/db (a/b) = -a/b² = -y/b
1317            // C64 (Wirtinger):
1318            //        d/dā = upstream / conj(b)
1319            //        d/db̄ = -upstream · conj(y) / conj(b)
1320            // Substituting `b ↦ conj(b)` and `y ↦ conj(y)` in the real
1321            // rule recovers the complex one — the kernel itself is
1322            // unchanged.
1323            let a_bwd = fwd_map[&node.inputs[0]];
1324            let b_bwd = fwd_map[&node.inputs[1]];
1325            let y_bwd = fwd_map[&node.id];
1326            let a_shape = bwd.node(a_bwd).shape.clone();
1327            let b_shape = bwd.node(b_bwd).shape.clone();
1328            let is_c64 = upstream_shape.dtype() == DType::C64;
1329
1330            let b_term = if is_c64 { bwd.conjugate(b_bwd) } else { b_bwd };
1331            let y_term = if is_c64 { bwd.conjugate(y_bwd) } else { y_bwd };
1332
1333            // d/da: upstream / b_term
1334            let g_a_full = bwd.binary(BinaryOp::Div, upstream, b_term, upstream_shape.clone());
1335            let g_a = unbroadcast(g_a_full, &a_shape, bwd);
1336
1337            // d/db: -upstream * y_term / b_term
1338            let neg_up = bwd.activation(Activation::Neg, upstream, upstream_shape.clone());
1339            let neg_up_y = bwd.binary(BinaryOp::Mul, neg_up, y_term, upstream_shape.clone());
1340            let g_b_full = bwd.binary(BinaryOp::Div, neg_up_y, b_term, upstream_shape);
1341            let g_b = unbroadcast(g_b_full, &b_shape, bwd);
1342
1343            vec![(0, g_a), (1, g_b)]
1344        }
1345
1346        // ── Reductions: gradient flows to where the reduction "saw" ──
1347        Op::Reduce {
1348            op: ReduceOp::Max,
1349            axes,
1350            keep_dim,
1351        }
1352        | Op::Reduce {
1353            op: ReduceOp::Min,
1354            axes,
1355            keep_dim,
1356        } => {
1357            // d_x[i] = upstream where x[i] equals the (broadcast)
1358            // reduce result, else 0. Composed via
1359            // expand(upstream) * (compare(x, expand(y), Eq) → 1.0).
1360            let is_max = matches!(
1361                node.op,
1362                Op::Reduce {
1363                    op: ReduceOp::Max,
1364                    ..
1365                }
1366            );
1367            let _ = is_max;
1368            let x_bwd = fwd_map[&node.inputs[0]];
1369            let y_bwd = fwd_map[&node.id];
1370            let x_shape = bwd.node(x_bwd).shape.clone();
1371            let y_expanded = expand_to(y_bwd, &x_shape, axes, *keep_dim, bwd);
1372            let mask_bool = bwd.add_node(
1373                Op::Compare(CmpOp::Eq),
1374                vec![x_bwd, y_expanded],
1375                Shape::from_dims(x_shape.dims(), DType::F32),
1376            );
1377            // Convert bool→f32 via Cast (the IR encodes bool/PRED as
1378            // F32 in our backends already; this is a no-op cast on
1379            // most paths).
1380            let mask_f32 = bwd.add_node(
1381                Op::Cast {
1382                    to: x_shape.dtype(),
1383                },
1384                vec![mask_bool],
1385                x_shape.clone(),
1386            );
1387            let upstream_expanded = expand_to(upstream, &x_shape, axes, *keep_dim, bwd);
1388            let dx = bwd.binary(BinaryOp::Mul, upstream_expanded, mask_f32, x_shape);
1389            vec![(0, dx)]
1390        }
1391
1392        // ── Rope: backward is forward with negated sin ──
1393        //
1394        //   forward:  out = x * cos + rotate(x) * sin
1395        //   reverse:  dx  = dy * cos + rotate(dy) * (-sin)
1396        //         =  rope(dy, cos, neg(sin))
1397        Op::Rope { head_dim, n_rot } => {
1398            let cos = fwd_map[&node.inputs[1]];
1399            let sin = fwd_map[&node.inputs[2]];
1400            let dx = bwd.rope_backward(upstream, cos, sin, *head_dim, *n_rot);
1401            vec![(0, dx)]
1402        }
1403
1404        Op::RmsNorm { axis, eps } => {
1405            let x = fwd_map[&node.inputs[0]];
1406            let gamma = fwd_map[&node.inputs[1]];
1407            let beta = fwd_map[&node.inputs[2]];
1408            let dx = bwd.rms_norm_backward_input(x, gamma, beta, upstream, *axis, *eps);
1409            let dgamma = bwd.rms_norm_backward_gamma(x, gamma, beta, upstream, *axis, *eps);
1410            let dbeta = bwd.rms_norm_backward_beta(x, gamma, beta, upstream, *axis, *eps);
1411            vec![(0, dx), (1, dgamma), (2, dbeta)]
1412        }
1413
1414        Op::GroupNorm { num_groups, eps } => {
1415            let x = fwd_map[&node.inputs[0]];
1416            let gamma = fwd_map[&node.inputs[1]];
1417            let beta = fwd_map[&node.inputs[2]];
1418            let gamma_shape = bwd.node(gamma).shape.clone();
1419            let beta_shape = bwd.node(beta).shape.clone();
1420            let dx = bwd.group_norm_backward_input(x, gamma, beta, upstream, *num_groups, *eps);
1421            let dgamma = bwd.group_norm_backward_gamma(x, upstream, gamma_shape, *num_groups, *eps);
1422            let dbeta = bwd.group_norm_backward_beta(x, upstream, beta_shape, *num_groups, *eps);
1423            vec![(0, dx), (1, dgamma), (2, dbeta)]
1424        }
1425
1426        // ── Attention → dedicated backward kernels ──────────────
1427        Op::Attention {
1428            num_heads,
1429            head_dim,
1430            mask_kind,
1431            score_scale: _,
1432            attn_logit_softcap: _,
1433        } => {
1434            let q = fwd_map[&node.inputs[0]];
1435            let k = fwd_map[&node.inputs[1]];
1436            let v = fwd_map[&node.inputs[2]];
1437            let mask = match mask_kind {
1438                MaskKind::Custom | MaskKind::Bias => Some(fwd_map[&node.inputs[3]]),
1439                _ => None,
1440            };
1441            let (dq, dk, dv) = bwd
1442                .attention_backward_all(q, k, v, upstream, *num_heads, *head_dim, *mask_kind, mask);
1443            vec![(0, dq), (1, dk), (2, dv)]
1444        }
1445
1446        // ── Reduce(Prod) ────────────────────────────────────────
1447        //
1448        // Forward: y[axes_reduced] = ∏ x[axes_reduced…]
1449        // Backward: dx[i] = upstream · y / x[i]   (per-row).
1450        // (Numerically dicey when any x[i] = 0; production users
1451        //  needing zero-safe Prod-grad should pre-mask.)
1452        Op::Reduce {
1453            op: ReduceOp::Prod,
1454            axes,
1455            keep_dim,
1456        } => {
1457            let x_bwd = fwd_map[&node.inputs[0]];
1458            let y_bwd = fwd_map[&node.id];
1459            let x_shape = bwd.node(x_bwd).shape.clone();
1460            let y_expanded = expand_to(y_bwd, &x_shape, axes, *keep_dim, bwd);
1461            let upstream_expanded = expand_to(upstream, &x_shape, axes, *keep_dim, bwd);
1462            // dx = upstream_b · y_b / x
1463            let num = bwd.binary(
1464                BinaryOp::Mul,
1465                upstream_expanded,
1466                y_expanded,
1467                x_shape.clone(),
1468            );
1469            let dx = bwd.binary(BinaryOp::Div, num, x_bwd, x_shape);
1470            vec![(0, dx)]
1471        }
1472
1473        // ── Pool(Mean) ──────────────────────────────────────────
1474        //
1475        // Forward: y[..., h_out, w_out] = mean(window).
1476        // Backward: dx[i] = upstream[output_pos(i)] / |window|
1477        //   distributed across each pool window.
1478        //
1479        // Compose via a Conv2dBackwardInput with a constant
1480        // 1/|window| kernel of shape [C, 1, kH, kW] and groups=C
1481        // (depthwise — no channel mixing). This gives the correct
1482        // "spread upstream over window" behavior including stride
1483        // and padding handling.
1484        Op::Pool {
1485            kind: ReduceOp::Mean,
1486            kernel_size,
1487            stride,
1488            padding,
1489        } => {
1490            assert_eq!(kernel_size.len(), 2, "Pool(Mean) VJP: 2-D pool only");
1491            let x_bwd = fwd_map[&node.inputs[0]];
1492            let x_shape = bwd.node(x_bwd).shape.clone();
1493            let dtype = x_shape.dtype();
1494            // Channels = x_shape.dim(1).
1495            let c = match x_shape.dim(1) {
1496                Dim::Static(n) => n,
1497                _ => panic!("Pool(Mean) VJP: dynamic channel dim"),
1498            };
1499            let kh = kernel_size[0];
1500            let kw = kernel_size[1];
1501            let inv_n = 1.0_f32 / (kh as f32 * kw as f32);
1502            let kernel_n = c * kh * kw;
1503            let mut bytes: Vec<u8> = Vec::with_capacity(kernel_n * 4);
1504            for _ in 0..kernel_n {
1505                bytes.extend_from_slice(&inv_n.to_le_bytes());
1506            }
1507            let kernel_shape = Shape::from_dims(
1508                &[
1509                    Dim::Static(c),
1510                    Dim::Static(1),
1511                    Dim::Static(kh),
1512                    Dim::Static(kw),
1513                ],
1514                dtype,
1515            );
1516            let kernel = bwd.add_node(Op::Constant { data: bytes }, vec![], kernel_shape);
1517            let dx = bwd.conv2d_backward_input(
1518                upstream,
1519                kernel,
1520                x_shape,
1521                kernel_size.clone(),
1522                stride.clone(),
1523                padding.clone(),
1524                vec![1, 1],
1525                c, // groups = c → depthwise
1526            );
1527            vec![(0, dx)]
1528        }
1529
1530        // ── Binary(Min/Max) ─────────────────────────────────────
1531        //
1532        // Element-wise min/max: gradient flows to whichever input
1533        // was selected (ties go to the first operand by convention).
1534        //   da = where(a == out, upstream, 0)
1535        //   db = where(a == out, 0, upstream)   ← exclusive
1536        Op::Binary(BinaryOp::Min) | Op::Binary(BinaryOp::Max) => {
1537            let a_bwd = fwd_map[&node.inputs[0]];
1538            let b_bwd = fwd_map[&node.inputs[1]];
1539            let y_bwd = fwd_map[&node.id];
1540            let a_shape = bwd.node(a_bwd).shape.clone();
1541            let b_shape = bwd.node(b_bwd).shape.clone();
1542            let dtype = upstream_shape.dtype();
1543
1544            let bool_shape = Shape::from_dims(upstream_shape.dims(), DType::Bool);
1545            let mask_pred = bwd.add_node(Op::Compare(CmpOp::Eq), vec![a_bwd, y_bwd], bool_shape);
1546            let mask_f32 = bwd.add_node(
1547                Op::Cast { to: dtype },
1548                vec![mask_pred],
1549                upstream_shape.clone(),
1550            );
1551            let zero_bytes = vec![
1552                0u8;
1553                upstream_shape
1554                    .num_elements()
1555                    .expect("Min/Max VJP: dyn shape")
1556                    * 4
1557            ];
1558            let zero = bwd.add_node(
1559                Op::Constant { data: zero_bytes },
1560                vec![],
1561                upstream_shape.clone(),
1562            );
1563            let g_a_full = bwd.add_node(
1564                Op::Where,
1565                vec![mask_f32, upstream, zero],
1566                upstream_shape.clone(),
1567            );
1568            let g_b_full = bwd.add_node(Op::Where, vec![mask_f32, zero, upstream], upstream_shape);
1569            let g_a = unbroadcast(g_a_full, &a_shape, bwd);
1570            let g_b = unbroadcast(g_b_full, &b_shape, bwd);
1571            vec![(0, g_a), (1, g_b)]
1572        }
1573
1574        // ── Binary(Pow) ─────────────────────────────────────────
1575        //
1576        //   d/da (aᵇ) = b · a^(b-1)
1577        //   d/db (aᵇ) = aᵇ · ln(a)
1578        //
1579        // We don't have a `Pow` activation, but `pow(a, b)` for
1580        // positive base equals `exp(b · ln(a))`, and the derivative
1581        // simplifies. Express via `Activation::Log / Exp` and `Mul`.
1582        Op::Binary(BinaryOp::Pow) => {
1583            let a_bwd = fwd_map[&node.inputs[0]];
1584            let b_bwd = fwd_map[&node.inputs[1]];
1585            let y_bwd = fwd_map[&node.id]; // a^b
1586            let a_shape = bwd.node(a_bwd).shape.clone();
1587            let b_shape = bwd.node(b_bwd).shape.clone();
1588
1589            // d/da: upstream · y / a = upstream · b · a^(b-1).
1590            // Easier route: upstream · y · b / a.
1591            let yb = bwd.binary(BinaryOp::Mul, y_bwd, b_bwd, upstream_shape.clone());
1592            let yb_over_a = bwd.binary(BinaryOp::Div, yb, a_bwd, upstream_shape.clone());
1593            let g_a_full = bwd.binary(BinaryOp::Mul, upstream, yb_over_a, upstream_shape.clone());
1594            let g_a = unbroadcast(g_a_full, &a_shape, bwd);
1595
1596            // d/db: upstream · y · ln(a)
1597            let ln_a = bwd.activation(Activation::Log, a_bwd, a_shape);
1598            let ln_a_b = unbroadcast_inverse(ln_a, &upstream_shape, bwd);
1599            let yln = bwd.binary(BinaryOp::Mul, y_bwd, ln_a_b, upstream_shape.clone());
1600            let g_b_full = bwd.binary(BinaryOp::Mul, upstream, yln, upstream_shape);
1601            let g_b = unbroadcast(g_b_full, &b_shape, bwd);
1602
1603            vec![(0, g_a), (1, g_b)]
1604        }
1605
1606        // ── DequantMatMul (QAT-style straight-through) ─────────
1607        //
1608        // Forward (Int8BlockAsym):
1609        //   w_dq = (cast<f32>(w_q) - zp_b) * scale_b
1610        //   y    = x @ w_dq
1611        //
1612        // Backward (QAT convention — scale and zp are typically
1613        // frozen during fine-tuning; w_q's int8 cast is treated as
1614        // a no-op for the gradient via straight-through):
1615        //   dx     = upstream @ w_dq^T
1616        //   dw_q   = x^T @ upstream * scale_b   (straight-through;
1617        //            the user's optimizer would project back to
1618        //            int8 after the step)
1619        //   dscale = 0   (frozen)
1620        //   dzp    = 0   (frozen)
1621        //
1622        // For full QAT with learnable scale/zp, replace the zero
1623        // gradients with the closed-form ∂y/∂scale / ∂y/∂zp.
1624        Op::DequantMatMul { scheme: _ } => {
1625            let x_bwd = fwd_map[&node.inputs[0]];
1626            let w_q_bwd = fwd_map[&node.inputs[1]];
1627            let scale_bwd = fwd_map[&node.inputs[2]];
1628            let zp_bwd = fwd_map[&node.inputs[3]];
1629            let x_shape = bwd.node(x_bwd).shape.clone();
1630            let w_shape = bwd.node(w_q_bwd).shape.clone();
1631            let scale_shape = bwd.node(scale_bwd).shape.clone();
1632            let zp_shape = bwd.node(zp_bwd).shape.clone();
1633
1634            // dx = upstream @ w_dq^T. Recompute w_dq inline.
1635            // w_q is int8 in the IR — cast to f32 for the matmul
1636            // backward graph (straight-through equivalent).
1637            let dtype = x_shape.dtype();
1638            let w_q_f32 = bwd.add_node(
1639                Op::Cast { to: dtype },
1640                vec![w_q_bwd],
1641                Shape::from_dims(w_shape.dims(), dtype),
1642            );
1643            // Broadcast scale/zp to w_shape before subtract/mul.
1644            let scale_b =
1645                unbroadcast_inverse(scale_bwd, &Shape::from_dims(w_shape.dims(), dtype), bwd);
1646            let zp_b = unbroadcast_inverse(zp_bwd, &Shape::from_dims(w_shape.dims(), dtype), bwd);
1647            let w_centered = bwd.binary(
1648                BinaryOp::Sub,
1649                w_q_f32,
1650                zp_b,
1651                Shape::from_dims(w_shape.dims(), dtype),
1652            );
1653            let w_dq = bwd.binary(
1654                BinaryOp::Mul,
1655                w_centered,
1656                scale_b,
1657                Shape::from_dims(w_shape.dims(), dtype),
1658            );
1659
1660            // Transpose w_dq's last two dims for dx = upstream @ w_dq^T.
1661            let w_rank = w_shape.rank();
1662            let mut perm: Vec<usize> = (0..w_rank).collect();
1663            perm.swap(w_rank - 2, w_rank - 1);
1664            let mut wdt_dims: Vec<Dim> = w_shape.dims().to_vec();
1665            wdt_dims.swap(w_rank - 2, w_rank - 1);
1666            let w_dq_t_shape = Shape::from_dims(&wdt_dims, dtype);
1667            let w_dq_t = bwd.add_node(Op::Transpose { perm }, vec![w_dq], w_dq_t_shape);
1668            let dx = bwd.matmul(upstream, w_dq_t, x_shape.clone());
1669
1670            // dw_q = (x^T @ upstream) * scale_b   (straight-through).
1671            // The result is in the int8-weight space — caller's
1672            // optimizer is expected to project back. We emit it as
1673            // f32 here and let downstream cast.
1674            let x_rank = x_shape.rank();
1675            let mut x_perm: Vec<usize> = (0..x_rank).collect();
1676            x_perm.swap(x_rank - 2, x_rank - 1);
1677            let mut x_t_dims: Vec<Dim> = x_shape.dims().to_vec();
1678            x_t_dims.swap(x_rank - 2, x_rank - 1);
1679            let x_t = bwd.add_node(
1680                Op::Transpose { perm: x_perm },
1681                vec![x_bwd],
1682                Shape::from_dims(&x_t_dims, dtype),
1683            );
1684            let dw_unscaled = bwd.matmul(x_t, upstream, Shape::from_dims(w_shape.dims(), dtype));
1685            let dw_q_f32 = bwd.binary(
1686                BinaryOp::Mul,
1687                dw_unscaled,
1688                scale_b,
1689                Shape::from_dims(w_shape.dims(), dtype),
1690            );
1691            // Cast back to the IR's int8 dtype convention.
1692            let dw_q = bwd.add_node(
1693                Op::Cast {
1694                    to: w_shape.dtype(),
1695                },
1696                vec![dw_q_f32],
1697                w_shape,
1698            );
1699
1700            // scale and zp: zero gradients (frozen QAT convention).
1701            let zero_scale_bytes =
1702                vec![0u8; scale_shape.num_elements().expect("DQMM VJP: dyn scale") * 4];
1703            let zero_zp_bytes = vec![0u8; zp_shape.num_elements().expect("DQMM VJP: dyn zp") * 4];
1704            let dscale = bwd.add_node(
1705                Op::Constant {
1706                    data: zero_scale_bytes,
1707                },
1708                vec![],
1709                scale_shape,
1710            );
1711            let dzp = bwd.add_node(
1712                Op::Constant {
1713                    data: zero_zp_bytes,
1714                },
1715                vec![],
1716                zp_shape,
1717            );
1718
1719            vec![(0, dx), (1, dw_q), (2, dscale), (3, dzp)]
1720        }
1721
1722        // ── ScatterAdd ──────────────────────────────────────────
1723        //
1724        // Forward: out[indices[i], ...] += updates[i, ...].
1725        // Backward: d_updates[i, ...] = upstream[indices[i], ...]  (gather).
1726        //   Indices are non-differentiable.
1727        Op::ScatterAdd => {
1728            let updates_bwd = fwd_map[&node.inputs[0]];
1729            let indices_bwd = fwd_map[&node.inputs[1]];
1730            let updates_shape = bwd.node(updates_bwd).shape.clone();
1731            let dupdates = bwd.add_node(
1732                Op::Gather { axis: 0 },
1733                vec![upstream, indices_bwd],
1734                updates_shape,
1735            );
1736            vec![(0, dupdates)]
1737        }
1738
1739        // ── Cumsum ──────────────────────────────────────────────
1740        //
1741        Op::Cumsum { axis, exclusive } => {
1742            let x_bwd = fwd_map[&node.inputs[0]];
1743            let x_shape = bwd.node(x_bwd).shape.clone();
1744            let dx = bwd.cumsum_backward(upstream, x_shape, *axis, *exclusive);
1745            vec![(0, dx)]
1746        }
1747
1748        // ── GroupedMatMul (MoE primitive) ──────────────────────
1749        //
1750        // Forward: y[i] = x[i] @ w[expert[i]]
1751        //   x        [M, K]
1752        //   w        [E, K, N]
1753        //   expert   [M] (f32-encoded indices)
1754        //   y        [M, N]
1755        //
1756        // Backward (composed via Gather + batched-MatMul + ScatterAdd):
1757        //   dx[i] = upstream[i] @ w[expert[i]]^T
1758        //   dw[e, k, n] = sum_{i : expert[i]=e} x[i,k] · upstream[i,n]
1759        //   dexpert: zero (non-differentiable index input).
1760        Op::GroupedMatMul => {
1761            let x_bwd = fwd_map[&node.inputs[0]];
1762            let w_bwd = fwd_map[&node.inputs[1]];
1763            let expert_bwd = fwd_map[&node.inputs[2]];
1764            let x_shape = bwd.node(x_bwd).shape.clone();
1765            let w_shape = bwd.node(w_bwd).shape.clone();
1766            let (dx, dw) =
1767                grouped_matmul_vjp(bwd, upstream, x_bwd, w_bwd, expert_bwd, &x_shape, &w_shape);
1768            vec![(0, dx), (1, dw)]
1769        }
1770
1771        // ── DequantGroupedMatMul (frozen GGUF MoE weights) ─────
1772        //
1773        // Materialize w_dq via `Op::DequantMoEWeights`, then reuse the
1774        // GroupedMatMul VJP. Packed U8 weights and expert indices are
1775        // non-differentiable (inference / QAT-frozen convention).
1776        Op::DequantGroupedMatMul { scheme } => {
1777            let x_bwd = fwd_map[&node.inputs[0]];
1778            let w_packed = fwd_map[&node.inputs[1]];
1779            let expert_bwd = fwd_map[&node.inputs[2]];
1780            let x_shape = bwd.node(x_bwd).shape.clone();
1781            let w_packed_shape = bwd.node(w_packed).shape.clone();
1782            let dtype = x_shape.dtype();
1783            let k = x_shape.dim(1);
1784            let n_out = node.shape.dim(node.shape.rank() - 1);
1785            let k_static = match k {
1786                Dim::Static(v) => v,
1787                _ => panic!("DequantGroupedMatMul VJP: K must be static"),
1788            };
1789            let n_static = match n_out {
1790                Dim::Static(v) => v,
1791                _ => panic!("DequantGroupedMatMul VJP: N must be static"),
1792            };
1793            let block_elems = scheme.gguf_block_size() as usize;
1794            let block_bytes = scheme.gguf_block_bytes() as usize;
1795            let slab_bytes = (k_static * n_static) / block_elems * block_bytes;
1796            let total_bytes = w_packed_shape
1797                .num_elements()
1798                .expect("DequantGroupedMatMul VJP: dyn packed");
1799            let e_static = total_bytes / slab_bytes.max(1);
1800            let w_shape = Shape::from_dims(
1801                &[
1802                    Dim::Static(e_static),
1803                    Dim::Static(k_static),
1804                    Dim::Static(n_static),
1805                ],
1806                dtype,
1807            );
1808            let w_dq = bwd.add_node(
1809                Op::DequantMoEWeights { scheme: *scheme },
1810                vec![w_packed],
1811                w_shape.clone(),
1812            );
1813            let (dx, _dw) =
1814                grouped_matmul_vjp(bwd, upstream, x_bwd, w_dq, expert_bwd, &x_shape, &w_shape);
1815            vec![(0, dx)]
1816        }
1817
1818        // ── QMatMul / QConv2d (straight-through INT8 backward) ──
1819        //
1820        // Real INT8 inference kernels. The forward applies
1821        //   out = clamp(round((x − x_zp) · (w − w_zp) · mult + bias)
1822        //               + out_zp, [-128, 127])
1823        // and outputs i8. For training, the standard QAT recipe
1824        // treats the round/clamp/quantize as straight-through, so
1825        // the gradient is what a plain f32 MatMul (or Conv) backward
1826        // would give applied to the dequantized representations.
1827        // Zero-points and `mult` are typically frozen (calibration
1828        // outputs); we emit zero gradients for them. Bias gets the
1829        // standard sum-over-batch gradient.
1830        Op::QMatMul {
1831            x_zp,
1832            w_zp,
1833            out_zp: _,
1834            mult,
1835        } => {
1836            let x_bwd = fwd_map[&node.inputs[0]];
1837            let w_bwd = fwd_map[&node.inputs[1]];
1838            let bias_bwd = fwd_map[&node.inputs[2]];
1839            let x_shape = bwd.node(x_bwd).shape.clone();
1840            let w_shape = bwd.node(w_bwd).shape.clone();
1841            let bias_shape = bwd.node(bias_bwd).shape.clone();
1842            let dtype = upstream_shape.dtype();
1843
1844            // Promote x and w to f32 (straight-through); subtract zps.
1845            let x_f32 = bwd.add_node(
1846                Op::Cast { to: dtype },
1847                vec![x_bwd],
1848                Shape::from_dims(x_shape.dims(), dtype),
1849            );
1850            let w_f32 = bwd.add_node(
1851                Op::Cast { to: dtype },
1852                vec![w_bwd],
1853                Shape::from_dims(w_shape.dims(), dtype),
1854            );
1855            let xzp_c = scalar_const(*x_zp as f32, bwd);
1856            let xzp_b = unbroadcast_inverse(xzp_c, &Shape::from_dims(x_shape.dims(), dtype), bwd);
1857            let _ = bwd.binary(
1858                BinaryOp::Sub,
1859                x_f32,
1860                xzp_b,
1861                Shape::from_dims(x_shape.dims(), dtype),
1862            );
1863            let wzp_c = scalar_const(*w_zp as f32, bwd);
1864            let wzp_b = unbroadcast_inverse(wzp_c, &Shape::from_dims(w_shape.dims(), dtype), bwd);
1865            let w_centered = bwd.binary(
1866                BinaryOp::Sub,
1867                w_f32,
1868                wzp_b,
1869                Shape::from_dims(w_shape.dims(), dtype),
1870            );
1871
1872            // mult scaling.
1873            let mult_c = scalar_const(*mult, bwd);
1874            let mult_b = unbroadcast_inverse(mult_c, &upstream_shape, bwd);
1875            let upstream_scaled =
1876                bwd.binary(BinaryOp::Mul, upstream, mult_b, upstream_shape.clone());
1877
1878            // dx = upstream_scaled @ w_centered^T   (still i8 dtype
1879            //  on the input side; cast the gradient back).
1880            let w_rank = w_shape.rank();
1881            let mut perm: Vec<usize> = (0..w_rank).collect();
1882            perm.swap(w_rank - 2, w_rank - 1);
1883            let mut wt_dims: Vec<Dim> = w_shape.dims().to_vec();
1884            wt_dims.swap(w_rank - 2, w_rank - 1);
1885            let w_t = bwd.add_node(
1886                Op::Transpose { perm },
1887                vec![w_centered],
1888                Shape::from_dims(&wt_dims, dtype),
1889            );
1890            let dx_f32 = bwd.matmul(
1891                upstream_scaled,
1892                w_t,
1893                Shape::from_dims(x_shape.dims(), dtype),
1894            );
1895            let dx = bwd.add_node(
1896                Op::Cast {
1897                    to: x_shape.dtype(),
1898                },
1899                vec![dx_f32],
1900                x_shape.clone(),
1901            );
1902
1903            // dw = x_centered^T @ upstream_scaled  (similarly cast).
1904            let x_rank = x_shape.rank();
1905            let mut x_perm: Vec<usize> = (0..x_rank).collect();
1906            x_perm.swap(x_rank - 2, x_rank - 1);
1907            let mut xt_dims: Vec<Dim> = x_shape.dims().to_vec();
1908            xt_dims.swap(x_rank - 2, x_rank - 1);
1909            // Need to pull x_centered into scope — recompute inline.
1910            let x_f32_2 = bwd.add_node(
1911                Op::Cast { to: dtype },
1912                vec![x_bwd],
1913                Shape::from_dims(x_shape.dims(), dtype),
1914            );
1915            let x_centered = bwd.binary(
1916                BinaryOp::Sub,
1917                x_f32_2,
1918                xzp_b,
1919                Shape::from_dims(x_shape.dims(), dtype),
1920            );
1921            let x_t = bwd.add_node(
1922                Op::Transpose { perm: x_perm },
1923                vec![x_centered],
1924                Shape::from_dims(&xt_dims, dtype),
1925            );
1926            let dw_f32 = bwd.matmul(
1927                x_t,
1928                upstream_scaled,
1929                Shape::from_dims(w_shape.dims(), dtype),
1930            );
1931            let dw = bwd.add_node(
1932                Op::Cast {
1933                    to: w_shape.dtype(),
1934                },
1935                vec![dw_f32],
1936                w_shape,
1937            );
1938
1939            // dbias = sum upstream_scaled over batch axes (matches
1940            // f32 MatMul-with-bias backward shape).
1941            let bias_rank = bias_shape.rank();
1942            let reduce_axes: Vec<usize> = (0..upstream_shape.rank())
1943                .filter(|&i| i + bias_rank < upstream_shape.rank() || i == 0)
1944                .collect();
1945            let dbias_f32 = bwd.add_node(
1946                Op::Reduce {
1947                    op: ReduceOp::Sum,
1948                    axes: reduce_axes,
1949                    keep_dim: false,
1950                },
1951                vec![upstream_scaled],
1952                Shape::from_dims(bias_shape.dims(), dtype),
1953            );
1954            let dbias = bwd.add_node(
1955                Op::Cast {
1956                    to: bias_shape.dtype(),
1957                },
1958                vec![dbias_f32],
1959                bias_shape,
1960            );
1961
1962            vec![(0, dx), (1, dw), (2, dbias)]
1963        }
1964
1965        Op::QConv2d {
1966            kernel_size,
1967            stride,
1968            padding,
1969            dilation,
1970            groups,
1971            x_zp,
1972            w_zp,
1973            out_zp: _,
1974            mult,
1975        } => {
1976            // Same straight-through pattern as QMatMul, lifted to
1977            // 2-D conv via the existing Conv2dBackwardInput / Weight
1978            // kernels.
1979            let x_bwd = fwd_map[&node.inputs[0]];
1980            let w_bwd = fwd_map[&node.inputs[1]];
1981            let bias_bwd = fwd_map[&node.inputs[2]];
1982            let x_shape = bwd.node(x_bwd).shape.clone();
1983            let w_shape = bwd.node(w_bwd).shape.clone();
1984            let bias_shape = bwd.node(bias_bwd).shape.clone();
1985            let dtype = upstream_shape.dtype();
1986
1987            // Promote and dequantize.
1988            let x_f32 = bwd.add_node(
1989                Op::Cast { to: dtype },
1990                vec![x_bwd],
1991                Shape::from_dims(x_shape.dims(), dtype),
1992            );
1993            let w_f32 = bwd.add_node(
1994                Op::Cast { to: dtype },
1995                vec![w_bwd],
1996                Shape::from_dims(w_shape.dims(), dtype),
1997            );
1998            let xzp_c = scalar_const(*x_zp as f32, bwd);
1999            let xzp_b = unbroadcast_inverse(xzp_c, &Shape::from_dims(x_shape.dims(), dtype), bwd);
2000            let x_centered = bwd.binary(
2001                BinaryOp::Sub,
2002                x_f32,
2003                xzp_b,
2004                Shape::from_dims(x_shape.dims(), dtype),
2005            );
2006            let wzp_c = scalar_const(*w_zp as f32, bwd);
2007            let wzp_b = unbroadcast_inverse(wzp_c, &Shape::from_dims(w_shape.dims(), dtype), bwd);
2008            let w_centered = bwd.binary(
2009                BinaryOp::Sub,
2010                w_f32,
2011                wzp_b,
2012                Shape::from_dims(w_shape.dims(), dtype),
2013            );
2014
2015            // mult scaling on upstream.
2016            let mult_c = scalar_const(*mult, bwd);
2017            let mult_b = unbroadcast_inverse(mult_c, &upstream_shape, bwd);
2018            let upstream_scaled =
2019                bwd.binary(BinaryOp::Mul, upstream, mult_b, upstream_shape.clone());
2020
2021            // dx, dw via the existing conv-backward kernels.
2022            let dx_f32 = bwd.conv2d_backward_input(
2023                upstream_scaled,
2024                w_centered,
2025                Shape::from_dims(x_shape.dims(), dtype),
2026                kernel_size.clone(),
2027                stride.clone(),
2028                padding.clone(),
2029                dilation.clone(),
2030                *groups,
2031            );
2032            let dx = bwd.add_node(
2033                Op::Cast {
2034                    to: x_shape.dtype(),
2035                },
2036                vec![dx_f32],
2037                x_shape,
2038            );
2039            let dw_f32 = bwd.conv2d_backward_weight(
2040                x_centered,
2041                upstream_scaled,
2042                Shape::from_dims(w_shape.dims(), dtype),
2043                kernel_size.clone(),
2044                stride.clone(),
2045                padding.clone(),
2046                dilation.clone(),
2047                *groups,
2048            );
2049            let dw = bwd.add_node(
2050                Op::Cast {
2051                    to: w_shape.dtype(),
2052                },
2053                vec![dw_f32],
2054                w_shape,
2055            );
2056
2057            // dbias = sum upstream_scaled over (N, H_out, W_out) keeping C_out.
2058            let dbias_f32 = bwd.add_node(
2059                Op::Reduce {
2060                    op: ReduceOp::Sum,
2061                    axes: vec![0, 2, 3],
2062                    keep_dim: false,
2063                },
2064                vec![upstream_scaled],
2065                Shape::from_dims(bias_shape.dims(), dtype),
2066            );
2067            let dbias = bwd.add_node(
2068                Op::Cast {
2069                    to: bias_shape.dtype(),
2070                },
2071                vec![dbias_f32],
2072                bias_shape,
2073            );
2074
2075            vec![(0, dx), (1, dw), (2, dbias)]
2076        }
2077
2078        // ── Sampling-style ops: non-differentiable ──
2079        Op::TopK { .. } | Op::Sample { .. } => {
2080            // TopK selects; Sample multinomial-draws. Gradient w.r.t.
2081            // the input distribution is undefined / zero in the
2082            // standard sense. Skip propagation.
2083            vec![]
2084        }
2085
2086        Op::GaussianSplatRender {
2087            width,
2088            height,
2089            tile_size,
2090            radius_scale,
2091            alpha_cutoff,
2092            max_splat_steps,
2093            transmittance_threshold,
2094            max_list_entries,
2095            ..
2096        } => {
2097            use rlx_ir::ops::splat::{
2098                GaussianSplatBackwardParams, GaussianSplatInputs, GaussianSplatRenderParams,
2099                unpack_gaussian_splat_packed_grads,
2100            };
2101            let render = GaussianSplatRenderParams {
2102                width: *width,
2103                height: *height,
2104                tile_size: *tile_size,
2105                radius_scale: *radius_scale,
2106                alpha_cutoff: *alpha_cutoff,
2107                max_splat_steps: *max_splat_steps,
2108                transmittance_threshold: *transmittance_threshold,
2109                max_list_entries: *max_list_entries,
2110            };
2111            let inputs = GaussianSplatInputs {
2112                positions: fwd_map[&node.inputs[0]],
2113                scales: fwd_map[&node.inputs[1]],
2114                rotations: fwd_map[&node.inputs[2]],
2115                opacities: fwd_map[&node.inputs[3]],
2116                colors: fwd_map[&node.inputs[4]],
2117                sh_coeffs: fwd_map[&node.inputs[5]],
2118                meta: fwd_map[&node.inputs[6]],
2119            };
2120            let count = bwd.shape(inputs.positions).num_elements().unwrap_or(0) / 3;
2121            let sh_len = bwd.shape(inputs.sh_coeffs).num_elements().unwrap_or(0);
2122            let meta_shape = bwd.shape(inputs.meta).clone();
2123            let packed = bwd.gaussian_splat_render_backward(
2124                inputs,
2125                upstream,
2126                GaussianSplatBackwardParams {
2127                    render,
2128                    loss_grad_clip: 1.0,
2129                    sh_band: 0,
2130                    max_anisotropy: 10.0,
2131                },
2132            );
2133            let sh_coeff_count = if count == 0 {
2134                1
2135            } else {
2136                (sh_len / (count * 3)).max(1)
2137            };
2138            let grads = unpack_gaussian_splat_packed_grads(bwd, packed, count, sh_coeff_count);
2139            let meta_n = meta_shape.num_elements().unwrap_or(0);
2140            let zero_meta = bwd.add_node(
2141                Op::Constant {
2142                    data: vec![0u8; meta_n * meta_shape.dtype().size_bytes()],
2143                },
2144                vec![],
2145                meta_shape,
2146            );
2147            vec![
2148                (0, grads.positions),
2149                (1, grads.scales),
2150                (2, grads.rotations),
2151                (3, grads.opacities),
2152                (4, grads.colors),
2153                (5, grads.sh_coeffs),
2154                (6, zero_meta),
2155            ]
2156        }
2157
2158        Op::GaussianSplatRenderBackward { .. } => {
2159            // Scene/meta inputs are not differentiated through this op in v1.
2160            vec![]
2161        }
2162
2163        Op::GaussianSplatPrepare { .. } | Op::GaussianSplatRasterize { .. } => {
2164            panic!(
2165                "autodiff: decomposed splat ops must be fused before AD — \
2166                 `prepare_graph_for_ad` rewrites Prepare→Rasterize into \
2167                 `GaussianSplatRender`, or use `Op::GaussianSplatRender` directly"
2168            );
2169        }
2170
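        // ── Ops without a VJP arm of their own ──
        //
        // Every op in the IR has either a per-op VJP rule in this match
        // or a pre-pass rewrite that decomposes it into ops that do:
        //   * Op::If → control_flow::inline_if (Where + inlined
        //     branches).
        //   * Op::While → control_flow::unroll_while (bounded
        //     unroll up to max_iterations).
        //   * Op::SelectiveScan / Op::FusedTransformerLayer /
        //     Op::FusedAttentionBlock / Op::FusedSwiGLU /
        //     Op::LoraMatMul / Op::FusedMatMulBiasAct /
        //     Op::FusedResidualLN → rlx_fusion::unfuse_fused_for_autodiff.
        // Anything that still reaches this match without a rule hits the
        // catch-all panic arm at the end of the match, which names the op.
        //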
2180        //   * Op::SelectiveScan / Op::FusedTransformerLayer /
2181        //     Op::FusedAttentionBlock / Op::FusedSwiGLU /
2182        //     Op::LoraMatMul / Op::FusedMatMulBiasAct /
2183        //     Op::FusedResidualLN → rlx_fusion::unfuse_fused_for_autodiff.
2184        //
2185        // User-defined sub-graph (Op::CustomFn) with override AD.
2186        // When `vjp_body` is supplied, inline it into `bwd`: each
2187        // primal Op::Input maps to the outer forward NodeId for that
2188        // primal; the special-named "primal_output" Input maps to the
2189        // forward NodeId of this CustomFn node; "d_output" maps to
2190        // `upstream`. The body's N outputs become this op's N input
2191        // gradients in declaration order.
2192        Op::CustomFn {
2193            vjp_body: Some(vjp_body),
2194            num_inputs,
2195            ..
2196        } => {
2197            // Map vjp_body NodeIds → bwd NodeIds.
2198            let mut sub_to_bwd: HashMap<NodeId, NodeId> = HashMap::new();
2199
2200            // Collect primal-input NodeIds from vjp_body (excluding
2201            // special names), sorted by NodeId. Position k in this list
2202            // matches the outer node's input k.
2203            let mut primal_input_ids: Vec<NodeId> = vjp_body
2204                .nodes()
2205                .iter()
2206                .filter_map(|n| match &n.op {
2207                    Op::Input { name } if name != "primal_output" && name != "d_output" => {
2208                        Some(n.id)
2209                    }
2210                    _ => None,
2211                })
2212                .collect();
2213            primal_input_ids.sort();
2214            assert_eq!(primal_input_ids.len(), *num_inputs as usize);
2215
2216            // Walk vjp_body in declaration order, cloning each non-Input
2217            // node into bwd with input remapping.
2218            for sub_node in vjp_body.nodes() {
2219                let new_id = match &sub_node.op {
2220                    Op::Input { name } if name == "primal_output" => fwd_map[&node.id],
2221                    Op::Input { name } if name == "d_output" => upstream,
2222                    Op::Input { .. } => {
2223                        // Find this Input's index in primal_input_ids.
2224                        let idx = primal_input_ids
2225                            .iter()
2226                            .position(|&id| id == sub_node.id)
2227                            .expect(
2228                                "custom_fn vjp_body: primal Input \
2229                                     not found in primal list",
2230                            );
2231                        fwd_map[&node.inputs[idx]]
2232                    }
2233                    _ => {
2234                        let new_inputs: Vec<NodeId> =
2235                            sub_node.inputs.iter().map(|i| sub_to_bwd[i]).collect();
2236                        bwd.add_node(sub_node.op.clone(), new_inputs, sub_node.shape.clone())
2237                    }
2238                };
2239                sub_to_bwd.insert(sub_node.id, new_id);
2240            }
2241
2242            // Collect outputs in set_outputs order — each maps to a
2243            // primal-input gradient.
2244            let mut grads: Vec<(usize, NodeId)> = Vec::with_capacity(*num_inputs as usize);
2245            for (i, out_id) in vjp_body.outputs.iter().enumerate() {
2246                grads.push((i, sub_to_bwd[out_id]));
2247            }
2248            grads
2249        }
2250
2251        // CustomFn without vjp_body is inlined by `inline_custom_fn_for_autodiff`
2252        // before the reverse walk — reaching here means the pre-pass missed it.
2253        Op::CustomFn { vjp_body: None, .. } => {
2254            panic!(
2255                "autodiff: Op::CustomFn has no vjp_body and was not inlined. \
2256                 This is an internal error in inline_custom_fn_for_autodiff."
2257            )
2258        }
2259
2260        // User-registered custom op — dispatch the VJP through the
2261        // op registry. The impl emits gradient nodes via the same
2262        // `bwd` builder built-in arms use; default impl returns
2263        // `vec![]` (non-differentiable).
2264        Op::Custom { name, .. } => {
2265            let ext = rlx_ir::lookup_op(name).unwrap_or_else(|| {
2266                panic!(
2267                    "autodiff: Op::Custom('{name}') is not registered \
2268                        in the op registry — register it via \
2269                        rlx_ir::register_op before compiling the graph"
2270                )
2271            });
2272            let mut ctx = rlx_ir::VjpContext {
2273                upstream,
2274                fwd_map,
2275                bwd,
2276            };
2277            ext.vjp(node, &mut ctx)
2278        }
2279
2280        Op::Conv2dBackwardInput {
2281            kernel_size,
2282            stride,
2283            padding,
2284            dilation,
2285            groups,
2286        } => {
2287            let dy_bwd = fwd_map[&node.inputs[0]];
2288            let w_bwd = fwd_map[&node.inputs[1]];
2289            let dy_shape = bwd.node(dy_bwd).shape.clone();
2290            let _x_shape = node.shape.clone();
2291            let d_dy = bwd.add_node(
2292                Op::Conv {
2293                    kernel_size: kernel_size.clone(),
2294                    stride: stride.clone(),
2295                    padding: padding.clone(),
2296                    dilation: dilation.clone(),
2297                    groups: *groups,
2298                },
2299                vec![upstream, w_bwd],
2300                dy_shape,
2301            );
2302            vec![(0, d_dy)]
2303        }
2304
2305        Op::Conv2dBackwardWeight {
2306            kernel_size,
2307            stride,
2308            padding,
2309            dilation,
2310            groups,
2311        } => {
2312            let x_bwd = fwd_map[&node.inputs[0]];
2313            let dy_bwd = fwd_map[&node.inputs[1]];
2314            let x_shape = bwd.node(x_bwd).shape.clone();
2315            let dy_shape = bwd.node(dy_bwd).shape.clone();
2316            let d_x = bwd.conv2d_backward_input(
2317                dy_bwd,
2318                upstream,
2319                x_shape,
2320                kernel_size.clone(),
2321                stride.clone(),
2322                padding.clone(),
2323                dilation.clone(),
2324                *groups,
2325            );
2326            let d_dy = bwd.add_node(
2327                Op::Conv {
2328                    kernel_size: kernel_size.clone(),
2329                    stride: stride.clone(),
2330                    padding: padding.clone(),
2331                    dilation: dilation.clone(),
2332                    groups: *groups,
2333                },
2334                vec![x_bwd, upstream],
2335                dy_shape,
2336            );
2337            vec![(0, d_x), (1, d_dy)]
2338        }
2339
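        // 1D FFT: y = s · fft(x; inverse), where the scalar
        // `s = norm.output_scale(n, inverse)` is the normalization factor
        // (presumably relative to the unnormalized transform — confirm
        // against `FftNorm`). The unnormalized forward and inverse
        // transforms are linear operators on the 2N real-block layout,
        // and the DFT matrix's transpose (over the real-block view)
        // equals the unnormalized inverse DFT. Since `s` is a scalar it
        // can be applied to the upstream gradient first:
        //   VJP(fft)  = ifft(s · upstream)
        //   VJP(ifft) = fft(s · upstream)
        // When `s == 1` the chain rule is just a direction flip. This
        // relies on `bwd.fft` emitting the unnormalized transform.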
2348        Op::Fft { inverse, norm } => {
2349            let n = rlx_ir::fft::fft_meta(bwd.shape(node.inputs[0])).n_complex;
2350            let s = norm.output_scale(n, *inverse) as f32;
2351            let z = if s != 1.0 {
2352                let sc = scalar_const(s, bwd);
2353                bwd.mul(upstream, sc)
2354            } else {
2355                upstream
2356            };
2357            let dx = bwd.fft(z, !*inverse);
2358            vec![(0, dx)]
2359        }
2360
2361        Op::LogMel => {
2362            let spec_bwd = fwd_map[&node.inputs[0]];
2363            let filt_bwd = fwd_map[&node.inputs[1]];
2364            let dx = bwd.log_mel_backward(spec_bwd, filt_bwd, upstream);
2365            vec![(0, dx)]
2366        }
2367
2368        // The catch-all below remains as a safety net: if a
2369        // future op is added without a VJP rule, this panic
2370        // names it for the implementer.
2371        other => panic!(
2372            "autodiff: no VJP rule for {other}. See the matching \
2373             entry in rlx-opt/src/autodiff.rs (catch-all panic) for \
2374             a pointer to what's needed to differentiate this op.",
2375        ),
2376    }
2377}
2378
// Fused-op decomposition (see `rlx_fusion::unfuse_fused_for_autodiff`):
// tier-2 fused ops are decomposed back to their primitive components so
// the per-op VJP rules cover them. Conceptually this is what a
// "training-aware compile" pipeline would do as a pre-pass: avoid
// running `FuseMatMulBiasAct` / `FuseResidualLN` / `FuseSwiGLU` /
// `FuseAttentionBlock` if you plan to autodiff afterward. The helper
// handles the case where they're already in the graph (e.g. from a
// re-trained inference model).
//
// Decomposed today: `FusedMatMulBiasAct`, `FusedResidualLN`,
// `LoraMatMul`, `FusedSwiGLU`, `FusedAttentionBlock`,
// `FusedTransformerLayer`, and `SelectiveScan` / `GatedDeltaNet` —
// each rewritten back to its primitive chain (matmul / narrow / attention /
// layer_norm / residual / activation, plus reduce-sum / concat glue).

// Scan trajectories (see `convert_scans_for_ad` below): that pre-AD pass
// converts every `Op::Scan { save_trajectory: false }` into
// `Op::Scan { save_trajectory: true }` followed by `Narrow` + `Reshape`
// to extract the final carry. After this rewrite, every scan in the graph
// carries its full trajectory — which is what the VJP rule needs to
// compute backward through time. The user-facing shape is unchanged
// (Narrow + Reshape collapse [length, *carry] back down to *carry).
//
// Memory cost: trajectory storage is `O(length × carry_size)` for the
// duration of the forward + backward pass. For Diffrax-style transients
// this is the same as Diffrax's `RecursiveCheckpointAdjoint::All`
// strategy. A recursive-checkpointing variant is driven by the scan's
// `num_checkpoints` field.
2399///
2400/// Memory cost: trajectory storage is now `O(length × carry_size)`
2401/// for the duration of the forward + backward pass. For Diffrax-style
2402/// transients this is the same as Diffrax's `RecursiveCheckpointAdjoint::All`
2403/// strategy. Recursive checkpointing is a future pass.
2404/// Pre-AD pass: rewrite `Op::Scan` nodes with `num_bcast > 0` into
2405/// equivalent `num_bcast = 0` scans by materialising each broadcast
2406/// input `b` of shape `*bcast` into a per-step xs of shape
2407/// `[length, *bcast]` (built as `ones([length, *bcast]) × b`). The
2408/// reverse-mode AD walk and the rest of `convert_scans_for_ad` then
2409/// see only carry+xs scans — the bcast channel is a forward-only
2410/// memory optimisation, transparent to backward.
2411fn materialize_bcasts_for_ad(g: Graph) -> Graph {
2412    use rlx_ir::op::BinaryOp;
2413
2414    let needs = g.nodes().iter().any(|n| {
2415        matches!(
2416            &n.op, Op::Scan { num_bcast, .. } if *num_bcast > 0
2417        )
2418    });
2419    if !needs {
2420        return g;
2421    }
2422
2423    let mut out = Graph::new(g.name.clone());
2424    let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
2425
2426    for node in g.nodes() {
2427        let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
2428        match &node.op {
2429            Op::Scan {
2430                body,
2431                length,
2432                save_trajectory,
2433                num_bcast,
2434                num_xs,
2435                num_checkpoints,
2436            } if *num_bcast > 0 => {
2437                // Each bcast input gets multiplied by an
2438                // `[length, 1, ..., 1]` ones constant of matching dtype
2439                // (broadcast against the bcast's natural shape) to
2440                // produce a `[length, *bcast]` materialised xs.
2441                let bcast_base = 1;
2442                let xs_base = 1 + *num_bcast as usize;
2443
2444                let mut new_scan_inputs = vec![new_inputs[0]];
2445
2446                // Original xs first remain xs.
2447                let mut materialised_xs: Vec<NodeId> = Vec::new();
2448                for i in 0..*num_bcast as usize {
2449                    let b_id = new_inputs[bcast_base + i];
2450                    let b_shape = out.node(b_id).shape.clone();
2451                    let dtype = b_shape.dtype();
2452
2453                    // ones with shape [length, 1, 1, ...] (matching b's rank
2454                    // beyond the leading axis we're prepending). Broadcast
2455                    // against b of shape [*bcast] gives [length, *bcast].
2456                    let mut ones_dims: Vec<rlx_ir::Dim> =
2457                        vec![rlx_ir::Dim::Static(*length as usize)];
2458                    for _ in 0..b_shape.rank() {
2459                        ones_dims.push(rlx_ir::Dim::Static(1));
2460                    }
2461                    let ones_shape = rlx_ir::Shape::from_dims(&ones_dims, dtype);
2462                    let n_elems: usize = ones_dims
2463                        .iter()
2464                        .map(|d| match d {
2465                            rlx_ir::Dim::Static(n) => *n,
2466                            rlx_ir::Dim::Dynamic(_) => 1,
2467                        })
2468                        .product();
2469                    let elem_size = dtype.size_bytes();
2470                    let mut data = Vec::with_capacity(n_elems * elem_size);
2471                    match dtype {
2472                        rlx_ir::DType::F64 => {
2473                            for _ in 0..n_elems {
2474                                data.extend_from_slice(&1.0_f64.to_le_bytes());
2475                            }
2476                        }
2477                        rlx_ir::DType::F32 => {
2478                            for _ in 0..n_elems {
2479                                data.extend_from_slice(&1.0_f32.to_le_bytes());
2480                            }
2481                        }
2482                        other => {
2483                            panic!("materialize_bcasts_for_ad: unsupported bcast dtype {other:?}")
2484                        }
2485                    }
2486                    let ones = out.add_node(Op::Constant { data }, vec![], ones_shape);
2487
2488                    // Output shape of broadcast Mul: [length, *bcast].
2489                    let mut xs_dims: Vec<rlx_ir::Dim> = vec![rlx_ir::Dim::Static(*length as usize)];
2490                    for i in 0..b_shape.rank() {
2491                        xs_dims.push(b_shape.dim(i));
2492                    }
2493                    let xs_shape = rlx_ir::Shape::from_dims(&xs_dims, dtype);
2494                    let xs_id = out.add_node(Op::Binary(BinaryOp::Mul), vec![ones, b_id], xs_shape);
2495                    materialised_xs.push(xs_id);
2496                }
2497
2498                new_scan_inputs.extend_from_slice(&materialised_xs);
2499                for i in 0..*num_xs as usize {
2500                    new_scan_inputs.push(new_inputs[xs_base + i]);
2501                }
2502
2503                let new_id = out.add_node(
2504                    Op::Scan {
2505                        body: body.clone(),
2506                        length: *length,
2507                        save_trajectory: *save_trajectory,
2508                        num_bcast: 0,
2509                        num_xs: *num_bcast + *num_xs,
2510                        num_checkpoints: *num_checkpoints,
2511                    },
2512                    new_scan_inputs,
2513                    node.shape.clone(),
2514                );
2515                id_map.insert(node.id, new_id);
2516            }
2517            _ => {
2518                let new_id = out.add_node(node.op.clone(), new_inputs, node.shape.clone());
2519                id_map.insert(node.id, new_id);
2520            }
2521        }
2522    }
2523
2524    let new_outputs: Vec<NodeId> = g.outputs.iter().map(|o| id_map[o]).collect();
2525    out.set_outputs(new_outputs);
2526    out
2527}
2528
2529pub fn convert_scans_for_ad(g: Graph) -> Graph {
2530    use rlx_ir::shape::Shape as IrShape;
2531
2532    // First, materialise broadcast inputs into per-step xs. The AD
2533    // walk and the rest of this pre-pass don't know about bcasts
2534    // (forward-only memory optimisation); after this rewrite the bwd
2535    // graph treats them as regular xs.
2536    let g = materialize_bcasts_for_ad(g);
2537
2538    // Quick check: does any scan need rewriting? Avoid a full graph
2539    // rebuild when the input is already trajectory-only.
2540    let needs = g.nodes().iter().any(|n| {
2541        matches!(
2542            &n.op,
2543            Op::Scan {
2544                save_trajectory: false,
2545                ..
2546            }
2547        )
2548    });
2549    if !needs {
2550        return g;
2551    }
2552
2553    let mut out = Graph::new(g.name.clone());
2554    let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
2555
2556    for node in g.nodes() {
2557        let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
2558        match &node.op {
2559            Op::Scan {
2560                body,
2561                length,
2562                save_trajectory: false,
2563                num_xs,
2564                num_checkpoints,
2565                ..
2566            } => {
2567                let carry_shape = node.shape.clone();
2568                // Trajectory shape: [length, *carry_shape].
2569                //
2570                // NB: when `num_checkpoints` is set (recursive
2571                // checkpointing), the executor only writes `K` rows
2572                // into this buffer (the saved checkpoints, indexed by
2573                // k=0..K-1 at offsets 0..K·cb). Rows K..length-1 stay
2574                // zero. The Narrow + Reshape below extracts row
2575                // `length-1`, which is **zero** in checkpointed mode
2576                // — i.e. the rewritten forward output is wrong (the
2577                // FORWARD value of `scan_checkpointed` followed by a
2578                // direct read is not currently supported).
2579                //
2580                // Backward gradients are still correct: Narrow's VJP
2581                // scatters the upstream into row `length-1` of the
2582                // gradient tensor, ScanBackward reads upstream[t·cb]
2583                // for t in 0..length finds zero everywhere except at
2584                // t=length-1 where it picks up `d_loss`, and the
2585                // segment-cached recompute uses the K saved
2586                // checkpoints (at offsets 0..K·cb) plus the forward
2587                // body to reconstruct intermediate carries.
2588                let mut traj_dims: Vec<Dim> = Vec::with_capacity(carry_shape.rank() + 1);
2589                traj_dims.push(Dim::Static(*length as usize));
2590                for i in 0..carry_shape.rank() {
2591                    traj_dims.push(carry_shape.dim(i));
2592                }
2593                let traj_shape = IrShape::from_dims(&traj_dims, carry_shape.dtype());
2594                let traj = out.add_node(
2595                    Op::Scan {
2596                        body: body.clone(),
2597                        length: *length,
2598                        save_trajectory: true,
2599                        num_bcast: 0,
2600                        num_xs: *num_xs,
2601                        num_checkpoints: *num_checkpoints,
2602                    },
2603                    new_inputs,
2604                    traj_shape,
2605                );
2606                // Narrow last row → [1, *carry].
2607                let mut narrow_dims: Vec<Dim> = Vec::with_capacity(carry_shape.rank() + 1);
2608                narrow_dims.push(Dim::Static(1));
2609                for i in 0..carry_shape.rank() {
2610                    narrow_dims.push(carry_shape.dim(i));
2611                }
2612                let narrow_shape = IrShape::from_dims(&narrow_dims, carry_shape.dtype());
2613                let narrowed = out.add_node(
2614                    Op::Narrow {
2615                        axis: 0,
2616                        start: (*length as usize).saturating_sub(1),
2617                        len: 1,
2618                    },
2619                    vec![traj],
2620                    narrow_shape,
2621                );
2622                // Reshape to drop the leading 1 → carry_shape.
2623                let new_shape: Vec<i64> = (0..carry_shape.rank())
2624                    .map(|i| match carry_shape.dim(i) {
2625                        Dim::Static(n) => n as i64,
2626                        Dim::Dynamic(_) => -1,
2627                    })
2628                    .collect();
2629                let final_id = out.add_node(Op::Reshape { new_shape }, vec![narrowed], carry_shape);
2630                id_map.insert(node.id, final_id);
2631            }
2632            _ => {
2633                let new_id = out.add_node(node.op.clone(), new_inputs, node.shape.clone());
2634                id_map.insert(node.id, new_id);
2635            }
2636        }
2637    }
2638
2639    let new_outputs: Vec<NodeId> = g.outputs.iter().map(|o| id_map[o]).collect();
2640    out.set_outputs(new_outputs);
2641    out
2642}
2643
2644/// Pre-AD pass: inline `Op::CustomFn` nodes that have neither a
2645/// `vjp_body` nor a `jvp_body` by expanding their `fwd_body` into the
2646/// parent graph. When either override body is present, keep the
2647/// `CustomFn` wrapper so reverse- / forward-mode AD can dispatch to it.
2648pub fn inline_custom_fn_for_autodiff(g: Graph) -> Graph {
2649    use rlx_fusion::control_flow::inline_subgraph_into;
2650
2651    let mut out = Graph::new(g.name.clone());
2652    let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
2653    let nodes: Vec<rlx_ir::Node> = g.nodes().to_vec();
2654
2655    for node in &nodes {
2656        let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
2657        let new_id = match &node.op {
2658            Op::CustomFn {
2659                vjp_body: None,
2660                jvp_body: None,
2661                fwd_body,
2662                num_inputs,
2663                ..
2664            } => {
2665                assert_eq!(
2666                    new_inputs.len(),
2667                    *num_inputs as usize,
2668                    "custom_fn: outer input count mismatch"
2669                );
2670                inline_subgraph_into(fwd_body, &new_inputs, &mut out)
2671            }
2672            _ => out.add_node(node.op.clone(), new_inputs, node.shape.clone()),
2673        };
2674        id_map.insert(node.id, new_id);
2675    }
2676
2677    let new_outputs: Vec<NodeId> = g.outputs.iter().map(|i| id_map[i]).collect();
2678    out.set_outputs(new_outputs);
2679    out
2680}
2681
2682/// Inverse of `unbroadcast`: broadcast a small tensor up to a target
2683/// shape via `Op::Expand`. Convenience wrapper for the few VJPs that
2684/// need it.
2685pub(crate) fn unbroadcast_inverse(x: NodeId, target: &Shape, bwd: &mut Graph) -> NodeId {
2686    let target_dims: Vec<i64> = target
2687        .dims()
2688        .iter()
2689        .map(|d| match d {
2690            Dim::Static(n) => *n as i64,
2691            Dim::Dynamic(_) => -1,
2692        })
2693        .collect();
2694    bwd.add_node(
2695        Op::Expand {
2696            target_shape: target_dims,
2697        },
2698        vec![x],
2699        target.clone(),
2700    )
2701}
2702
2703/// Expand a gradient back to its pre-reduction shape: optionally
2704/// reshape to insert size-1 axes (when forward had `keep_dim=false`),
2705/// then `Op::Expand` to broadcast to `x_shape`. The reverse of
2706/// `Reduce::Sum`.
2707fn expand_to(
2708    grad: NodeId,
2709    x_shape: &Shape,
2710    axes: &[usize],
2711    keep_dim: bool,
2712    bwd: &mut Graph,
2713) -> NodeId {
2714    let mut current = grad;
2715    if !keep_dim {
2716        // Insert size-1 axes at the reduced positions so the rank
2717        // matches x_shape and Expand can broadcast cleanly.
2718        let kept_dims: Vec<Dim> = (0..x_shape.rank())
2719            .map(|i| {
2720                if axes.contains(&i) {
2721                    Dim::Static(1)
2722                } else {
2723                    x_shape.dim(i)
2724                }
2725            })
2726            .collect();
2727        let kept = Shape::from_dims(&kept_dims, x_shape.dtype());
2728        current = reshape_to(current, &kept, bwd);
2729    }
2730    let target_shape: Vec<i64> = x_shape
2731        .dims()
2732        .iter()
2733        .map(|d| match d {
2734            Dim::Static(n) => *n as i64,
2735            Dim::Dynamic(_) => -1,
2736        })
2737        .collect();
2738    bwd.add_node(Op::Expand { target_shape }, vec![current], x_shape.clone())
2739}
2740
2741#[cfg(test)]
2742mod tests {
2743    use super::*;
2744
2745    #[test]
2746    fn grad_of_add_is_identity() {
2747        let mut g = Graph::new("test");
2748        let x = g.input("x", Shape::new(&[4], DType::F32));
2749        let y = g.input("y", Shape::new(&[4], DType::F32));
2750        let z = g.binary(BinaryOp::Add, x, y, Shape::new(&[4], DType::F32));
2751        g.set_outputs(vec![z]);
2752
2753        let bwd = grad(&g, &[x, y]);
2754        // bwd graph should expose two outputs: dz/dx and dz/dy, both = d_output.
2755        assert_eq!(bwd.outputs.len(), 2);
2756    }
2757
2758    #[test]
2759    fn grad_of_mul_uses_other_operand() {
2760        let mut g = Graph::new("test");
2761        let x = g.input("x", Shape::new(&[4], DType::F32));
2762        let y = g.input("y", Shape::new(&[4], DType::F32));
2763        let z = g.binary(BinaryOp::Mul, x, y, Shape::new(&[4], DType::F32));
2764        g.set_outputs(vec![z]);
2765
2766        let bwd = grad(&g, &[x, y]);
2767        // bwd should contain Mul nodes (upstream * y, upstream * x).
2768        assert!(
2769            bwd.nodes()
2770                .iter()
2771                .filter(|n| matches!(n.op, Op::Binary(BinaryOp::Mul)))
2772                .count()
2773                >= 2
2774        );
2775    }
2776
2777    #[test]
2778    fn grad_with_loss_returns_loss_first() {
2779        let mut g = Graph::new("loss");
2780        let x = g.input("x", Shape::new(&[4], DType::F32));
2781        let y = g.input("y", Shape::new(&[4], DType::F32));
2782        let z = g.binary(BinaryOp::Add, x, y, Shape::new(&[4], DType::F32));
2783        g.set_outputs(vec![z]);
2784
2785        let bwd = grad_with_loss(&g, &[x, y]);
2786        // [loss, dz/dx, dz/dy] — three outputs.
2787        assert_eq!(bwd.outputs.len(), 3);
2788    }
2789
2790    #[test]
2791    fn grad_of_dense_solve_emits_implicit_function_rule() {
2792        // Forward:
2793        //   A      : Param [2,2]
2794        //   b      : Input [2]
2795        //   x      = solve(A, b)
2796        //   loss   = sum(x)         (scalar)
2797        //
2798        // Backward must contain:
2799        //   - a Transpose of A
2800        //   - a second DenseSolve (dx_int = solve(Aᵀ, upstream))
2801        //   - a MatMul (the outer product dx_int · xᵀ)
2802        //   - a Neg (the −outer)
2803        //
2804        // Outputs are [loss, dA, db].
2805        let mut g = Graph::new("solve_test");
2806        let a = g.param("A", Shape::new(&[2, 2], DType::F32));
2807        let b = g.input("b", Shape::new(&[2], DType::F32));
2808        let x = g.dense_solve(a, b, Shape::new(&[2], DType::F32));
2809        let loss = g.reduce(
2810            x,
2811            ReduceOp::Sum,
2812            vec![0],
2813            false,
2814            Shape::new(&[1], DType::F32),
2815        );
2816        g.set_outputs(vec![loss]);
2817
2818        let bwd = grad_with_loss(&g, &[a, b]);
2819        assert_eq!(bwd.outputs.len(), 3, "expect [loss, dA, db]");
2820
2821        let count =
2822            |pred: fn(&Op) -> bool| -> usize { bwd.nodes().iter().filter(|n| pred(&n.op)).count() };
2823
2824        // Forward is mirrored into bwd, so we expect 1 + 1 = 2 DenseSolves
2825        // (forward copy + reverse).
2826        assert!(
2827            count(|o| matches!(o, Op::DenseSolve)) >= 2,
2828            "expected ≥2 DenseSolve nodes (forward mirror + reverse), got\n{bwd}"
2829        );
2830        assert!(
2831            count(|o| matches!(o, Op::Transpose { .. })) >= 1,
2832            "expected a Transpose for Aᵀ, got\n{bwd}"
2833        );
2834        assert!(
2835            count(|o| matches!(o, Op::MatMul)) >= 1,
2836            "expected a MatMul for the outer product, got\n{bwd}"
2837        );
2838        assert!(
2839            count(|o| matches!(o, Op::Activation(Activation::Neg))) >= 1,
2840            "expected a Neg for −outer, got\n{bwd}"
2841        );
2842    }
2843
2844    #[test]
2845    fn inline_if_replaces_with_where() {
2846        // Build a parent graph:
2847        //   x       : Input
2848        //   pred    : Input (scalar bool)
2849        //   then_b  : sub-graph with Input(0) → Activation(Relu)
2850        //   else_b  : sub-graph with Input(0) → Activation(Sigmoid)
2851        //   out     : If(pred, [x] -> then_b, else_b)
2852        let s = Shape::new(&[4], DType::F32);
2853        let pred_s = Shape::new(&[1], DType::F32);
2854
2855        let mut then_g = Graph::new("then_branch");
2856        let then_in = then_g.input("captured", s.clone());
2857        let then_out = then_g.activation(Activation::Relu, then_in, s.clone());
2858        then_g.set_outputs(vec![then_out]);
2859
2860        let mut else_g = Graph::new("else_branch");
2861        let else_in = else_g.input("captured", s.clone());
2862        let else_out = else_g.activation(Activation::Sigmoid, else_in, s.clone());
2863        else_g.set_outputs(vec![else_out]);
2864
2865        let mut g = Graph::new("parent");
2866        let x = g.input("x", s.clone());
2867        let pred = g.input("pred", pred_s);
2868        let if_out = g.add_node(
2869            Op::If {
2870                then_branch: Box::new(then_g),
2871                else_branch: Box::new(else_g),
2872            },
2873            vec![pred, x],
2874            s,
2875        );
2876        g.set_outputs(vec![if_out]);
2877
2878        let inlined = rlx_fusion::control_flow::inline_if(g);
2879
2880        // After inlining: no Op::If, exactly one Op::Where, one
2881        // Activation(Relu), one Activation(Sigmoid). Inputs (x,
2882        // pred) and the original output count are preserved.
2883        let has_if = inlined
2884            .nodes()
2885            .iter()
2886            .any(|n| matches!(n.op, Op::If { .. }));
2887        let has_where = inlined.nodes().iter().any(|n| matches!(n.op, Op::Where));
2888        let has_relu = inlined
2889            .nodes()
2890            .iter()
2891            .any(|n| matches!(n.op, Op::Activation(Activation::Relu)));
2892        let has_sigmoid = inlined
2893            .nodes()
2894            .iter()
2895            .any(|n| matches!(n.op, Op::Activation(Activation::Sigmoid)));
2896        assert!(!has_if, "Op::If should be inlined away");
2897        assert!(has_where, "Op::Where should replace the Op::If");
2898        assert!(has_relu, "then_branch's Activation(Relu) should be inlined");
2899        assert!(
2900            has_sigmoid,
2901            "else_branch's Activation(Sigmoid) should be inlined"
2902        );
2903        assert_eq!(inlined.outputs.len(), 1);
2904    }
2905
2906    #[test]
2907    fn grad_through_if_propagates() {
2908        // Sanity: autodiff a graph with Op::If and confirm it
2909        // produces a gradient (the Where VJP handles the join).
2910        let s = Shape::new(&[4], DType::F32);
2911        let pred_s = Shape::new(&[1], DType::F32);
2912
2913        let mut then_g = Graph::new("th");
2914        let ti = then_g.input("c", s.clone());
2915        let to = then_g.binary(BinaryOp::Mul, ti, ti, s.clone());
2916        then_g.set_outputs(vec![to]);
2917
2918        let mut else_g = Graph::new("el");
2919        let ei = else_g.input("c", s.clone());
2920        let eo = else_g.activation(Activation::Relu, ei, s.clone());
2921        else_g.set_outputs(vec![eo]);
2922
2923        let mut g = Graph::new("parent");
2924        let x = g.input("x", s.clone());
2925        let pred = g.input("pred", pred_s);
2926        let z = g.add_node(
2927            Op::If {
2928                then_branch: Box::new(then_g),
2929                else_branch: Box::new(else_g),
2930            },
2931            vec![pred, x],
2932            s,
2933        );
2934        g.set_outputs(vec![z]);
2935
2936        let bwd = grad_with_loss(&g, &[x]);
2937        // [loss, dz/dx] — two outputs.
2938        assert_eq!(bwd.outputs.len(), 2, "expected loss + 1 grad output");
2939    }
2940
2941    #[test]
2942    fn unroll_while_replicates_body_n_times() {
2943        // Build a parent graph:
2944        //   x   : Input
2945        //   out : While(cond=trivial, body=Activation(Relu), N=3)
2946        // After unrolling we expect zero Op::While, three Activation
2947        // (Relu) nodes (one per replica).
2948        let s = Shape::new(&[4], DType::F32);
2949        let bool_s = Shape::new(&[1], DType::F32);
2950
2951        let mut cond_g = Graph::new("cond");
2952        let ci = cond_g.input("c", s.clone());
2953        // dummy bool: just feed input through (cond is not evaluated
2954        // by the unroll pass, so its body doesn't matter).
2955        cond_g.set_outputs(vec![ci]);
2956        // Replace output shape: cond's output is logically a scalar
2957        // bool — but the unroll pass never inspects it.
2958        let _ = bool_s;
2959
2960        let mut body_g = Graph::new("body");
2961        let bi = body_g.input("c", s.clone());
2962        let bo = body_g.activation(Activation::Relu, bi, s.clone());
2963        body_g.set_outputs(vec![bo]);
2964
2965        let mut g = Graph::new("parent");
2966        let x = g.input("x", s.clone());
2967        let w = g.add_node(
2968            Op::While {
2969                cond: Box::new(cond_g),
2970                body: Box::new(body_g),
2971                max_iterations: Some(3),
2972            },
2973            vec![x],
2974            s,
2975        );
2976        g.set_outputs(vec![w]);
2977
2978        let unrolled = rlx_fusion::control_flow::unroll_while(g);
2979
2980        let has_while = unrolled
2981            .nodes()
2982            .iter()
2983            .any(|n| matches!(n.op, Op::While { .. }));
2984        let relu_count = unrolled
2985            .nodes()
2986            .iter()
2987            .filter(|n| matches!(n.op, Op::Activation(Activation::Relu)))
2988            .count();
2989        assert!(!has_while, "Op::While should be unrolled away");
2990        assert_eq!(
2991            relu_count, 3,
2992            "body's Activation(Relu) should appear once per iteration"
2993        );
2994        assert_eq!(unrolled.outputs.len(), 1);
2995    }
2996
2997    #[test]
2998    fn grad_through_while_propagates() {
2999        // Sanity: autodiff a graph with Op::While and confirm the
3000        // gradient pipeline produces a result (the unroll pass turns
3001        // it into a chain of body replicas before the gradient walk).
3002        let s = Shape::new(&[4], DType::F32);
3003
3004        let mut cond_g = Graph::new("cond");
3005        let ci = cond_g.input("c", s.clone());
3006        cond_g.set_outputs(vec![ci]);
3007
3008        let mut body_g = Graph::new("body");
3009        let bi = body_g.input("c", s.clone());
3010        let bo = body_g.binary(BinaryOp::Mul, bi, bi, s.clone());
3011        body_g.set_outputs(vec![bo]);
3012
3013        let mut g = Graph::new("parent");
3014        let x = g.input("x", s.clone());
3015        let w = g.add_node(
3016            Op::While {
3017                cond: Box::new(cond_g),
3018                body: Box::new(body_g),
3019                max_iterations: Some(2),
3020            },
3021            vec![x],
3022            s,
3023        );
3024        g.set_outputs(vec![w]);
3025
3026        let bwd = grad_with_loss(&g, &[x]);
3027        assert_eq!(bwd.outputs.len(), 2, "expected loss + 1 grad output");
3028    }
3029
3030    /// Build a tiny BERT-style FTL graph with the given bias mode.
3031    /// Returns (graph, hidden_input_id, all_param_ids).
3032    fn build_ftl_graph(has_bias: bool) -> (Graph, NodeId, Vec<NodeId>) {
3033        // B=1, S=2, hidden=4, heads=2, head_dim=2, intermediate=8.
3034        let mut g = Graph::new("ftl_test");
3035        let h_shape = Shape::new(&[1, 2, 4], DType::F32);
3036        let h = g.input("h", h_shape.clone());
3037        let qkv_w = g.param("qkv_w", Shape::new(&[4, 12], DType::F32));
3038        let out_w = g.param("out_w", Shape::new(&[4, 4], DType::F32));
3039        let ln1_g = g.param("ln1_g", Shape::new(&[4], DType::F32));
3040        let fc1_w = g.param("fc1_w", Shape::new(&[4, 8], DType::F32));
3041        let fc2_w = g.param("fc2_w", Shape::new(&[8, 4], DType::F32));
3042        let ln2_g = g.param("ln2_g", Shape::new(&[4], DType::F32));
3043        let mask = g.input("mask", Shape::new(&[1, 2, 2, 2], DType::F32));
3044
3045        let (inputs, params) = if has_bias {
3046            let qkv_b = g.param("qkv_b", Shape::new(&[12], DType::F32));
3047            let out_b = g.param("out_b", Shape::new(&[4], DType::F32));
3048            let ln1_b = g.param("ln1_b", Shape::new(&[4], DType::F32));
3049            let fc1_b = g.param("fc1_b", Shape::new(&[8], DType::F32));
3050            let fc2_b = g.param("fc2_b", Shape::new(&[4], DType::F32));
3051            let ln2_b = g.param("ln2_b", Shape::new(&[4], DType::F32));
3052            (
3053                vec![
3054                    h, qkv_w, qkv_b, out_w, out_b, ln1_g, ln1_b, fc1_w, fc1_b, fc2_w, fc2_b, ln2_g,
3055                    ln2_b, mask,
3056                ],
3057                vec![
3058                    qkv_w, qkv_b, out_w, out_b, ln1_g, ln1_b, fc1_w, fc1_b, fc2_w, fc2_b, ln2_g,
3059                    ln2_b,
3060                ],
3061            )
3062        } else {
3063            (
3064                vec![h, qkv_w, out_w, ln1_g, fc1_w, fc2_w, ln2_g, mask],
3065                vec![qkv_w, out_w, ln1_g, fc1_w, fc2_w, ln2_g],
3066            )
3067        };
3068        let y = g.add_node(
3069            Op::FusedTransformerLayer {
3070                num_heads: 2,
3071                head_dim: 2,
3072                intermediate_size: 8,
3073                eps1: 1e-5,
3074                eps2: 1e-5,
3075                activation: rlx_ir::op::Activation::Gelu,
3076                has_bias,
3077            },
3078            inputs,
3079            h_shape,
3080        );
3081        g.set_outputs(vec![y]);
3082        (g, h, params)
3083    }
3084
3085    #[test]
3086    fn unfuse_decomposes_fused_transformer_layer() {
3087        // After rlx_fusion::unfuse_fused_for_autodiff, the FTL node is gone and
3088        // primitives appear: at least 4 MatMul (qkv, out, fc1, fc2),
3089        // 1 Attention, 2 LayerNorm, plus narrows / adds / activation.
3090        let (g, _h, _params) = build_ftl_graph(true);
3091        let unfused = rlx_fusion::unfuse_fused_for_autodiff(g);
3092
3093        let has_ftl = unfused
3094            .nodes()
3095            .iter()
3096            .any(|n| matches!(n.op, Op::FusedTransformerLayer { .. }));
3097        assert!(!has_ftl, "Op::FusedTransformerLayer should be unfused");
3098
3099        let count = |pred: fn(&Op) -> bool| -> usize {
3100            unfused.nodes().iter().filter(|n| pred(&n.op)).count()
3101        };
3102        assert!(
3103            count(|o| matches!(o, Op::MatMul)) >= 4,
3104            "expected >=4 MatMul after FTL unfuse"
3105        );
3106        assert_eq!(
3107            count(|o| matches!(o, Op::Attention { .. })),
3108            1,
3109            "expected exactly 1 Attention after FTL unfuse"
3110        );
3111        assert_eq!(
3112            count(|o| matches!(o, Op::LayerNorm { .. })),
3113            2,
3114            "expected exactly 2 LayerNorm after FTL unfuse"
3115        );
3116        assert!(
3117            count(|o| matches!(o, Op::Narrow { .. })) >= 3,
3118            "expected >=3 Narrow (Q/K/V split) after FTL unfuse"
3119        );
3120        assert_eq!(
3121            count(|o| matches!(o, Op::Activation(_))),
3122            1,
3123            "expected exactly 1 Activation (FFN) after FTL unfuse"
3124        );
3125    }
3126
3127    #[test]
3128    fn grad_through_fused_transformer_layer_propagates() {
3129        // End-to-end: grad_with_loss through an FTL graph returns
3130        // [loss, ...grads]. Confirms every primitive emitted by the
3131        // unfuse has a VJP rule on the gradient walk.
3132        let (g, _h, params) = build_ftl_graph(true);
3133        let bwd = grad_with_loss(&g, &params);
3134        assert_eq!(
3135            bwd.outputs.len(),
3136            1 + params.len(),
3137            "expected loss + {} param grads",
3138            params.len()
3139        );
3140    }
3141
3142    #[test]
3143    fn grad_through_fused_transformer_layer_no_bias() {
3144        // No-bias variant exercises the synthesized zero-beta path
3145        // for both LayerNorms.
3146        let (g, _h, params) = build_ftl_graph(false);
3147        let bwd = grad_with_loss(&g, &params);
3148        assert_eq!(
3149            bwd.outputs.len(),
3150            1 + params.len(),
3151            "expected loss + {} param grads (no-bias)",
3152            params.len()
3153        );
3154    }
3155
3156    /// Build a tiny SelectiveScan graph: B=1, S=3, H=2, N=4.
3157    /// Returns (graph, [x, delta, a, b, c]).
3158    fn build_ssm_graph() -> (Graph, NodeId, Vec<NodeId>) {
3159        let mut g = Graph::new("ssm_test");
3160        let bsh = Shape::new(&[1, 3, 2], DType::F32);
3161        let hn = Shape::new(&[2, 4], DType::F32);
3162        let bsn = Shape::new(&[1, 3, 4], DType::F32);
3163
3164        let x = g.input("x", bsh.clone());
3165        let delta = g.input("delta", bsh.clone());
3166        let a = g.param("a", hn);
3167        let b = g.input("b", bsn.clone());
3168        let c = g.input("c", bsn);
3169        let y = g.selective_scan(x, delta, a, b, c, 4, bsh);
3170        g.set_outputs(vec![y]);
3171        (g, x, vec![a])
3172    }
3173
3174    #[test]
3175    fn unfuse_decomposes_selective_scan() {
3176        // After unfuse, no Op::SelectiveScan; instead we see Concat
3177        // (one for S>1), per-step Reduce(Sum), per-step Activation::Exp,
3178        // and many Mul / Add / Narrow / Reshape / Expand nodes.
3179        // S=3 → 3 timesteps.
3180        let (g, _x, _params) = build_ssm_graph();
3181        let unfused = rlx_fusion::unfuse_fused_for_autodiff(g);
3182
3183        let has_ssm = unfused
3184            .nodes()
3185            .iter()
3186            .any(|n| matches!(n.op, Op::SelectiveScan { .. }));
3187        assert!(!has_ssm, "Op::SelectiveScan should be unfused");
3188
3189        let count = |pred: fn(&Op) -> bool| -> usize {
3190            unfused.nodes().iter().filter(|n| pred(&n.op)).count()
3191        };
3192        assert_eq!(
3193            count(|o| matches!(o, Op::Concat { .. })),
3194            1,
3195            "expected 1 Concat (over the 3 time steps)"
3196        );
3197        assert_eq!(
3198            count(|o| matches!(
3199                o,
3200                Op::Reduce {
3201                    op: ReduceOp::Sum,
3202                    ..
3203                }
3204            )),
3205            3,
3206            "expected one Reduce(Sum) per time step (S=3)"
3207        );
3208        assert_eq!(
3209            count(|o| matches!(o, Op::Activation(Activation::Exp))),
3210            3,
3211            "expected one exp(δA) per time step (S=3)"
3212        );
3213        assert!(
3214            count(|o| matches!(o, Op::Narrow { .. })) >= 12,
3215            "expected >=12 Narrows (4 per step × 3 steps)"
3216        );
3217    }
3218
3219    #[test]
3220    fn grad_through_selective_scan_propagates() {
3221        // End-to-end: grad_with_loss through SelectiveScan returns
3222        // [loss, da] — confirms every primitive emitted by the
3223        // unroll has a VJP rule on the gradient walk (Mul, Add,
3224        // Activation::Exp, Reduce::Sum, Concat, Narrow, Reshape,
3225        // Expand).
3226        let (g, _x, params) = build_ssm_graph();
3227        let bwd = grad_with_loss(&g, &params);
3228        assert_eq!(
3229            bwd.outputs.len(),
3230            1 + params.len(),
3231            "expected loss + {} param grads",
3232            params.len()
3233        );
3234    }
3235
3236    /// Tiny GatedDeltaNet: B=1, S=3, H=2, N=4.
3237    fn build_gdn_graph() -> (Graph, NodeId, Vec<NodeId>) {
3238        let (b, s, h, n) = (1usize, 3, 2, 4);
3239        let mut g = Graph::new("gdn_test");
3240        let bshn = Shape::new(&[b, s, h, n], DType::F32);
3241        let bsh = Shape::new(&[b, s, h], DType::F32);
3242        let q = g.input("q", bshn.clone());
3243        let k = g.input("k", bshn.clone());
3244        let v = g.input("v", bshn.clone());
3245        let g_in = g.input("g", bsh.clone());
3246        let beta = g.input("beta", bsh);
3247        let y = g.gated_delta_net(q, k, v, g_in, beta, n, bshn);
3248        g.set_outputs(vec![y]);
3249        (g, q, vec![q, k, v, g_in, beta])
3250    }
3251
3252    #[test]
3253    fn unfuse_decomposes_gated_delta_net() {
3254        let (g, _q, _params) = build_gdn_graph();
3255        let unfused = rlx_fusion::unfuse_fused_for_autodiff(g);
3256
3257        let has_gdn = unfused
3258            .nodes()
3259            .iter()
3260            .any(|n| matches!(n.op, Op::GatedDeltaNet { .. }));
3261        assert!(!has_gdn, "Op::GatedDeltaNet should be unfused");
3262
3263        let count = |pred: fn(&Op) -> bool| -> usize {
3264            unfused.nodes().iter().filter(|n| pred(&n.op)).count()
3265        };
3266        assert_eq!(
3267            count(|o| matches!(o, Op::Concat { .. })),
3268            1,
3269            "expected 1 Concat over S=3 steps"
3270        );
3271        assert!(
3272            count(|o| matches!(o, Op::MatMul)) >= 3,
3273            "expected >=3 MatMul per step (sk + out) × S=3"
3274        );
3275        assert_eq!(
3276            count(|o| matches!(o, Op::Activation(Activation::Exp))),
3277            3,
3278            "expected one exp(g) per time step"
3279        );
3280    }
3281
3282    #[test]
3283    fn grad_through_gated_delta_net_propagates() {
3284        let (g, _q, params) = build_gdn_graph();
3285        let bwd = grad_with_loss(&g, &params);
3286        assert_eq!(
3287            bwd.outputs.len(),
3288            1 + params.len(),
3289            "expected loss + {} input grads",
3290            params.len()
3291        );
3292    }
3293
3294    #[test]
3295    fn custom_fn_vjp_body_is_inlined_into_bwd() {
3296        // Forward: y = x² via custom_fn (fwd_body = Mul(x, x)).
3297        // Override VJP to return Activation::Sin(d_output) — a unique
3298        // marker that natural autodiff of Mul would never emit. If
3299        // grad_with_loss inlines the override correctly, the bwd graph
3300        // must contain a Sin node; if it falls back to recursing into
3301        // fwd_body, it would emit two Muls (upstream·x + x·upstream)
3302        // and no Sin.
3303        let n = 4usize;
3304        let shape = Shape::new(&[n], DType::F32);
3305
3306        // fwd_body: x → x · x.
3307        let mut fwd_body = Graph::new("square_fwd");
3308        let xb = fwd_body.input("x", shape.clone());
3309        let yb = fwd_body.binary(BinaryOp::Mul, xb, xb, shape.clone());
3310        fwd_body.set_outputs(vec![yb]);
3311
3312        // vjp_body: (x, primal_output, d_output) → sin(d_output).
3313        let mut vjp_body = Graph::new("square_vjp");
3314        let _vx = vjp_body.input("x", shape.clone());
3315        let _vp = vjp_body.input("primal_output", shape.clone());
3316        let vd = vjp_body.input("d_output", shape.clone());
3317        let dx = vjp_body.activation(Activation::Sin, vd, shape.clone());
3318        vjp_body.set_outputs(vec![dx]);
3319
3320        let mut g = Graph::new("custom_fn_test");
3321        let x = g.input("x", shape.clone());
3322        let y = g.custom_fn(vec![x], fwd_body, Some(vjp_body), None);
3323        let loss = g.reduce(
3324            y,
3325            ReduceOp::Sum,
3326            vec![0],
3327            false,
3328            Shape::new(&[1], DType::F32),
3329        );
3330        g.set_outputs(vec![loss]);
3331
3332        let bwd = grad_with_loss(&g, &[x]);
3333        assert_eq!(bwd.outputs.len(), 2, "expect [loss, dx]");
3334        let sin_count = bwd
3335            .nodes()
3336            .iter()
3337            .filter(|n| matches!(n.op, Op::Activation(Activation::Sin)))
3338            .count();
3339        assert!(
3340            sin_count >= 1,
3341            "expected the vjp_body's Sin to be inlined into bwd, got\n{bwd}"
3342        );
3343    }
3344
3345    #[test]
3346    fn custom_fn_without_vjp_inlines_fwd_body_for_autodiff() {
3347        // Forward: y = x² via custom_fn without vjp_body. After the
3348        // inline pre-pass, autodiff should recurse into Mul and emit
3349        // dx = 2·x·d_output (two Mul nodes in the backward graph).
3350        let n = 4usize;
3351        let shape = Shape::new(&[n], DType::F32);
3352
3353        let mut fwd_body = Graph::new("square_fwd");
3354        let xb = fwd_body.input("x", shape.clone());
3355        let yb = fwd_body.binary(BinaryOp::Mul, xb, xb, shape.clone());
3356        fwd_body.set_outputs(vec![yb]);
3357
3358        let mut g = Graph::new("custom_fn_no_vjp");
3359        let x = g.input("x", shape.clone());
3360        let y = g.custom_fn(vec![x], fwd_body, None, None);
3361        let loss = g.reduce(
3362            y,
3363            ReduceOp::Sum,
3364            vec![0],
3365            false,
3366            Shape::new(&[1], DType::F32),
3367        );
3368        g.set_outputs(vec![loss]);
3369
3370        let bwd = grad_with_loss(&g, &[x]);
3371        assert_eq!(bwd.outputs.len(), 2, "expect [loss, dx]");
3372        let custom_fn_count = bwd
3373            .nodes()
3374            .iter()
3375            .filter(|n| matches!(n.op, Op::CustomFn { .. }))
3376            .count();
3377        assert_eq!(
3378            custom_fn_count, 0,
3379            "CustomFn should be inlined away before autodiff"
3380        );
3381        let mul_count = bwd
3382            .nodes()
3383            .iter()
3384            .filter(|n| matches!(n.op, Op::Binary(BinaryOp::Mul)))
3385            .count();
3386        assert!(mul_count >= 2, "expected Mul-based VJP for x², got\n{bwd}");
3387    }
3388
3389    #[test]
3390    fn convert_scans_for_ad_forces_save_trajectory_true() {
3391        // grad_with_loss runs `convert_scans_for_ad` as a pre-pass: any
3392        // forward Op::Scan with `save_trajectory: false` is rewritten
3393        // to `save_trajectory: true` followed by Narrow + Reshape so
3394        // the reverse pass has the trajectory it needs. This test
3395        // verifies the rewrite happens — the bwd graph should contain
3396        // at least one Scan with save_trajectory == true.
3397        let n = 2usize;
3398        let length = 3u32;
3399        let carry = Shape::new(&[n], DType::F32);
3400        let xs_shape = Shape::new(&[length as usize, n], DType::F32);
3401
3402        // body: (carry, x_t) → carry + x_t. One primal Input each.
3403        let mut body = Graph::new("scan_body");
3404        let bc = body.input("carry", carry.clone());
3405        let bx = body.input("x_t", carry.clone());
3406        let by = body.binary(BinaryOp::Add, bc, bx, carry.clone());
3407        body.set_outputs(vec![by]);
3408
3409        let mut g = Graph::new("scan_save_false");
3410        let init = g.input("init", carry.clone());
3411        let xs = g.input("xs", xs_shape);
3412        let scan_out = g.add_node(
3413            Op::Scan {
3414                body: Box::new(body),
3415                length,
3416                save_trajectory: false,
3417                num_bcast: 0,
3418                num_xs: 1,
3419                num_checkpoints: 0,
3420            },
3421            vec![init, xs],
3422            carry.clone(),
3423        );
3424        let loss = g.reduce(
3425            scan_out,
3426            ReduceOp::Sum,
3427            vec![0],
3428            false,
3429            Shape::new(&[1], DType::F32),
3430        );
3431        g.set_outputs(vec![loss]);
3432
3433        let bwd = grad_with_loss(&g, &[init, xs]);
3434        let saved_traj = bwd.nodes().iter().any(|n| {
3435            matches!(
3436                &n.op,
3437                Op::Scan {
3438                    save_trajectory: true,
3439                    ..
3440                }
3441            )
3442        });
3443        assert!(
3444            saved_traj,
3445            "convert_scans_for_ad should rewrite save_trajectory=false → \
3446             save_trajectory=true in the AD-prepared graph; got\n{bwd}"
3447        );
3448    }
3449}