1use crate::pass::Pass;
47use rlx_ir::op::{BinaryOp, CmpOp, ReduceOp};
48use rlx_ir::shape::Dim;
49use rlx_ir::{DType, Graph, NodeId, Op, Shape};
50use std::collections::HashMap;
51
52pub struct LowerControlFlow;
55
56impl Pass for LowerControlFlow {
57 fn name(&self) -> &str {
58 "LowerControlFlow"
59 }
60 fn run(&self, graph: Graph) -> Graph {
61 let g = inline_if(graph);
62 unroll_while(g)
63 }
64}
65
66pub fn inline_if(g: Graph) -> Graph {
70 let mut out = Graph::new(g.name.clone());
71 let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
72 let nodes: Vec<rlx_ir::Node> = g.nodes().to_vec();
73
74 for node in &nodes {
75 let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
76 let new_id = match &node.op {
77 Op::If {
78 then_branch,
79 else_branch,
80 } => {
81 let captures: Vec<NodeId> = new_inputs[1..].to_vec();
82 let then_out = inline_subgraph_into(then_branch, &captures, &mut out);
83 let else_out = inline_subgraph_into(else_branch, &captures, &mut out);
84 let predicate = expand_to_shape(new_inputs[0], &node.shape, &mut out);
90 out.add_node(
91 Op::Where,
92 vec![predicate, then_out, else_out],
93 node.shape.clone(),
94 )
95 }
96 _ => out.add_node(node.op.clone(), new_inputs, node.shape.clone()),
97 };
98 id_map.insert(node.id, new_id);
99 }
100 let new_outputs: Vec<NodeId> = g.outputs.iter().map(|i| id_map[i]).collect();
101 out.set_outputs(new_outputs);
102 out
103}
104
105pub fn unroll_while(g: Graph) -> Graph {
109 let mut out = Graph::new(g.name.clone());
110 let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
111 let nodes: Vec<rlx_ir::Node> = g.nodes().to_vec();
112 let scalar_f32 = Shape::new(&[1], DType::F32);
113
114 for node in &nodes {
115 let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
116 let new_id = match &node.op {
117 Op::While {
118 cond,
119 body,
120 max_iterations: Some(n),
121 ..
122 } => {
123 if new_inputs.is_empty() {
124 panic!(
125 "Op::While unroll: at least one \
126 loop-carried input required"
127 );
128 }
129 let one = out.add_node(
130 Op::Constant {
131 data: 1.0_f32.to_le_bytes().to_vec(),
132 },
133 vec![],
134 scalar_f32.clone(),
135 );
136 let mut active = one;
137 let mut carried = new_inputs;
138 for _ in 0..*n {
139 let cond_out = inline_subgraph_into(cond, &carried, &mut out);
140 let cond_f = cond_to_scalar_f32(cond_out, &mut out, &scalar_f32);
141 active = out.binary(BinaryOp::Mul, active, cond_f, scalar_f32.clone());
142
143 let body_outs = inline_subgraph_into_outputs(body, &carried, &mut out);
144 assert_eq!(
145 body_outs.len(),
146 carried.len(),
147 "Op::While: body output count must match loop-carried arity"
148 );
149 let mut next = Vec::with_capacity(carried.len());
150 for (body_out, &prev) in body_outs.iter().zip(carried.iter()) {
151 let shape = out.node(prev).shape.clone();
152 let mask = expand_to_shape(active, &shape, &mut out);
153 let merged = out.add_node(Op::Where, vec![mask, *body_out, prev], shape);
154 next.push(merged);
155 }
156 carried = next;
157 }
158 carried[0]
159 }
160 Op::While {
161 max_iterations: None,
162 ..
163 } => {
164 panic!(
165 "LowerControlFlow: Op::While requires \
166 max_iterations = Some(N) for unrolling. \
167 Either set a bounded max_iterations on the \
168 forward graph, or use the dynamic \
169 `rlx_runtime::subgraph::run_while` helper."
170 );
171 }
172 _ => out.add_node(node.op.clone(), new_inputs, node.shape.clone()),
173 };
174 id_map.insert(node.id, new_id);
175 }
176 let new_outputs: Vec<NodeId> = g.outputs.iter().map(|i| id_map[i]).collect();
177 out.set_outputs(new_outputs);
178 out
179}
180
181fn cond_to_scalar_f32(cond_out: NodeId, out: &mut Graph, scalar_f32: &Shape) -> NodeId {
185 let cond_shape = out.node(cond_out).shape.clone();
186 let n = cond_shape
187 .dims()
188 .iter()
189 .filter_map(|d| match d {
190 Dim::Static(n) => Some(*n),
191 _ => None,
192 })
193 .product::<usize>();
194 let as_f32 = if cond_shape.dtype() == DType::F32 {
195 cond_out
196 } else {
197 out.add_node(
198 Op::Cast { to: DType::F32 },
199 vec![cond_out],
200 cond_shape.with_dtype(DType::F32),
201 )
202 };
203 if n <= 1 {
204 return as_f32;
205 }
206 let as_f32_shape = out.node(as_f32).shape.clone();
207 let rank = as_f32_shape.rank();
208 let zero = out.add_node(
209 Op::Constant {
210 data: 0.0_f32.to_le_bytes().to_vec(),
211 },
212 vec![],
213 scalar_f32.clone(),
214 );
215 let nonzero = out.add_node(
216 Op::Compare(CmpOp::Ne),
217 vec![as_f32, zero],
218 as_f32_shape.clone().with_dtype(DType::Bool),
219 );
220 let nonzero_f = out.add_node(
221 Op::Cast { to: DType::F32 },
222 vec![nonzero],
223 as_f32_shape.with_dtype(DType::F32),
224 );
225 let axes: Vec<usize> = (0..rank).collect();
226 out.reduce(nonzero_f, ReduceOp::Min, axes, true, scalar_f32.clone())
227}
228
229fn expand_to_shape(src: NodeId, target: &rlx_ir::Shape, out: &mut Graph) -> NodeId {
234 let src_shape = out.node(src).shape.clone();
235 let src_n = src_shape
236 .dims()
237 .iter()
238 .filter_map(|d| match d {
239 Dim::Static(n) => Some(*n),
240 _ => None,
241 })
242 .product::<usize>();
243 let tgt_n = target
244 .dims()
245 .iter()
246 .filter_map(|d| match d {
247 Dim::Static(n) => Some(*n),
248 _ => None,
249 })
250 .product::<usize>();
251 if src_shape.dims() == target.dims() {
252 return src;
253 }
254 let target_dims_i64: Vec<i64> = target
255 .dims()
256 .iter()
257 .map(|d| match d {
258 Dim::Static(n) => *n as i64,
259 _ => -1,
260 })
261 .collect();
262 let src_rank = src_shape.rank();
265 let tgt_rank = target.dims().len();
266 let to_expand = if src_rank < tgt_rank {
267 let mut padded_dims: Vec<Dim> = std::iter::repeat_n(Dim::Static(1), tgt_rank - src_rank)
268 .chain(src_shape.dims().iter().copied())
269 .collect();
270 let _ = src_n;
272 let _ = tgt_n;
273 let dtype = src_shape.dtype();
274 let pad_dims_i64: Vec<i64> = padded_dims
275 .iter()
276 .map(|d| match d {
277 Dim::Static(n) => *n as i64,
278 _ => -1,
279 })
280 .collect();
281 let pad_shape = rlx_ir::Shape::from_dims(&padded_dims, dtype);
283 padded_dims.clear();
284 out.reshape(src, pad_dims_i64, pad_shape)
285 } else {
286 src
287 };
288 out.add_node(
289 Op::Expand {
290 target_shape: target_dims_i64,
291 },
292 vec![to_expand],
293 target.clone(),
294 )
295}
296
297pub fn inline_subgraph_into_outputs(
300 sub: &Graph,
301 captures: &[NodeId],
302 out: &mut Graph,
303) -> Vec<NodeId> {
304 let mut sub_to_parent: HashMap<NodeId, NodeId> = HashMap::new();
305 let mut input_idx = 0usize;
306 for sub_node in sub.nodes() {
307 let new_id = match &sub_node.op {
308 Op::Input { .. } => {
309 let parent_id = captures[input_idx];
310 input_idx += 1;
311 parent_id
312 }
313 _ => {
314 let new_inputs: Vec<NodeId> =
315 sub_node.inputs.iter().map(|i| sub_to_parent[i]).collect();
316 out.add_node(sub_node.op.clone(), new_inputs, sub_node.shape.clone())
317 }
318 };
319 sub_to_parent.insert(sub_node.id, new_id);
320 }
321 assert_eq!(
322 input_idx,
323 captures.len(),
324 "Op::While/If sub-graph: {} Op::Input nodes but {} captures",
325 input_idx,
326 captures.len()
327 );
328 sub.outputs.iter().map(|o| sub_to_parent[o]).collect()
329}
330
331pub fn inline_subgraph_into(sub: &Graph, captures: &[NodeId], out: &mut Graph) -> NodeId {
335 inline_subgraph_into_outputs(sub, captures, out)[0]
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::{Activation, BinaryOp};
    use rlx_ir::{DType, Shape};

    /// Number of top-level nodes in `g` whose op satisfies `pred`.
    fn count_ops(g: &Graph, pred: impl Fn(&Op) -> bool) -> usize {
        g.nodes().iter().filter(|node| pred(&node.op)).count()
    }

    #[test]
    fn lower_control_flow_pass_handles_both_if_and_while() {
        let vec2 = Shape::new(&[2], DType::F32);

        // If branches: relu on the then side, sigmoid on the else side.
        let mut then_g = Graph::new("th");
        let then_in = then_g.input("c", vec2.clone());
        let then_out = then_g.activation(Activation::Relu, then_in, vec2.clone());
        then_g.set_outputs(vec![then_out]);
        let mut else_g = Graph::new("el");
        let else_in = else_g.input("c", vec2.clone());
        let else_out = else_g.activation(Activation::Sigmoid, else_in, vec2.clone());
        else_g.set_outputs(vec![else_out]);

        // While: the body squares the carry; the cond passes the carry through.
        let mut body_g = Graph::new("body");
        let body_in = body_g.input("c", vec2.clone());
        let body_out = body_g.binary(BinaryOp::Mul, body_in, body_in, vec2.clone());
        body_g.set_outputs(vec![body_out]);
        let mut cond_g = Graph::new("cond");
        let cond_in = cond_g.input("c", vec2.clone());
        cond_g.set_outputs(vec![cond_in]);

        let mut parent = Graph::new("parent");
        let x = parent.input("x", vec2.clone());
        let pred = parent.input("p", Shape::new(&[1], DType::F32));
        let if_node = parent.add_node(
            Op::If {
                then_branch: Box::new(then_g),
                else_branch: Box::new(else_g),
            },
            vec![pred, x],
            vec2.clone(),
        );
        let while_node = parent.add_node(
            Op::While {
                cond: Box::new(cond_g),
                body: Box::new(body_g),
                max_iterations: Some(2),
            },
            vec![if_node],
            vec2.clone(),
        );
        parent.set_outputs(vec![while_node]);

        let lowered = LowerControlFlow.run(parent);

        let erased = count_ops(&lowered, |op| {
            matches!(op, Op::If { .. } | Op::While { .. })
        }) == 0;
        assert!(erased, "LowerControlFlow should erase both If and While");

        let where_count = count_ops(&lowered, |op| matches!(op, Op::Where));
        let mul_count = count_ops(&lowered, |op| matches!(op, Op::Binary(BinaryOp::Mul)));
        assert_eq!(
            where_count, 3,
            "expected 1 Where from If + 2 from While (N=2, 1 carry)"
        );
        assert_eq!(
            mul_count, 4,
            "expected 2 body Mul + 2 active*cond_f Mul from While (N=2)"
        );
    }

    #[test]
    fn unroll_while_multi_carry_cond_freezes_updates() {
        let vec_shape = Shape::new(&[2], DType::F32);
        let scalar_shape = Shape::new(&[1], DType::F32);

        // body(v, s) = (v + 1, s)
        let mut body_g = Graph::new("body");
        let v_param = body_g.input("v", vec_shape.clone());
        let s_param = body_g.input("s", scalar_shape.clone());
        let one = body_g.add_node(
            Op::Constant {
                data: 1.0_f32.to_le_bytes().to_vec(),
            },
            vec![],
            scalar_shape.clone(),
        );
        let v_next = body_g.binary(BinaryOp::Add, v_param, one, vec_shape.clone());
        body_g.set_outputs(vec![v_next, s_param]);

        // cond(v, s) = v < 10
        let mut cond_g = Graph::new("cond");
        let v_cond = cond_g.input("v", vec_shape.clone());
        let _s_cond = cond_g.input("s", scalar_shape.clone());
        let ten = cond_g.add_node(
            Op::Constant {
                data: 10.0_f32.to_le_bytes().to_vec(),
            },
            vec![],
            scalar_shape.clone(),
        );
        let less_than = cond_g.add_node(
            Op::Compare(rlx_ir::op::CmpOp::Lt),
            vec![v_cond, ten],
            Shape::new(&[1], DType::Bool),
        );
        cond_g.set_outputs(vec![less_than]);

        let mut parent = Graph::new("parent");
        let v0 = parent.input("v0", vec_shape.clone());
        let s0 = parent.input("s0", scalar_shape.clone());
        let while_node = parent.add_node(
            Op::While {
                cond: Box::new(cond_g),
                body: Box::new(body_g),
                max_iterations: Some(3),
            },
            vec![v0, s0],
            vec_shape.clone(),
        );
        parent.set_outputs(vec![while_node]);

        let lowered = unroll_while(parent);

        let while_erased = count_ops(&lowered, |op| matches!(op, Op::While { .. })) == 0;
        assert!(while_erased, "While should be erased");

        let where_count = count_ops(&lowered, |op| matches!(op, Op::Where));
        assert_eq!(where_count, 6, "expected 3 iters × 2 carries Where masks");
    }
}