1use crate::pass::Pass;
47use rlx_ir::op::BinaryOp;
48use rlx_ir::shape::Dim;
49use rlx_ir::{DType, Graph, NodeId, Op, Shape};
50use std::collections::HashMap;
51
52pub struct LowerControlFlow;
55
56impl Pass for LowerControlFlow {
57 fn name(&self) -> &str {
58 "LowerControlFlow"
59 }
60 fn run(&self, graph: Graph) -> Graph {
61 let g = inline_if(graph);
62 unroll_while(g)
63 }
64}
65
66pub fn inline_if(g: Graph) -> Graph {
70 let mut out = Graph::new(g.name.clone());
71 let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
72 let nodes: Vec<rlx_ir::Node> = g.nodes().to_vec();
73
74 for node in &nodes {
75 let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
76 let new_id = match &node.op {
77 Op::If {
78 then_branch,
79 else_branch,
80 } => {
81 let captures: Vec<NodeId> = new_inputs[1..].to_vec();
82 let then_out = inline_subgraph_into(then_branch, &captures, &mut out);
83 let else_out = inline_subgraph_into(else_branch, &captures, &mut out);
84 let predicate = expand_to_shape(new_inputs[0], &node.shape, &mut out);
90 out.add_node(
91 Op::Where,
92 vec![predicate, then_out, else_out],
93 node.shape.clone(),
94 )
95 }
96 _ => out.add_node(node.op.clone(), new_inputs, node.shape.clone()),
97 };
98 id_map.insert(node.id, new_id);
99 }
100 let new_outputs: Vec<NodeId> = g.outputs.iter().map(|i| id_map[i]).collect();
101 out.set_outputs(new_outputs);
102 out
103}
104
105pub fn unroll_while(g: Graph) -> Graph {
109 let mut out = Graph::new(g.name.clone());
110 let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
111 let nodes: Vec<rlx_ir::Node> = g.nodes().to_vec();
112 let scalar_f32 = Shape::new(&[1], DType::F32);
113
114 for node in &nodes {
115 let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
116 let new_id = match &node.op {
117 Op::While {
118 cond,
119 body,
120 max_iterations: Some(n),
121 ..
122 } => {
123 if new_inputs.is_empty() {
124 panic!(
125 "Op::While unroll: at least one \
126 loop-carried input required"
127 );
128 }
129 let one = out.add_node(
130 Op::Constant {
131 data: 1.0_f32.to_le_bytes().to_vec(),
132 },
133 vec![],
134 scalar_f32.clone(),
135 );
136 let mut active = one;
137 let mut carried = new_inputs;
138 for _ in 0..*n {
139 let cond_out = inline_subgraph_into(cond, &carried, &mut out);
140 let cond_f = cond_to_f32_mask(cond_out, &mut out);
141 let cond_shape = out.node(cond_f).shape.clone();
142 let active_lhs = expand_to_shape(active, &cond_shape, &mut out);
143 active = out.binary(BinaryOp::Mul, active_lhs, cond_f, cond_shape);
144
145 let body_outs = inline_subgraph_into_outputs(body, &carried, &mut out);
146 assert_eq!(
147 body_outs.len(),
148 carried.len(),
149 "Op::While: body output count must match loop-carried arity"
150 );
151 let mut next = Vec::with_capacity(carried.len());
152 for (body_out, &prev) in body_outs.iter().zip(carried.iter()) {
153 let shape = out.node(prev).shape.clone();
154 let mask = expand_to_shape(active, &shape, &mut out);
155 let merged = out.add_node(Op::Where, vec![mask, *body_out, prev], shape);
156 next.push(merged);
157 }
158 carried = next;
159 }
160 carried[0]
161 }
162 Op::While {
163 max_iterations: None,
164 ..
165 } => {
166 panic!(
167 "LowerControlFlow: Op::While requires \
168 max_iterations = Some(N) for unrolling. \
169 Either set a bounded max_iterations on the \
170 forward graph, or use the dynamic \
171 `rlx_runtime::subgraph::run_while` helper."
172 );
173 }
174 _ => out.add_node(node.op.clone(), new_inputs, node.shape.clone()),
175 };
176 id_map.insert(node.id, new_id);
177 }
178 let new_outputs: Vec<NodeId> = g.outputs.iter().map(|i| id_map[i]).collect();
179 out.set_outputs(new_outputs);
180 out
181}
182
183fn cond_to_f32_mask(cond_out: NodeId, out: &mut Graph) -> NodeId {
186 let cond_shape = out.node(cond_out).shape.clone();
187 match cond_shape.dtype() {
188 DType::F32 => cond_out,
189 DType::Bool => {
190 let f32_shape = cond_shape.clone().with_dtype(DType::F32);
191 let i32_shape = cond_shape.with_dtype(DType::I32);
192 let as_i32 = out.add_node(Op::Cast { to: DType::I32 }, vec![cond_out], i32_shape);
193 out.add_node(Op::Cast { to: DType::F32 }, vec![as_i32], f32_shape)
194 }
195 _ => out.add_node(
196 Op::Cast { to: DType::F32 },
197 vec![cond_out],
198 cond_shape.with_dtype(DType::F32),
199 ),
200 }
201}
202
203fn expand_to_shape(src: NodeId, target: &rlx_ir::Shape, out: &mut Graph) -> NodeId {
208 let src_shape = out.node(src).shape.clone();
209 let src_n = src_shape
210 .dims()
211 .iter()
212 .filter_map(|d| match d {
213 Dim::Static(n) => Some(*n),
214 _ => None,
215 })
216 .product::<usize>();
217 let tgt_n = target
218 .dims()
219 .iter()
220 .filter_map(|d| match d {
221 Dim::Static(n) => Some(*n),
222 _ => None,
223 })
224 .product::<usize>();
225 if src_shape.dims() == target.dims() {
226 return src;
227 }
228 let target_dims_i64: Vec<i64> = target
229 .dims()
230 .iter()
231 .map(|d| match d {
232 Dim::Static(n) => *n as i64,
233 _ => -1,
234 })
235 .collect();
236 let src_rank = src_shape.rank();
239 let tgt_rank = target.dims().len();
240 let to_expand = if src_rank < tgt_rank {
241 let mut padded_dims: Vec<Dim> = std::iter::repeat_n(Dim::Static(1), tgt_rank - src_rank)
242 .chain(src_shape.dims().iter().copied())
243 .collect();
244 let _ = src_n;
246 let _ = tgt_n;
247 let dtype = src_shape.dtype();
248 let pad_dims_i64: Vec<i64> = padded_dims
249 .iter()
250 .map(|d| match d {
251 Dim::Static(n) => *n as i64,
252 _ => -1,
253 })
254 .collect();
255 let pad_shape = rlx_ir::Shape::from_dims(&padded_dims, dtype);
257 padded_dims.clear();
258 out.reshape(src, pad_dims_i64, pad_shape)
259 } else {
260 src
261 };
262 out.add_node(
263 Op::Expand {
264 target_shape: target_dims_i64,
265 },
266 vec![to_expand],
267 target.clone(),
268 )
269}
270
271pub fn inline_subgraph_into_outputs(
274 sub: &Graph,
275 captures: &[NodeId],
276 out: &mut Graph,
277) -> Vec<NodeId> {
278 let mut sub_to_parent: HashMap<NodeId, NodeId> = HashMap::new();
279 let mut input_idx = 0usize;
280 for sub_node in sub.nodes() {
281 let new_id = match &sub_node.op {
282 Op::Input { .. } => {
283 let parent_id = captures[input_idx];
284 input_idx += 1;
285 parent_id
286 }
287 _ => {
288 let new_inputs: Vec<NodeId> =
289 sub_node.inputs.iter().map(|i| sub_to_parent[i]).collect();
290 out.add_node(sub_node.op.clone(), new_inputs, sub_node.shape.clone())
291 }
292 };
293 sub_to_parent.insert(sub_node.id, new_id);
294 }
295 assert_eq!(
296 input_idx,
297 captures.len(),
298 "Op::While/If sub-graph: {} Op::Input nodes but {} captures",
299 input_idx,
300 captures.len()
301 );
302 sub.outputs.iter().map(|o| sub_to_parent[o]).collect()
303}
304
305pub fn inline_subgraph_into(sub: &Graph, captures: &[NodeId], out: &mut Graph) -> NodeId {
309 inline_subgraph_into_outputs(sub, captures, out)[0]
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::{Activation, BinaryOp};
    use rlx_ir::{DType, Shape};

    /// Runs the combined pass on a top-level `If` whose result feeds a bounded `While`
    /// (N = 2, one carry). No control-flow op may survive, and the counts of emitted
    /// `Where` / `Mul` nodes must match the lowering recipe.
    #[test]
    fn lower_control_flow_pass_handles_both_if_and_while() {
        let s = Shape::new(&[2], DType::F32);

        // then-branch: relu(c); else-branch: sigmoid(c). Each takes one captured input.
        let mut then_g = Graph::new("th");
        let ti = then_g.input("c", s.clone());
        let to = then_g.activation(Activation::Relu, ti, s.clone());
        then_g.set_outputs(vec![to]);
        let mut else_g = Graph::new("el");
        let ei = else_g.input("c", s.clone());
        let eo = else_g.activation(Activation::Sigmoid, ei, s.clone());
        else_g.set_outputs(vec![eo]);

        // Loop body squares the carry; the condition graph just returns the carry itself.
        let mut body_g = Graph::new("body");
        let bi = body_g.input("c", s.clone());
        let bo = body_g.binary(BinaryOp::Mul, bi, bi, s.clone());
        body_g.set_outputs(vec![bo]);
        let mut cond_g = Graph::new("cond");
        let ci = cond_g.input("c", s.clone());
        cond_g.set_outputs(vec![ci]);

        let mut g = Graph::new("parent");
        let x = g.input("x", s.clone());
        let pred = g.input("p", Shape::new(&[1], DType::F32));
        // If inputs: [predicate, captures...].
        let if_out = g.add_node(
            Op::If {
                then_branch: Box::new(then_g),
                else_branch: Box::new(else_g),
            },
            vec![pred, x],
            s.clone(),
        );
        let w_out = g.add_node(
            Op::While {
                cond: Box::new(cond_g),
                body: Box::new(body_g),
                max_iterations: Some(2),
            },
            vec![if_out],
            s.clone(),
        );
        g.set_outputs(vec![w_out]);

        let lowered = LowerControlFlow.run(g);
        let has_if = lowered
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::If { .. }));
        let has_while = lowered
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::While { .. }));
        assert!(
            !has_if && !has_while,
            "LowerControlFlow should erase both If and While"
        );
        // Count the ops the lowering introduces. The source graph itself contains no Where,
        // and Mul appears only in the loop body.
        let n_where = lowered
            .nodes()
            .iter()
            .filter(|n| matches!(n.op, Op::Where))
            .count();
        let n_mul = lowered
            .nodes()
            .iter()
            .filter(|n| matches!(n.op, Op::Binary(BinaryOp::Mul)))
            .count();
        assert_eq!(
            n_where, 3,
            "expected 1 Where from If + 2 from While (N=2, 1 carry)"
        );
        assert_eq!(
            n_mul, 4,
            "expected 2 body Mul + 2 active*cond_f Mul from While (N=2)"
        );
    }

    /// Two loop-carried values: a `[2]` vector `v` and a `[1]` scalar `s`.
    /// - The body computes `v + 1` and passes `s` through.
    /// - The condition is `Compare(Lt)(v, 10)` with a Bool output.
    ///
    /// Unrolling N = 3 must emit one masking `Where` per carry per iteration.
    #[test]
    fn unroll_while_multi_carry_cond_freezes_updates() {
        let v_shape = Shape::new(&[2], DType::F32);
        let s_shape = Shape::new(&[1], DType::F32);

        let mut body = Graph::new("body");
        let v_in = body.input("v", v_shape.clone());
        let s_in = body.input("s", s_shape.clone());
        let one = body.add_node(
            Op::Constant {
                data: 1.0_f32.to_le_bytes().to_vec(),
            },
            vec![],
            s_shape.clone(),
        );
        let v_out = body.binary(BinaryOp::Add, v_in, one, v_shape.clone());
        body.set_outputs(vec![v_out, s_in]);

        // The condition must declare an Input for every carry (captures are bound
        // positionally), even though `s` is unused here.
        let mut cond = Graph::new("cond");
        let v_c = cond.input("v", v_shape.clone());
        let _s_c = cond.input("s", s_shape.clone());
        let ten = cond.add_node(
            Op::Constant {
                data: 10.0_f32.to_le_bytes().to_vec(),
            },
            vec![],
            s_shape.clone(),
        );
        let lt = cond.add_node(
            Op::Compare(rlx_ir::op::CmpOp::Lt),
            vec![v_c, ten],
            Shape::new(&[1], DType::Bool),
        );
        cond.set_outputs(vec![lt]);

        let mut g = Graph::new("parent");
        let v0 = g.input("v0", v_shape.clone());
        let s0 = g.input("s0", s_shape.clone());
        let w = g.add_node(
            Op::While {
                cond: Box::new(cond),
                body: Box::new(body),
                max_iterations: Some(3),
            },
            vec![v0, s0],
            v_shape.clone(),
        );
        g.set_outputs(vec![w]);

        let lowered = unroll_while(g);
        assert!(
            !lowered
                .nodes()
                .iter()
                .any(|n| matches!(n.op, Op::While { .. })),
            "While should be erased"
        );
        let n_where = lowered
            .nodes()
            .iter()
            .filter(|n| matches!(n.op, Op::Where))
            .count();
        assert_eq!(n_where, 6, "expected 3 iters × 2 carries Where masks");
    }

    /// Executes an unrolled loop on the CPU thunk backend.
    /// - The body squares the carry three times and the condition is the carry itself.
    /// - For x = [2, 3] every value stays nonzero, so all iterations are live.
    /// - The result is x^8 = [256, 6561].
    #[test]
    fn unroll_while_squares_on_cpu_thunks() {
        let s = Shape::new(&[2], DType::F32);
        let mut body_g = Graph::new("body");
        let bi = body_g.input("c", s.clone());
        let bo = body_g.binary(BinaryOp::Mul, bi, bi, s.clone());
        body_g.set_outputs(vec![bo]);
        let mut cond_g = Graph::new("cond");
        let ci = cond_g.input("c", s.clone());
        cond_g.set_outputs(vec![ci]);

        let mut g = Graph::new("while_test");
        let x = g.input("x", s.clone());
        let y = g.add_node(
            Op::While {
                cond: Box::new(cond_g),
                body: Box::new(body_g),
                max_iterations: Some(3),
            },
            vec![x],
            s.clone(),
        );
        g.set_outputs(vec![y]);

        let lowered = unroll_while(g);
        assert!(
            !lowered
                .nodes()
                .iter()
                .any(|n| matches!(n.op, Op::While { .. }))
        );

        // Node ids are reassigned by the rewrite, so locate input `x` by name.
        let x_id = lowered
            .nodes()
            .iter()
            .find(|n| matches!(&n.op, Op::Input { name, .. } if name == "x"))
            .expect("lowered graph missing input x")
            .id;
        let plan = rlx_opt::memory::plan_memory(&lowered);
        let mut arena = rlx_cpu::arena::Arena::from_plan(plan);
        let sched = rlx_cpu::thunk::compile_thunks(&lowered, &arena);
        // Copy each Constant's little-endian f32 payload into its arena buffer by hand
        // before execution.
        for node in lowered.nodes() {
            if let Op::Constant { data } = &node.op
                && arena.has_buffer(node.id)
                && !data.is_empty()
            {
                let buf = arena.slice_mut(node.id);
                let n_floats = data.len() / 4;
                let n = buf.len().min(n_floats);
                for i in 0..n {
                    let bytes = [
                        data[i * 4],
                        data[i * 4 + 1],
                        data[i * 4 + 2],
                        data[i * 4 + 3],
                    ];
                    buf[i] = f32::from_le_bytes(bytes);
                }
            }
        }
        let x_off = arena.byte_offset(x_id);
        let out_id = lowered.outputs[0];
        let out_off = arena.byte_offset(out_id);
        let buf = arena.raw_buf_mut();
        // SAFETY (assumed): `x_off` is the byte offset of input x, whose shape is [2] x f32,
        // so two f32 writes stay inside its planned buffer. Confirm alignment guarantees
        // of the arena.
        unsafe {
            let px = buf.as_mut_ptr().add(x_off) as *mut f32;
            *px.add(0) = 2.0;
            *px.add(1) = 3.0;
        }
        rlx_cpu::thunk::execute_thunks(&sched, arena.raw_buf_mut());
        // SAFETY (assumed): `out_off` is the byte offset of the [2] x f32 graph output.
        let got: Vec<f32> = unsafe {
            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
            vec![*p.add(0), *p.add(1)]
        };
        // 2^8 and 3^8: three successive squarings.
        let want = [256.0_f32, 6561.0_f32];
        for (i, (&a, &b)) in got.iter().zip(&want).enumerate() {
            assert!(
                (a - b).abs() < 1e-3,
                "unrolled while[{i}]: got {a} want {b}"
            );
        }
    }
}