use crate::pass::Pass;
use rlx_ir::op::BinaryOp;
use rlx_ir::shape::Dim;
use rlx_ir::{DType, Graph, NodeId, Op, Shape};
use std::collections::HashMap;
/// Graph pass that lowers structured control flow. Its [`Pass`] implementation
/// runs [`inline_if`] on the graph and then [`unroll_while`] on the result.
pub struct LowerControlFlow;
impl Pass for LowerControlFlow {
    /// Name under which this pass reports itself.
    fn name(&self) -> &str {
        "LowerControlFlow"
    }

    /// Lowers `graph` in two stages: every `Op::If` is inlined first, then
    /// every bounded `Op::While` in the resulting graph is unrolled.
    fn run(&self, graph: Graph) -> Graph {
        unroll_while(inline_if(graph))
    }
}
/// Rebuilds `g` with every top-level `Op::If` node replaced by straight-line
/// nodes.
///
/// For each `Op::If`, input 0 is the predicate and the remaining inputs are
/// the values captured by both branches. Both branch sub-graphs are inlined
/// into the new graph (so both are always present in the output), the
/// predicate is expanded to the `If` node's shape, and one `Op::Where` node
/// with inputs `[predicate, then_out, else_out]` takes the `If` node's place.
/// Every other node is copied with its inputs remapped to the new ids.
///
/// Only the first output of each branch is used. `Op::If` nodes that live
/// inside a sub-graph (of another `If` or of a `While`) are copied verbatim,
/// not lowered by this function.
///
/// # Panics
///
/// * if an `Op::If` node has no inputs (the predicate is mandatory);
/// * if a node or graph output refers to a node id that has not been emitted
///   yet, i.e. `g.nodes()` is not in dependency order;
/// * whenever inlining a branch panics (capture/`Op::Input` count mismatch,
///   branch without outputs).
pub fn inline_if(g: Graph) -> Graph {
    let mut out = Graph::new(g.name.clone());
    let mut id_map: HashMap<NodeId, NodeId> = HashMap::with_capacity(g.nodes().len());

    // Iterate the node list by reference: `g` is never mutated, so there is
    // no need to deep-clone every node (and its boxed sub-graphs) first.
    for node in g.nodes() {
        let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
        let new_id = match &node.op {
            Op::If {
                then_branch,
                else_branch,
            } => {
                let Some((&predicate_src, captures)) = new_inputs.split_first() else {
                    panic!("Op::If inline: at least the predicate input is required");
                };
                // Emission order matters for the ids handed out by `out`:
                // then-branch, else-branch, predicate expansion, Where.
                let then_out = inline_subgraph_into(then_branch, captures, &mut out);
                let else_out = inline_subgraph_into(else_branch, captures, &mut out);
                let predicate = expand_to_shape(predicate_src, &node.shape, &mut out);
                out.add_node(
                    Op::Where,
                    vec![predicate, then_out, else_out],
                    node.shape.clone(),
                )
            }
            _ => out.add_node(node.op.clone(), new_inputs, node.shape.clone()),
        };
        id_map.insert(node.id, new_id);
    }

    let new_outputs: Vec<NodeId> = g.outputs.iter().map(|i| id_map[i]).collect();
    out.set_outputs(new_outputs);
    out
}
/// Rebuilds `g` with every top-level bounded `Op::While` node unrolled into
/// `max_iterations` copies of its `cond` and `body` sub-graphs.
///
/// The node's inputs are the initial loop-carried values. For each of the
/// `n` iterations:
///
/// 1. `cond` is inlined on the current carried values and turned into an f32
///    mask;
/// 2. a running `active` mask (seeded with an f32 constant `1.0` of shape
///    `[1]`) is multiplied by that mask, so it is the product of all
///    condition values seen so far;
/// 3. `body` is inlined on the current carried values, and each carried value
///    becomes `Op::Where` over `[active, body_out, previous]`.
///    NOTE(review): this relies on `Op::Where` keeping `previous` wherever the
///    mask is zero — confirm against the `Where` kernel.
///
/// The `While` node itself is replaced by the *first* carried value after the
/// last iteration. All other nodes are copied with remapped inputs.
///
/// # Panics
///
/// * if an `Op::While` has `max_iterations: None`;
/// * if an `Op::While` has no inputs;
/// * if the body's output count differs from the number of carried values;
/// * if a node or graph output refers to a node id that has not been emitted
///   yet, or whenever inlining a sub-graph panics.
pub fn unroll_while(g: Graph) -> Graph {
    let mut out = Graph::new(g.name.clone());
    let mut id_map: HashMap<NodeId, NodeId> = HashMap::with_capacity(g.nodes().len());
    let scalar_f32 = Shape::new(&[1], DType::F32);

    // Iterate by reference: cloning the node list would deep-copy every
    // boxed `cond`/`body` sub-graph for no benefit.
    for node in g.nodes() {
        let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
        let new_id = match &node.op {
            Op::While {
                cond,
                body,
                max_iterations: Some(n),
                ..
            } => {
                assert!(
                    !new_inputs.is_empty(),
                    "Op::While unroll: at least one \
                     loop-carried input required"
                );
                let one = out.add_node(
                    Op::Constant {
                        data: 1.0_f32.to_le_bytes().to_vec(),
                    },
                    vec![],
                    scalar_f32.clone(),
                );
                let mut active = one;
                let mut carried = new_inputs;
                for _ in 0..*n {
                    // Fold this iteration's condition into the running mask.
                    let cond_out = inline_subgraph_into(cond, &carried, &mut out);
                    let cond_f = cond_to_f32_mask(cond_out, &mut out);
                    let cond_shape = out.node(cond_f).shape.clone();
                    let active_lhs = expand_to_shape(active, &cond_shape, &mut out);
                    active = out.binary(BinaryOp::Mul, active_lhs, cond_f, cond_shape);

                    let body_outs = inline_subgraph_into_outputs(body, &carried, &mut out);
                    assert_eq!(
                        body_outs.len(),
                        carried.len(),
                        "Op::While: body output count must match loop-carried arity"
                    );

                    // Blend each body output with the previous value under
                    // the mask, broadcast to that value's shape.
                    let mut next = Vec::with_capacity(carried.len());
                    for (&body_out, &prev) in body_outs.iter().zip(carried.iter()) {
                        let shape = out.node(prev).shape.clone();
                        let mask = expand_to_shape(active, &shape, &mut out);
                        next.push(out.add_node(Op::Where, vec![mask, body_out, prev], shape));
                    }
                    carried = next;
                }
                // Non-empty by the assertion above; only the first carried
                // value stands in for the `While` node.
                carried[0]
            }
            Op::While {
                max_iterations: None,
                ..
            } => {
                panic!(
                    "LowerControlFlow: Op::While requires \
                     max_iterations = Some(N) for unrolling. \
                     Either set a bounded max_iterations on the \
                     forward graph, or use the dynamic \
                     `rlx_runtime::subgraph::run_while` helper."
                );
            }
            _ => out.add_node(node.op.clone(), new_inputs, node.shape.clone()),
        };
        id_map.insert(node.id, new_id);
    }

    let new_outputs: Vec<NodeId> = g.outputs.iter().map(|i| id_map[i]).collect();
    out.set_outputs(new_outputs);
    out
}
/// Returns a node holding `cond_out` as an f32 tensor of the same dims.
///
/// * f32 input: returned unchanged, no node is added.
/// * `Bool` input: cast to `I32` first and then to `F32` (two `Op::Cast`
///   nodes).
/// * any other dtype: a single `Op::Cast` to `F32`.
fn cond_to_f32_mask(cond_out: NodeId, out: &mut Graph) -> NodeId {
    let shape = out.node(cond_out).shape.clone();
    let dtype = shape.dtype();
    if matches!(dtype, DType::F32) {
        return cond_out;
    }

    let as_f32_shape = shape.clone().with_dtype(DType::F32);
    // Booleans take a detour through I32 before the final cast.
    let cast_source = if matches!(dtype, DType::Bool) {
        let as_i32_shape = shape.with_dtype(DType::I32);
        out.add_node(Op::Cast { to: DType::I32 }, vec![cond_out], as_i32_shape)
    } else {
        cond_out
    };
    out.add_node(Op::Cast { to: DType::F32 }, vec![cast_source], as_f32_shape)
}
/// Returns a node whose dims equal `target`'s, derived from `src`.
///
/// * If `src` already has exactly `target`'s dims, `src` is returned and no
///   node is added (dtypes are not compared).
/// * If `src` has lower rank than `target`, it is first reshaped by
///   left-padding its dims with size-1 dims up to `target`'s rank.
/// * The (possibly reshaped) node is then fed to an `Op::Expand` whose
///   `target_shape` lists `target`'s dims as `i64`, with every non-static dim
///   encoded as `-1`. The new node is given `target` (including its dtype) as
///   its shape.
///
/// A `src` of higher rank than `target` is passed to `Op::Expand` as is.
fn expand_to_shape(src: NodeId, target: &rlx_ir::Shape, out: &mut Graph) -> NodeId {
    let src_shape = out.node(src).shape.clone();
    if src_shape.dims() == target.dims() {
        return src;
    }

    // Static dims map to their size, anything else to the -1 placeholder.
    let dim_as_i64 = |d: &Dim| match d {
        Dim::Static(n) => *n as i64,
        _ => -1,
    };

    let src_rank = src_shape.rank();
    let tgt_rank = target.dims().len();
    let to_expand = if src_rank < tgt_rank {
        // Align ranks by prepending unit dims before expanding.
        let padded_dims: Vec<Dim> = std::iter::repeat_n(Dim::Static(1), tgt_rank - src_rank)
            .chain(src_shape.dims().iter().copied())
            .collect();
        let pad_dims_i64: Vec<i64> = padded_dims.iter().map(dim_as_i64).collect();
        let pad_shape = rlx_ir::Shape::from_dims(&padded_dims, src_shape.dtype());
        out.reshape(src, pad_dims_i64, pad_shape)
    } else {
        src
    };

    let target_dims_i64: Vec<i64> = target.dims().iter().map(dim_as_i64).collect();
    out.add_node(
        Op::Expand {
            target_shape: target_dims_i64,
        },
        vec![to_expand],
        target.clone(),
    )
}
/// Copies the nodes of `sub` into `out` and returns the ids (in `out`) that
/// correspond to `sub`'s outputs, in order.
///
/// `Op::Input` nodes of `sub` are not copied: the k-th `Op::Input`
/// encountered in `sub.nodes()` order is bound to `captures[k]`. Every other
/// node is added to `out` with a clone of its op and shape and with its
/// inputs remapped. Nested control-flow ops are therefore copied verbatim.
///
/// # Panics
///
/// * if the number of `Op::Input` nodes in `sub` differs from
///   `captures.len()`; this is checked before anything is added to `out`;
/// * if a node or output of `sub` refers to a node id that does not appear
///   earlier in `sub.nodes()`.
pub fn inline_subgraph_into_outputs(
    sub: &Graph,
    captures: &[NodeId],
    out: &mut Graph,
) -> Vec<NodeId> {
    // Validate arity first so that *both* mismatch directions produce the
    // descriptive message instead of an index-out-of-bounds panic.
    let input_count = sub
        .nodes()
        .iter()
        .filter(|n| matches!(n.op, Op::Input { .. }))
        .count();
    assert_eq!(
        input_count,
        captures.len(),
        "Op::While/If sub-graph: {} Op::Input nodes but {} captures",
        input_count,
        captures.len()
    );

    let mut remaining_captures = captures.iter().copied();
    let mut sub_to_parent: HashMap<NodeId, NodeId> = HashMap::with_capacity(sub.nodes().len());
    for sub_node in sub.nodes() {
        let new_id = match &sub_node.op {
            Op::Input { .. } => remaining_captures
                .next()
                .expect("capture count was validated against the Op::Input count"),
            _ => {
                let new_inputs: Vec<NodeId> =
                    sub_node.inputs.iter().map(|i| sub_to_parent[i]).collect();
                out.add_node(sub_node.op.clone(), new_inputs, sub_node.shape.clone())
            }
        };
        sub_to_parent.insert(sub_node.id, new_id);
    }

    sub.outputs.iter().map(|o| sub_to_parent[o]).collect()
}
/// Single-output convenience wrapper around [`inline_subgraph_into_outputs`]:
/// inlines `sub` into `out` and returns the id mapped to its first output.
///
/// # Panics
///
/// Panics if `sub` has no outputs, or whenever the underlying inlining panics.
pub fn inline_subgraph_into(sub: &Graph, captures: &[NodeId], out: &mut Graph) -> NodeId {
    let inlined_outputs = inline_subgraph_into_outputs(sub, captures, out);
    inlined_outputs[0]
}
// Unit tests for the control-flow lowering: two structural tests that count
// the nodes left after lowering, and one test that executes an unrolled loop.
#[cfg(test)]
mod tests {
use super::*;
use rlx_ir::op::{Activation, BinaryOp};
use rlx_ir::{DType, Shape};
// Runs the full pass on a graph where an `If` (Relu / Sigmoid branches)
// feeds a `While` (body squares its input, cond is the identity, N = 2) and
// checks that neither op survives and that the Where / Mul counts match.
#[test]
fn lower_control_flow_pass_handles_both_if_and_while() {
let s = Shape::new(&[2], DType::F32);
// then-branch: Relu(c)
let mut then_g = Graph::new("th");
let ti = then_g.input("c", s.clone());
let to = then_g.activation(Activation::Relu, ti, s.clone());
then_g.set_outputs(vec![to]);
// else-branch: Sigmoid(c)
let mut else_g = Graph::new("el");
let ei = else_g.input("c", s.clone());
let eo = else_g.activation(Activation::Sigmoid, ei, s.clone());
else_g.set_outputs(vec![eo]);
// loop body: c * c
let mut body_g = Graph::new("body");
let bi = body_g.input("c", s.clone());
let bo = body_g.binary(BinaryOp::Mul, bi, bi, s.clone());
body_g.set_outputs(vec![bo]);
// loop condition: the carried value itself (already f32)
let mut cond_g = Graph::new("cond");
let ci = cond_g.input("c", s.clone());
cond_g.set_outputs(vec![ci]);
// parent graph: While(If(p, x))
let mut g = Graph::new("parent");
let x = g.input("x", s.clone());
let pred = g.input("p", Shape::new(&[1], DType::F32));
let if_out = g.add_node(
Op::If {
then_branch: Box::new(then_g),
else_branch: Box::new(else_g),
},
vec![pred, x],
s.clone(),
);
let w_out = g.add_node(
Op::While {
cond: Box::new(cond_g),
body: Box::new(body_g),
max_iterations: Some(2),
},
vec![if_out],
s.clone(),
);
g.set_outputs(vec![w_out]);
let lowered = LowerControlFlow.run(g);
let has_if = lowered
.nodes()
.iter()
.any(|n| matches!(n.op, Op::If { .. }));
let has_while = lowered
.nodes()
.iter()
.any(|n| matches!(n.op, Op::While { .. }));
assert!(
!has_if && !has_while,
"LowerControlFlow should erase both If and While"
);
let n_where = lowered
.nodes()
.iter()
.filter(|n| matches!(n.op, Op::Where))
.count();
let n_mul = lowered
.nodes()
.iter()
.filter(|n| matches!(n.op, Op::Binary(BinaryOp::Mul)))
.count();
assert_eq!(
n_where, 3,
"expected 1 Where from If + 2 from While (N=2, 1 carry)"
);
assert_eq!(
n_mul, 4,
"expected 2 body Mul + 2 active*cond_f Mul from While (N=2)"
);
}
// Unrolls a `While` with two loop-carried values (`v` of shape [2], `s` of
// shape [1]) and a Bool condition `v < 10`, then checks that one Where node
// is emitted per carried value per iteration (3 iterations x 2 carries).
#[test]
fn unroll_while_multi_carry_cond_freezes_updates() {
let v_shape = Shape::new(&[2], DType::F32);
let s_shape = Shape::new(&[1], DType::F32);
// body: (v, s) -> (v + 1, s)
let mut body = Graph::new("body");
let v_in = body.input("v", v_shape.clone());
let s_in = body.input("s", s_shape.clone());
let one = body.add_node(
Op::Constant {
data: 1.0_f32.to_le_bytes().to_vec(),
},
vec![],
s_shape.clone(),
);
let v_out = body.binary(BinaryOp::Add, v_in, one, v_shape.clone());
body.set_outputs(vec![v_out, s_in]);
// cond: v < 10, declared with a Bool shape of [1]; `s` is captured but unused
let mut cond = Graph::new("cond");
let v_c = cond.input("v", v_shape.clone());
let _s_c = cond.input("s", s_shape.clone());
let ten = cond.add_node(
Op::Constant {
data: 10.0_f32.to_le_bytes().to_vec(),
},
vec![],
s_shape.clone(),
);
let lt = cond.add_node(
Op::Compare(rlx_ir::op::CmpOp::Lt),
vec![v_c, ten],
Shape::new(&[1], DType::Bool),
);
cond.set_outputs(vec![lt]);
let mut g = Graph::new("parent");
let v0 = g.input("v0", v_shape.clone());
let s0 = g.input("s0", s_shape.clone());
let w = g.add_node(
Op::While {
cond: Box::new(cond),
body: Box::new(body),
max_iterations: Some(3),
},
vec![v0, s0],
v_shape.clone(),
);
g.set_outputs(vec![w]);
let lowered = unroll_while(g);
assert!(
!lowered
.nodes()
.iter()
.any(|n| matches!(n.op, Op::While { .. })),
"While should be erased"
);
let n_where = lowered
.nodes()
.iter()
.filter(|n| matches!(n.op, Op::Where))
.count();
assert_eq!(n_where, 6, "expected 3 iters × 2 carries Where masks");
}
// End-to-end check: unrolls a 3-iteration squaring loop, executes the
// lowered graph through the CPU thunk scheduler and compares the result
// against x^8 for x = [2, 3].
#[test]
fn unroll_while_squares_on_cpu_thunks() {
let s = Shape::new(&[2], DType::F32);
// body: c * c
let mut body_g = Graph::new("body");
let bi = body_g.input("c", s.clone());
let bo = body_g.binary(BinaryOp::Mul, bi, bi, s.clone());
body_g.set_outputs(vec![bo]);
// cond: the carried value itself, used directly as the f32 mask
let mut cond_g = Graph::new("cond");
let ci = cond_g.input("c", s.clone());
cond_g.set_outputs(vec![ci]);
let mut g = Graph::new("while_test");
let x = g.input("x", s.clone());
let y = g.add_node(
Op::While {
cond: Box::new(cond_g),
body: Box::new(body_g),
max_iterations: Some(3),
},
vec![x],
s.clone(),
);
g.set_outputs(vec![y]);
let lowered = unroll_while(g);
assert!(
!lowered
.nodes()
.iter()
.any(|n| matches!(n.op, Op::While { .. }))
);
// Node ids are reassigned by the lowering, so look the input up by name.
let x_id = lowered
.nodes()
.iter()
.find(|n| matches!(&n.op, Op::Input { name, .. } if name == "x"))
.expect("lowered graph missing input x")
.id;
let plan = rlx_opt::memory::plan_memory(&lowered);
let mut arena = rlx_cpu::arena::Arena::from_plan(plan);
let sched = rlx_cpu::thunk::compile_thunks(&lowered, &arena);
// Materialise constants by hand: decode each constant's little-endian f32
// payload into its arena buffer (bounded by the shorter of the two).
for node in lowered.nodes() {
if let Op::Constant { data } = &node.op
&& arena.has_buffer(node.id)
&& !data.is_empty()
{
let buf = arena.slice_mut(node.id);
let n_floats = data.len() / 4;
let n = buf.len().min(n_floats);
for i in 0..n {
let bytes = [
data[i * 4],
data[i * 4 + 1],
data[i * 4 + 2],
data[i * 4 + 3],
];
buf[i] = f32::from_le_bytes(bytes);
}
}
}
let x_off = arena.byte_offset(x_id);
let out_id = lowered.outputs[0];
let out_off = arena.byte_offset(out_id);
let buf = arena.raw_buf_mut();
// SAFETY (NOTE(review)): assumes `x_off` is an f32-aligned byte offset with
// room for two f32 values inside the arena buffer — confirm against
// `Arena::byte_offset` / the memory plan.
unsafe {
let px = buf.as_mut_ptr().add(x_off) as *mut f32;
*px.add(0) = 2.0;
*px.add(1) = 3.0;
}
rlx_cpu::thunk::execute_thunks(&sched, arena.raw_buf_mut());
// SAFETY (NOTE(review)): same assumption as above for `out_off` and the
// two-element output buffer — confirm.
let got: Vec<f32> = unsafe {
let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
vec![*p.add(0), *p.add(1)]
};
// Three squarings: 2^8 = 256 and 3^8 = 6561.
let want = [256.0_f32, 6561.0_f32];
for (i, (&a, &b)) in got.iter().zip(&want).enumerate() {
assert!(
(a - b).abs() < 1e-3,
"unrolled while[{i}]: got {a} want {b}"
);
}
}
}