Skip to main content

rlx_autodiff/
vmap.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Batched function transformation (vmap).
17//!
18//! Lifts a graph that operates on shape `[*]` to operate on shape
19//! `[B, *]`, threading a leading batch axis through every op. Mirror
20//! of JAX's `vmap` with MVP constraints:
21//!
22//! * Leading-axis batching only — `in_axes` is a list of input names
23//!   to batch on axis 0; everything else is shared across the batch.
24//! * Outputs always land with the batch axis at 0.
25//! * Per-op rules cover the elementwise / shape / reduce / matmul
26//!   subset. Ops without a rule panic, mirroring the autodiff pass's
27//!   policy of "no silent miscompute."
28//!
29//! ## Use case
30//!
31//! Parameter sweeps: build a graph parameterised by a small input
32//! vector, `vmap` over the batched parameter values to evaluate every
33//! variant in one shot, take a gradient w.r.t. the parameter vector.
34//! Pairs naturally with `Op::BatchedDenseSolve` for batched implicit
35//! solves.
36
37use rlx_ir::shape::Dim;
38use rlx_ir::*;
39use std::collections::{HashMap, HashSet};
40
41/// Vectorize `forward` over a leading batch axis.
42///
43/// `batched_input_names` lists the `Op::Input` names whose leading
44/// axis is the batch axis after vmap. Inputs/Params not in the list
45/// are shared across the batch (they get broadcast on demand by ops
46/// that consume them alongside batched values).
47///
48/// The returned graph:
49/// * Has the same input names as `forward`. Batched inputs gain a
50///   leading `[batch_size, ...]` dim.
51/// * Has the same output count. Every output gains a leading batch
52///   axis (out_axes = 0 implicit).
53/// * Has the same set of `Op::Param` slots — params are always shared.
54///
55/// # Panics
56/// Panics on any op without a vmap rule. Add rules incrementally.
57pub fn vmap(forward: &Graph, batched_input_names: &[&str], batch_size: usize) -> Graph {
58    let batched_set: HashSet<&str> = batched_input_names.iter().copied().collect();
59    let mut out = Graph::new(format!("{}_vmap", forward.name));
60    let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
61    // Set of node IDs (in the OUTPUT graph) that carry a leading batch
62    // axis. `lift_to_batched` reads this to decide whether a value
63    // needs broadcasting before being combined with a batched value.
64    let mut batched: HashSet<NodeId> = HashSet::new();
65
66    for node in forward.nodes() {
67        let new_id = match &node.op {
68            Op::Input { name } => {
69                if batched_set.contains(name.as_str()) {
70                    let mut dims: Vec<Dim> = vec![Dim::Static(batch_size)];
71                    dims.extend(node.shape.dims().iter().copied());
72                    let s = Shape::from_dims(&dims, node.shape.dtype());
73                    let id = out.input(name.clone(), s);
74                    batched.insert(id);
75                    id
76                } else {
77                    out.input(name.clone(), node.shape.clone())
78                }
79            }
80            Op::Param { name } => {
81                // Params are always shared in the MVP. Convert to
82                // Input if you need batched params.
83                out.param(name.clone(), node.shape.clone())
84            }
85            Op::Constant { data } => out.add_node(
86                Op::Constant { data: data.clone() },
87                vec![],
88                node.shape.clone(),
89            ),
90            _ => {
91                let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
92                let any_batched = new_inputs.iter().any(|i| batched.contains(i));
93                if !any_batched {
94                    // No batched input reaches this node — the original
95                    // op shape applies and the node is shared.
96                    out.add_node(node.op.clone(), new_inputs, node.shape.clone())
97                } else {
98                    let id = vmap_op(node, &new_inputs, &mut out, &mut batched, batch_size);
99                    batched.insert(id);
100                    id
101                }
102            }
103        };
104        id_map.insert(node.id, new_id);
105    }
106
107    let new_outputs: Vec<NodeId> = forward.outputs.iter().map(|o| id_map[o]).collect();
108    out.set_outputs(new_outputs);
109    out
110}
111
112/// Apply the per-op vmap rule. At least one input is batched.
113fn vmap_op(
114    node: &Node,
115    new_inputs: &[NodeId],
116    out: &mut Graph,
117    batched: &mut HashSet<NodeId>,
118    batch_size: usize,
119) -> NodeId {
120    let orig_shape = &node.shape;
121    let dtype = orig_shape.dtype();
122
123    // Output shape with leading batch axis.
124    let batched_shape = || -> Shape {
125        let mut dims: Vec<Dim> = vec![Dim::Static(batch_size)];
126        dims.extend(orig_shape.dims().iter().copied());
127        Shape::from_dims(&dims, dtype)
128    };
129
130    match &node.op {
131        // ── Pure elementwise — broadcast unbatched inputs, apply op ──
132        Op::Binary(_) | Op::Activation(_) | Op::Where | Op::Compare(_) | Op::Cast { .. } => {
133            let lifted: Vec<NodeId> = new_inputs
134                .iter()
135                .map(|&id| lift_to_batched(out, id, batched, batch_size))
136                .collect();
137            for &id in &lifted {
138                batched.insert(id);
139            }
140            out.add_node(node.op.clone(), lifted, batched_shape())
141        }
142
143        // ── Reshape: prepend batch dim ──
144        Op::Reshape { new_shape } => {
145            let lifted = lift_to_batched(out, new_inputs[0], batched, batch_size);
146            batched.insert(lifted);
147            let mut bsh: Vec<i64> = vec![batch_size as i64];
148            bsh.extend(new_shape.iter().copied());
149            out.add_node(
150                Op::Reshape { new_shape: bsh },
151                vec![lifted],
152                batched_shape(),
153            )
154        }
155
156        // ── Transpose: shift perm by 1, prepend 0 ──
157        Op::Transpose { perm } => {
158            let lifted = lift_to_batched(out, new_inputs[0], batched, batch_size);
159            batched.insert(lifted);
160            let mut new_perm: Vec<usize> = vec![0];
161            new_perm.extend(perm.iter().map(|p| p + 1));
162            out.add_node(
163                Op::Transpose { perm: new_perm },
164                vec![lifted],
165                batched_shape(),
166            )
167        }
168
169        // ── Expand: prepend batch dim to target_shape ──
170        Op::Expand { target_shape } => {
171            let lifted = lift_to_batched(out, new_inputs[0], batched, batch_size);
172            batched.insert(lifted);
173            let mut bsh: Vec<i64> = vec![batch_size as i64];
174            bsh.extend(target_shape.iter().copied());
175            out.add_node(
176                Op::Expand { target_shape: bsh },
177                vec![lifted],
178                batched_shape(),
179            )
180        }
181
182        // ── Reduce: shift axes by 1 (don't reduce batch axis) ──
183        Op::Reduce { op, axes, keep_dim } => {
184            let lifted = lift_to_batched(out, new_inputs[0], batched, batch_size);
185            batched.insert(lifted);
186            let new_axes: Vec<usize> = axes.iter().map(|a| a + 1).collect();
187            out.add_node(
188                Op::Reduce {
189                    op: *op,
190                    axes: new_axes,
191                    keep_dim: *keep_dim,
192                },
193                vec![lifted],
194                batched_shape(),
195            )
196        }
197
198        // ── MatMul: rely on built-in batch broadcasting ──
199        // Per Op::MatMul docs: "Batch dimensions are broadcast." So
200        // [B, M, K] @ [B, K, N] → [B, M, N], and [B, M, K] @ [K, N]
201        // also works via broadcasting.
202        Op::MatMul => {
203            let a = lift_to_batched(out, new_inputs[0], batched, batch_size);
204            let b = lift_to_batched(out, new_inputs[1], batched, batch_size);
205            batched.insert(a);
206            batched.insert(b);
207            out.matmul(a, b, batched_shape())
208        }
209
210        // ── DenseSolve: emit BatchedDenseSolve ──
211        // A becomes [B, N, N], b becomes [B, N] or [B, N, K].
212        Op::DenseSolve => {
213            let a = lift_to_batched(out, new_inputs[0], batched, batch_size);
214            let b = lift_to_batched(out, new_inputs[1], batched, batch_size);
215            batched.insert(a);
216            batched.insert(b);
217            out.batched_dense_solve(a, b, batched_shape())
218        }
219
220        // ── Scan: recursively vmap the body ──
221        //
222        // Forward Op::Scan iterates `length` times, with carry shape
223        // `*carry` and per-step xs shape `*per_step_i`. After vmap:
224        //   * init becomes `[B, *carry]`
225        //   * each xs_i becomes `[B, length, *per_step_i]`; we
226        //     transpose to `[length, B, *per_step_i]` so the Scan
227        //     reads per-step slices of shape `[B, *per_step_i]`.
228        //   * the body is recursively vmap'd — its inputs gain a
229        //     leading B and its computations become batched.
230        //   * the inner Scan output is `[B, *carry]` (final-only)
231        //     or `[length, B, *carry]` (trajectory). For trajectory,
232        //     we add a final transpose to put batch at axis 0:
233        //     `[B, length, *carry]`.
234        Op::Scan {
235            body,
236            length,
237            save_trajectory,
238            num_xs,
239            num_checkpoints: _,
240            num_bcast,
241        } => {
242            // Lift init to [B, *carry].
243            let init_b = lift_to_batched(out, new_inputs[0], batched, batch_size);
244            batched.insert(init_b);
245
246            // Bcasts are lifted to [B, *bcast]; the body sees them
247            // un-transposed (no per-step axis). Same handling as the
248            // carry — each iteration the body reads the lifted slot.
249            let mut bcasts_b: Vec<NodeId> = Vec::with_capacity(*num_bcast as usize);
250            for i in 0..*num_bcast as usize {
251                let bcast_in = new_inputs[1 + i];
252                let lifted = lift_to_batched(out, bcast_in, batched, batch_size);
253                batched.insert(lifted);
254                bcasts_b.push(lifted);
255            }
256
257            // For each xs: lift to [B, length, *per_step], then
258            // transpose first two axes → [length, B, *per_step].
259            let xs_base = 1 + *num_bcast as usize;
260            let mut xs_t: Vec<NodeId> = Vec::with_capacity(*num_xs as usize);
261            for i in 0..*num_xs as usize {
262                let xs_in = new_inputs[xs_base + i];
263                let lifted = lift_to_batched(out, xs_in, batched, batch_size);
264                batched.insert(lifted);
265                let xs_shape = out.node(lifted).shape.clone();
266                let r = xs_shape.rank();
267                let mut perm: Vec<usize> = vec![1, 0];
268                for k in 2..r {
269                    perm.push(k);
270                }
271                let mut new_dims: Vec<Dim> = xs_shape.dims().to_vec();
272                new_dims.swap(0, 1);
273                let new_shape = Shape::from_dims(&new_dims, xs_shape.dtype());
274                let transposed = out.add_node(Op::Transpose { perm }, vec![lifted], new_shape);
275                batched.insert(transposed);
276                xs_t.push(transposed);
277            }
278
279            // Recursively vmap the body. All body inputs are batched
280            // — carry comes in as `[B, *carry]`, each x_t comes in as
281            // `[B, *per_step_i]`. Collect names and dispatch.
282            let body_input_names_owned: Vec<String> = body
283                .nodes()
284                .iter()
285                .filter_map(|n| match &n.op {
286                    Op::Input { name } => Some(name.clone()),
287                    _ => None,
288                })
289                .collect();
290            let body_input_names: Vec<&str> =
291                body_input_names_owned.iter().map(|s| s.as_str()).collect();
292            let body_b = vmap(body, &body_input_names, batch_size);
293
294            // Compute the inner Scan's natural output shape.
295            //   final-only:   [B, *carry]
296            //   trajectory:   [length, B, *carry]
297            let dtype = orig_shape.dtype();
298            // `orig_shape` was either `*carry` (final-only) or
299            // `[length, *carry]` (trajectory).
300            let inner_out_shape: Shape = if *save_trajectory {
301                // Original was [length, *carry]; new is [length, B, *carry].
302                let mut dims: Vec<Dim> = vec![orig_shape.dim(0)];
303                dims.push(Dim::Static(batch_size));
304                for i in 1..orig_shape.rank() {
305                    dims.push(orig_shape.dim(i));
306                }
307                Shape::from_dims(&dims, dtype)
308            } else {
309                // Original was *carry; new is [B, *carry].
310                let mut dims: Vec<Dim> = vec![Dim::Static(batch_size)];
311                for i in 0..orig_shape.rank() {
312                    dims.push(orig_shape.dim(i));
313                }
314                Shape::from_dims(&dims, dtype)
315            };
316
317            // Build inputs: init + lifted bcasts + transposed xs.
318            let mut inner_inputs = vec![init_b];
319            inner_inputs.extend_from_slice(&bcasts_b);
320            inner_inputs.extend_from_slice(&xs_t);
321
322            let inner_id = out.add_node(
323                Op::Scan {
324                    body: Box::new(body_b),
325                    length: *length,
326                    save_trajectory: *save_trajectory,
327                    num_xs: *num_xs,
328                    num_checkpoints: 0,
329                    num_bcast: *num_bcast,
330                },
331                inner_inputs,
332                inner_out_shape,
333            );
334
335            if *save_trajectory {
336                // Trajectory: transpose [length, B, *carry] → [B, length, *carry].
337                let r = orig_shape.rank() + 1; // includes leading length + B
338                let mut perm: Vec<usize> = vec![1, 0];
339                for k in 2..r {
340                    perm.push(k);
341                }
342                out.add_node(Op::Transpose { perm }, vec![inner_id], batched_shape())
343            } else {
344                inner_id
345            }
346        }
347
348        // ── Narrow: shift axis by 1 ──
349        Op::Narrow { axis, start, len } => {
350            let lifted = lift_to_batched(out, new_inputs[0], batched, batch_size);
351            batched.insert(lifted);
352            out.add_node(
353                Op::Narrow {
354                    axis: axis + 1,
355                    start: *start,
356                    len: *len,
357                },
358                vec![lifted],
359                batched_shape(),
360            )
361        }
362
363        // ── Concat: shift axis by 1 ──
364        Op::Concat { axis } => {
365            let lifted: Vec<NodeId> = new_inputs
366                .iter()
367                .map(|&id| lift_to_batched(out, id, batched, batch_size))
368                .collect();
369            for &id in &lifted {
370                batched.insert(id);
371            }
372            out.add_node(Op::Concat { axis: axis + 1 }, lifted, batched_shape())
373        }
374
375        // ── Softmax: shift axis by 1 ──
376        Op::Softmax { axis } => {
377            let lifted = lift_to_batched(out, new_inputs[0], batched, batch_size);
378            batched.insert(lifted);
379            out.add_node(
380                Op::Softmax { axis: *axis + 1 },
381                vec![lifted],
382                batched_shape(),
383            )
384        }
385
386        // ── Cumsum: shift axis by 1 ──
387        Op::Cumsum { axis, exclusive } => {
388            let lifted = lift_to_batched(out, new_inputs[0], batched, batch_size);
389            batched.insert(lifted);
390            out.add_node(
391                Op::Cumsum {
392                    axis: *axis + 1,
393                    exclusive: *exclusive,
394                },
395                vec![lifted],
396                batched_shape(),
397            )
398        }
399
400        // ── LayerNorm: shift axis by 1 ──
401        // Inputs: [x, gamma, beta]. gamma/beta apply on the feature
402        // axis only — they stay shared across batch (lift_to_batched
403        // is a no-op if they're already batched, otherwise they're
404        // broadcast on demand by the kernel).
405        Op::LayerNorm { axis, eps } => {
406            let x = lift_to_batched(out, new_inputs[0], batched, batch_size);
407            batched.insert(x);
408            out.add_node(
409                Op::LayerNorm {
410                    axis: *axis + 1,
411                    eps: *eps,
412                },
413                vec![x, new_inputs[1], new_inputs[2]],
414                batched_shape(),
415            )
416        }
417
418        // ── RmsNorm: shift axis by 1 ──
419        Op::RmsNorm { axis, eps } => {
420            let x = lift_to_batched(out, new_inputs[0], batched, batch_size);
421            batched.insert(x);
422            out.add_node(
423                Op::RmsNorm {
424                    axis: *axis + 1,
425                    eps: *eps,
426                },
427                vec![x, new_inputs[1], new_inputs[2]],
428                batched_shape(),
429            )
430        }
431
432        // ── Gather: shift axis by 1; both table and indices lift ──
433        // table[indices] selects along the original axis, with B
434        // prepended both inputs index per-batch.
435        Op::Gather { axis } => {
436            let table = lift_to_batched(out, new_inputs[0], batched, batch_size);
437            let indices = lift_to_batched(out, new_inputs[1], batched, batch_size);
438            batched.insert(table);
439            batched.insert(indices);
440            out.add_node(
441                Op::Gather { axis: axis + 1 },
442                vec![table, indices],
443                batched_shape(),
444            )
445        }
446
447        // ── ScatterAdd: lift updates and indices; output gains B axis ──
448        // Forward: output[indices[i]] += updates[i]. After vmap each
449        // batch's scatter is independent; the existing kernel iterates
450        // a flat updates list, so as long as updates and indices are
451        // batched on axis 0 and the output's leading dim is B, the
452        // executor handles per-batch slicing via the scatter indices.
453        Op::ScatterAdd => {
454            let updates = lift_to_batched(out, new_inputs[0], batched, batch_size);
455            let indices = lift_to_batched(out, new_inputs[1], batched, batch_size);
456            batched.insert(updates);
457            batched.insert(indices);
458            out.add_node(Op::ScatterAdd, vec![updates, indices], batched_shape())
459        }
460
461        // ── ElementwiseRegion: same policy as plain elementwise ──
462        // The chain operates on shape `[*]` per element; lifting all
463        // inputs to `[B, *]` and letting the chain run with the wider
464        // shape gives the right per-batch result. The fused kernel's
465        // `input_modulus` machinery already handles broadcast inputs
466        // — but for true unbatched-into-batched broadcast we'd need
467        // to update those moduli. For MVP: lift everything to
468        // batched (so all inputs share `[B, *]`), keep the chain.
469        Op::ElementwiseRegion { .. } => {
470            let lifted: Vec<NodeId> = new_inputs
471                .iter()
472                .map(|&id| lift_to_batched(out, id, batched, batch_size))
473                .collect();
474            for &id in &lifted {
475                batched.insert(id);
476            }
477            out.add_node(node.op.clone(), lifted, batched_shape())
478        }
479
480        // ── DotGeneral: shift contracting + batch dim indices by 1 ──
481        Op::DotGeneral {
482            lhs_contracting,
483            rhs_contracting,
484            lhs_batch,
485            rhs_batch,
486        } => {
487            let lhs = lift_to_batched(out, new_inputs[0], batched, batch_size);
488            let rhs = lift_to_batched(out, new_inputs[1], batched, batch_size);
489            batched.insert(lhs);
490            batched.insert(rhs);
491            // Every dim index shifts by 1; axis 0 (the new batch axis)
492            // joins lhs_batch and rhs_batch since both operands are
493            // now batched on it.
494            let mut new_lhs_b: Vec<usize> = vec![0];
495            new_lhs_b.extend(lhs_batch.iter().map(|i| i + 1));
496            let mut new_rhs_b: Vec<usize> = vec![0];
497            new_rhs_b.extend(rhs_batch.iter().map(|i| i + 1));
498            out.add_node(
499                Op::DotGeneral {
500                    lhs_contracting: lhs_contracting.iter().map(|i| i + 1).collect(),
501                    rhs_contracting: rhs_contracting.iter().map(|i| i + 1).collect(),
502                    lhs_batch: new_lhs_b,
503                    rhs_batch: new_rhs_b,
504                },
505                vec![lhs, rhs],
506                batched_shape(),
507            )
508        }
509
510        // ── Backward ops emitted by autodiff: same shape lift as
511        // elementwise. ReluBackward and ActivationBackward read
512        // (x, dy) and write dx — same shape across all three. Lift
513        // and keep the op kind unchanged.
514        Op::ReluBackward | Op::ActivationBackward { .. } => {
515            let lifted: Vec<NodeId> = new_inputs
516                .iter()
517                .map(|&id| lift_to_batched(out, id, batched, batch_size))
518                .collect();
519            for &id in &lifted {
520                batched.insert(id);
521            }
522            out.add_node(node.op.clone(), lifted, batched_shape())
523        }
524
525        // ── ScanBackward: recursive AD-loop vmap ──
526        // Same shape-juggling as Op::Scan's vmap rule, plus the
527        // extra `upstream` input and a body_vjp instead of a body.
528        Op::ScanBackward {
529            body_vjp,
530            length,
531            save_trajectory,
532            num_xs,
533            num_checkpoints: _,
534            forward_body: _,
535        } => {
536            // init [B, *carry]
537            let init_b = lift_to_batched(out, new_inputs[0], batched, batch_size);
538            batched.insert(init_b);
539
540            // trajectory after lift is [B, length, *carry]; ScanBackward's
541            // executor reads it row-by-row indexed by t along axis 0,
542            // so transpose to [length, B, *carry].
543            let traj_lifted = lift_to_batched(out, new_inputs[1], batched, batch_size);
544            batched.insert(traj_lifted);
545            let traj_t = transpose_swap_01(out, traj_lifted);
546            batched.insert(traj_t);
547
548            // upstream layout depends on save_trajectory:
549            //   save_trajectory=true:  same shape as trajectory →
550            //     transpose [B, length, *carry] → [length, B, *carry]
551            //   save_trajectory=false: [B, *carry] (carry shape; no
552            //     length axis) → no transpose, but lift if needed.
553            let up_lifted = lift_to_batched(out, new_inputs[2], batched, batch_size);
554            batched.insert(up_lifted);
555            let up_t = if *save_trajectory {
556                let id = transpose_swap_01(out, up_lifted);
557                batched.insert(id);
558                id
559            } else {
560                up_lifted
561            };
562
563            // Per-xs: lift to [B, length, *per_step], transpose to [length, B, *per_step].
564            let mut xs_t: Vec<NodeId> = Vec::with_capacity(*num_xs as usize);
565            for i in 0..*num_xs as usize {
566                let xs_in = new_inputs[3 + i];
567                let lifted = lift_to_batched(out, xs_in, batched, batch_size);
568                batched.insert(lifted);
569                let t = transpose_swap_01(out, lifted);
570                batched.insert(t);
571                xs_t.push(t);
572            }
573
574            // Recursively vmap body_vjp. All its Op::Input nodes are
575            // marked batched (carry, every x_t_i, AND "d_output").
576            let body_input_names_owned: Vec<String> = body_vjp
577                .nodes()
578                .iter()
579                .filter_map(|n| match &n.op {
580                    Op::Input { name } => Some(name.clone()),
581                    _ => None,
582                })
583                .collect();
584            let body_input_names: Vec<&str> =
585                body_input_names_owned.iter().map(|s| s.as_str()).collect();
586            let body_vjp_b = vmap(body_vjp, &body_input_names, batch_size);
587
588            // dinit shape: [B, *carry] (orig_shape was *carry).
589            let mut dinit_dims: Vec<Dim> = vec![Dim::Static(batch_size)];
590            for i in 0..orig_shape.rank() {
591                dinit_dims.push(orig_shape.dim(i));
592            }
593            let dinit_shape = Shape::from_dims(&dinit_dims, dtype);
594
595            let mut inner_inputs = vec![init_b, traj_t, up_t];
596            inner_inputs.extend_from_slice(&xs_t);
597
598            out.scan_backward(
599                init_b,
600                traj_t,
601                up_t,
602                &xs_t,
603                body_vjp_b,
604                *length,
605                *save_trajectory,
606                dinit_shape,
607            )
608        }
609
610        // ── ScanBackwardXs: like ScanBackward but output is per-step
611        // dxs_i. Inner output is [length, B, *per_step]; transpose
612        // back to [B, length, *per_step] so batch ends up at axis 0.
613        Op::ScanBackwardXs {
614            body_vjp,
615            length,
616            save_trajectory,
617            num_xs,
618            xs_idx,
619            num_checkpoints: _,
620            forward_body: _,
621        } => {
622            let init_b = lift_to_batched(out, new_inputs[0], batched, batch_size);
623            batched.insert(init_b);
624            let traj_lifted = lift_to_batched(out, new_inputs[1], batched, batch_size);
625            batched.insert(traj_lifted);
626            let traj_t = transpose_swap_01(out, traj_lifted);
627            batched.insert(traj_t);
628            let up_lifted = lift_to_batched(out, new_inputs[2], batched, batch_size);
629            batched.insert(up_lifted);
630            let up_t = if *save_trajectory {
631                let id = transpose_swap_01(out, up_lifted);
632                batched.insert(id);
633                id
634            } else {
635                up_lifted
636            };
637
638            let mut xs_t: Vec<NodeId> = Vec::with_capacity(*num_xs as usize);
639            for i in 0..*num_xs as usize {
640                let xs_in = new_inputs[3 + i];
641                let lifted = lift_to_batched(out, xs_in, batched, batch_size);
642                batched.insert(lifted);
643                let t = transpose_swap_01(out, lifted);
644                batched.insert(t);
645                xs_t.push(t);
646            }
647
648            let body_input_names_owned: Vec<String> = body_vjp
649                .nodes()
650                .iter()
651                .filter_map(|n| match &n.op {
652                    Op::Input { name } => Some(name.clone()),
653                    _ => None,
654                })
655                .collect();
656            let body_input_names: Vec<&str> =
657                body_input_names_owned.iter().map(|s| s.as_str()).collect();
658            let body_vjp_b = vmap(body_vjp, &body_input_names, batch_size);
659
660            // Inner output natural shape is [length, B, *per_step]
661            // (orig_shape is [length, *per_step]).
662            let mut inner_dims: Vec<Dim> = vec![orig_shape.dim(0)];
663            inner_dims.push(Dim::Static(batch_size));
664            for i in 1..orig_shape.rank() {
665                inner_dims.push(orig_shape.dim(i));
666            }
667            let inner_shape = Shape::from_dims(&inner_dims, dtype);
668
669            let inner_id = out.scan_backward_xs(
670                init_b,
671                traj_t,
672                up_t,
673                &xs_t,
674                body_vjp_b,
675                *length,
676                *save_trajectory,
677                *xs_idx,
678                inner_shape,
679            );
680
681            // Final transpose [length, B, *per_step] → [B, length, *per_step].
682            transpose_swap_01(out, inner_id)
683        }
684
685        // ── Quantize / Dequantize: per-channel; chan_axis +1 if Some ──
686        Op::Quantize {
687            axis,
688            scales,
689            zero_points,
690        } => {
691            let lifted = lift_to_batched(out, new_inputs[0], batched, batch_size);
692            batched.insert(lifted);
693            let new_axis = axis.map(|a| a + 1);
694            out.add_node(
695                Op::Quantize {
696                    axis: new_axis,
697                    scales: scales.clone(),
698                    zero_points: zero_points.clone(),
699                },
700                vec![lifted],
701                batched_shape(),
702            )
703        }
704        Op::Dequantize {
705            axis,
706            scales,
707            zero_points,
708        } => {
709            let lifted = lift_to_batched(out, new_inputs[0], batched, batch_size);
710            batched.insert(lifted);
711            let new_axis = axis.map(|a| a + 1);
712            out.add_node(
713                Op::Dequantize {
714                    axis: new_axis,
715                    scales: scales.clone(),
716                    zero_points: zero_points.clone(),
717                },
718                vec![lifted],
719                batched_shape(),
720            )
721        }
722        Op::FakeQuantize {
723            bits,
724            axis,
725            ste,
726            scale_mode,
727        } => {
728            let lifted: Vec<NodeId> = new_inputs
729                .iter()
730                .map(|&id| lift_to_batched(out, id, batched, batch_size))
731                .collect();
732            for &id in &lifted {
733                batched.insert(id);
734            }
735            let new_axis = axis.map(|a| a + 1);
736            out.add_node(
737                Op::FakeQuantize {
738                    bits: *bits,
739                    axis: new_axis,
740                    ste: *ste,
741                    scale_mode: *scale_mode,
742                },
743                lifted,
744                batched_shape(),
745            )
746        }
747        Op::FakeQuantizeBackward { bits, axis, ste } => {
748            let lifted: Vec<NodeId> = new_inputs
749                .iter()
750                .map(|&id| lift_to_batched(out, id, batched, batch_size))
751                .collect();
752            for &id in &lifted {
753                batched.insert(id);
754            }
755            let new_axis = axis.map(|a| a + 1);
756            out.add_node(
757                Op::FakeQuantizeBackward {
758                    bits: *bits,
759                    axis: new_axis,
760                    ste: *ste,
761                },
762                lifted,
763                batched_shape(),
764            )
765        }
766        Op::FakeQuantizeLSQ { bits, axis } => {
767            let lifted: Vec<NodeId> = new_inputs
768                .iter()
769                .map(|&id| lift_to_batched(out, id, batched, batch_size))
770                .collect();
771            for &id in &lifted {
772                batched.insert(id);
773            }
774            out.add_node(
775                Op::FakeQuantizeLSQ {
776                    bits: *bits,
777                    axis: axis.map(|a| a + 1),
778                },
779                lifted,
780                batched_shape(),
781            )
782        }
783        Op::FakeQuantizeLSQBackwardX { bits, axis } => {
784            let lifted: Vec<NodeId> = new_inputs
785                .iter()
786                .map(|&id| lift_to_batched(out, id, batched, batch_size))
787                .collect();
788            for &id in &lifted {
789                batched.insert(id);
790            }
791            out.add_node(
792                Op::FakeQuantizeLSQBackwardX {
793                    bits: *bits,
794                    axis: axis.map(|a| a + 1),
795                },
796                lifted,
797                batched_shape(),
798            )
799        }
800        Op::FakeQuantizeLSQBackwardScale { bits, axis } => {
801            let lifted: Vec<NodeId> = new_inputs
802                .iter()
803                .map(|&id| lift_to_batched(out, id, batched, batch_size))
804                .collect();
805            for &id in &lifted {
806                batched.insert(id);
807            }
808            out.add_node(
809                Op::FakeQuantizeLSQBackwardScale {
810                    bits: *bits,
811                    axis: axis.map(|a| a + 1),
812                },
813                lifted,
814                batched_shape(),
815            )
816        }
817
818        // ── LayerNorm/RmsNorm backward: axis +1, lift inputs ──
819        Op::LayerNormBackwardInput { axis, eps } => {
820            let lifted: Vec<NodeId> = new_inputs
821                .iter()
822                .map(|&id| lift_to_batched(out, id, batched, batch_size))
823                .collect();
824            for &id in &lifted {
825                batched.insert(id);
826            }
827            out.add_node(
828                Op::LayerNormBackwardInput {
829                    axis: axis + 1,
830                    eps: *eps,
831                },
832                lifted,
833                batched_shape(),
834            )
835        }
836        Op::LayerNormBackwardGamma { axis, eps } => {
837            let lifted: Vec<NodeId> = new_inputs
838                .iter()
839                .map(|&id| lift_to_batched(out, id, batched, batch_size))
840                .collect();
841            for &id in &lifted {
842                batched.insert(id);
843            }
844            out.add_node(
845                Op::LayerNormBackwardGamma {
846                    axis: axis + 1,
847                    eps: *eps,
848                },
849                lifted,
850                batched_shape(),
851            )
852        }
853
854        // ── TopK / Sample: operate on the last axis (logits). After
855        // vmap they still operate on the last axis — just lift inputs.
856        Op::TopK { k } => {
857            let lifted = lift_to_batched(out, new_inputs[0], batched, batch_size);
858            batched.insert(lifted);
859            out.add_node(Op::TopK { k: *k }, vec![lifted], batched_shape())
860        }
861        Op::Sample {
862            top_k,
863            top_p,
864            temperature,
865            seed,
866        } => {
867            let lifted = lift_to_batched(out, new_inputs[0], batched, batch_size);
868            batched.insert(lifted);
869            out.add_node(
870                Op::Sample {
871                    top_k: *top_k,
872                    top_p: *top_p,
873                    temperature: *temperature,
874                    seed: *seed,
875                },
876                vec![lifted],
877                batched_shape(),
878            )
879        }
880
881        // ── LoraMatMul: lift x/w/a/b, output [B, *] ──
882        Op::LoraMatMul { scale } => {
883            let lifted: Vec<NodeId> = new_inputs
884                .iter()
885                .map(|&id| lift_to_batched(out, id, batched, batch_size))
886                .collect();
887            for &id in &lifted {
888                batched.insert(id);
889            }
890            out.add_node(Op::LoraMatMul { scale: *scale }, lifted, batched_shape())
891        }
892
        // ── Conv / Pool: reshape-trick ──
        // The kernel expects a specific input rank. We fold the new
        // batch axis into the existing leading axis (N) via Reshape,
        // run the op, then reshape back to expose the vmap batch axis.
        // Attention / Rope are NOT handled this way yet: they panic in
        // the dedicated arm further below.
898        Op::Conv {
899            kernel_size,
900            stride,
901            padding,
902            dilation,
903            groups,
904        } => {
905            let x = lift_to_batched(out, new_inputs[0], batched, batch_size);
906            let w = new_inputs[1]; // weights stay shared
907            batched.insert(x);
908            // x is [B, N, C_in, H, W]. Flatten B*N → 4-D, run conv,
909            // reshape back.
910            let x_shape = out.node(x).shape.clone();
911            let r = x_shape.rank();
912            assert!(r == 5, "vmap Conv: expected 5-D after lift, got {r}");
913            let n_orig = match x_shape.dim(1) {
914                Dim::Static(n) => n,
915                _ => panic!("dynamic N"),
916            };
917            let bn = batch_size * n_orig;
918            let inner_dims_static: Vec<i64> = (2..r)
919                .map(|i| match x_shape.dim(i) {
920                    Dim::Static(d) => d as i64,
921                    _ => -1,
922                })
923                .collect();
924            let mut flat_dims = vec![bn as i64];
925            flat_dims.extend(inner_dims_static.iter().copied());
926            let mut flat_dim_objs = vec![Dim::Static(bn)];
927            for i in 2..r {
928                flat_dim_objs.push(x_shape.dim(i));
929            }
930            let flat_shape = Shape::from_dims(&flat_dim_objs, x_shape.dtype());
931            let x_flat = out.add_node(
932                Op::Reshape {
933                    new_shape: flat_dims,
934                },
935                vec![x],
936                flat_shape,
937            );
938            // Conv output: [B*N, C_out, H_out, W_out] in flat form.
939            let mut conv_out_dims = vec![Dim::Static(bn)];
940            for i in 1..orig_shape.rank() {
941                conv_out_dims.push(orig_shape.dim(i));
942            }
943            let conv_out_shape = Shape::from_dims(&conv_out_dims, dtype);
944            let conv_out = out.add_node(
945                Op::Conv {
946                    kernel_size: kernel_size.clone(),
947                    stride: stride.clone(),
948                    padding: padding.clone(),
949                    dilation: dilation.clone(),
950                    groups: *groups,
951                },
952                vec![x_flat, w],
953                conv_out_shape,
954            );
955            // Reshape back to [B, N, C_out, H_out, W_out].
956            let mut final_dims_static: Vec<i64> = vec![batch_size as i64];
957            for i in 0..orig_shape.rank() {
958                final_dims_static.push(match orig_shape.dim(i) {
959                    Dim::Static(d) => d as i64,
960                    _ => -1,
961                });
962            }
963            out.add_node(
964                Op::Reshape {
965                    new_shape: final_dims_static,
966                },
967                vec![conv_out],
968                batched_shape(),
969            )
970        }
971        Op::Pool {
972            kind,
973            kernel_size,
974            stride,
975            padding,
976        } => {
977            let x = lift_to_batched(out, new_inputs[0], batched, batch_size);
978            batched.insert(x);
979            let x_shape = out.node(x).shape.clone();
980            let r = x_shape.rank();
981            assert!(r == 5, "vmap Pool: expected 5-D after lift, got {r}");
982            let n_orig = match x_shape.dim(1) {
983                Dim::Static(n) => n,
984                _ => panic!("dynamic N"),
985            };
986            let bn = batch_size * n_orig;
987            let mut flat_dims = vec![bn as i64];
988            for i in 2..r {
989                flat_dims.push(match x_shape.dim(i) {
990                    Dim::Static(d) => d as i64,
991                    _ => -1,
992                });
993            }
994            let mut flat_dim_objs = vec![Dim::Static(bn)];
995            for i in 2..r {
996                flat_dim_objs.push(x_shape.dim(i));
997            }
998            let flat_shape = Shape::from_dims(&flat_dim_objs, x_shape.dtype());
999            let x_flat = out.add_node(
1000                Op::Reshape {
1001                    new_shape: flat_dims,
1002                },
1003                vec![x],
1004                flat_shape,
1005            );
1006            let mut pool_dims = vec![Dim::Static(bn)];
1007            for i in 1..orig_shape.rank() {
1008                pool_dims.push(orig_shape.dim(i));
1009            }
1010            let pool_out_shape = Shape::from_dims(&pool_dims, dtype);
1011            let pool_out = out.add_node(
1012                Op::Pool {
1013                    kind: *kind,
1014                    kernel_size: kernel_size.clone(),
1015                    stride: stride.clone(),
1016                    padding: padding.clone(),
1017                },
1018                vec![x_flat],
1019                pool_out_shape,
1020            );
1021            let mut final_dims_static: Vec<i64> = vec![batch_size as i64];
1022            for i in 0..orig_shape.rank() {
1023                final_dims_static.push(match orig_shape.dim(i) {
1024                    Dim::Static(d) => d as i64,
1025                    _ => -1,
1026                });
1027            }
1028            out.add_node(
1029                Op::Reshape {
1030                    new_shape: final_dims_static,
1031                },
1032                vec![pool_out],
1033                batched_shape(),
1034            )
1035        }
1036
1037        // ── Ops with hard kernel-shape requirements that need real
1038        // engineering before vmap can support them. Panic with a
1039        // pointer to the right follow-up rather than silently lifting
1040        // and producing wrong shapes.
1041        Op::Attention { .. }
1042        | Op::FusedAttentionBlock { .. }
1043        | Op::FusedTransformerLayer { .. }
1044        | Op::Rope { .. } => panic!(
1045            "vmap: {:?} kernels expect a fixed input rank — extra batch \
1046             axis would need either decomposition (use rlx-opt unfuse \
1047             passes first) or a kernel rewrite. Skipped in MVP.",
1048            node.op,
1049        ),
1050
1051        // ── Conv2dBackwardInput: reshape-trick around the kernel ──
1052        // Inputs: [dy, w]. dy is [N, C_out, H_out, W_out]; w is
1053        // [C_out, C_in/g, kH, kW]. Output: [N, C_in, H, W].
1054        // After vmap with N batched: dy [B, N, C_out, H_out, W_out],
1055        // weights stay shared. Fold B into N, run Conv2dBackwardInput,
1056        // fold back.
1057        Op::Conv2dBackwardInput {
1058            kernel_size,
1059            stride,
1060            padding,
1061            dilation,
1062            groups,
1063        } => {
1064            let dy = lift_to_batched(out, new_inputs[0], batched, batch_size);
1065            let w = new_inputs[1];
1066            batched.insert(dy);
1067            let dy_shape = out.node(dy).shape.clone();
1068            assert_eq!(
1069                dy_shape.rank(),
1070                5,
1071                "vmap Conv2dBackwardInput: expected 5-D dy"
1072            );
1073            let n_orig = match dy_shape.dim(1) {
1074                Dim::Static(n) => n,
1075                _ => panic!("dynamic N"),
1076            };
1077            let bn = batch_size * n_orig;
1078            let mut flat_dims_static: Vec<i64> = vec![bn as i64];
1079            for i in 2..dy_shape.rank() {
1080                flat_dims_static.push(match dy_shape.dim(i) {
1081                    Dim::Static(d) => d as i64,
1082                    _ => -1,
1083                });
1084            }
1085            let mut flat_dim_objs = vec![Dim::Static(bn)];
1086            for i in 2..dy_shape.rank() {
1087                flat_dim_objs.push(dy_shape.dim(i));
1088            }
1089            let dy_flat = out.add_node(
1090                Op::Reshape {
1091                    new_shape: flat_dims_static,
1092                },
1093                vec![dy],
1094                Shape::from_dims(&flat_dim_objs, dy_shape.dtype()),
1095            );
1096            // Output flat shape: [B*N, C_in, H, W].
1097            let mut out_flat_dim_objs = vec![Dim::Static(bn)];
1098            for i in 1..orig_shape.rank() {
1099                out_flat_dim_objs.push(orig_shape.dim(i));
1100            }
1101            let out_flat_shape = Shape::from_dims(&out_flat_dim_objs, dtype);
1102            let out_flat = out.add_node(
1103                Op::Conv2dBackwardInput {
1104                    kernel_size: kernel_size.clone(),
1105                    stride: stride.clone(),
1106                    padding: padding.clone(),
1107                    dilation: dilation.clone(),
1108                    groups: *groups,
1109                },
1110                vec![dy_flat, w],
1111                out_flat_shape,
1112            );
1113            // Reshape back to [B, N, C_in, H, W].
1114            let mut final_dims: Vec<i64> = vec![batch_size as i64];
1115            for i in 0..orig_shape.rank() {
1116                final_dims.push(match orig_shape.dim(i) {
1117                    Dim::Static(d) => d as i64,
1118                    _ => -1,
1119                });
1120            }
1121            out.add_node(
1122                Op::Reshape {
1123                    new_shape: final_dims,
1124                },
1125                vec![out_flat],
1126                batched_shape(),
1127            )
1128        }
1129
1130        // ── MaxPool2dBackward: reshape-trick like Conv ──
1131        // Inputs: [x, dy]. Both 4-D NCHW. After vmap: [B, N, C, H, W].
1132        // Fold B into N for both, run, fold back.
1133        Op::MaxPool2dBackward {
1134            kernel_size,
1135            stride,
1136            padding,
1137        } => {
1138            let x = lift_to_batched(out, new_inputs[0], batched, batch_size);
1139            let dy = lift_to_batched(out, new_inputs[1], batched, batch_size);
1140            batched.insert(x);
1141            batched.insert(dy);
1142            let x_shape = out.node(x).shape.clone();
1143            assert_eq!(x_shape.rank(), 5, "vmap MaxPool2dBackward: expected 5-D x");
1144            let n_orig = match x_shape.dim(1) {
1145                Dim::Static(n) => n,
1146                _ => panic!("dynamic N"),
1147            };
1148            let bn = batch_size * n_orig;
1149            let flatten = |out: &mut Graph, id: NodeId| -> NodeId {
1150                let s = out.node(id).shape.clone();
1151                let mut flat_objs = vec![Dim::Static(bn)];
1152                for i in 2..s.rank() {
1153                    flat_objs.push(s.dim(i));
1154                }
1155                let flat_shape = Shape::from_dims(&flat_objs, s.dtype());
1156                let mut flat_static: Vec<i64> = vec![bn as i64];
1157                for i in 2..s.rank() {
1158                    flat_static.push(match s.dim(i) {
1159                        Dim::Static(d) => d as i64,
1160                        _ => -1,
1161                    });
1162                }
1163                out.add_node(
1164                    Op::Reshape {
1165                        new_shape: flat_static,
1166                    },
1167                    vec![id],
1168                    flat_shape,
1169                )
1170            };
1171            let x_flat = flatten(out, x);
1172            let dy_flat = flatten(out, dy);
1173            let mut out_flat_objs = vec![Dim::Static(bn)];
1174            for i in 1..orig_shape.rank() {
1175                out_flat_objs.push(orig_shape.dim(i));
1176            }
1177            let out_flat_shape = Shape::from_dims(&out_flat_objs, dtype);
1178            let pool_out = out.add_node(
1179                Op::MaxPool2dBackward {
1180                    kernel_size: kernel_size.clone(),
1181                    stride: stride.clone(),
1182                    padding: padding.clone(),
1183                },
1184                vec![x_flat, dy_flat],
1185                out_flat_shape,
1186            );
1187            let mut final_dims: Vec<i64> = vec![batch_size as i64];
1188            for i in 0..orig_shape.rank() {
1189                final_dims.push(match orig_shape.dim(i) {
1190                    Dim::Static(d) => d as i64,
1191                    _ => -1,
1192                });
1193            }
1194            out.add_node(
1195                Op::Reshape {
1196                    new_shape: final_dims,
1197                },
1198                vec![pool_out],
1199                batched_shape(),
1200            )
1201        }
1202
1203        Op::Conv2dBackwardWeight { .. } => panic!(
1204            "vmap: Conv2dBackwardWeight: weight gradient is summed across \
1205             samples — vmap-batching gives a B-stack of independent dWs. \
1206             Reshape-trick doesn't apply since the output isn't naturally \
1207             N-leading. Add a per-batch dW pattern when needed.",
1208        ),
1209
1210        Op::SelectiveScan { .. }
1211        | Op::GroupedMatMul
1212        | Op::QMatMul { .. }
1213        | Op::QConv2d { .. }
1214        | Op::DequantMatMul { .. } => panic!(
1215            "vmap: {:?} has its own internal batch handling; \
1216             the right rule depends on whether the user wants \
1217             nested batching or to fold into the existing batch \
1218             dim. Add a rule when a real workload demands it.",
1219            node.op,
1220        ),
1221
1222        // ── DequantGroupedMatMul: shared expert weights, batched tokens ──
1223        Op::DequantGroupedMatMul { scheme } => {
1224            let x = lift_to_batched(out, new_inputs[0], batched, batch_size);
1225            let idx = lift_to_batched(out, new_inputs[2], batched, batch_size);
1226            let w = new_inputs[1];
1227            batched.insert(x);
1228            batched.insert(idx);
1229            let x_shape = out.node(x).shape.clone();
1230            assert_eq!(
1231                x_shape.rank(),
1232                3,
1233                "vmap DequantGroupedMatMul: expected 3-D x"
1234            );
1235            let m_orig = match x_shape.dim(1) {
1236                Dim::Static(v) => v,
1237                _ => panic!("dynamic M"),
1238            };
1239            let k = match x_shape.dim(2) {
1240                Dim::Static(v) => v as i64,
1241                _ => -1,
1242            };
1243            let bm = batch_size * m_orig;
1244            let n = match orig_shape.dim(orig_shape.rank() - 1) {
1245                Dim::Static(v) => v as i64,
1246                _ => -1,
1247            };
1248            let x_flat = out.add_node(
1249                Op::Reshape {
1250                    new_shape: vec![bm as i64, k],
1251                },
1252                vec![x],
1253                Shape::from_dims(&[Dim::Static(bm), x_shape.dim(2)], orig_shape.dtype()),
1254            );
1255            let idx_flat = out.add_node(
1256                Op::Reshape {
1257                    new_shape: vec![bm as i64],
1258                },
1259                vec![idx],
1260                Shape::from_dims(&[Dim::Static(bm)], orig_shape.dtype()),
1261            );
1262            let y_flat = out.add_node(
1263                Op::DequantGroupedMatMul { scheme: *scheme },
1264                vec![x_flat, w, idx_flat],
1265                Shape::from_dims(
1266                    &[Dim::Static(bm), orig_shape.dim(orig_shape.rank() - 1)],
1267                    orig_shape.dtype(),
1268                ),
1269            );
1270            let mut final_dims: Vec<i64> = vec![batch_size as i64, m_orig as i64];
1271            final_dims.push(n);
1272            out.add_node(
1273                Op::Reshape {
1274                    new_shape: final_dims,
1275                },
1276                vec![y_flat],
1277                batched_shape(),
1278            )
1279        }
1280
1281        Op::DequantMoEWeights { .. } => panic!(
1282            "vmap: DequantMoEWeights is a weight materialization helper; \
1283             vmap the downstream GroupedMatMul / DequantGroupedMatMul instead.",
1284        ),
1285
1286        Op::FusedSwiGLU { .. }
1287        | Op::FusedMatMulBiasAct { .. }
1288        | Op::FusedResidualLN { .. }
1289        | Op::FusedResidualRmsNorm { .. } => {
1290            panic!(
1291                "vmap: {:?} is fused — decompose first via \
1292             `rlx_fusion::UnfuseElementwiseRegions` (or \
1293             `rlx_fusion::unfuse_fused_for_autodiff`) so the simpler \
1294             ops get vmap'd individually.",
1295                node.op,
1296            )
1297        }
1298
1299        Op::SoftmaxCrossEntropyWithLogits | Op::SoftmaxCrossEntropyBackward => panic!(
1300            "vmap: SoftmaxCrossEntropy* expect 2-D logits; lifting to \
1301             3-D would need a kernel change. Workaround: reshape \
1302             logits to 2-D before the op and back after.",
1303        ),
1304
1305        Op::Custom { name, .. } => {
1306            // Dispatch through the OpExtension registry. The op's
1307            // `vmap` impl receives the already-lifted inputs and
1308            // returns the lifted output. Default impl returns None,
1309            // which we surface as a clear panic.
1310            let ext = rlx_ir::lookup_op(name)
1311                .unwrap_or_else(|| panic!("vmap: Op::Custom('{name}') not registered"));
1312            let is_batched: Vec<bool> = new_inputs.iter().map(|i| batched.contains(i)).collect();
1313            let mut ctx = rlx_ir::VmapContext {
1314                lifted_inputs: new_inputs,
1315                is_batched: &is_batched,
1316                batch_size,
1317                out,
1318            };
1319            match ext.vmap(node, &mut ctx) {
1320                Some(id) => id,
1321                None => panic!(
1322                    "vmap: Op::Custom('{name}') has no vmap rule registered. \
1323                     Override `OpExtension::vmap` on the impl to add one."
1324                ),
1325            }
1326        }
1327
1328        // CustomFn: recursively vmap each body (fwd / vjp / jvp). All
1329        // Inputs in each body are treated as batched — primals become
1330        // [B, *primal] (matching the lifted outer inputs), and the
1331        // AD-special-named Inputs ("primal_output", "d_output",
1332        // "tangent_*") are likewise batched since the outer graph
1333        // wires them to batched producers post-vmap.
1334        Op::CustomFn {
1335            fwd_body,
1336            vjp_body,
1337            jvp_body,
1338            num_inputs,
1339        } => {
1340            // Lift each primal input to [B, *primal].
1341            let mut lifted_inputs: Vec<NodeId> = Vec::with_capacity(*num_inputs as usize);
1342            for &raw in new_inputs.iter() {
1343                let lifted = lift_to_batched(out, raw, batched, batch_size);
1344                batched.insert(lifted);
1345                lifted_inputs.push(lifted);
1346            }
1347
1348            let vmap_body = |body: &Graph| -> Graph {
1349                let names_owned: Vec<String> = body
1350                    .nodes()
1351                    .iter()
1352                    .filter_map(|n| match &n.op {
1353                        Op::Input { name } => Some(name.clone()),
1354                        _ => None,
1355                    })
1356                    .collect();
1357                let names: Vec<&str> = names_owned.iter().map(|s| s.as_str()).collect();
1358                vmap(body, &names, batch_size)
1359            };
1360
1361            let fwd_b = vmap_body(fwd_body);
1362            let vjp_b = vjp_body.as_ref().map(|g| vmap_body(g));
1363            let jvp_b = jvp_body.as_ref().map(|g| vmap_body(g));
1364
1365            // Output shape: [B, *orig_output].
1366            let mut out_dims: Vec<Dim> = vec![Dim::Static(batch_size)];
1367            for i in 0..orig_shape.rank() {
1368                out_dims.push(orig_shape.dim(i));
1369            }
1370            let out_shape = Shape::from_dims(&out_dims, orig_shape.dtype());
1371
1372            let id = out.add_node(
1373                Op::CustomFn {
1374                    fwd_body: Box::new(fwd_b),
1375                    vjp_body: vjp_b.map(Box::new),
1376                    jvp_body: jvp_b.map(Box::new),
1377                    num_inputs: *num_inputs,
1378                },
1379                lifted_inputs,
1380                out_shape,
1381            );
1382            batched.insert(id);
1383            id
1384        }
1385
1386        other => panic!(
1387            "vmap: no rule for op {:?}. Add a per-op rule in vmap.rs.",
1388            other,
1389        ),
1390    }
1391}
1392
1393/// Swap the first two axes of a tensor (perm = [1, 0, 2, 3, ...]).
1394/// Used by the Scan / ScanBackward / ScanBackwardXs vmap rules to
1395/// move the batch axis between the natural-after-vmap leading
1396/// position and the position the inner Scan-family op expects
1397/// (`length` first, batch second per row).
1398fn transpose_swap_01(out: &mut Graph, id: NodeId) -> NodeId {
1399    let s = out.node(id).shape.clone();
1400    let r = s.rank();
1401    debug_assert!(r >= 2, "transpose_swap_01 needs rank ≥ 2");
1402    let mut perm: Vec<usize> = vec![1, 0];
1403    for i in 2..r {
1404        perm.push(i);
1405    }
1406    let mut new_dims: Vec<Dim> = s.dims().to_vec();
1407    new_dims.swap(0, 1);
1408    let new_shape = Shape::from_dims(&new_dims, s.dtype());
1409    out.add_node(Op::Transpose { perm }, vec![id], new_shape)
1410}
1411
1412/// Make sure `id` carries a leading batch axis. If it already does,
1413/// return it unchanged. Otherwise emit `Reshape([1, *])` followed by
1414/// `Expand([B, *])`.
1415fn lift_to_batched(
1416    out: &mut Graph,
1417    id: NodeId,
1418    batched: &HashSet<NodeId>,
1419    batch_size: usize,
1420) -> NodeId {
1421    if batched.contains(&id) {
1422        return id;
1423    }
1424    let orig_shape = out.node(id).shape.clone();
1425    let dtype = orig_shape.dtype();
1426
1427    // Reshape [orig...] → [1, orig...].
1428    let mut dims_with_1: Vec<Dim> = vec![Dim::Static(1)];
1429    dims_with_1.extend(orig_shape.dims().iter().copied());
1430    let with1_shape = Shape::from_dims(&dims_with_1, dtype);
1431    let reshape_dims: Vec<i64> = dims_with_1
1432        .iter()
1433        .map(|d| match d {
1434            Dim::Static(n) => *n as i64,
1435            Dim::Dynamic(_) => -1,
1436        })
1437        .collect();
1438    let with1 = out.add_node(
1439        Op::Reshape {
1440            new_shape: reshape_dims,
1441        },
1442        vec![id],
1443        with1_shape,
1444    );
1445
1446    // Expand [1, orig...] → [B, orig...].
1447    let mut target_dims: Vec<i64> = vec![batch_size as i64];
1448    for d in orig_shape.dims().iter() {
1449        target_dims.push(match d {
1450            Dim::Static(n) => *n as i64,
1451            Dim::Dynamic(_) => -1,
1452        });
1453    }
1454    let mut target_shape_dims: Vec<Dim> = vec![Dim::Static(batch_size)];
1455    target_shape_dims.extend(orig_shape.dims().iter().copied());
1456    let target_shape = Shape::from_dims(&target_shape_dims, dtype);
1457    out.add_node(
1458        Op::Expand {
1459            target_shape: target_dims,
1460        },
1461        vec![with1],
1462        target_shape,
1463    )
1464}
1465
1466#[cfg(test)]
1467mod tests {
1468    use super::*;
1469    use rlx_ir::op::{BinaryOp, ReduceOp};
1470
1471    /// Smallest possible vmap: elementwise scaling. f(x) = 2·x.
1472    /// vmap(f) over batch=4 should produce a graph with:
1473    ///   Input "x" : [4, 3] f64
1474    ///   Constant 2: [3] f64 (shared, lifted via Reshape+Expand at use site)
1475    ///   Mul        : [4, 3] f64
1476    /// Asserts the structural shape and that the batched output node
1477    /// is recorded as such.
1478    #[test]
1479    fn vmap_elementwise_scaling_lifts_to_batched_shape() {
1480        let n = 3usize;
1481        let batch = 4usize;
1482        let mut g = Graph::new("scale");
1483        let x = g.input("x", Shape::new(&[n], DType::F64));
1484        let two_bytes: Vec<u8> = (0..n).flat_map(|_| 2.0_f64.to_le_bytes()).collect();
1485        let two = g.add_node(
1486            Op::Constant { data: two_bytes },
1487            vec![],
1488            Shape::new(&[n], DType::F64),
1489        );
1490        let y = g.binary(BinaryOp::Mul, x, two, Shape::new(&[n], DType::F64));
1491        g.set_outputs(vec![y]);
1492
1493        let bg = vmap(&g, &["x"], batch);
1494        // Output should be [batch, n].
1495        let out_id = bg.outputs[0];
1496        let out_shape = &bg.node(out_id).shape;
1497        assert_eq!(out_shape.dims().len(), 2);
1498        assert_eq!(out_shape.dim(0), Dim::Static(batch));
1499        assert_eq!(out_shape.dim(1), Dim::Static(n));
1500    }
1501
1502    /// vmap of a matmul: f(x) = MatMul(x, w). x is `[m, k]`, batched
1503    /// to `[B, m, k]`. w stays `[k, n]`. Output: `[B, m, n]`. Built-in
1504    /// MatMul batch broadcasting handles it; vmap just lifts x.
1505    #[test]
1506    fn vmap_matmul_with_shared_weight() {
1507        let m = 2usize;
1508        let k = 3usize;
1509        let n = 4usize;
1510        let batch = 5usize;
1511        let mut g = Graph::new("mm");
1512        let x = g.input("x", Shape::new(&[m, k], DType::F64));
1513        let w = g.input("w", Shape::new(&[k, n], DType::F64));
1514        let y = g.matmul(x, w, Shape::new(&[m, n], DType::F64));
1515        g.set_outputs(vec![y]);
1516
1517        let bg = vmap(&g, &["x"], batch);
1518        let out_id = bg.outputs[0];
1519        let out_shape = &bg.node(out_id).shape;
1520        assert_eq!(out_shape.dims().len(), 3);
1521        assert_eq!(out_shape.dim(0), Dim::Static(batch));
1522        assert_eq!(out_shape.dim(1), Dim::Static(m));
1523        assert_eq!(out_shape.dim(2), Dim::Static(n));
1524    }
1525
1526    /// basic test for the recently-added rules: build a graph that
1527    /// exercises Gather, ElementwiseRegion fallback, ReluBackward,
1528    /// and ActivationBackward — vmap and assert it completes without
1529    /// hitting the catch-all panic.
1530    #[test]
1531    fn vmap_extended_op_set_lifts_without_panic() {
1532        // Gather pattern: table[indices] with table batched.
1533        let mut g = Graph::new("gather_check");
1534        let table = g.input("table", Shape::new(&[5, 4], DType::F32));
1535        let idx = g.input("idx", Shape::new(&[3], DType::F32));
1536        let out_node = g.add_node(
1537            Op::Gather { axis: 0 },
1538            vec![table, idx],
1539            Shape::new(&[3, 4], DType::F32),
1540        );
1541        g.set_outputs(vec![out_node]);
1542        let bg = vmap(&g, &["table"], 2);
1543        // Output: [B, 3, 4].
1544        let s = &bg.node(bg.outputs[0]).shape;
1545        assert_eq!(s.rank(), 3);
1546        assert_eq!(s.dim(0), Dim::Static(2));
1547
1548        // ReluBackward check — inputs (x, dy), output dx.
1549        let mut g = Graph::new("relu_bwd_check");
1550        let x = g.input("x", Shape::new(&[4], DType::F32));
1551        let dy = g.input("dy", Shape::new(&[4], DType::F32));
1552        let dx = g.add_node(Op::ReluBackward, vec![x, dy], Shape::new(&[4], DType::F32));
1553        g.set_outputs(vec![dx]);
1554        let bg = vmap(&g, &["x"], 3);
1555        let s = &bg.node(bg.outputs[0]).shape;
1556        assert_eq!(s.rank(), 2);
1557        assert_eq!(s.dim(0), Dim::Static(3));
1558    }
1559
1560    /// vmap composition test: f(x) = sum(x · w + b) → loss per-batch.
1561    /// Asserts the output is `[batch]` (sum over axis 1 of [batch, n]).
1562    #[test]
1563    fn vmap_combined_matmul_add_reduce() {
1564        let n = 3usize;
1565        let batch = 4usize;
1566        let mut g = Graph::new("combined");
1567        let x = g.input("x", Shape::new(&[n], DType::F64));
1568        let w = g.input("w", Shape::new(&[n, n], DType::F64));
1569        let b = g.input("b", Shape::new(&[n], DType::F64));
1570        // Reshape x to [1, n] so MatMul works on [1, n] @ [n, n] = [1, n]
1571        let x_row = g.add_node(
1572            Op::Reshape {
1573                new_shape: vec![1, n as i64],
1574            },
1575            vec![x],
1576            Shape::new(&[1, n], DType::F64),
1577        );
1578        let mm = g.matmul(x_row, w, Shape::new(&[1, n], DType::F64));
1579        let mm_flat = g.add_node(
1580            Op::Reshape {
1581                new_shape: vec![n as i64],
1582            },
1583            vec![mm],
1584            Shape::new(&[n], DType::F64),
1585        );
1586        let yv = g.binary(BinaryOp::Add, mm_flat, b, Shape::new(&[n], DType::F64));
1587        let loss = g.reduce(
1588            yv,
1589            ReduceOp::Sum,
1590            vec![0],
1591            false,
1592            Shape::new(&[1], DType::F64),
1593        );
1594        g.set_outputs(vec![loss]);
1595
1596        let bg = vmap(&g, &["x"], batch);
1597        let out = bg.node(bg.outputs[0]);
1598        // After Reduce::Sum on shifted axis (1), keep_dim=false → shape [B, 1].
1599        // (Reduce shifts axis 0 → 1; the original [1] output becomes [B, 1].)
1600        assert_eq!(out.shape.dim(0), Dim::Static(batch));
1601        assert_eq!(out.shape.rank(), 2);
1602    }
1603}