Skip to main content

rlx_ir/ops/
special.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Specialised builders: SSM selective scan + space for future
//! exotic ops (plan #53).
18
19use crate::shape::Dim;
20use crate::{DType, Graph, NodeId, Op, Shape};
21
22/// Handle to a multi-output [`Op::CustomFn`] built via
23/// [`Graph::custom_fn_multi`]. Internally the op produces a flat 1-D
24/// concatenated output; this handle remembers each sub-output's
25/// offset + original shape so [`Self::output`] can materialize
26/// component `i` lazily via `Op::Narrow` + `Op::Reshape`.
27#[derive(Debug, Clone)]
28pub struct MultiOutputHandle {
29    /// NodeId of the wrapped Op::CustomFn (1-D F32, length =
30    /// `Σ sub_shapes[i].num_elements`).
31    pub source: NodeId,
32    /// Original per-sub-output shapes (in declaration order).
33    pub sub_shapes: Vec<Shape>,
34    /// Per-sub-output start offsets into `source` (element-counted,
35    /// not byte-counted).
36    pub offsets: Vec<usize>,
37}
38
39impl MultiOutputHandle {
40    /// Number of sub-outputs.
41    pub fn n_outputs(&self) -> usize {
42        self.sub_shapes.len()
43    }
44
45    /// Materialize sub-output `idx` as an outer-graph NodeId.
46    /// Internally: `Op::Narrow(source, axis=0, start=offsets[idx],
47    /// len=numel(sub_shapes[idx]))` → `Op::Reshape` back to the
48    /// declared shape.
49    pub fn output(&self, g: &mut Graph, idx: usize) -> NodeId {
50        assert!(idx < self.sub_shapes.len(), "output index out of range");
51        let sub = &self.sub_shapes[idx];
52        let n_elems: usize = sub
53            .dims()
54            .iter()
55            .map(|d| match d {
56                Dim::Static(k) => *k,
57                Dim::Dynamic(_) => panic!("dynamic sub-output dim"),
58            })
59            .product();
60        let flat_shape = Shape::from_dims(&[Dim::Static(n_elems)], sub.dtype());
61        let narrowed = g.add_node(
62            Op::Narrow {
63                axis: 0,
64                start: self.offsets[idx],
65                len: n_elems,
66            },
67            vec![self.source],
68            flat_shape,
69        );
70        if sub.rank() == 1 {
71            // Already the right shape.
72            narrowed
73        } else {
74            let dims: Vec<i64> = sub
75                .dims()
76                .iter()
77                .map(|d| match d {
78                    Dim::Static(k) => *k as i64,
79                    Dim::Dynamic(_) => unreachable!(),
80                })
81                .collect();
82            g.add_node(Op::Reshape { new_shape: dims }, vec![narrowed], sub.clone())
83        }
84    }
85}
86
87impl Graph {
88    /// Mamba-style selective scan: y = SSM(x, Δ, A, B, C).
89    /// Inputs: x \[b,s,h\], delta \[b,s,h\], a \[h,n\], b \[b,s,n\], c \[b,s,n\].
90    /// Output \[b,s,h\]. n is the state size.
91    pub fn selective_scan(
92        &mut self,
93        x: NodeId,
94        delta: NodeId,
95        a: NodeId,
96        b: NodeId,
97        c: NodeId,
98        state_size: usize,
99        shape: Shape,
100    ) -> NodeId {
101        self.push(
102            Op::SelectiveScan { state_size },
103            vec![x, delta, a, b, c],
104            shape,
105            None,
106        )
107    }
108
109    /// Gated DeltaNet linear-attention scan (Qwen3.5/3.6 trunk,
110    /// Qwen3-Next, Kimi-Linear). See [`Op::GatedDeltaNet`] for the
111    /// recurrence math. All five inputs are `f32`. Shapes:
112    /// `q,k,v`: `[b, s, h_v, n]`; `g,beta`: `[b, s, h_v]`. Output:
113    /// `[b, s, h_v, n]`. State is implicit (reset per batch) unless
114    /// `carry_state` is set — then pass `state` as a sixth input.
115    pub fn gated_delta_net(
116        &mut self,
117        q: NodeId,
118        k: NodeId,
119        v: NodeId,
120        g: NodeId,
121        beta: NodeId,
122        state_size: usize,
123        shape: Shape,
124    ) -> NodeId {
125        self.push(
126            Op::GatedDeltaNet {
127                state_size,
128                carry_state: false,
129            },
130            vec![q, k, v, g, beta],
131            shape,
132            None,
133        )
134    }
135
136    /// Same as [`Self::gated_delta_net`] but threads `state`
137    /// `[b, h_v, n, n]` in/out for decode-mode recurrence.
138    pub fn gated_delta_net_carry(
139        &mut self,
140        q: NodeId,
141        k: NodeId,
142        v: NodeId,
143        g: NodeId,
144        beta: NodeId,
145        state: NodeId,
146        state_size: usize,
147        shape: Shape,
148    ) -> NodeId {
149        self.push(
150            Op::GatedDeltaNet {
151                state_size,
152                carry_state: true,
153            },
154            vec![q, k, v, g, beta, state],
155            shape,
156            None,
157        )
158    }
159
160    /// Multi-layer (optionally bidirectional) LSTM with packed weights.
161    /// See [`Op::Lstm`] for the gate math, weight packing, and shapes.
162    /// Output `shape` = `[batch, seq, D*hidden]` (`D = 2` iff
163    /// `bidirectional`). Initial `h0`/`c0` are zero.
164    #[allow(clippy::too_many_arguments)]
165    pub fn lstm(
166        &mut self,
167        x: NodeId,
168        w_ih: NodeId,
169        w_hh: NodeId,
170        bias: NodeId,
171        hidden_size: usize,
172        num_layers: usize,
173        bidirectional: bool,
174        shape: Shape,
175    ) -> NodeId {
176        self.push(
177            Op::Lstm {
178                hidden_size,
179                num_layers,
180                bidirectional,
181                carry: false,
182            },
183            vec![x, w_ih, w_hh, bias],
184            shape,
185            None,
186        )
187    }
188
189    /// Same as [`Self::lstm`] but threads decode state: `h0`/`c0`
190    /// `[L*D, batch, hidden]` in, `hn`/`cn` written back in place.
191    #[allow(clippy::too_many_arguments)]
192    pub fn lstm_carry(
193        &mut self,
194        x: NodeId,
195        w_ih: NodeId,
196        w_hh: NodeId,
197        bias: NodeId,
198        h0: NodeId,
199        c0: NodeId,
200        hidden_size: usize,
201        num_layers: usize,
202        bidirectional: bool,
203        shape: Shape,
204    ) -> NodeId {
205        self.push(
206            Op::Lstm {
207                hidden_size,
208                num_layers,
209                bidirectional,
210                carry: true,
211            },
212            vec![x, w_ih, w_hh, bias, h0, c0],
213            shape,
214            None,
215        )
216    }
217
218    /// Multi-layer (optionally bidirectional) GRU. Inputs
219    /// `[x, w_ih, w_hh, b_ih, b_hh]` (`+ [h0]` when `carry`). See
220    /// [`Op::Gru`]. Output `shape` = `[batch, seq, D*hidden]`.
221    #[allow(clippy::too_many_arguments)]
222    pub fn gru(
223        &mut self,
224        x: NodeId,
225        w_ih: NodeId,
226        w_hh: NodeId,
227        b_ih: NodeId,
228        b_hh: NodeId,
229        h0: Option<NodeId>,
230        hidden_size: usize,
231        num_layers: usize,
232        bidirectional: bool,
233        shape: Shape,
234    ) -> NodeId {
235        let mut inputs = vec![x, w_ih, w_hh, b_ih, b_hh];
236        let carry = h0.is_some();
237        if let Some(h) = h0 {
238            inputs.push(h);
239        }
240        self.push(
241            Op::Gru {
242                hidden_size,
243                num_layers,
244                bidirectional,
245                carry,
246            },
247            inputs,
248            shape,
249            None,
250        )
251    }
252
253    /// Multi-layer (optionally bidirectional) Elman RNN. Inputs
254    /// `[x, w_ih, w_hh, bias]` (`+ [h0]` when `carry`). `relu` selects the
255    /// activation (else tanh). See [`Op::Rnn`]. Output `[batch, seq, D*hidden]`.
256    #[allow(clippy::too_many_arguments)]
257    pub fn rnn(
258        &mut self,
259        x: NodeId,
260        w_ih: NodeId,
261        w_hh: NodeId,
262        bias: NodeId,
263        h0: Option<NodeId>,
264        hidden_size: usize,
265        num_layers: usize,
266        bidirectional: bool,
267        relu: bool,
268        shape: Shape,
269    ) -> NodeId {
270        let mut inputs = vec![x, w_ih, w_hh, bias];
271        let carry = h0.is_some();
272        if let Some(h) = h0 {
273            inputs.push(h);
274        }
275        self.push(
276            Op::Rnn {
277                hidden_size,
278                num_layers,
279                bidirectional,
280                carry,
281                relu,
282            },
283            inputs,
284            shape,
285            None,
286        )
287    }
288
289    /// Mamba-2 / SSD scalar-decay SSM scan. See [`Op::Mamba2`]. Inputs
290    /// `[x, dt, a, b, c]`; output `shape` = `x` shape `[B,S,H,P]`.
291    pub fn mamba2(
292        &mut self,
293        x: NodeId,
294        dt: NodeId,
295        a: NodeId,
296        b: NodeId,
297        c: NodeId,
298        head_dim: usize,
299        state_size: usize,
300        shape: Shape,
301    ) -> NodeId {
302        self.push(
303            Op::Mamba2 {
304                head_dim,
305                state_size,
306            },
307            vec![x, dt, a, b, c],
308            shape,
309            None,
310        )
311    }
312
313    /// Bounded scan returning the final carry. Body must have exactly
314    /// one `Op::Input` (the carry) and one output, both same shape as
315    /// `init`. Output shape matches `init`.
316    pub fn scan(&mut self, init: NodeId, body: Graph, length: u32) -> NodeId {
317        let init_shape = self.shape(init).clone();
318        self.push(
319            Op::Scan {
320                body: Box::new(body),
321                length,
322                save_trajectory: false,
323                num_bcast: 0,
324                num_xs: 0,
325                num_checkpoints: 0,
326            },
327            vec![init],
328            init_shape,
329            None,
330        )
331    }
332
333    /// Bounded scan with recursive checkpointing for memory-bounded
334    /// backward AD. Equivalent to [`Self::scan`] for the forward
335    /// computation, but during backward only `num_checkpoints` carry
336    /// values are cached; intermediate carries are recomputed via the
337    /// body. Memory: `O(num_checkpoints · carry_size)`. Time: forward
338    /// unchanged; backward `O(length)` (segment-cached).
339    ///
340    /// The AD pre-pass propagates `num_checkpoints` into the rewritten
341    /// trajectory-saving Scan and into the emitted ScanBackward, so a
342    /// single call to [`crate::Graph::scan_checkpointed`] is enough
343    /// to enable the memory bound across the whole forward+backward
344    /// pipeline.
345    pub fn scan_checkpointed(
346        &mut self,
347        init: NodeId,
348        body: Graph,
349        length: u32,
350        num_checkpoints: u32,
351    ) -> NodeId {
352        assert!(
353            num_checkpoints > 0 && num_checkpoints <= length,
354            "scan_checkpointed: num_checkpoints={num_checkpoints} \
355             must be in 1..=length={length}"
356        );
357        let init_shape = self.shape(init).clone();
358        self.push(
359            Op::Scan {
360                body: Box::new(body),
361                length,
362                save_trajectory: false,
363                num_bcast: 0,
364                num_xs: 0,
365                num_checkpoints,
366            },
367            vec![init],
368            init_shape,
369            None,
370        )
371    }
372
373    /// Bounded scan with broadcast and per-step inputs.
374    ///
375    /// Body `Op::Input`s in NodeId order: `[carry, bcast_0..bcast_{B-1},
376    /// x_t_0..x_t_{X-1}]`. Bcast inputs keep their natural shape (the
377    /// CPU executor fills them once before the scan loop). xs\[i\] has
378    /// shape `[length, *per_step]` and the body sees `xs[i][t]` per
379    /// iteration. Output shape matches `init`.
380    pub fn scan_with_bcasts_and_xs(
381        &mut self,
382        init: NodeId,
383        bcasts: &[NodeId],
384        xs: &[NodeId],
385        body: Graph,
386        length: u32,
387    ) -> NodeId {
388        let init_shape = self.shape(init).clone();
389        let mut inputs = vec![init];
390        inputs.extend_from_slice(bcasts);
391        inputs.extend_from_slice(xs);
392        self.push(
393            Op::Scan {
394                body: Box::new(body),
395                length,
396                save_trajectory: false,
397                num_bcast: bcasts.len() as u32,
398                num_xs: xs.len() as u32,
399                num_checkpoints: 0,
400            },
401            inputs,
402            init_shape,
403            None,
404        )
405    }
406
407    /// Bounded scan with per-step `xs` inputs returning the final carry.
408    /// Body has `1 + xs.len()` Op::Inputs in NodeId construction order
409    /// (first declared is the carry; the remaining match `xs` in order).
410    /// Each `xs[i]` has shape `[length, *per_step_shape_i]`; the body
411    /// sees a `per_step_shape_i` slice on iteration `t`.
412    pub fn scan_with_xs(
413        &mut self,
414        init: NodeId,
415        xs: &[NodeId],
416        body: Graph,
417        length: u32,
418    ) -> NodeId {
419        let init_shape = self.shape(init).clone();
420        let mut inputs = vec![init];
421        inputs.extend_from_slice(xs);
422        self.push(
423            Op::Scan {
424                body: Box::new(body),
425                length,
426                save_trajectory: false,
427                num_bcast: 0,
428                num_xs: xs.len() as u32,
429                num_checkpoints: 0,
430            },
431            inputs,
432            init_shape,
433            None,
434        )
435    }
436
437    /// Reverse-mode AD companion to [`Self::scan`] /
438    /// [`Self::scan_trajectory`]. Typically constructed by the
439    /// autodiff pass, not by hand.
440    ///
441    /// `xs` is the list of per-step input tensors (must match the
442    /// forward Op::Scan's xs in count, order, and per-step shape).
443    /// Body_vjp's `1 + xs.len() + 1` Op::Inputs match the forward
444    /// body's inputs plus a fresh `"d_output"` Input.
445    pub fn scan_backward(
446        &mut self,
447        init: NodeId,
448        trajectory: NodeId,
449        upstream: NodeId,
450        xs: &[NodeId],
451        body_vjp: Graph,
452        length: u32,
453        save_trajectory: bool,
454        out_shape: Shape,
455    ) -> NodeId {
456        self.scan_backward_with_checkpoints(
457            init,
458            trajectory,
459            upstream,
460            xs,
461            body_vjp,
462            length,
463            save_trajectory,
464            0,
465            None,
466            out_shape,
467        )
468    }
469
470    /// Lower-level `scan_backward` with explicit checkpointing config.
471    /// `num_checkpoints == 0` (default) means no checkpointing — the
472    /// trajectory cache holds every step's carry. `0 < K < length`
473    /// enables segment-cached recompute via `forward_body` (must be
474    /// `Some`).
475    #[allow(clippy::too_many_arguments)]
476    pub fn scan_backward_with_checkpoints(
477        &mut self,
478        init: NodeId,
479        trajectory: NodeId,
480        upstream: NodeId,
481        xs: &[NodeId],
482        body_vjp: Graph,
483        length: u32,
484        save_trajectory: bool,
485        num_checkpoints: u32,
486        forward_body: Option<Graph>,
487        out_shape: Shape,
488    ) -> NodeId {
489        let mut inputs = vec![init, trajectory, upstream];
490        inputs.extend_from_slice(xs);
491        self.push(
492            Op::ScanBackward {
493                body_vjp: Box::new(body_vjp),
494                length,
495                save_trajectory,
496                num_xs: xs.len() as u32,
497                num_checkpoints,
498                forward_body: forward_body.map(Box::new),
499            },
500            inputs,
501            out_shape,
502            None,
503        )
504    }
505
506    /// Per-step xs gradient companion to [`Self::scan_backward`].
507    /// Same inputs and same `body_vjp` graph, plus an `xs_idx`
508    /// selecting which body_vjp output to stack into the result.
509    /// Output shape is `[length, *per_step_xs_shape]`.
510    pub fn scan_backward_xs(
511        &mut self,
512        init: NodeId,
513        trajectory: NodeId,
514        upstream: NodeId,
515        xs: &[NodeId],
516        body_vjp: Graph,
517        length: u32,
518        save_trajectory: bool,
519        xs_idx: u32,
520        out_shape: Shape,
521    ) -> NodeId {
522        self.scan_backward_xs_with_checkpoints(
523            init,
524            trajectory,
525            upstream,
526            xs,
527            body_vjp,
528            length,
529            save_trajectory,
530            xs_idx,
531            0,
532            None,
533            out_shape,
534        )
535    }
536
537    #[allow(clippy::too_many_arguments)]
538    pub fn scan_backward_xs_with_checkpoints(
539        &mut self,
540        init: NodeId,
541        trajectory: NodeId,
542        upstream: NodeId,
543        xs: &[NodeId],
544        body_vjp: Graph,
545        length: u32,
546        save_trajectory: bool,
547        xs_idx: u32,
548        num_checkpoints: u32,
549        forward_body: Option<Graph>,
550        out_shape: Shape,
551    ) -> NodeId {
552        let mut inputs = vec![init, trajectory, upstream];
553        inputs.extend_from_slice(xs);
554        self.push(
555            Op::ScanBackwardXs {
556                body_vjp: Box::new(body_vjp),
557                length,
558                save_trajectory,
559                num_xs: xs.len() as u32,
560                xs_idx,
561                num_checkpoints,
562                forward_body: forward_body.map(Box::new),
563            },
564            inputs,
565            out_shape,
566            None,
567        )
568    }
569
570    /// User-defined sub-graph with optional override AD rules.
571    /// JAX-shaped `custom_vjp` / `custom_jvp` — see [`Op::CustomFn`].
572    ///
573    /// `inputs.len()` must equal the number of `Op::Input` nodes in
574    /// `fwd_body`. Output shape is inferred from `fwd_body`'s declared
575    /// output. When supplied, `vjp_body` and `jvp_body` must follow the
576    /// conventions documented on [`Op::CustomFn`] (special-named
577    /// `"primal_output"` / `"d_output"` / `"tangent_*"` Inputs).
578    pub fn custom_fn(
579        &mut self,
580        inputs: Vec<NodeId>,
581        fwd_body: Graph,
582        vjp_body: Option<Graph>,
583        jvp_body: Option<Graph>,
584    ) -> NodeId {
585        let n_in = inputs.len();
586        // Count fwd_body's primal Inputs (no special names — fwd has none).
587        let fwd_inputs: usize = fwd_body
588            .nodes()
589            .iter()
590            .filter(|n| matches!(n.op, Op::Input { .. }))
591            .count();
592        assert_eq!(
593            fwd_inputs, n_in,
594            "custom_fn: fwd_body has {fwd_inputs} Op::Input(s); outer call \
595             provides {n_in}. Counts must match.",
596        );
597        let fwd_out_id = fwd_body
598            .outputs
599            .first()
600            .copied()
601            .expect("custom_fn: fwd_body must declare exactly one output");
602        let out_shape = fwd_body.node(fwd_out_id).shape.clone();
603
604        if let Some(vjp) = vjp_body.as_ref() {
605            let primal_count = vjp
606                .nodes()
607                .iter()
608                .filter(|n| {
609                    matches!(&n.op,
610                    Op::Input { name } if name != "primal_output" && name != "d_output")
611                })
612                .count();
613            assert_eq!(
614                primal_count, n_in,
615                "custom_fn: vjp_body has {primal_count} primal Op::Input(s) \
616                 (excluding 'primal_output' / 'd_output'); expected {n_in}",
617            );
618            let has_primal_out = vjp
619                .nodes()
620                .iter()
621                .any(|n| matches!(&n.op, Op::Input { name } if name == "primal_output"));
622            let has_d_output = vjp
623                .nodes()
624                .iter()
625                .any(|n| matches!(&n.op, Op::Input { name } if name == "d_output"));
626            assert!(
627                has_primal_out,
628                "custom_fn: vjp_body must declare an Op::Input named 'primal_output'"
629            );
630            assert!(
631                has_d_output,
632                "custom_fn: vjp_body must declare an Op::Input named 'd_output'"
633            );
634            assert_eq!(
635                vjp.outputs.len(),
636                n_in,
637                "custom_fn: vjp_body has {} outputs; expected {n_in} \
638                 (one gradient per primal input)",
639                vjp.outputs.len(),
640            );
641        }
642        if let Some(jvp) = jvp_body.as_ref() {
643            let primal_count = jvp
644                .nodes()
645                .iter()
646                .filter(|n| {
647                    matches!(&n.op,
648                    Op::Input { name }
649                        if !name.starts_with("tangent_") && name != "primal_output")
650                })
651                .count();
652            assert_eq!(
653                primal_count, n_in,
654                "custom_fn: jvp_body has {primal_count} primal Op::Input(s) \
655                 (excluding 'primal_output' / 'tangent_*'); expected {n_in}",
656            );
657            for i in 0..n_in {
658                let want = format!("tangent_{i}");
659                let has = jvp
660                    .nodes()
661                    .iter()
662                    .any(|n| matches!(&n.op, Op::Input { name } if name == &want));
663                assert!(
664                    has,
665                    "custom_fn: jvp_body must declare an Op::Input named '{want}'"
666                );
667            }
668            assert_eq!(
669                jvp.outputs.len(),
670                1,
671                "custom_fn: jvp_body has {} outputs; expected 1 (output tangent)",
672                jvp.outputs.len(),
673            );
674        }
675
676        self.push(
677            Op::CustomFn {
678                fwd_body: Box::new(fwd_body),
679                vjp_body: vjp_body.map(Box::new),
680                jvp_body: jvp_body.map(Box::new),
681                num_inputs: n_in as u32,
682            },
683            inputs,
684            out_shape,
685            None,
686        )
687    }
688
689    /// Multi-output `custom_fn` via the **concat-with-Narrow** design:
690    /// rewrites `fwd_body` to flatten + concat its `K` declared outputs
691    /// into a single 1-D F32 output, wraps that as [`Op::CustomFn`],
692    /// and returns a [`MultiOutputHandle`] the caller uses to extract
693    /// each sub-output via `Op::Narrow` + `Op::Reshape`.
694    ///
695    /// Per PLAN line 484, this avoids rewriting rlx's "1 Op = 1 output"
696    /// IR contract: the wrapped Op::CustomFn still has one output (the
697    /// flat concat), and `MultiOutputHandle::output(g, i)` materializes
698    /// component `i` lazily on the outer graph.
699    ///
700    /// Constraints (MVP):
701    /// - All sub-outputs must be `DType::F32`. Tuples-of-mixed-dtype
702    ///   need either a per-dtype split or a future tuple-type
703    ///   extension.
704    /// - All sub-output shapes must be statically known (no
705    ///   `Dim::Dynamic`).
706    /// - `vjp_body` / `jvp_body` aren't yet rewritten through the
707    ///   concat — caller must provide bodies that already expect
708    ///   the flat-concat output convention if they need custom AD.
709    pub fn custom_fn_multi(
710        &mut self,
711        inputs: Vec<NodeId>,
712        mut fwd_body: Graph,
713    ) -> MultiOutputHandle {
714        use crate::op::BinaryOp;
715        // Snapshot the original outputs + their shapes BEFORE
716        // appending concat ops. Outputs land at the end of the graph;
717        // we'll replace them.
718        let original_outputs = fwd_body.outputs.clone();
719        assert!(
720            !original_outputs.is_empty(),
721            "custom_fn_multi: fwd_body must have ≥ 1 declared output"
722        );
723        let mut sub_shapes: Vec<Shape> = Vec::with_capacity(original_outputs.len());
724        let mut offsets: Vec<usize> = Vec::with_capacity(original_outputs.len());
725        let mut total_len: usize = 0;
726        for &out_id in &original_outputs {
727            let s = fwd_body.node(out_id).shape.clone();
728            assert_eq!(
729                s.dtype(),
730                DType::F32,
731                "custom_fn_multi MVP: all sub-outputs must be F32, got {:?} \
732                 (sub-output #{})",
733                s.dtype(),
734                sub_shapes.len()
735            );
736            let n_elems: usize = s
737                .dims()
738                .iter()
739                .map(|d| match d {
740                    Dim::Static(k) => *k,
741                    Dim::Dynamic(_) => {
742                        panic!("custom_fn_multi MVP: dynamic dims not supported")
743                    }
744                })
745                .product();
746            offsets.push(total_len);
747            total_len += n_elems;
748            sub_shapes.push(s);
749        }
750        // Flatten each sub-output to [n_elems] and concat along axis 0.
751        let mut flats: Vec<NodeId> = Vec::with_capacity(original_outputs.len());
752        for (out_id, sh) in original_outputs.iter().zip(sub_shapes.iter()) {
753            let n: usize = sh
754                .dims()
755                .iter()
756                .map(|d| match d {
757                    Dim::Static(k) => *k,
758                    Dim::Dynamic(_) => unreachable!(),
759                })
760                .product();
761            let flat_shape = Shape::from_dims(&[Dim::Static(n)], DType::F32);
762            let flat = fwd_body.add_node(
763                Op::Reshape {
764                    new_shape: vec![n as i64],
765                },
766                vec![*out_id],
767                flat_shape,
768            );
769            flats.push(flat);
770        }
771        let concat_shape = Shape::from_dims(&[Dim::Static(total_len)], DType::F32);
772        let concat = fwd_body.add_node(Op::Concat { axis: 0 }, flats.clone(), concat_shape);
773        let _ = BinaryOp::Add; // import preserved if we extend later
774        fwd_body.set_outputs(vec![concat]);
775
776        // Now build the outer custom_fn with the rewritten body. Reuses
777        // the single-output asserts; flat concat satisfies them.
778        let source = self.custom_fn(inputs, fwd_body, None, None);
779
780        MultiOutputHandle {
781            source,
782            sub_shapes,
783            offsets,
784        }
785    }
786
787    /// Bounded scan returning the stacked trajectory.
788    /// Output shape is `[length, *init.shape]` — row `t` is the carry
789    /// after step `t+1`, so row `length-1` equals the result of plain
790    /// [`Self::scan`].
791    pub fn scan_trajectory(&mut self, init: NodeId, body: Graph, length: u32) -> NodeId {
792        let init_shape = self.shape(init).clone();
793        let mut traj_dims: Vec<crate::Dim> = Vec::with_capacity(init_shape.rank() + 1);
794        traj_dims.push(crate::Dim::Static(length as usize));
795        for i in 0..init_shape.rank() {
796            traj_dims.push(init_shape.dim(i));
797        }
798        let traj_shape = crate::Shape::from_dims(&traj_dims, init_shape.dtype());
799        self.push(
800            Op::Scan {
801                body: Box::new(body),
802                length,
803                save_trajectory: true,
804                num_xs: 0,
805                num_bcast: 0,
806                num_checkpoints: 0,
807            },
808            vec![init],
809            traj_shape,
810            None,
811        )
812    }
813}