Skip to main content

rlx_ir/ops/
special.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Specialised builders: SSM selective scan + space for future
17//! exotic ops (plan #53).
18
19use crate::shape::Dim;
20use crate::{DType, Graph, NodeId, Op, Shape};
21
22/// Handle to a multi-output [`Op::CustomFn`] built via
23/// [`Graph::custom_fn_multi`]. Internally the op produces a flat 1-D
24/// concatenated output; this handle remembers each sub-output's
25/// offset + original shape so [`Self::output`] can materialize
26/// component `i` lazily via `Op::Narrow` + `Op::Reshape`.
27#[derive(Debug, Clone)]
28pub struct MultiOutputHandle {
29    /// NodeId of the wrapped Op::CustomFn (1-D F32, length =
30    /// `Σ sub_shapes[i].num_elements`).
31    pub source: NodeId,
32    /// Original per-sub-output shapes (in declaration order).
33    pub sub_shapes: Vec<Shape>,
34    /// Per-sub-output start offsets into `source` (element-counted,
35    /// not byte-counted).
36    pub offsets: Vec<usize>,
37}
38
39impl MultiOutputHandle {
40    /// Number of sub-outputs.
41    pub fn n_outputs(&self) -> usize {
42        self.sub_shapes.len()
43    }
44
45    /// Materialize sub-output `idx` as an outer-graph NodeId.
46    /// Internally: `Op::Narrow(source, axis=0, start=offsets[idx],
47    /// len=numel(sub_shapes[idx]))` → `Op::Reshape` back to the
48    /// declared shape.
49    pub fn output(&self, g: &mut Graph, idx: usize) -> NodeId {
50        assert!(idx < self.sub_shapes.len(), "output index out of range");
51        let sub = &self.sub_shapes[idx];
52        let n_elems: usize = sub
53            .dims()
54            .iter()
55            .map(|d| match d {
56                Dim::Static(k) => *k,
57                Dim::Dynamic(_) => panic!("dynamic sub-output dim"),
58            })
59            .product();
60        let flat_shape = Shape::from_dims(&[Dim::Static(n_elems)], sub.dtype());
61        let narrowed = g.add_node(
62            Op::Narrow {
63                axis: 0,
64                start: self.offsets[idx],
65                len: n_elems,
66            },
67            vec![self.source],
68            flat_shape,
69        );
70        if sub.rank() == 1 {
71            // Already the right shape.
72            narrowed
73        } else {
74            let dims: Vec<i64> = sub
75                .dims()
76                .iter()
77                .map(|d| match d {
78                    Dim::Static(k) => *k as i64,
79                    Dim::Dynamic(_) => unreachable!(),
80                })
81                .collect();
82            g.add_node(Op::Reshape { new_shape: dims }, vec![narrowed], sub.clone())
83        }
84    }
85}
86
87impl Graph {
88    /// Mamba-style selective scan: y = SSM(x, Δ, A, B, C).
89    /// Inputs: x \[b,s,h\], delta \[b,s,h\], a \[h,n\], b \[b,s,n\], c \[b,s,n\].
90    /// Output \[b,s,h\]. n is the state size.
91    pub fn selective_scan(
92        &mut self,
93        x: NodeId,
94        delta: NodeId,
95        a: NodeId,
96        b: NodeId,
97        c: NodeId,
98        state_size: usize,
99        shape: Shape,
100    ) -> NodeId {
101        self.push(
102            Op::SelectiveScan { state_size },
103            vec![x, delta, a, b, c],
104            shape,
105            None,
106        )
107    }
108
109    /// Gated DeltaNet linear-attention scan (Qwen3.5/3.6 trunk,
110    /// Qwen3-Next, Kimi-Linear). See [`Op::GatedDeltaNet`] for the
111    /// recurrence math. All five inputs are `f32`. Shapes:
112    /// `q,k,v`: `[b, s, h_v, n]`; `g,beta`: `[b, s, h_v]`. Output:
113    /// `[b, s, h_v, n]`. State is implicit (reset per batch) unless
114    /// `carry_state` is set — then pass `state` as a sixth input.
115    pub fn gated_delta_net(
116        &mut self,
117        q: NodeId,
118        k: NodeId,
119        v: NodeId,
120        g: NodeId,
121        beta: NodeId,
122        state_size: usize,
123        shape: Shape,
124    ) -> NodeId {
125        self.push(
126            Op::GatedDeltaNet {
127                state_size,
128                carry_state: false,
129            },
130            vec![q, k, v, g, beta],
131            shape,
132            None,
133        )
134    }
135
136    /// Same as [`Self::gated_delta_net`] but threads `state`
137    /// `[b, h_v, n, n]` in/out for decode-mode recurrence.
138    pub fn gated_delta_net_carry(
139        &mut self,
140        q: NodeId,
141        k: NodeId,
142        v: NodeId,
143        g: NodeId,
144        beta: NodeId,
145        state: NodeId,
146        state_size: usize,
147        shape: Shape,
148    ) -> NodeId {
149        self.push(
150            Op::GatedDeltaNet {
151                state_size,
152                carry_state: true,
153            },
154            vec![q, k, v, g, beta, state],
155            shape,
156            None,
157        )
158    }
159
160    /// Bounded scan returning the final carry. Body must have exactly
161    /// one `Op::Input` (the carry) and one output, both same shape as
162    /// `init`. Output shape matches `init`.
163    pub fn scan(&mut self, init: NodeId, body: Graph, length: u32) -> NodeId {
164        let init_shape = self.shape(init).clone();
165        self.push(
166            Op::Scan {
167                body: Box::new(body),
168                length,
169                save_trajectory: false,
170                num_bcast: 0,
171                num_xs: 0,
172                num_checkpoints: 0,
173            },
174            vec![init],
175            init_shape,
176            None,
177        )
178    }
179
180    /// Bounded scan with recursive checkpointing for memory-bounded
181    /// backward AD. Equivalent to [`Self::scan`] for the forward
182    /// computation, but during backward only `num_checkpoints` carry
183    /// values are cached; intermediate carries are recomputed via the
184    /// body. Memory: `O(num_checkpoints · carry_size)`. Time: forward
185    /// unchanged; backward `O(length)` (segment-cached).
186    ///
187    /// The AD pre-pass propagates `num_checkpoints` into the rewritten
188    /// trajectory-saving Scan and into the emitted ScanBackward, so a
189    /// single call to [`crate::Graph::scan_checkpointed`] is enough
190    /// to enable the memory bound across the whole forward+backward
191    /// pipeline.
192    pub fn scan_checkpointed(
193        &mut self,
194        init: NodeId,
195        body: Graph,
196        length: u32,
197        num_checkpoints: u32,
198    ) -> NodeId {
199        assert!(
200            num_checkpoints > 0 && num_checkpoints <= length,
201            "scan_checkpointed: num_checkpoints={num_checkpoints} \
202             must be in 1..=length={length}"
203        );
204        let init_shape = self.shape(init).clone();
205        self.push(
206            Op::Scan {
207                body: Box::new(body),
208                length,
209                save_trajectory: false,
210                num_bcast: 0,
211                num_xs: 0,
212                num_checkpoints,
213            },
214            vec![init],
215            init_shape,
216            None,
217        )
218    }
219
220    /// Bounded scan with broadcast and per-step inputs.
221    ///
222    /// Body `Op::Input`s in NodeId order: `[carry, bcast_0..bcast_{B-1},
223    /// x_t_0..x_t_{X-1}]`. Bcast inputs keep their natural shape (the
224    /// CPU executor fills them once before the scan loop). xs\[i\] has
225    /// shape `[length, *per_step]` and the body sees `xs[i][t]` per
226    /// iteration. Output shape matches `init`.
227    pub fn scan_with_bcasts_and_xs(
228        &mut self,
229        init: NodeId,
230        bcasts: &[NodeId],
231        xs: &[NodeId],
232        body: Graph,
233        length: u32,
234    ) -> NodeId {
235        let init_shape = self.shape(init).clone();
236        let mut inputs = vec![init];
237        inputs.extend_from_slice(bcasts);
238        inputs.extend_from_slice(xs);
239        self.push(
240            Op::Scan {
241                body: Box::new(body),
242                length,
243                save_trajectory: false,
244                num_bcast: bcasts.len() as u32,
245                num_xs: xs.len() as u32,
246                num_checkpoints: 0,
247            },
248            inputs,
249            init_shape,
250            None,
251        )
252    }
253
254    /// Bounded scan with per-step `xs` inputs returning the final carry.
255    /// Body has `1 + xs.len()` Op::Inputs in NodeId construction order
256    /// (first declared is the carry; the remaining match `xs` in order).
257    /// Each `xs[i]` has shape `[length, *per_step_shape_i]`; the body
258    /// sees a `per_step_shape_i` slice on iteration `t`.
259    pub fn scan_with_xs(
260        &mut self,
261        init: NodeId,
262        xs: &[NodeId],
263        body: Graph,
264        length: u32,
265    ) -> NodeId {
266        let init_shape = self.shape(init).clone();
267        let mut inputs = vec![init];
268        inputs.extend_from_slice(xs);
269        self.push(
270            Op::Scan {
271                body: Box::new(body),
272                length,
273                save_trajectory: false,
274                num_bcast: 0,
275                num_xs: xs.len() as u32,
276                num_checkpoints: 0,
277            },
278            inputs,
279            init_shape,
280            None,
281        )
282    }
283
284    /// Reverse-mode AD companion to [`Self::scan`] /
285    /// [`Self::scan_trajectory`]. Typically constructed by the
286    /// autodiff pass, not by hand.
287    ///
288    /// `xs` is the list of per-step input tensors (must match the
289    /// forward Op::Scan's xs in count, order, and per-step shape).
290    /// Body_vjp's `1 + xs.len() + 1` Op::Inputs match the forward
291    /// body's inputs plus a fresh `"d_output"` Input.
292    pub fn scan_backward(
293        &mut self,
294        init: NodeId,
295        trajectory: NodeId,
296        upstream: NodeId,
297        xs: &[NodeId],
298        body_vjp: Graph,
299        length: u32,
300        save_trajectory: bool,
301        out_shape: Shape,
302    ) -> NodeId {
303        self.scan_backward_with_checkpoints(
304            init,
305            trajectory,
306            upstream,
307            xs,
308            body_vjp,
309            length,
310            save_trajectory,
311            0,
312            None,
313            out_shape,
314        )
315    }
316
317    /// Lower-level `scan_backward` with explicit checkpointing config.
318    /// `num_checkpoints == 0` (default) means no checkpointing — the
319    /// trajectory cache holds every step's carry. `0 < K < length`
320    /// enables segment-cached recompute via `forward_body` (must be
321    /// `Some`).
322    #[allow(clippy::too_many_arguments)]
323    pub fn scan_backward_with_checkpoints(
324        &mut self,
325        init: NodeId,
326        trajectory: NodeId,
327        upstream: NodeId,
328        xs: &[NodeId],
329        body_vjp: Graph,
330        length: u32,
331        save_trajectory: bool,
332        num_checkpoints: u32,
333        forward_body: Option<Graph>,
334        out_shape: Shape,
335    ) -> NodeId {
336        let mut inputs = vec![init, trajectory, upstream];
337        inputs.extend_from_slice(xs);
338        self.push(
339            Op::ScanBackward {
340                body_vjp: Box::new(body_vjp),
341                length,
342                save_trajectory,
343                num_xs: xs.len() as u32,
344                num_checkpoints,
345                forward_body: forward_body.map(Box::new),
346            },
347            inputs,
348            out_shape,
349            None,
350        )
351    }
352
353    /// Per-step xs gradient companion to [`Self::scan_backward`].
354    /// Same inputs and same `body_vjp` graph, plus an `xs_idx`
355    /// selecting which body_vjp output to stack into the result.
356    /// Output shape is `[length, *per_step_xs_shape]`.
357    pub fn scan_backward_xs(
358        &mut self,
359        init: NodeId,
360        trajectory: NodeId,
361        upstream: NodeId,
362        xs: &[NodeId],
363        body_vjp: Graph,
364        length: u32,
365        save_trajectory: bool,
366        xs_idx: u32,
367        out_shape: Shape,
368    ) -> NodeId {
369        self.scan_backward_xs_with_checkpoints(
370            init,
371            trajectory,
372            upstream,
373            xs,
374            body_vjp,
375            length,
376            save_trajectory,
377            xs_idx,
378            0,
379            None,
380            out_shape,
381        )
382    }
383
384    #[allow(clippy::too_many_arguments)]
385    pub fn scan_backward_xs_with_checkpoints(
386        &mut self,
387        init: NodeId,
388        trajectory: NodeId,
389        upstream: NodeId,
390        xs: &[NodeId],
391        body_vjp: Graph,
392        length: u32,
393        save_trajectory: bool,
394        xs_idx: u32,
395        num_checkpoints: u32,
396        forward_body: Option<Graph>,
397        out_shape: Shape,
398    ) -> NodeId {
399        let mut inputs = vec![init, trajectory, upstream];
400        inputs.extend_from_slice(xs);
401        self.push(
402            Op::ScanBackwardXs {
403                body_vjp: Box::new(body_vjp),
404                length,
405                save_trajectory,
406                num_xs: xs.len() as u32,
407                xs_idx,
408                num_checkpoints,
409                forward_body: forward_body.map(Box::new),
410            },
411            inputs,
412            out_shape,
413            None,
414        )
415    }
416
417    /// User-defined sub-graph with optional override AD rules.
418    /// JAX-shaped `custom_vjp` / `custom_jvp` — see [`Op::CustomFn`].
419    ///
420    /// `inputs.len()` must equal the number of `Op::Input` nodes in
421    /// `fwd_body`. Output shape is inferred from `fwd_body`'s declared
422    /// output. When supplied, `vjp_body` and `jvp_body` must follow the
423    /// conventions documented on [`Op::CustomFn`] (special-named
424    /// `"primal_output"` / `"d_output"` / `"tangent_*"` Inputs).
425    pub fn custom_fn(
426        &mut self,
427        inputs: Vec<NodeId>,
428        fwd_body: Graph,
429        vjp_body: Option<Graph>,
430        jvp_body: Option<Graph>,
431    ) -> NodeId {
432        let n_in = inputs.len();
433        // Count fwd_body's primal Inputs (no special names — fwd has none).
434        let fwd_inputs: usize = fwd_body
435            .nodes()
436            .iter()
437            .filter(|n| matches!(n.op, Op::Input { .. }))
438            .count();
439        assert_eq!(
440            fwd_inputs, n_in,
441            "custom_fn: fwd_body has {fwd_inputs} Op::Input(s); outer call \
442             provides {n_in}. Counts must match.",
443        );
444        let fwd_out_id = fwd_body
445            .outputs
446            .first()
447            .copied()
448            .expect("custom_fn: fwd_body must declare exactly one output");
449        let out_shape = fwd_body.node(fwd_out_id).shape.clone();
450
451        if let Some(vjp) = vjp_body.as_ref() {
452            let primal_count = vjp
453                .nodes()
454                .iter()
455                .filter(|n| {
456                    matches!(&n.op,
457                    Op::Input { name } if name != "primal_output" && name != "d_output")
458                })
459                .count();
460            assert_eq!(
461                primal_count, n_in,
462                "custom_fn: vjp_body has {primal_count} primal Op::Input(s) \
463                 (excluding 'primal_output' / 'd_output'); expected {n_in}",
464            );
465            let has_primal_out = vjp
466                .nodes()
467                .iter()
468                .any(|n| matches!(&n.op, Op::Input { name } if name == "primal_output"));
469            let has_d_output = vjp
470                .nodes()
471                .iter()
472                .any(|n| matches!(&n.op, Op::Input { name } if name == "d_output"));
473            assert!(
474                has_primal_out,
475                "custom_fn: vjp_body must declare an Op::Input named 'primal_output'"
476            );
477            assert!(
478                has_d_output,
479                "custom_fn: vjp_body must declare an Op::Input named 'd_output'"
480            );
481            assert_eq!(
482                vjp.outputs.len(),
483                n_in,
484                "custom_fn: vjp_body has {} outputs; expected {n_in} \
485                 (one gradient per primal input)",
486                vjp.outputs.len(),
487            );
488        }
489        if let Some(jvp) = jvp_body.as_ref() {
490            let primal_count = jvp
491                .nodes()
492                .iter()
493                .filter(|n| {
494                    matches!(&n.op,
495                    Op::Input { name }
496                        if !name.starts_with("tangent_") && name != "primal_output")
497                })
498                .count();
499            assert_eq!(
500                primal_count, n_in,
501                "custom_fn: jvp_body has {primal_count} primal Op::Input(s) \
502                 (excluding 'primal_output' / 'tangent_*'); expected {n_in}",
503            );
504            for i in 0..n_in {
505                let want = format!("tangent_{i}");
506                let has = jvp
507                    .nodes()
508                    .iter()
509                    .any(|n| matches!(&n.op, Op::Input { name } if name == &want));
510                assert!(
511                    has,
512                    "custom_fn: jvp_body must declare an Op::Input named '{want}'"
513                );
514            }
515            assert_eq!(
516                jvp.outputs.len(),
517                1,
518                "custom_fn: jvp_body has {} outputs; expected 1 (output tangent)",
519                jvp.outputs.len(),
520            );
521        }
522
523        self.push(
524            Op::CustomFn {
525                fwd_body: Box::new(fwd_body),
526                vjp_body: vjp_body.map(Box::new),
527                jvp_body: jvp_body.map(Box::new),
528                num_inputs: n_in as u32,
529            },
530            inputs,
531            out_shape,
532            None,
533        )
534    }
535
536    /// Multi-output `custom_fn` via the **concat-with-Narrow** design:
537    /// rewrites `fwd_body` to flatten + concat its `K` declared outputs
538    /// into a single 1-D F32 output, wraps that as [`Op::CustomFn`],
539    /// and returns a [`MultiOutputHandle`] the caller uses to extract
540    /// each sub-output via `Op::Narrow` + `Op::Reshape`.
541    ///
542    /// Per PLAN line 484, this avoids rewriting rlx's "1 Op = 1 output"
543    /// IR contract: the wrapped Op::CustomFn still has one output (the
544    /// flat concat), and `MultiOutputHandle::output(g, i)` materializes
545    /// component `i` lazily on the outer graph.
546    ///
547    /// Constraints (MVP):
548    /// - All sub-outputs must be `DType::F32`. Tuples-of-mixed-dtype
549    ///   need either a per-dtype split or a future tuple-type
550    ///   extension.
551    /// - All sub-output shapes must be statically known (no
552    ///   `Dim::Dynamic`).
553    /// - `vjp_body` / `jvp_body` aren't yet rewritten through the
554    ///   concat — caller must provide bodies that already expect
555    ///   the flat-concat output convention if they need custom AD.
556    pub fn custom_fn_multi(
557        &mut self,
558        inputs: Vec<NodeId>,
559        mut fwd_body: Graph,
560    ) -> MultiOutputHandle {
561        use crate::op::BinaryOp;
562        // Snapshot the original outputs + their shapes BEFORE
563        // appending concat ops. Outputs land at the end of the graph;
564        // we'll replace them.
565        let original_outputs = fwd_body.outputs.clone();
566        assert!(
567            !original_outputs.is_empty(),
568            "custom_fn_multi: fwd_body must have ≥ 1 declared output"
569        );
570        let mut sub_shapes: Vec<Shape> = Vec::with_capacity(original_outputs.len());
571        let mut offsets: Vec<usize> = Vec::with_capacity(original_outputs.len());
572        let mut total_len: usize = 0;
573        for &out_id in &original_outputs {
574            let s = fwd_body.node(out_id).shape.clone();
575            assert_eq!(
576                s.dtype(),
577                DType::F32,
578                "custom_fn_multi MVP: all sub-outputs must be F32, got {:?} \
579                 (sub-output #{})",
580                s.dtype(),
581                sub_shapes.len()
582            );
583            let n_elems: usize = s
584                .dims()
585                .iter()
586                .map(|d| match d {
587                    Dim::Static(k) => *k,
588                    Dim::Dynamic(_) => {
589                        panic!("custom_fn_multi MVP: dynamic dims not supported")
590                    }
591                })
592                .product();
593            offsets.push(total_len);
594            total_len += n_elems;
595            sub_shapes.push(s);
596        }
597        // Flatten each sub-output to [n_elems] and concat along axis 0.
598        let mut flats: Vec<NodeId> = Vec::with_capacity(original_outputs.len());
599        for (out_id, sh) in original_outputs.iter().zip(sub_shapes.iter()) {
600            let n: usize = sh
601                .dims()
602                .iter()
603                .map(|d| match d {
604                    Dim::Static(k) => *k,
605                    Dim::Dynamic(_) => unreachable!(),
606                })
607                .product();
608            let flat_shape = Shape::from_dims(&[Dim::Static(n)], DType::F32);
609            let flat = fwd_body.add_node(
610                Op::Reshape {
611                    new_shape: vec![n as i64],
612                },
613                vec![*out_id],
614                flat_shape,
615            );
616            flats.push(flat);
617        }
618        let concat_shape = Shape::from_dims(&[Dim::Static(total_len)], DType::F32);
619        let concat = fwd_body.add_node(Op::Concat { axis: 0 }, flats.clone(), concat_shape);
620        let _ = BinaryOp::Add; // import preserved if we extend later
621        fwd_body.set_outputs(vec![concat]);
622
623        // Now build the outer custom_fn with the rewritten body. Reuses
624        // the single-output asserts; flat concat satisfies them.
625        let source = self.custom_fn(inputs, fwd_body, None, None);
626
627        MultiOutputHandle {
628            source,
629            sub_shapes,
630            offsets,
631        }
632    }
633
634    /// Bounded scan returning the stacked trajectory.
635    /// Output shape is `[length, *init.shape]` — row `t` is the carry
636    /// after step `t+1`, so row `length-1` equals the result of plain
637    /// [`Self::scan`].
638    pub fn scan_trajectory(&mut self, init: NodeId, body: Graph, length: u32) -> NodeId {
639        let init_shape = self.shape(init).clone();
640        let mut traj_dims: Vec<crate::Dim> = Vec::with_capacity(init_shape.rank() + 1);
641        traj_dims.push(crate::Dim::Static(length as usize));
642        for i in 0..init_shape.rank() {
643            traj_dims.push(init_shape.dim(i));
644        }
645        let traj_shape = crate::Shape::from_dims(&traj_dims, init_shape.dtype());
646        self.push(
647            Op::Scan {
648                body: Box::new(body),
649                length,
650                save_trajectory: true,
651                num_xs: 0,
652                num_bcast: 0,
653                num_checkpoints: 0,
654            },
655            vec![init],
656            traj_shape,
657            None,
658        )
659    }
660}