1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//! Verify that the runtime's `LowerControlFlow` pre-pass lets the CPU
//! backend execute graphs containing `Op::If` and `Op::While` end-to-end
//! — neither op is in `CPU_SUPPORTED_OPS`, so without the rewrite the
//! legalize check would reject them.
//!
//! `Op::If` rewrites to `Where(predicate, then_inlined, else_inlined)`
//! and `Op::While` to a chain of body replicas up to `max_iterations`.
//! Both branches / all iterations always execute in the rewritten graph
//! (the trade-off documented in `rlx_opt::control_flow`).
#![cfg(feature = "cpu")]
use rlx_ir::op::{Activation, BinaryOp};
use rlx_ir::{DType, Graph, Op, Shape};
use rlx_runtime::{Device, Session};
#[test]
fn cpu_runs_op_if_via_lower_control_flow() {
// then = relu(x), else = sigmoid(x). Predicate=1.0 selects then.
let s = Shape::new(&[4], DType::F32);
let mut then_g = Graph::new("then");
let ti = then_g.input("c", s.clone());
let to = then_g.activation(Activation::Relu, ti, s.clone());
then_g.set_outputs(vec![to]);
let mut else_g = Graph::new("else");
let ei = else_g.input("c", s.clone());
let eo = else_g.activation(Activation::Sigmoid, ei, s.clone());
else_g.set_outputs(vec![eo]);
let mut g = Graph::new("if_test");
let x = g.input("x", s.clone());
let pred = g.input("pred", Shape::new(&[1], DType::F32));
let y = g.add_node(
Op::If {
then_branch: Box::new(then_g),
else_branch: Box::new(else_g),
},
vec![pred, x],
s.clone(),
);
g.set_outputs(vec![y]);
let session = Session::new(Device::Cpu);
let mut compiled = session.compile(g);
let xs: Vec<f32> = vec![-1.0, 0.0, 1.0, 2.0];
let pred_true = vec![1.0f32];
let outs = compiled.run(&[("x", &xs), ("pred", &pred_true)]);
// pred=1.0 → take Relu(x) = [0, 0, 1, 2].
assert_eq!(
outs[0],
vec![0.0, 0.0, 1.0, 2.0],
"Op::If should select then-branch when predicate is non-zero"
);
}
#[test]
fn cpu_runs_op_while_via_lower_control_flow() {
// body = c * c. Loop-carried single value, max_iterations=3.
// Starting from x = [2, 3], after 3 iterations the value is
// x ^ (2^3) = x^8.
let s = Shape::new(&[2], DType::F32);
let mut body_g = Graph::new("body");
let bi = body_g.input("c", s.clone());
let bo = body_g.binary(BinaryOp::Mul, bi, bi, s.clone());
body_g.set_outputs(vec![bo]);
let mut cond_g = Graph::new("cond");
let ci = cond_g.input("c", s.clone());
cond_g.set_outputs(vec![ci]);
let mut g = Graph::new("while_test");
let x = g.input("x", s.clone());
let y = g.add_node(
Op::While {
cond: Box::new(cond_g),
body: Box::new(body_g),
max_iterations: Some(3),
},
vec![x],
s.clone(),
);
g.set_outputs(vec![y]);
let session = Session::new(Device::Cpu);
let mut compiled = session.compile(g);
let xs: Vec<f32> = vec![2.0, 3.0];
let outs = compiled.run(&[("x", &xs)]);
let want = vec![2.0_f32.powi(8), 3.0_f32.powi(8)]; // [256, 6561]
assert!(
outs[0].iter().zip(&want).all(|(a, b)| (a - b).abs() < 1e-3),
"Op::While unroll should square 3 times: got {:?} want {want:?}",
outs[0]
);
}