Skip to main content

rlx_fusion/
control_flow.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Control-flow lowering passes: `Op::If` → `Where` + inlined
17//! branches; `Op::While` → bounded unroll of body replicas.
18//!
19//! Backends that don't have native sub-graph executors run
20//! `LowerControlFlow` BEFORE the legalize / supported-set check so
21//! they never see `Op::If` or `Op::While`. Used by rlx-cpu and
22//! rlx-metal (the runtime-side `run_if` / `run_while` helpers exist
23//! but the executor wiring through their thunk schedules is more
24//! invasive than this rewrite). Other backends (rlx-wgpu, rlx-cuda,
25//! rlx-rocm, rlx-tpu) ship their own per-backend unfuse passes that
26//! do equivalent work — this module is the portable, IR-level
27//! version.
28//!
29//! Trade-offs:
30//!   * `Op::If` always evaluates **both** branches in the rewritten
31//!     graph. That's the price of expressing it via primitives. Fine
32//!     for inference where Op::If is rare; if a workload hits a
33//!     hot Op::If on a path where both branches are expensive, the
34//!     fix is a backend-native If executor, not this rewrite.
35//!   * `Op::While` requires `max_iterations = Some(N)` — unbounded
36//!     loops have no terminating count and panic with a clear
37//!     message pointing at `rlx_runtime::subgraph::run_while` for
38//!     the dynamic alternative.
39//!
40//! Capture binding (used by both passes): each sub-graph's
41//! `Op::Input` nodes appear in the same order as the parent's
42//! captures (`inputs[1..]` for `Op::If` past the predicate, all
43//! `inputs[..]` for `Op::While`). Sub-graph `Op::Input[i]` rewires
44//! to `captures[i]` when inlined into the parent.
45
46use crate::pass::Pass;
47use rlx_ir::op::{BinaryOp, CmpOp, ReduceOp};
48use rlx_ir::shape::Dim;
49use rlx_ir::{DType, Graph, NodeId, Op, Shape};
50use std::collections::HashMap;
51
52/// Pass form: rewrites `Op::If` and `Op::While` into primitive ops.
53/// No-op when neither op is present.
54pub struct LowerControlFlow;
55
56impl Pass for LowerControlFlow {
57    fn name(&self) -> &str {
58        "LowerControlFlow"
59    }
60    fn run(&self, graph: Graph) -> Graph {
61        let g = inline_if(graph);
62        unroll_while(g)
63    }
64}
65
66/// Inline `Op::If` sub-graphs into the parent and replace the If
67/// node with `Where(predicate, then_output, else_output)`. Both
68/// branches are present in the rewritten graph and always evaluate.
69pub fn inline_if(g: Graph) -> Graph {
70    let mut out = Graph::new(g.name.clone());
71    let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
72    let nodes: Vec<rlx_ir::Node> = g.nodes().to_vec();
73
74    for node in &nodes {
75        let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
76        let new_id = match &node.op {
77            Op::If {
78                then_branch,
79                else_branch,
80            } => {
81                let captures: Vec<NodeId> = new_inputs[1..].to_vec();
82                let then_out = inline_subgraph_into(then_branch, &captures, &mut out);
83                let else_out = inline_subgraph_into(else_branch, &captures, &mut out);
84                // Most backends' Where kernel requires the predicate
85                // to share the output's element count (no broadcast
86                // inside the kernel). Expand a smaller predicate up
87                // to the output shape so the rewritten graph runs
88                // out of the box on CPU/Metal.
89                let predicate = expand_to_shape(new_inputs[0], &node.shape, &mut out);
90                out.add_node(
91                    Op::Where,
92                    vec![predicate, then_out, else_out],
93                    node.shape.clone(),
94                )
95            }
96            _ => out.add_node(node.op.clone(), new_inputs, node.shape.clone()),
97        };
98        id_map.insert(node.id, new_id);
99    }
100    let new_outputs: Vec<NodeId> = g.outputs.iter().map(|i| id_map[i]).collect();
101    out.set_outputs(new_outputs);
102    out
103}
104
105/// Bounded-unroll `Op::While` up to `max_iterations`. Each iteration
106/// inlines `cond` and `body` with all loop-carried captures, then
107/// applies `Where(active, body_out, carried)` per carry (MLX semantics).
108pub fn unroll_while(g: Graph) -> Graph {
109    let mut out = Graph::new(g.name.clone());
110    let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
111    let nodes: Vec<rlx_ir::Node> = g.nodes().to_vec();
112    let scalar_f32 = Shape::new(&[1], DType::F32);
113
114    for node in &nodes {
115        let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
116        let new_id = match &node.op {
117            Op::While {
118                cond,
119                body,
120                max_iterations: Some(n),
121                ..
122            } => {
123                if new_inputs.is_empty() {
124                    panic!(
125                        "Op::While unroll: at least one \
126                            loop-carried input required"
127                    );
128                }
129                let one = out.add_node(
130                    Op::Constant {
131                        data: 1.0_f32.to_le_bytes().to_vec(),
132                    },
133                    vec![],
134                    scalar_f32.clone(),
135                );
136                let mut active = one;
137                let mut carried = new_inputs;
138                for _ in 0..*n {
139                    let cond_out = inline_subgraph_into(cond, &carried, &mut out);
140                    let cond_f = cond_to_scalar_f32(cond_out, &mut out, &scalar_f32);
141                    active = out.binary(BinaryOp::Mul, active, cond_f, scalar_f32.clone());
142
143                    let body_outs = inline_subgraph_into_outputs(body, &carried, &mut out);
144                    assert_eq!(
145                        body_outs.len(),
146                        carried.len(),
147                        "Op::While: body output count must match loop-carried arity"
148                    );
149                    let mut next = Vec::with_capacity(carried.len());
150                    for (body_out, &prev) in body_outs.iter().zip(carried.iter()) {
151                        let shape = out.node(prev).shape.clone();
152                        let mask = expand_to_shape(active, &shape, &mut out);
153                        let merged = out.add_node(Op::Where, vec![mask, *body_out, prev], shape);
154                        next.push(merged);
155                    }
156                    carried = next;
157                }
158                carried[0]
159            }
160            Op::While {
161                max_iterations: None,
162                ..
163            } => {
164                panic!(
165                    "LowerControlFlow: Op::While requires \
166                        max_iterations = Some(N) for unrolling. \
167                        Either set a bounded max_iterations on the \
168                        forward graph, or use the dynamic \
169                        `rlx_runtime::subgraph::run_while` helper."
170                );
171            }
172            _ => out.add_node(node.op.clone(), new_inputs, node.shape.clone()),
173        };
174        id_map.insert(node.id, new_id);
175    }
176    let new_outputs: Vec<NodeId> = g.outputs.iter().map(|i| id_map[i]).collect();
177    out.set_outputs(new_outputs);
178    out
179}
180
181/// Fold a cond-subgraph output into a scalar f32 loop flag for `active`.
182/// Vector conds are reduced with min(nonzero) so every element must be
183/// truthy for the loop to keep running (matches treating 0 as false).
184fn cond_to_scalar_f32(cond_out: NodeId, out: &mut Graph, scalar_f32: &Shape) -> NodeId {
185    let cond_shape = out.node(cond_out).shape.clone();
186    let n = cond_shape
187        .dims()
188        .iter()
189        .filter_map(|d| match d {
190            Dim::Static(n) => Some(*n),
191            _ => None,
192        })
193        .product::<usize>();
194    let as_f32 = if cond_shape.dtype() == DType::F32 {
195        cond_out
196    } else {
197        out.add_node(
198            Op::Cast { to: DType::F32 },
199            vec![cond_out],
200            cond_shape.with_dtype(DType::F32),
201        )
202    };
203    if n <= 1 {
204        return as_f32;
205    }
206    let as_f32_shape = out.node(as_f32).shape.clone();
207    let rank = as_f32_shape.rank();
208    let zero = out.add_node(
209        Op::Constant {
210            data: 0.0_f32.to_le_bytes().to_vec(),
211        },
212        vec![],
213        scalar_f32.clone(),
214    );
215    let nonzero = out.add_node(
216        Op::Compare(CmpOp::Ne),
217        vec![as_f32, zero],
218        as_f32_shape.clone().with_dtype(DType::Bool),
219    );
220    let nonzero_f = out.add_node(
221        Op::Cast { to: DType::F32 },
222        vec![nonzero],
223        as_f32_shape.with_dtype(DType::F32),
224    );
225    let axes: Vec<usize> = (0..rank).collect();
226    out.reduce(nonzero_f, ReduceOp::Min, axes, true, scalar_f32.clone())
227}
228
229/// Expand a tensor up to `target` via `Op::Expand` if its shape
230/// (specifically its element count) differs from the target. Used to
231/// promote a scalar / smaller predicate up to the Where output shape
232/// during `Op::If` lowering.
233fn expand_to_shape(src: NodeId, target: &rlx_ir::Shape, out: &mut Graph) -> NodeId {
234    let src_shape = out.node(src).shape.clone();
235    let src_n = src_shape
236        .dims()
237        .iter()
238        .filter_map(|d| match d {
239            Dim::Static(n) => Some(*n),
240            _ => None,
241        })
242        .product::<usize>();
243    let tgt_n = target
244        .dims()
245        .iter()
246        .filter_map(|d| match d {
247            Dim::Static(n) => Some(*n),
248            _ => None,
249        })
250        .product::<usize>();
251    if src_shape.dims() == target.dims() {
252        return src;
253    }
254    let target_dims_i64: Vec<i64> = target
255        .dims()
256        .iter()
257        .map(|d| match d {
258            Dim::Static(n) => *n as i64,
259            _ => -1,
260        })
261        .collect();
262    // Op::Expand requires equal rank (broadcast via 1-dim only).
263    // If src has a smaller rank, left-pad with 1s via a Reshape first.
264    let src_rank = src_shape.rank();
265    let tgt_rank = target.dims().len();
266    let to_expand = if src_rank < tgt_rank {
267        let mut padded_dims: Vec<Dim> = std::iter::repeat_n(Dim::Static(1), tgt_rank - src_rank)
268            .chain(src_shape.dims().iter().copied())
269            .collect();
270        // Width of last dim follows src; rank gain pads with 1s.
271        let _ = src_n;
272        let _ = tgt_n;
273        let dtype = src_shape.dtype();
274        let pad_dims_i64: Vec<i64> = padded_dims
275            .iter()
276            .map(|d| match d {
277                Dim::Static(n) => *n as i64,
278                _ => -1,
279            })
280            .collect();
281        // Borrow the padded shape for Reshape's output.
282        let pad_shape = rlx_ir::Shape::from_dims(&padded_dims, dtype);
283        padded_dims.clear();
284        out.reshape(src, pad_dims_i64, pad_shape)
285    } else {
286        src
287    };
288    out.add_node(
289        Op::Expand {
290            target_shape: target_dims_i64,
291        },
292        vec![to_expand],
293        target.clone(),
294    )
295}
296
297/// Inline `sub` into `out`, wiring `Op::Input` slots to `captures` in
298/// subgraph node order. Returns every output node (declaration order).
299pub fn inline_subgraph_into_outputs(
300    sub: &Graph,
301    captures: &[NodeId],
302    out: &mut Graph,
303) -> Vec<NodeId> {
304    let mut sub_to_parent: HashMap<NodeId, NodeId> = HashMap::new();
305    let mut input_idx = 0usize;
306    for sub_node in sub.nodes() {
307        let new_id = match &sub_node.op {
308            Op::Input { .. } => {
309                let parent_id = captures[input_idx];
310                input_idx += 1;
311                parent_id
312            }
313            _ => {
314                let new_inputs: Vec<NodeId> =
315                    sub_node.inputs.iter().map(|i| sub_to_parent[i]).collect();
316                out.add_node(sub_node.op.clone(), new_inputs, sub_node.shape.clone())
317            }
318        };
319        sub_to_parent.insert(sub_node.id, new_id);
320    }
321    assert_eq!(
322        input_idx,
323        captures.len(),
324        "Op::While/If sub-graph: {} Op::Input nodes but {} captures",
325        input_idx,
326        captures.len()
327    );
328    sub.outputs.iter().map(|o| sub_to_parent[o]).collect()
329}
330
331/// Helper: copy `sub`'s nodes into `out`, mapping each Op::Input
332/// by position to the corresponding capture. Returns the new
333/// NodeId in `out` of the sub-graph's first declared output.
334pub fn inline_subgraph_into(sub: &Graph, captures: &[NodeId], out: &mut Graph) -> NodeId {
335    inline_subgraph_into_outputs(sub, captures, out)[0]
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::{Activation, BinaryOp};
    use rlx_ir::{DType, Shape};

    /// Number of nodes in `g` whose op satisfies `pred`.
    fn count_ops(g: &Graph, pred: impl Fn(&Op) -> bool) -> usize {
        g.nodes().iter().filter(|n| pred(&n.op)).count()
    }

    /// An `If` feeding a bounded `While`: both must disappear, and the
    /// number of emitted `Where` / `Mul` nodes must match the unroll.
    #[test]
    fn lower_control_flow_pass_handles_both_if_and_while() {
        let s = Shape::new(&[2], DType::F32);

        // then: relu(c), else: sigmoid(c)
        let mut then_g = Graph::new("th");
        let ti = then_g.input("c", s.clone());
        let to = then_g.activation(Activation::Relu, ti, s.clone());
        then_g.set_outputs(vec![to]);
        let mut else_g = Graph::new("el");
        let ei = else_g.input("c", s.clone());
        let eo = else_g.activation(Activation::Sigmoid, ei, s.clone());
        else_g.set_outputs(vec![eo]);

        // body: c * c, cond: c itself (a 2-element f32 vector)
        let mut body_g = Graph::new("body");
        let bi = body_g.input("c", s.clone());
        let bo = body_g.binary(BinaryOp::Mul, bi, bi, s.clone());
        body_g.set_outputs(vec![bo]);
        let mut cond_g = Graph::new("cond");
        let ci = cond_g.input("c", s.clone());
        cond_g.set_outputs(vec![ci]);

        let mut g = Graph::new("parent");
        let x = g.input("x", s.clone());
        let pred = g.input("p", Shape::new(&[1], DType::F32));
        let if_out = g.add_node(
            Op::If {
                then_branch: Box::new(then_g),
                else_branch: Box::new(else_g),
            },
            vec![pred, x],
            s.clone(),
        );
        let w_out = g.add_node(
            Op::While {
                cond: Box::new(cond_g),
                body: Box::new(body_g),
                max_iterations: Some(2),
            },
            vec![if_out],
            s.clone(),
        );
        g.set_outputs(vec![w_out]);

        let lowered = LowerControlFlow.run(g);
        let has_if = count_ops(&lowered, |op| matches!(op, Op::If { .. })) > 0;
        let has_while = count_ops(&lowered, |op| matches!(op, Op::While { .. })) > 0;
        assert!(
            !has_if && !has_while,
            "LowerControlFlow should erase both If and While"
        );
        // 1 Where from If; While unroll adds 1 Where per iteration per
        // carry (MLX semantics, see `unroll_while`).
        let n_where = count_ops(&lowered, |op| matches!(op, Op::Where));
        let n_mul = count_ops(&lowered, |op| matches!(op, Op::Binary(BinaryOp::Mul)));
        assert_eq!(
            n_where, 3,
            "expected 1 Where from If + 2 from While (N=2, 1 carry)"
        );
        assert_eq!(
            n_mul, 4,
            "expected 2 body Mul + 2 active*cond_f Mul from While (N=2)"
        );
    }

    /// Two loop carries and three iterations: every iteration must
    /// emit one `Where` mask per carry.
    #[test]
    fn unroll_while_multi_carry_cond_freezes_updates() {
        let v_shape = Shape::new(&[2], DType::F32);
        let s_shape = Shape::new(&[1], DType::F32);

        // body: (v, s) -> (v + 1, s)
        let mut body = Graph::new("body");
        let v_in = body.input("v", v_shape.clone());
        let s_in = body.input("s", s_shape.clone());
        let one = body.add_node(
            Op::Constant {
                data: 1.0_f32.to_le_bytes().to_vec(),
            },
            vec![],
            s_shape.clone(),
        );
        let v_out = body.binary(BinaryOp::Add, v_in, one, v_shape.clone());
        body.set_outputs(vec![v_out, s_in]);

        // cond: v < 10, element-wise over the 2-element `v`
        let mut cond = Graph::new("cond");
        let v_c = cond.input("v", v_shape.clone());
        let _s_c = cond.input("s", s_shape.clone());
        let ten = cond.add_node(
            Op::Constant {
                data: 10.0_f32.to_le_bytes().to_vec(),
            },
            vec![],
            s_shape.clone(),
        );
        // Comparing a [2] tensor with a [1] constant yields one flag
        // per element of `v`, so the declared result shape is [2].
        let lt = cond.add_node(
            Op::Compare(rlx_ir::op::CmpOp::Lt),
            vec![v_c, ten],
            Shape::new(&[2], DType::Bool),
        );
        cond.set_outputs(vec![lt]);

        let mut g = Graph::new("parent");
        let v0 = g.input("v0", v_shape.clone());
        let s0 = g.input("s0", s_shape.clone());
        let w = g.add_node(
            Op::While {
                cond: Box::new(cond),
                body: Box::new(body),
                max_iterations: Some(3),
            },
            vec![v0, s0],
            v_shape.clone(),
        );
        g.set_outputs(vec![w]);

        let lowered = unroll_while(g);
        assert_eq!(
            count_ops(&lowered, |op| matches!(op, Op::While { .. })),
            0,
            "While should be erased"
        );
        let n_where = count_ops(&lowered, |op| matches!(op, Op::Where));
        assert_eq!(n_where, 6, "expected 3 iters × 2 carries Where masks");
    }
}