rlx_ir/ops/special.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Specialised builders: SSM selective scan + space for future
17//! exotic ops (plan #53).
18
19use crate::shape::Dim;
20use crate::{DType, Graph, NodeId, Op, Shape};
21
/// Handle to a multi-output [`Op::CustomFn`] built via
/// [`Graph::custom_fn_multi`].
///
/// Internally the op produces a flat 1-D concatenated output; this
/// handle remembers each sub-output's offset + original shape so
/// [`Self::output`] can materialize component `i` lazily via
/// `Op::Narrow` + `Op::Reshape`.
///
/// `sub_shapes` and `offsets` are parallel vectors: entry `i` of each
/// describes sub-output `i`. [`Graph::custom_fn_multi`] fills them in
/// lock-step; because the fields are public, hand-built handles must
/// keep the two the same length.
#[derive(Debug, Clone)]
pub struct MultiOutputHandle {
    /// NodeId of the wrapped Op::CustomFn (1-D F32, length =
    /// `Σ sub_shapes[i].num_elements`).
    pub source: NodeId,
    /// Original per-sub-output shapes (in declaration order).
    pub sub_shapes: Vec<Shape>,
    /// Per-sub-output start offsets into `source` (element-counted,
    /// not byte-counted).
    pub offsets: Vec<usize>,
}
38
39impl MultiOutputHandle {
40 /// Number of sub-outputs.
41 pub fn n_outputs(&self) -> usize {
42 self.sub_shapes.len()
43 }
44
45 /// Materialize sub-output `idx` as an outer-graph NodeId.
46 /// Internally: `Op::Narrow(source, axis=0, start=offsets[idx],
47 /// len=numel(sub_shapes[idx]))` → `Op::Reshape` back to the
48 /// declared shape.
49 pub fn output(&self, g: &mut Graph, idx: usize) -> NodeId {
50 assert!(idx < self.sub_shapes.len(), "output index out of range");
51 let sub = &self.sub_shapes[idx];
52 let n_elems: usize = sub
53 .dims()
54 .iter()
55 .map(|d| match d {
56 Dim::Static(k) => *k,
57 Dim::Dynamic(_) => panic!("dynamic sub-output dim"),
58 })
59 .product();
60 let flat_shape = Shape::from_dims(&[Dim::Static(n_elems)], sub.dtype());
61 let narrowed = g.add_node(
62 Op::Narrow {
63 axis: 0,
64 start: self.offsets[idx],
65 len: n_elems,
66 },
67 vec![self.source],
68 flat_shape,
69 );
70 if sub.rank() == 1 {
71 // Already the right shape.
72 narrowed
73 } else {
74 let dims: Vec<i64> = sub
75 .dims()
76 .iter()
77 .map(|d| match d {
78 Dim::Static(k) => *k as i64,
79 Dim::Dynamic(_) => unreachable!(),
80 })
81 .collect();
82 g.add_node(Op::Reshape { new_shape: dims }, vec![narrowed], sub.clone())
83 }
84 }
85}
86
87impl Graph {
88 /// Mamba-style selective scan: y = SSM(x, Δ, A, B, C).
89 /// Inputs: x \[b,s,h\], delta \[b,s,h\], a \[h,n\], b \[b,s,n\], c \[b,s,n\].
90 /// Output \[b,s,h\]. n is the state size.
91 pub fn selective_scan(
92 &mut self,
93 x: NodeId,
94 delta: NodeId,
95 a: NodeId,
96 b: NodeId,
97 c: NodeId,
98 state_size: usize,
99 shape: Shape,
100 ) -> NodeId {
101 self.push(
102 Op::SelectiveScan { state_size },
103 vec![x, delta, a, b, c],
104 shape,
105 None,
106 )
107 }
108
109 /// Gated DeltaNet linear-attention scan (Qwen3.5/3.6 trunk,
110 /// Qwen3-Next, Kimi-Linear). See [`Op::GatedDeltaNet`] for the
111 /// recurrence math. All five inputs are `f32`. Shapes:
112 /// `q,k,v`: `[b, s, h_v, n]`; `g,beta`: `[b, s, h_v]`. Output:
113 /// `[b, s, h_v, n]`. State is implicit (reset per batch) unless
114 /// `carry_state` is set — then pass `state` as a sixth input.
115 pub fn gated_delta_net(
116 &mut self,
117 q: NodeId,
118 k: NodeId,
119 v: NodeId,
120 g: NodeId,
121 beta: NodeId,
122 state_size: usize,
123 shape: Shape,
124 ) -> NodeId {
125 self.push(
126 Op::GatedDeltaNet {
127 state_size,
128 carry_state: false,
129 },
130 vec![q, k, v, g, beta],
131 shape,
132 None,
133 )
134 }
135
136 /// Same as [`Self::gated_delta_net`] but threads `state`
137 /// `[b, h_v, n, n]` in/out for decode-mode recurrence.
138 pub fn gated_delta_net_carry(
139 &mut self,
140 q: NodeId,
141 k: NodeId,
142 v: NodeId,
143 g: NodeId,
144 beta: NodeId,
145 state: NodeId,
146 state_size: usize,
147 shape: Shape,
148 ) -> NodeId {
149 self.push(
150 Op::GatedDeltaNet {
151 state_size,
152 carry_state: true,
153 },
154 vec![q, k, v, g, beta, state],
155 shape,
156 None,
157 )
158 }
159
160 /// Multi-layer (optionally bidirectional) LSTM with packed weights.
161 /// See [`Op::Lstm`] for the gate math, weight packing, and shapes.
162 /// Output `shape` = `[batch, seq, D*hidden]` (`D = 2` iff
163 /// `bidirectional`). Initial `h0`/`c0` are zero.
164 #[allow(clippy::too_many_arguments)]
165 pub fn lstm(
166 &mut self,
167 x: NodeId,
168 w_ih: NodeId,
169 w_hh: NodeId,
170 bias: NodeId,
171 hidden_size: usize,
172 num_layers: usize,
173 bidirectional: bool,
174 shape: Shape,
175 ) -> NodeId {
176 self.push(
177 Op::Lstm {
178 hidden_size,
179 num_layers,
180 bidirectional,
181 carry: false,
182 },
183 vec![x, w_ih, w_hh, bias],
184 shape,
185 None,
186 )
187 }
188
189 /// Same as [`Self::lstm`] but threads decode state: `h0`/`c0`
190 /// `[L*D, batch, hidden]` in, `hn`/`cn` written back in place.
191 #[allow(clippy::too_many_arguments)]
192 pub fn lstm_carry(
193 &mut self,
194 x: NodeId,
195 w_ih: NodeId,
196 w_hh: NodeId,
197 bias: NodeId,
198 h0: NodeId,
199 c0: NodeId,
200 hidden_size: usize,
201 num_layers: usize,
202 bidirectional: bool,
203 shape: Shape,
204 ) -> NodeId {
205 self.push(
206 Op::Lstm {
207 hidden_size,
208 num_layers,
209 bidirectional,
210 carry: true,
211 },
212 vec![x, w_ih, w_hh, bias, h0, c0],
213 shape,
214 None,
215 )
216 }
217
218 /// Multi-layer (optionally bidirectional) GRU. Inputs
219 /// `[x, w_ih, w_hh, b_ih, b_hh]` (`+ [h0]` when `carry`). See
220 /// [`Op::Gru`]. Output `shape` = `[batch, seq, D*hidden]`.
221 #[allow(clippy::too_many_arguments)]
222 pub fn gru(
223 &mut self,
224 x: NodeId,
225 w_ih: NodeId,
226 w_hh: NodeId,
227 b_ih: NodeId,
228 b_hh: NodeId,
229 h0: Option<NodeId>,
230 hidden_size: usize,
231 num_layers: usize,
232 bidirectional: bool,
233 shape: Shape,
234 ) -> NodeId {
235 let mut inputs = vec![x, w_ih, w_hh, b_ih, b_hh];
236 let carry = h0.is_some();
237 if let Some(h) = h0 {
238 inputs.push(h);
239 }
240 self.push(
241 Op::Gru {
242 hidden_size,
243 num_layers,
244 bidirectional,
245 carry,
246 },
247 inputs,
248 shape,
249 None,
250 )
251 }
252
253 /// Multi-layer (optionally bidirectional) Elman RNN. Inputs
254 /// `[x, w_ih, w_hh, bias]` (`+ [h0]` when `carry`). `relu` selects the
255 /// activation (else tanh). See [`Op::Rnn`]. Output `[batch, seq, D*hidden]`.
256 #[allow(clippy::too_many_arguments)]
257 pub fn rnn(
258 &mut self,
259 x: NodeId,
260 w_ih: NodeId,
261 w_hh: NodeId,
262 bias: NodeId,
263 h0: Option<NodeId>,
264 hidden_size: usize,
265 num_layers: usize,
266 bidirectional: bool,
267 relu: bool,
268 shape: Shape,
269 ) -> NodeId {
270 let mut inputs = vec![x, w_ih, w_hh, bias];
271 let carry = h0.is_some();
272 if let Some(h) = h0 {
273 inputs.push(h);
274 }
275 self.push(
276 Op::Rnn {
277 hidden_size,
278 num_layers,
279 bidirectional,
280 carry,
281 relu,
282 },
283 inputs,
284 shape,
285 None,
286 )
287 }
288
289 /// Mamba-2 / SSD scalar-decay SSM scan. See [`Op::Mamba2`]. Inputs
290 /// `[x, dt, a, b, c]`; output `shape` = `x` shape `[B,S,H,P]`.
291 pub fn mamba2(
292 &mut self,
293 x: NodeId,
294 dt: NodeId,
295 a: NodeId,
296 b: NodeId,
297 c: NodeId,
298 head_dim: usize,
299 state_size: usize,
300 shape: Shape,
301 ) -> NodeId {
302 self.push(
303 Op::Mamba2 {
304 head_dim,
305 state_size,
306 },
307 vec![x, dt, a, b, c],
308 shape,
309 None,
310 )
311 }
312
313 /// Bounded scan returning the final carry. Body must have exactly
314 /// one `Op::Input` (the carry) and one output, both same shape as
315 /// `init`. Output shape matches `init`.
316 pub fn scan(&mut self, init: NodeId, body: Graph, length: u32) -> NodeId {
317 let init_shape = self.shape(init).clone();
318 self.push(
319 Op::Scan {
320 body: Box::new(body),
321 length,
322 save_trajectory: false,
323 num_bcast: 0,
324 num_xs: 0,
325 num_checkpoints: 0,
326 },
327 vec![init],
328 init_shape,
329 None,
330 )
331 }
332
333 /// Bounded scan with recursive checkpointing for memory-bounded
334 /// backward AD. Equivalent to [`Self::scan`] for the forward
335 /// computation, but during backward only `num_checkpoints` carry
336 /// values are cached; intermediate carries are recomputed via the
337 /// body. Memory: `O(num_checkpoints · carry_size)`. Time: forward
338 /// unchanged; backward `O(length)` (segment-cached).
339 ///
340 /// The AD pre-pass propagates `num_checkpoints` into the rewritten
341 /// trajectory-saving Scan and into the emitted ScanBackward, so a
342 /// single call to [`crate::Graph::scan_checkpointed`] is enough
343 /// to enable the memory bound across the whole forward+backward
344 /// pipeline.
345 pub fn scan_checkpointed(
346 &mut self,
347 init: NodeId,
348 body: Graph,
349 length: u32,
350 num_checkpoints: u32,
351 ) -> NodeId {
352 assert!(
353 num_checkpoints > 0 && num_checkpoints <= length,
354 "scan_checkpointed: num_checkpoints={num_checkpoints} \
355 must be in 1..=length={length}"
356 );
357 let init_shape = self.shape(init).clone();
358 self.push(
359 Op::Scan {
360 body: Box::new(body),
361 length,
362 save_trajectory: false,
363 num_bcast: 0,
364 num_xs: 0,
365 num_checkpoints,
366 },
367 vec![init],
368 init_shape,
369 None,
370 )
371 }
372
373 /// Bounded scan with broadcast and per-step inputs.
374 ///
375 /// Body `Op::Input`s in NodeId order: `[carry, bcast_0..bcast_{B-1},
376 /// x_t_0..x_t_{X-1}]`. Bcast inputs keep their natural shape (the
377 /// CPU executor fills them once before the scan loop). xs\[i\] has
378 /// shape `[length, *per_step]` and the body sees `xs[i][t]` per
379 /// iteration. Output shape matches `init`.
380 pub fn scan_with_bcasts_and_xs(
381 &mut self,
382 init: NodeId,
383 bcasts: &[NodeId],
384 xs: &[NodeId],
385 body: Graph,
386 length: u32,
387 ) -> NodeId {
388 let init_shape = self.shape(init).clone();
389 let mut inputs = vec![init];
390 inputs.extend_from_slice(bcasts);
391 inputs.extend_from_slice(xs);
392 self.push(
393 Op::Scan {
394 body: Box::new(body),
395 length,
396 save_trajectory: false,
397 num_bcast: bcasts.len() as u32,
398 num_xs: xs.len() as u32,
399 num_checkpoints: 0,
400 },
401 inputs,
402 init_shape,
403 None,
404 )
405 }
406
407 /// Bounded scan with per-step `xs` inputs returning the final carry.
408 /// Body has `1 + xs.len()` Op::Inputs in NodeId construction order
409 /// (first declared is the carry; the remaining match `xs` in order).
410 /// Each `xs[i]` has shape `[length, *per_step_shape_i]`; the body
411 /// sees a `per_step_shape_i` slice on iteration `t`.
412 pub fn scan_with_xs(
413 &mut self,
414 init: NodeId,
415 xs: &[NodeId],
416 body: Graph,
417 length: u32,
418 ) -> NodeId {
419 let init_shape = self.shape(init).clone();
420 let mut inputs = vec![init];
421 inputs.extend_from_slice(xs);
422 self.push(
423 Op::Scan {
424 body: Box::new(body),
425 length,
426 save_trajectory: false,
427 num_bcast: 0,
428 num_xs: xs.len() as u32,
429 num_checkpoints: 0,
430 },
431 inputs,
432 init_shape,
433 None,
434 )
435 }
436
437 /// Reverse-mode AD companion to [`Self::scan`] /
438 /// [`Self::scan_trajectory`]. Typically constructed by the
439 /// autodiff pass, not by hand.
440 ///
441 /// `xs` is the list of per-step input tensors (must match the
442 /// forward Op::Scan's xs in count, order, and per-step shape).
443 /// Body_vjp's `1 + xs.len() + 1` Op::Inputs match the forward
444 /// body's inputs plus a fresh `"d_output"` Input.
445 pub fn scan_backward(
446 &mut self,
447 init: NodeId,
448 trajectory: NodeId,
449 upstream: NodeId,
450 xs: &[NodeId],
451 body_vjp: Graph,
452 length: u32,
453 save_trajectory: bool,
454 out_shape: Shape,
455 ) -> NodeId {
456 self.scan_backward_with_checkpoints(
457 init,
458 trajectory,
459 upstream,
460 xs,
461 body_vjp,
462 length,
463 save_trajectory,
464 0,
465 None,
466 out_shape,
467 )
468 }
469
470 /// Lower-level `scan_backward` with explicit checkpointing config.
471 /// `num_checkpoints == 0` (default) means no checkpointing — the
472 /// trajectory cache holds every step's carry. `0 < K < length`
473 /// enables segment-cached recompute via `forward_body` (must be
474 /// `Some`).
475 #[allow(clippy::too_many_arguments)]
476 pub fn scan_backward_with_checkpoints(
477 &mut self,
478 init: NodeId,
479 trajectory: NodeId,
480 upstream: NodeId,
481 xs: &[NodeId],
482 body_vjp: Graph,
483 length: u32,
484 save_trajectory: bool,
485 num_checkpoints: u32,
486 forward_body: Option<Graph>,
487 out_shape: Shape,
488 ) -> NodeId {
489 let mut inputs = vec![init, trajectory, upstream];
490 inputs.extend_from_slice(xs);
491 self.push(
492 Op::ScanBackward {
493 body_vjp: Box::new(body_vjp),
494 length,
495 save_trajectory,
496 num_xs: xs.len() as u32,
497 num_checkpoints,
498 forward_body: forward_body.map(Box::new),
499 },
500 inputs,
501 out_shape,
502 None,
503 )
504 }
505
506 /// Per-step xs gradient companion to [`Self::scan_backward`].
507 /// Same inputs and same `body_vjp` graph, plus an `xs_idx`
508 /// selecting which body_vjp output to stack into the result.
509 /// Output shape is `[length, *per_step_xs_shape]`.
510 pub fn scan_backward_xs(
511 &mut self,
512 init: NodeId,
513 trajectory: NodeId,
514 upstream: NodeId,
515 xs: &[NodeId],
516 body_vjp: Graph,
517 length: u32,
518 save_trajectory: bool,
519 xs_idx: u32,
520 out_shape: Shape,
521 ) -> NodeId {
522 self.scan_backward_xs_with_checkpoints(
523 init,
524 trajectory,
525 upstream,
526 xs,
527 body_vjp,
528 length,
529 save_trajectory,
530 xs_idx,
531 0,
532 None,
533 out_shape,
534 )
535 }
536
537 #[allow(clippy::too_many_arguments)]
538 pub fn scan_backward_xs_with_checkpoints(
539 &mut self,
540 init: NodeId,
541 trajectory: NodeId,
542 upstream: NodeId,
543 xs: &[NodeId],
544 body_vjp: Graph,
545 length: u32,
546 save_trajectory: bool,
547 xs_idx: u32,
548 num_checkpoints: u32,
549 forward_body: Option<Graph>,
550 out_shape: Shape,
551 ) -> NodeId {
552 let mut inputs = vec![init, trajectory, upstream];
553 inputs.extend_from_slice(xs);
554 self.push(
555 Op::ScanBackwardXs {
556 body_vjp: Box::new(body_vjp),
557 length,
558 save_trajectory,
559 num_xs: xs.len() as u32,
560 xs_idx,
561 num_checkpoints,
562 forward_body: forward_body.map(Box::new),
563 },
564 inputs,
565 out_shape,
566 None,
567 )
568 }
569
570 /// User-defined sub-graph with optional override AD rules.
571 /// JAX-shaped `custom_vjp` / `custom_jvp` — see [`Op::CustomFn`].
572 ///
573 /// `inputs.len()` must equal the number of `Op::Input` nodes in
574 /// `fwd_body`. Output shape is inferred from `fwd_body`'s declared
575 /// output. When supplied, `vjp_body` and `jvp_body` must follow the
576 /// conventions documented on [`Op::CustomFn`] (special-named
577 /// `"primal_output"` / `"d_output"` / `"tangent_*"` Inputs).
578 pub fn custom_fn(
579 &mut self,
580 inputs: Vec<NodeId>,
581 fwd_body: Graph,
582 vjp_body: Option<Graph>,
583 jvp_body: Option<Graph>,
584 ) -> NodeId {
585 let n_in = inputs.len();
586 // Count fwd_body's primal Inputs (no special names — fwd has none).
587 let fwd_inputs: usize = fwd_body
588 .nodes()
589 .iter()
590 .filter(|n| matches!(n.op, Op::Input { .. }))
591 .count();
592 assert_eq!(
593 fwd_inputs, n_in,
594 "custom_fn: fwd_body has {fwd_inputs} Op::Input(s); outer call \
595 provides {n_in}. Counts must match.",
596 );
597 let fwd_out_id = fwd_body
598 .outputs
599 .first()
600 .copied()
601 .expect("custom_fn: fwd_body must declare exactly one output");
602 let out_shape = fwd_body.node(fwd_out_id).shape.clone();
603
604 if let Some(vjp) = vjp_body.as_ref() {
605 let primal_count = vjp
606 .nodes()
607 .iter()
608 .filter(|n| {
609 matches!(&n.op,
610 Op::Input { name } if name != "primal_output" && name != "d_output")
611 })
612 .count();
613 assert_eq!(
614 primal_count, n_in,
615 "custom_fn: vjp_body has {primal_count} primal Op::Input(s) \
616 (excluding 'primal_output' / 'd_output'); expected {n_in}",
617 );
618 let has_primal_out = vjp
619 .nodes()
620 .iter()
621 .any(|n| matches!(&n.op, Op::Input { name } if name == "primal_output"));
622 let has_d_output = vjp
623 .nodes()
624 .iter()
625 .any(|n| matches!(&n.op, Op::Input { name } if name == "d_output"));
626 assert!(
627 has_primal_out,
628 "custom_fn: vjp_body must declare an Op::Input named 'primal_output'"
629 );
630 assert!(
631 has_d_output,
632 "custom_fn: vjp_body must declare an Op::Input named 'd_output'"
633 );
634 assert_eq!(
635 vjp.outputs.len(),
636 n_in,
637 "custom_fn: vjp_body has {} outputs; expected {n_in} \
638 (one gradient per primal input)",
639 vjp.outputs.len(),
640 );
641 }
642 if let Some(jvp) = jvp_body.as_ref() {
643 let primal_count = jvp
644 .nodes()
645 .iter()
646 .filter(|n| {
647 matches!(&n.op,
648 Op::Input { name }
649 if !name.starts_with("tangent_") && name != "primal_output")
650 })
651 .count();
652 assert_eq!(
653 primal_count, n_in,
654 "custom_fn: jvp_body has {primal_count} primal Op::Input(s) \
655 (excluding 'primal_output' / 'tangent_*'); expected {n_in}",
656 );
657 for i in 0..n_in {
658 let want = format!("tangent_{i}");
659 let has = jvp
660 .nodes()
661 .iter()
662 .any(|n| matches!(&n.op, Op::Input { name } if name == &want));
663 assert!(
664 has,
665 "custom_fn: jvp_body must declare an Op::Input named '{want}'"
666 );
667 }
668 assert_eq!(
669 jvp.outputs.len(),
670 1,
671 "custom_fn: jvp_body has {} outputs; expected 1 (output tangent)",
672 jvp.outputs.len(),
673 );
674 }
675
676 self.push(
677 Op::CustomFn {
678 fwd_body: Box::new(fwd_body),
679 vjp_body: vjp_body.map(Box::new),
680 jvp_body: jvp_body.map(Box::new),
681 num_inputs: n_in as u32,
682 },
683 inputs,
684 out_shape,
685 None,
686 )
687 }
688
689 /// Multi-output `custom_fn` via the **concat-with-Narrow** design:
690 /// rewrites `fwd_body` to flatten + concat its `K` declared outputs
691 /// into a single 1-D F32 output, wraps that as [`Op::CustomFn`],
692 /// and returns a [`MultiOutputHandle`] the caller uses to extract
693 /// each sub-output via `Op::Narrow` + `Op::Reshape`.
694 ///
695 /// Per PLAN line 484, this avoids rewriting rlx's "1 Op = 1 output"
696 /// IR contract: the wrapped Op::CustomFn still has one output (the
697 /// flat concat), and `MultiOutputHandle::output(g, i)` materializes
698 /// component `i` lazily on the outer graph.
699 ///
700 /// Constraints (MVP):
701 /// - All sub-outputs must be `DType::F32`. Tuples-of-mixed-dtype
702 /// need either a per-dtype split or a future tuple-type
703 /// extension.
704 /// - All sub-output shapes must be statically known (no
705 /// `Dim::Dynamic`).
706 /// - `vjp_body` / `jvp_body` aren't yet rewritten through the
707 /// concat — caller must provide bodies that already expect
708 /// the flat-concat output convention if they need custom AD.
709 pub fn custom_fn_multi(
710 &mut self,
711 inputs: Vec<NodeId>,
712 mut fwd_body: Graph,
713 ) -> MultiOutputHandle {
714 use crate::op::BinaryOp;
715 // Snapshot the original outputs + their shapes BEFORE
716 // appending concat ops. Outputs land at the end of the graph;
717 // we'll replace them.
718 let original_outputs = fwd_body.outputs.clone();
719 assert!(
720 !original_outputs.is_empty(),
721 "custom_fn_multi: fwd_body must have ≥ 1 declared output"
722 );
723 let mut sub_shapes: Vec<Shape> = Vec::with_capacity(original_outputs.len());
724 let mut offsets: Vec<usize> = Vec::with_capacity(original_outputs.len());
725 let mut total_len: usize = 0;
726 for &out_id in &original_outputs {
727 let s = fwd_body.node(out_id).shape.clone();
728 assert_eq!(
729 s.dtype(),
730 DType::F32,
731 "custom_fn_multi MVP: all sub-outputs must be F32, got {:?} \
732 (sub-output #{})",
733 s.dtype(),
734 sub_shapes.len()
735 );
736 let n_elems: usize = s
737 .dims()
738 .iter()
739 .map(|d| match d {
740 Dim::Static(k) => *k,
741 Dim::Dynamic(_) => {
742 panic!("custom_fn_multi MVP: dynamic dims not supported")
743 }
744 })
745 .product();
746 offsets.push(total_len);
747 total_len += n_elems;
748 sub_shapes.push(s);
749 }
750 // Flatten each sub-output to [n_elems] and concat along axis 0.
751 let mut flats: Vec<NodeId> = Vec::with_capacity(original_outputs.len());
752 for (out_id, sh) in original_outputs.iter().zip(sub_shapes.iter()) {
753 let n: usize = sh
754 .dims()
755 .iter()
756 .map(|d| match d {
757 Dim::Static(k) => *k,
758 Dim::Dynamic(_) => unreachable!(),
759 })
760 .product();
761 let flat_shape = Shape::from_dims(&[Dim::Static(n)], DType::F32);
762 let flat = fwd_body.add_node(
763 Op::Reshape {
764 new_shape: vec![n as i64],
765 },
766 vec![*out_id],
767 flat_shape,
768 );
769 flats.push(flat);
770 }
771 let concat_shape = Shape::from_dims(&[Dim::Static(total_len)], DType::F32);
772 let concat = fwd_body.add_node(Op::Concat { axis: 0 }, flats.clone(), concat_shape);
773 let _ = BinaryOp::Add; // import preserved if we extend later
774 fwd_body.set_outputs(vec![concat]);
775
776 // Now build the outer custom_fn with the rewritten body. Reuses
777 // the single-output asserts; flat concat satisfies them.
778 let source = self.custom_fn(inputs, fwd_body, None, None);
779
780 MultiOutputHandle {
781 source,
782 sub_shapes,
783 offsets,
784 }
785 }
786
787 /// Bounded scan returning the stacked trajectory.
788 /// Output shape is `[length, *init.shape]` — row `t` is the carry
789 /// after step `t+1`, so row `length-1` equals the result of plain
790 /// [`Self::scan`].
791 pub fn scan_trajectory(&mut self, init: NodeId, body: Graph, length: u32) -> NodeId {
792 let init_shape = self.shape(init).clone();
793 let mut traj_dims: Vec<crate::Dim> = Vec::with_capacity(init_shape.rank() + 1);
794 traj_dims.push(crate::Dim::Static(length as usize));
795 for i in 0..init_shape.rank() {
796 traj_dims.push(init_shape.dim(i));
797 }
798 let traj_shape = crate::Shape::from_dims(&traj_dims, init_shape.dtype());
799 self.push(
800 Op::Scan {
801 body: Box::new(body),
802 length,
803 save_trajectory: true,
804 num_xs: 0,
805 num_bcast: 0,
806 num_checkpoints: 0,
807 },
808 vec![init],
809 traj_shape,
810 None,
811 )
812 }
813}