Skip to main content

rlx_fusion/
control_flow.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Control-flow lowering passes: `Op::If` → `Where` + inlined
17//! branches; `Op::While` → bounded unroll of body replicas.
18//!
19//! Backends that don't have native sub-graph executors run
20//! `LowerControlFlow` BEFORE the legalize / supported-set check so
21//! they never see `Op::If` or `Op::While`. Used by rlx-cpu and
22//! rlx-metal (the runtime-side `run_if` / `run_while` helpers exist
23//! but the executor wiring through their thunk schedules is more
24//! invasive than this rewrite). Other backends (rlx-wgpu, rlx-cuda,
25//! rlx-rocm, rlx-tpu) ship their own per-backend unfuse passes that
26//! do equivalent work — this module is the portable, IR-level
27//! version.
28//!
29//! Trade-offs:
30//!   * `Op::If` always evaluates **both** branches in the rewritten
31//!     graph. That's the price of expressing it via primitives. Fine
32//!     for inference where Op::If is rare; if a workload hits a
33//!     hot Op::If on a path where both branches are expensive, the
34//!     fix is a backend-native If executor, not this rewrite.
35//!   * `Op::While` requires `max_iterations = Some(N)` — unbounded
36//!     loops have no terminating count and panic with a clear
37//!     message pointing at `rlx_runtime::subgraph::run_while` for
38//!     the dynamic alternative.
39//!
40//! Capture binding (used by both passes): each sub-graph's
41//! `Op::Input` nodes appear in the same order as the parent's
42//! captures (`inputs[1..]` for `Op::If` past the predicate, all
43//! `inputs[..]` for `Op::While`). Sub-graph `Op::Input[i]` rewires
44//! to `captures[i]` when inlined into the parent.
45
46use crate::pass::Pass;
47use rlx_ir::op::BinaryOp;
48use rlx_ir::shape::Dim;
49use rlx_ir::{DType, Graph, NodeId, Op, Shape};
50use std::collections::HashMap;
51
52/// Pass form: rewrites `Op::If` and `Op::While` into primitive ops.
53/// No-op when neither op is present.
54pub struct LowerControlFlow;
55
56impl Pass for LowerControlFlow {
57    fn name(&self) -> &str {
58        "LowerControlFlow"
59    }
60    fn run(&self, graph: Graph) -> Graph {
61        let g = inline_if(graph);
62        unroll_while(g)
63    }
64}
65
66/// Inline `Op::If` sub-graphs into the parent and replace the If
67/// node with `Where(predicate, then_output, else_output)`. Both
68/// branches are present in the rewritten graph and always evaluate.
69pub fn inline_if(g: Graph) -> Graph {
70    let mut out = Graph::new(g.name.clone());
71    let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
72    let nodes: Vec<rlx_ir::Node> = g.nodes().to_vec();
73
74    for node in &nodes {
75        let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
76        let new_id = match &node.op {
77            Op::If {
78                then_branch,
79                else_branch,
80            } => {
81                let captures: Vec<NodeId> = new_inputs[1..].to_vec();
82                let then_out = inline_subgraph_into(then_branch, &captures, &mut out);
83                let else_out = inline_subgraph_into(else_branch, &captures, &mut out);
84                // Most backends' Where kernel requires the predicate
85                // to share the output's element count (no broadcast
86                // inside the kernel). Expand a smaller predicate up
87                // to the output shape so the rewritten graph runs
88                // out of the box on CPU/Metal.
89                let predicate = expand_to_shape(new_inputs[0], &node.shape, &mut out);
90                out.add_node(
91                    Op::Where,
92                    vec![predicate, then_out, else_out],
93                    node.shape.clone(),
94                )
95            }
96            _ => out.add_node(node.op.clone(), new_inputs, node.shape.clone()),
97        };
98        id_map.insert(node.id, new_id);
99    }
100    let new_outputs: Vec<NodeId> = g.outputs.iter().map(|i| id_map[i]).collect();
101    out.set_outputs(new_outputs);
102    out
103}
104
105/// Bounded-unroll `Op::While` up to `max_iterations`. Each iteration
106/// inlines `cond` and `body` with all loop-carried captures, then
107/// applies `Where(active, body_out, carried)` per carry (MLX semantics).
108pub fn unroll_while(g: Graph) -> Graph {
109    let mut out = Graph::new(g.name.clone());
110    let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
111    let nodes: Vec<rlx_ir::Node> = g.nodes().to_vec();
112    let scalar_f32 = Shape::new(&[1], DType::F32);
113
114    for node in &nodes {
115        let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
116        let new_id = match &node.op {
117            Op::While {
118                cond,
119                body,
120                max_iterations: Some(n),
121                ..
122            } => {
123                if new_inputs.is_empty() {
124                    panic!(
125                        "Op::While unroll: at least one \
126                            loop-carried input required"
127                    );
128                }
129                let one = out.add_node(
130                    Op::Constant {
131                        data: 1.0_f32.to_le_bytes().to_vec(),
132                    },
133                    vec![],
134                    scalar_f32.clone(),
135                );
136                let mut active = one;
137                let mut carried = new_inputs;
138                for _ in 0..*n {
139                    let cond_out = inline_subgraph_into(cond, &carried, &mut out);
140                    let cond_f = cond_to_f32_mask(cond_out, &mut out);
141                    let cond_shape = out.node(cond_f).shape.clone();
142                    let active_lhs = expand_to_shape(active, &cond_shape, &mut out);
143                    active = out.binary(BinaryOp::Mul, active_lhs, cond_f, cond_shape);
144
145                    let body_outs = inline_subgraph_into_outputs(body, &carried, &mut out);
146                    assert_eq!(
147                        body_outs.len(),
148                        carried.len(),
149                        "Op::While: body output count must match loop-carried arity"
150                    );
151                    let mut next = Vec::with_capacity(carried.len());
152                    for (body_out, &prev) in body_outs.iter().zip(carried.iter()) {
153                        let shape = out.node(prev).shape.clone();
154                        let mask = expand_to_shape(active, &shape, &mut out);
155                        let merged = out.add_node(Op::Where, vec![mask, *body_out, prev], shape);
156                        next.push(merged);
157                    }
158                    carried = next;
159                }
160                carried[0]
161            }
162            Op::While {
163                max_iterations: None,
164                ..
165            } => {
166                panic!(
167                    "LowerControlFlow: Op::While requires \
168                        max_iterations = Some(N) for unrolling. \
169                        Either set a bounded max_iterations on the \
170                        forward graph, or use the dynamic \
171                        `rlx_runtime::subgraph::run_while` helper."
172                );
173            }
174            _ => out.add_node(node.op.clone(), new_inputs, node.shape.clone()),
175        };
176        id_map.insert(node.id, new_id);
177    }
178    let new_outputs: Vec<NodeId> = g.outputs.iter().map(|i| id_map[i]).collect();
179    out.set_outputs(new_outputs);
180    out
181}
182
183/// Cast a cond-subgraph output to an f32 loop mask. Bool uses the
184/// Bool→I32→F32 chain because CPU `Cast` Bool→F32 is a raw byte copy.
185fn cond_to_f32_mask(cond_out: NodeId, out: &mut Graph) -> NodeId {
186    let cond_shape = out.node(cond_out).shape.clone();
187    match cond_shape.dtype() {
188        DType::F32 => cond_out,
189        DType::Bool => {
190            let f32_shape = cond_shape.clone().with_dtype(DType::F32);
191            let i32_shape = cond_shape.with_dtype(DType::I32);
192            let as_i32 = out.add_node(Op::Cast { to: DType::I32 }, vec![cond_out], i32_shape);
193            out.add_node(Op::Cast { to: DType::F32 }, vec![as_i32], f32_shape)
194        }
195        _ => out.add_node(
196            Op::Cast { to: DType::F32 },
197            vec![cond_out],
198            cond_shape.with_dtype(DType::F32),
199        ),
200    }
201}
202
203/// Expand a tensor up to `target` via `Op::Expand` if its shape
204/// (specifically its element count) differs from the target. Used to
205/// promote a scalar / smaller predicate up to the Where output shape
206/// during `Op::If` lowering.
207fn expand_to_shape(src: NodeId, target: &rlx_ir::Shape, out: &mut Graph) -> NodeId {
208    let src_shape = out.node(src).shape.clone();
209    let src_n = src_shape
210        .dims()
211        .iter()
212        .filter_map(|d| match d {
213            Dim::Static(n) => Some(*n),
214            _ => None,
215        })
216        .product::<usize>();
217    let tgt_n = target
218        .dims()
219        .iter()
220        .filter_map(|d| match d {
221            Dim::Static(n) => Some(*n),
222            _ => None,
223        })
224        .product::<usize>();
225    if src_shape.dims() == target.dims() {
226        return src;
227    }
228    let target_dims_i64: Vec<i64> = target
229        .dims()
230        .iter()
231        .map(|d| match d {
232            Dim::Static(n) => *n as i64,
233            _ => -1,
234        })
235        .collect();
236    // Op::Expand requires equal rank (broadcast via 1-dim only).
237    // If src has a smaller rank, left-pad with 1s via a Reshape first.
238    let src_rank = src_shape.rank();
239    let tgt_rank = target.dims().len();
240    let to_expand = if src_rank < tgt_rank {
241        let mut padded_dims: Vec<Dim> = std::iter::repeat_n(Dim::Static(1), tgt_rank - src_rank)
242            .chain(src_shape.dims().iter().copied())
243            .collect();
244        // Width of last dim follows src; rank gain pads with 1s.
245        let _ = src_n;
246        let _ = tgt_n;
247        let dtype = src_shape.dtype();
248        let pad_dims_i64: Vec<i64> = padded_dims
249            .iter()
250            .map(|d| match d {
251                Dim::Static(n) => *n as i64,
252                _ => -1,
253            })
254            .collect();
255        // Borrow the padded shape for Reshape's output.
256        let pad_shape = rlx_ir::Shape::from_dims(&padded_dims, dtype);
257        padded_dims.clear();
258        out.reshape(src, pad_dims_i64, pad_shape)
259    } else {
260        src
261    };
262    out.add_node(
263        Op::Expand {
264            target_shape: target_dims_i64,
265        },
266        vec![to_expand],
267        target.clone(),
268    )
269}
270
271/// Inline `sub` into `out`, wiring `Op::Input` slots to `captures` in
272/// subgraph node order. Returns every output node (declaration order).
273pub fn inline_subgraph_into_outputs(
274    sub: &Graph,
275    captures: &[NodeId],
276    out: &mut Graph,
277) -> Vec<NodeId> {
278    let mut sub_to_parent: HashMap<NodeId, NodeId> = HashMap::new();
279    let mut input_idx = 0usize;
280    for sub_node in sub.nodes() {
281        let new_id = match &sub_node.op {
282            Op::Input { .. } => {
283                let parent_id = captures[input_idx];
284                input_idx += 1;
285                parent_id
286            }
287            _ => {
288                let new_inputs: Vec<NodeId> =
289                    sub_node.inputs.iter().map(|i| sub_to_parent[i]).collect();
290                out.add_node(sub_node.op.clone(), new_inputs, sub_node.shape.clone())
291            }
292        };
293        sub_to_parent.insert(sub_node.id, new_id);
294    }
295    assert_eq!(
296        input_idx,
297        captures.len(),
298        "Op::While/If sub-graph: {} Op::Input nodes but {} captures",
299        input_idx,
300        captures.len()
301    );
302    sub.outputs.iter().map(|o| sub_to_parent[o]).collect()
303}
304
305/// Helper: copy `sub`'s nodes into `out`, mapping each Op::Input
306/// by position to the corresponding capture. Returns the new
307/// NodeId in `out` of the sub-graph's first declared output.
308pub fn inline_subgraph_into(sub: &Graph, captures: &[NodeId], out: &mut Graph) -> NodeId {
309    inline_subgraph_into_outputs(sub, captures, out)[0]
310}
311
312#[cfg(test)]
313mod tests {
314    use super::*;
315    use rlx_ir::op::{Activation, BinaryOp};
316    use rlx_ir::{DType, Shape};
317
318    #[test]
319    fn lower_control_flow_pass_handles_both_if_and_while() {
320        let s = Shape::new(&[2], DType::F32);
321
322        let mut then_g = Graph::new("th");
323        let ti = then_g.input("c", s.clone());
324        let to = then_g.activation(Activation::Relu, ti, s.clone());
325        then_g.set_outputs(vec![to]);
326        let mut else_g = Graph::new("el");
327        let ei = else_g.input("c", s.clone());
328        let eo = else_g.activation(Activation::Sigmoid, ei, s.clone());
329        else_g.set_outputs(vec![eo]);
330
331        let mut body_g = Graph::new("body");
332        let bi = body_g.input("c", s.clone());
333        let bo = body_g.binary(BinaryOp::Mul, bi, bi, s.clone());
334        body_g.set_outputs(vec![bo]);
335        let mut cond_g = Graph::new("cond");
336        let ci = cond_g.input("c", s.clone());
337        cond_g.set_outputs(vec![ci]);
338
339        let mut g = Graph::new("parent");
340        let x = g.input("x", s.clone());
341        let pred = g.input("p", Shape::new(&[1], DType::F32));
342        let if_out = g.add_node(
343            Op::If {
344                then_branch: Box::new(then_g),
345                else_branch: Box::new(else_g),
346            },
347            vec![pred, x],
348            s.clone(),
349        );
350        let w_out = g.add_node(
351            Op::While {
352                cond: Box::new(cond_g),
353                body: Box::new(body_g),
354                max_iterations: Some(2),
355            },
356            vec![if_out],
357            s.clone(),
358        );
359        g.set_outputs(vec![w_out]);
360
361        let lowered = LowerControlFlow.run(g);
362        let has_if = lowered
363            .nodes()
364            .iter()
365            .any(|n| matches!(n.op, Op::If { .. }));
366        let has_while = lowered
367            .nodes()
368            .iter()
369            .any(|n| matches!(n.op, Op::While { .. }));
370        assert!(
371            !has_if && !has_while,
372            "LowerControlFlow should erase both If and While"
373        );
374        // 1 Where from If; While unroll adds 1 Where per iteration per
375        // carry (MLX semantics, see `unroll_while`).
376        let n_where = lowered
377            .nodes()
378            .iter()
379            .filter(|n| matches!(n.op, Op::Where))
380            .count();
381        let n_mul = lowered
382            .nodes()
383            .iter()
384            .filter(|n| matches!(n.op, Op::Binary(BinaryOp::Mul)))
385            .count();
386        assert_eq!(
387            n_where, 3,
388            "expected 1 Where from If + 2 from While (N=2, 1 carry)"
389        );
390        assert_eq!(
391            n_mul, 4,
392            "expected 2 body Mul + 2 active*cond_f Mul from While (N=2)"
393        );
394    }
395
396    #[test]
397    fn unroll_while_multi_carry_cond_freezes_updates() {
398        let v_shape = Shape::new(&[2], DType::F32);
399        let s_shape = Shape::new(&[1], DType::F32);
400
401        let mut body = Graph::new("body");
402        let v_in = body.input("v", v_shape.clone());
403        let s_in = body.input("s", s_shape.clone());
404        let one = body.add_node(
405            Op::Constant {
406                data: 1.0_f32.to_le_bytes().to_vec(),
407            },
408            vec![],
409            s_shape.clone(),
410        );
411        let v_out = body.binary(BinaryOp::Add, v_in, one, v_shape.clone());
412        body.set_outputs(vec![v_out, s_in]);
413
414        let mut cond = Graph::new("cond");
415        let v_c = cond.input("v", v_shape.clone());
416        let _s_c = cond.input("s", s_shape.clone());
417        let ten = cond.add_node(
418            Op::Constant {
419                data: 10.0_f32.to_le_bytes().to_vec(),
420            },
421            vec![],
422            s_shape.clone(),
423        );
424        let lt = cond.add_node(
425            Op::Compare(rlx_ir::op::CmpOp::Lt),
426            vec![v_c, ten],
427            Shape::new(&[1], DType::Bool),
428        );
429        cond.set_outputs(vec![lt]);
430
431        let mut g = Graph::new("parent");
432        let v0 = g.input("v0", v_shape.clone());
433        let s0 = g.input("s0", s_shape.clone());
434        let w = g.add_node(
435            Op::While {
436                cond: Box::new(cond),
437                body: Box::new(body),
438                max_iterations: Some(3),
439            },
440            vec![v0, s0],
441            v_shape.clone(),
442        );
443        g.set_outputs(vec![w]);
444
445        let lowered = unroll_while(g);
446        assert!(
447            !lowered
448                .nodes()
449                .iter()
450                .any(|n| matches!(n.op, Op::While { .. })),
451            "While should be erased"
452        );
453        let n_where = lowered
454            .nodes()
455            .iter()
456            .filter(|n| matches!(n.op, Op::Where))
457            .count();
458        assert_eq!(n_where, 6, "expected 3 iters × 2 carries Where masks");
459    }
460
    /// End-to-end check: unroll a 3-iteration "square the carry" loop, run
    /// the lowered graph through the rlx-cpu thunk executor and compare
    /// against x^8 (three squarings) for x = [2, 3].
    #[test]
    fn unroll_while_squares_on_cpu_thunks() {
        let s = Shape::new(&[2], DType::F32);
        // body: c -> c * c
        let mut body_g = Graph::new("body");
        let bi = body_g.input("c", s.clone());
        let bo = body_g.binary(BinaryOp::Mul, bi, bi, s.clone());
        body_g.set_outputs(vec![bo]);
        // cond: the carry itself (f32, so no cast is inserted by the unroll).
        let mut cond_g = Graph::new("cond");
        let ci = cond_g.input("c", s.clone());
        cond_g.set_outputs(vec![ci]);

        let mut g = Graph::new("while_test");
        let x = g.input("x", s.clone());
        let y = g.add_node(
            Op::While {
                cond: Box::new(cond_g),
                body: Box::new(body_g),
                max_iterations: Some(3),
            },
            vec![x],
            s.clone(),
        );
        g.set_outputs(vec![y]);

        let lowered = unroll_while(g);
        assert!(
            !lowered
                .nodes()
                .iter()
                .any(|n| matches!(n.op, Op::While { .. }))
        );

        // Node ids change during lowering, so look the input up by name.
        let x_id = lowered
            .nodes()
            .iter()
            .find(|n| matches!(&n.op, Op::Input { name, .. } if name == "x"))
            .expect("lowered graph missing input x")
            .id;
        let plan = rlx_opt::memory::plan_memory(&lowered);
        let mut arena = rlx_cpu::arena::Arena::from_plan(plan);
        let sched = rlx_cpu::thunk::compile_thunks(&lowered, &arena);
        // Copy every non-empty constant payload into its arena buffer,
        // decoding the bytes as little-endian f32 (this includes the `1.0`
        // mask seed emitted by `unroll_while`).
        for node in lowered.nodes() {
            if let Op::Constant { data } = &node.op
                && arena.has_buffer(node.id)
                && !data.is_empty()
            {
                let buf = arena.slice_mut(node.id);
                let n_floats = data.len() / 4;
                // Never write past the shorter of buffer and payload.
                let n = buf.len().min(n_floats);
                for i in 0..n {
                    let bytes = [
                        data[i * 4],
                        data[i * 4 + 1],
                        data[i * 4 + 2],
                        data[i * 4 + 3],
                    ];
                    buf[i] = f32::from_le_bytes(bytes);
                }
            }
        }
        let x_off = arena.byte_offset(x_id);
        let out_id = lowered.outputs[0];
        let out_off = arena.byte_offset(out_id);
        let buf = arena.raw_buf_mut();
        // SAFETY(review): assumes `x_off` is f32-aligned and that at least
        // two f32 slots of the arena buffer start there (x has shape [2]) —
        // TODO confirm against the memory planner's guarantees.
        unsafe {
            let px = buf.as_mut_ptr().add(x_off) as *mut f32;
            *px.add(0) = 2.0;
            *px.add(1) = 3.0;
        }
        rlx_cpu::thunk::execute_thunks(&sched, arena.raw_buf_mut());
        // SAFETY(review): same assumption for `out_off` and the [2]-shaped
        // output — TODO confirm.
        let got: Vec<f32> = unsafe {
            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
            vec![*p.add(0), *p.add(1)]
        };
        // Three squarings: 2^8 = 256, 3^8 = 6561.
        let want = [256.0_f32, 6561.0_f32];
        for (i, (&a, &b)) in got.iter().zip(&want).enumerate() {
            assert!(
                (a - b).abs() < 1e-3,
                "unrolled while[{i}]: got {a} want {b}"
            );
        }
    }
542    }
543}