rlx_ir/ops/special.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
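//! Specialised graph builders: the SSM selective scan, the Gated
//! DeltaNet scan, bounded `Scan` / `ScanBackward` variants, and
//! user-defined `CustomFn` ops (single- and multi-output), with room
//! for future exotic ops (plan #53).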
18
19use crate::shape::Dim;
20use crate::{DType, Graph, NodeId, Op, Shape};
21
22/// Handle to a multi-output [`Op::CustomFn`] built via
23/// [`Graph::custom_fn_multi`]. Internally the op produces a flat 1-D
24/// concatenated output; this handle remembers each sub-output's
25/// offset + original shape so [`Self::output`] can materialize
26/// component `i` lazily via `Op::Narrow` + `Op::Reshape`.
27#[derive(Debug, Clone)]
28pub struct MultiOutputHandle {
29 /// NodeId of the wrapped Op::CustomFn (1-D F32, length =
30 /// `Σ sub_shapes[i].num_elements`).
31 pub source: NodeId,
32 /// Original per-sub-output shapes (in declaration order).
33 pub sub_shapes: Vec<Shape>,
34 /// Per-sub-output start offsets into `source` (element-counted,
35 /// not byte-counted).
36 pub offsets: Vec<usize>,
37}
38
39impl MultiOutputHandle {
40 /// Number of sub-outputs.
41 pub fn n_outputs(&self) -> usize {
42 self.sub_shapes.len()
43 }
44
45 /// Materialize sub-output `idx` as an outer-graph NodeId.
46 /// Internally: `Op::Narrow(source, axis=0, start=offsets[idx],
47 /// len=numel(sub_shapes[idx]))` → `Op::Reshape` back to the
48 /// declared shape.
49 pub fn output(&self, g: &mut Graph, idx: usize) -> NodeId {
50 assert!(idx < self.sub_shapes.len(), "output index out of range");
51 let sub = &self.sub_shapes[idx];
52 let n_elems: usize = sub
53 .dims()
54 .iter()
55 .map(|d| match d {
56 Dim::Static(k) => *k,
57 Dim::Dynamic(_) => panic!("dynamic sub-output dim"),
58 })
59 .product();
60 let flat_shape = Shape::from_dims(&[Dim::Static(n_elems)], sub.dtype());
61 let narrowed = g.add_node(
62 Op::Narrow {
63 axis: 0,
64 start: self.offsets[idx],
65 len: n_elems,
66 },
67 vec![self.source],
68 flat_shape,
69 );
70 if sub.rank() == 1 {
71 // Already the right shape.
72 narrowed
73 } else {
74 let dims: Vec<i64> = sub
75 .dims()
76 .iter()
77 .map(|d| match d {
78 Dim::Static(k) => *k as i64,
79 Dim::Dynamic(_) => unreachable!(),
80 })
81 .collect();
82 g.add_node(Op::Reshape { new_shape: dims }, vec![narrowed], sub.clone())
83 }
84 }
85}
86
87impl Graph {
88 /// Mamba-style selective scan: y = SSM(x, Δ, A, B, C).
89 /// Inputs: x \[b,s,h\], delta \[b,s,h\], a \[h,n\], b \[b,s,n\], c \[b,s,n\].
90 /// Output \[b,s,h\]. n is the state size.
91 pub fn selective_scan(
92 &mut self,
93 x: NodeId,
94 delta: NodeId,
95 a: NodeId,
96 b: NodeId,
97 c: NodeId,
98 state_size: usize,
99 shape: Shape,
100 ) -> NodeId {
101 self.push(
102 Op::SelectiveScan { state_size },
103 vec![x, delta, a, b, c],
104 shape,
105 None,
106 )
107 }
108
109 /// Gated DeltaNet linear-attention scan (Qwen3.5/3.6 trunk,
110 /// Qwen3-Next, Kimi-Linear). See [`Op::GatedDeltaNet`] for the
111 /// recurrence math. All five inputs are `f32`. Shapes:
112 /// `q,k,v`: `[b, s, h_v, n]`; `g,beta`: `[b, s, h_v]`. Output:
113 /// `[b, s, h_v, n]`. State is implicit (reset per batch) unless
114 /// `carry_state` is set — then pass `state` as a sixth input.
115 pub fn gated_delta_net(
116 &mut self,
117 q: NodeId,
118 k: NodeId,
119 v: NodeId,
120 g: NodeId,
121 beta: NodeId,
122 state_size: usize,
123 shape: Shape,
124 ) -> NodeId {
125 self.push(
126 Op::GatedDeltaNet {
127 state_size,
128 carry_state: false,
129 },
130 vec![q, k, v, g, beta],
131 shape,
132 None,
133 )
134 }
135
136 /// Same as [`Self::gated_delta_net`] but threads `state`
137 /// `[b, h_v, n, n]` in/out for decode-mode recurrence.
138 pub fn gated_delta_net_carry(
139 &mut self,
140 q: NodeId,
141 k: NodeId,
142 v: NodeId,
143 g: NodeId,
144 beta: NodeId,
145 state: NodeId,
146 state_size: usize,
147 shape: Shape,
148 ) -> NodeId {
149 self.push(
150 Op::GatedDeltaNet {
151 state_size,
152 carry_state: true,
153 },
154 vec![q, k, v, g, beta, state],
155 shape,
156 None,
157 )
158 }
159
160 /// Bounded scan returning the final carry. Body must have exactly
161 /// one `Op::Input` (the carry) and one output, both same shape as
162 /// `init`. Output shape matches `init`.
163 pub fn scan(&mut self, init: NodeId, body: Graph, length: u32) -> NodeId {
164 let init_shape = self.shape(init).clone();
165 self.push(
166 Op::Scan {
167 body: Box::new(body),
168 length,
169 save_trajectory: false,
170 num_bcast: 0,
171 num_xs: 0,
172 num_checkpoints: 0,
173 },
174 vec![init],
175 init_shape,
176 None,
177 )
178 }
179
180 /// Bounded scan with recursive checkpointing for memory-bounded
181 /// backward AD. Equivalent to [`Self::scan`] for the forward
182 /// computation, but during backward only `num_checkpoints` carry
183 /// values are cached; intermediate carries are recomputed via the
184 /// body. Memory: `O(num_checkpoints · carry_size)`. Time: forward
185 /// unchanged; backward `O(length)` (segment-cached).
186 ///
187 /// The AD pre-pass propagates `num_checkpoints` into the rewritten
188 /// trajectory-saving Scan and into the emitted ScanBackward, so a
189 /// single call to [`crate::Graph::scan_checkpointed`] is enough
190 /// to enable the memory bound across the whole forward+backward
191 /// pipeline.
192 pub fn scan_checkpointed(
193 &mut self,
194 init: NodeId,
195 body: Graph,
196 length: u32,
197 num_checkpoints: u32,
198 ) -> NodeId {
199 assert!(
200 num_checkpoints > 0 && num_checkpoints <= length,
201 "scan_checkpointed: num_checkpoints={num_checkpoints} \
202 must be in 1..=length={length}"
203 );
204 let init_shape = self.shape(init).clone();
205 self.push(
206 Op::Scan {
207 body: Box::new(body),
208 length,
209 save_trajectory: false,
210 num_bcast: 0,
211 num_xs: 0,
212 num_checkpoints,
213 },
214 vec![init],
215 init_shape,
216 None,
217 )
218 }
219
220 /// Bounded scan with broadcast and per-step inputs.
221 ///
222 /// Body `Op::Input`s in NodeId order: `[carry, bcast_0..bcast_{B-1},
223 /// x_t_0..x_t_{X-1}]`. Bcast inputs keep their natural shape (the
224 /// CPU executor fills them once before the scan loop). xs\[i\] has
225 /// shape `[length, *per_step]` and the body sees `xs[i][t]` per
226 /// iteration. Output shape matches `init`.
227 pub fn scan_with_bcasts_and_xs(
228 &mut self,
229 init: NodeId,
230 bcasts: &[NodeId],
231 xs: &[NodeId],
232 body: Graph,
233 length: u32,
234 ) -> NodeId {
235 let init_shape = self.shape(init).clone();
236 let mut inputs = vec![init];
237 inputs.extend_from_slice(bcasts);
238 inputs.extend_from_slice(xs);
239 self.push(
240 Op::Scan {
241 body: Box::new(body),
242 length,
243 save_trajectory: false,
244 num_bcast: bcasts.len() as u32,
245 num_xs: xs.len() as u32,
246 num_checkpoints: 0,
247 },
248 inputs,
249 init_shape,
250 None,
251 )
252 }
253
254 /// Bounded scan with per-step `xs` inputs returning the final carry.
255 /// Body has `1 + xs.len()` Op::Inputs in NodeId construction order
256 /// (first declared is the carry; the remaining match `xs` in order).
257 /// Each `xs[i]` has shape `[length, *per_step_shape_i]`; the body
258 /// sees a `per_step_shape_i` slice on iteration `t`.
259 pub fn scan_with_xs(
260 &mut self,
261 init: NodeId,
262 xs: &[NodeId],
263 body: Graph,
264 length: u32,
265 ) -> NodeId {
266 let init_shape = self.shape(init).clone();
267 let mut inputs = vec![init];
268 inputs.extend_from_slice(xs);
269 self.push(
270 Op::Scan {
271 body: Box::new(body),
272 length,
273 save_trajectory: false,
274 num_bcast: 0,
275 num_xs: xs.len() as u32,
276 num_checkpoints: 0,
277 },
278 inputs,
279 init_shape,
280 None,
281 )
282 }
283
284 /// Reverse-mode AD companion to [`Self::scan`] /
285 /// [`Self::scan_trajectory`]. Typically constructed by the
286 /// autodiff pass, not by hand.
287 ///
288 /// `xs` is the list of per-step input tensors (must match the
289 /// forward Op::Scan's xs in count, order, and per-step shape).
290 /// Body_vjp's `1 + xs.len() + 1` Op::Inputs match the forward
291 /// body's inputs plus a fresh `"d_output"` Input.
292 pub fn scan_backward(
293 &mut self,
294 init: NodeId,
295 trajectory: NodeId,
296 upstream: NodeId,
297 xs: &[NodeId],
298 body_vjp: Graph,
299 length: u32,
300 save_trajectory: bool,
301 out_shape: Shape,
302 ) -> NodeId {
303 self.scan_backward_with_checkpoints(
304 init,
305 trajectory,
306 upstream,
307 xs,
308 body_vjp,
309 length,
310 save_trajectory,
311 0,
312 None,
313 out_shape,
314 )
315 }
316
317 /// Lower-level `scan_backward` with explicit checkpointing config.
318 /// `num_checkpoints == 0` (default) means no checkpointing — the
319 /// trajectory cache holds every step's carry. `0 < K < length`
320 /// enables segment-cached recompute via `forward_body` (must be
321 /// `Some`).
322 #[allow(clippy::too_many_arguments)]
323 pub fn scan_backward_with_checkpoints(
324 &mut self,
325 init: NodeId,
326 trajectory: NodeId,
327 upstream: NodeId,
328 xs: &[NodeId],
329 body_vjp: Graph,
330 length: u32,
331 save_trajectory: bool,
332 num_checkpoints: u32,
333 forward_body: Option<Graph>,
334 out_shape: Shape,
335 ) -> NodeId {
336 let mut inputs = vec![init, trajectory, upstream];
337 inputs.extend_from_slice(xs);
338 self.push(
339 Op::ScanBackward {
340 body_vjp: Box::new(body_vjp),
341 length,
342 save_trajectory,
343 num_xs: xs.len() as u32,
344 num_checkpoints,
345 forward_body: forward_body.map(Box::new),
346 },
347 inputs,
348 out_shape,
349 None,
350 )
351 }
352
353 /// Per-step xs gradient companion to [`Self::scan_backward`].
354 /// Same inputs and same `body_vjp` graph, plus an `xs_idx`
355 /// selecting which body_vjp output to stack into the result.
356 /// Output shape is `[length, *per_step_xs_shape]`.
357 pub fn scan_backward_xs(
358 &mut self,
359 init: NodeId,
360 trajectory: NodeId,
361 upstream: NodeId,
362 xs: &[NodeId],
363 body_vjp: Graph,
364 length: u32,
365 save_trajectory: bool,
366 xs_idx: u32,
367 out_shape: Shape,
368 ) -> NodeId {
369 self.scan_backward_xs_with_checkpoints(
370 init,
371 trajectory,
372 upstream,
373 xs,
374 body_vjp,
375 length,
376 save_trajectory,
377 xs_idx,
378 0,
379 None,
380 out_shape,
381 )
382 }
383
384 #[allow(clippy::too_many_arguments)]
385 pub fn scan_backward_xs_with_checkpoints(
386 &mut self,
387 init: NodeId,
388 trajectory: NodeId,
389 upstream: NodeId,
390 xs: &[NodeId],
391 body_vjp: Graph,
392 length: u32,
393 save_trajectory: bool,
394 xs_idx: u32,
395 num_checkpoints: u32,
396 forward_body: Option<Graph>,
397 out_shape: Shape,
398 ) -> NodeId {
399 let mut inputs = vec![init, trajectory, upstream];
400 inputs.extend_from_slice(xs);
401 self.push(
402 Op::ScanBackwardXs {
403 body_vjp: Box::new(body_vjp),
404 length,
405 save_trajectory,
406 num_xs: xs.len() as u32,
407 xs_idx,
408 num_checkpoints,
409 forward_body: forward_body.map(Box::new),
410 },
411 inputs,
412 out_shape,
413 None,
414 )
415 }
416
417 /// User-defined sub-graph with optional override AD rules.
418 /// JAX-shaped `custom_vjp` / `custom_jvp` — see [`Op::CustomFn`].
419 ///
420 /// `inputs.len()` must equal the number of `Op::Input` nodes in
421 /// `fwd_body`. Output shape is inferred from `fwd_body`'s declared
422 /// output. When supplied, `vjp_body` and `jvp_body` must follow the
423 /// conventions documented on [`Op::CustomFn`] (special-named
424 /// `"primal_output"` / `"d_output"` / `"tangent_*"` Inputs).
425 pub fn custom_fn(
426 &mut self,
427 inputs: Vec<NodeId>,
428 fwd_body: Graph,
429 vjp_body: Option<Graph>,
430 jvp_body: Option<Graph>,
431 ) -> NodeId {
432 let n_in = inputs.len();
433 // Count fwd_body's primal Inputs (no special names — fwd has none).
434 let fwd_inputs: usize = fwd_body
435 .nodes()
436 .iter()
437 .filter(|n| matches!(n.op, Op::Input { .. }))
438 .count();
439 assert_eq!(
440 fwd_inputs, n_in,
441 "custom_fn: fwd_body has {fwd_inputs} Op::Input(s); outer call \
442 provides {n_in}. Counts must match.",
443 );
444 let fwd_out_id = fwd_body
445 .outputs
446 .first()
447 .copied()
448 .expect("custom_fn: fwd_body must declare exactly one output");
449 let out_shape = fwd_body.node(fwd_out_id).shape.clone();
450
451 if let Some(vjp) = vjp_body.as_ref() {
452 let primal_count = vjp
453 .nodes()
454 .iter()
455 .filter(|n| {
456 matches!(&n.op,
457 Op::Input { name } if name != "primal_output" && name != "d_output")
458 })
459 .count();
460 assert_eq!(
461 primal_count, n_in,
462 "custom_fn: vjp_body has {primal_count} primal Op::Input(s) \
463 (excluding 'primal_output' / 'd_output'); expected {n_in}",
464 );
465 let has_primal_out = vjp
466 .nodes()
467 .iter()
468 .any(|n| matches!(&n.op, Op::Input { name } if name == "primal_output"));
469 let has_d_output = vjp
470 .nodes()
471 .iter()
472 .any(|n| matches!(&n.op, Op::Input { name } if name == "d_output"));
473 assert!(
474 has_primal_out,
475 "custom_fn: vjp_body must declare an Op::Input named 'primal_output'"
476 );
477 assert!(
478 has_d_output,
479 "custom_fn: vjp_body must declare an Op::Input named 'd_output'"
480 );
481 assert_eq!(
482 vjp.outputs.len(),
483 n_in,
484 "custom_fn: vjp_body has {} outputs; expected {n_in} \
485 (one gradient per primal input)",
486 vjp.outputs.len(),
487 );
488 }
489 if let Some(jvp) = jvp_body.as_ref() {
490 let primal_count = jvp
491 .nodes()
492 .iter()
493 .filter(|n| {
494 matches!(&n.op,
495 Op::Input { name }
496 if !name.starts_with("tangent_") && name != "primal_output")
497 })
498 .count();
499 assert_eq!(
500 primal_count, n_in,
501 "custom_fn: jvp_body has {primal_count} primal Op::Input(s) \
502 (excluding 'primal_output' / 'tangent_*'); expected {n_in}",
503 );
504 for i in 0..n_in {
505 let want = format!("tangent_{i}");
506 let has = jvp
507 .nodes()
508 .iter()
509 .any(|n| matches!(&n.op, Op::Input { name } if name == &want));
510 assert!(
511 has,
512 "custom_fn: jvp_body must declare an Op::Input named '{want}'"
513 );
514 }
515 assert_eq!(
516 jvp.outputs.len(),
517 1,
518 "custom_fn: jvp_body has {} outputs; expected 1 (output tangent)",
519 jvp.outputs.len(),
520 );
521 }
522
523 self.push(
524 Op::CustomFn {
525 fwd_body: Box::new(fwd_body),
526 vjp_body: vjp_body.map(Box::new),
527 jvp_body: jvp_body.map(Box::new),
528 num_inputs: n_in as u32,
529 },
530 inputs,
531 out_shape,
532 None,
533 )
534 }
535
536 /// Multi-output `custom_fn` via the **concat-with-Narrow** design:
537 /// rewrites `fwd_body` to flatten + concat its `K` declared outputs
538 /// into a single 1-D F32 output, wraps that as [`Op::CustomFn`],
539 /// and returns a [`MultiOutputHandle`] the caller uses to extract
540 /// each sub-output via `Op::Narrow` + `Op::Reshape`.
541 ///
542 /// Per PLAN line 484, this avoids rewriting rlx's "1 Op = 1 output"
543 /// IR contract: the wrapped Op::CustomFn still has one output (the
544 /// flat concat), and `MultiOutputHandle::output(g, i)` materializes
545 /// component `i` lazily on the outer graph.
546 ///
547 /// Constraints (MVP):
548 /// - All sub-outputs must be `DType::F32`. Tuples-of-mixed-dtype
549 /// need either a per-dtype split or a future tuple-type
550 /// extension.
551 /// - All sub-output shapes must be statically known (no
552 /// `Dim::Dynamic`).
553 /// - `vjp_body` / `jvp_body` aren't yet rewritten through the
554 /// concat — caller must provide bodies that already expect
555 /// the flat-concat output convention if they need custom AD.
556 pub fn custom_fn_multi(
557 &mut self,
558 inputs: Vec<NodeId>,
559 mut fwd_body: Graph,
560 ) -> MultiOutputHandle {
561 use crate::op::BinaryOp;
562 // Snapshot the original outputs + their shapes BEFORE
563 // appending concat ops. Outputs land at the end of the graph;
564 // we'll replace them.
565 let original_outputs = fwd_body.outputs.clone();
566 assert!(
567 !original_outputs.is_empty(),
568 "custom_fn_multi: fwd_body must have ≥ 1 declared output"
569 );
570 let mut sub_shapes: Vec<Shape> = Vec::with_capacity(original_outputs.len());
571 let mut offsets: Vec<usize> = Vec::with_capacity(original_outputs.len());
572 let mut total_len: usize = 0;
573 for &out_id in &original_outputs {
574 let s = fwd_body.node(out_id).shape.clone();
575 assert_eq!(
576 s.dtype(),
577 DType::F32,
578 "custom_fn_multi MVP: all sub-outputs must be F32, got {:?} \
579 (sub-output #{})",
580 s.dtype(),
581 sub_shapes.len()
582 );
583 let n_elems: usize = s
584 .dims()
585 .iter()
586 .map(|d| match d {
587 Dim::Static(k) => *k,
588 Dim::Dynamic(_) => {
589 panic!("custom_fn_multi MVP: dynamic dims not supported")
590 }
591 })
592 .product();
593 offsets.push(total_len);
594 total_len += n_elems;
595 sub_shapes.push(s);
596 }
597 // Flatten each sub-output to [n_elems] and concat along axis 0.
598 let mut flats: Vec<NodeId> = Vec::with_capacity(original_outputs.len());
599 for (out_id, sh) in original_outputs.iter().zip(sub_shapes.iter()) {
600 let n: usize = sh
601 .dims()
602 .iter()
603 .map(|d| match d {
604 Dim::Static(k) => *k,
605 Dim::Dynamic(_) => unreachable!(),
606 })
607 .product();
608 let flat_shape = Shape::from_dims(&[Dim::Static(n)], DType::F32);
609 let flat = fwd_body.add_node(
610 Op::Reshape {
611 new_shape: vec![n as i64],
612 },
613 vec![*out_id],
614 flat_shape,
615 );
616 flats.push(flat);
617 }
618 let concat_shape = Shape::from_dims(&[Dim::Static(total_len)], DType::F32);
619 let concat = fwd_body.add_node(Op::Concat { axis: 0 }, flats.clone(), concat_shape);
620 let _ = BinaryOp::Add; // import preserved if we extend later
621 fwd_body.set_outputs(vec![concat]);
622
623 // Now build the outer custom_fn with the rewritten body. Reuses
624 // the single-output asserts; flat concat satisfies them.
625 let source = self.custom_fn(inputs, fwd_body, None, None);
626
627 MultiOutputHandle {
628 source,
629 sub_shapes,
630 offsets,
631 }
632 }
633
634 /// Bounded scan returning the stacked trajectory.
635 /// Output shape is `[length, *init.shape]` — row `t` is the carry
636 /// after step `t+1`, so row `length-1` equals the result of plain
637 /// [`Self::scan`].
638 pub fn scan_trajectory(&mut self, init: NodeId, body: Graph, length: u32) -> NodeId {
639 let init_shape = self.shape(init).clone();
640 let mut traj_dims: Vec<crate::Dim> = Vec::with_capacity(init_shape.rank() + 1);
641 traj_dims.push(crate::Dim::Static(length as usize));
642 for i in 0..init_shape.rank() {
643 traj_dims.push(init_shape.dim(i));
644 }
645 let traj_shape = crate::Shape::from_dims(&traj_dims, init_shape.dtype());
646 self.push(
647 Op::Scan {
648 body: Box::new(body),
649 length,
650 save_trajectory: true,
651 num_xs: 0,
652 num_bcast: 0,
653 num_checkpoints: 0,
654 },
655 vec![init],
656 traj_shape,
657 None,
658 )
659 }
660}