rlx_compile/
const_fold.rs1use rlx_fusion::pass::Pass;
32use rlx_ir::op::{Activation, BinaryOp};
33use rlx_ir::{Graph, NodeId, Op};
34use std::collections::{HashMap, HashSet};
35
/// Graph pass that evaluates pure nodes whose inputs are all compile-time
/// constants and replaces each such node with an `Op::Constant` holding the
/// computed value.
#[derive(Debug, Clone, Copy, Default)]
pub struct ConstantFolding;

38fn is_pure(op: &Op) -> bool {
40 matches!(
41 op,
42 Op::Activation(_)
43 | Op::Binary(_)
44 | Op::Compare(_)
45 | Op::Reshape { .. }
46 | Op::Expand { .. }
47 | Op::Cast { .. }
48 )
49}
50
51fn is_foldable(node_id: NodeId, graph: &Graph, folded: &HashSet<NodeId>) -> bool {
54 let node = graph.node(node_id);
55 if !is_pure(&node.op) {
56 return false;
57 }
58 node.inputs.iter().all(|i| folded.contains(i))
59}
60
61fn evaluate(node: &rlx_ir::Node, inputs: &[&Vec<f32>]) -> Option<Vec<f32>> {
64 let total = node.shape.num_elements()?;
65 let mut out = vec![0f32; total];
66
67 match &node.op {
68 Op::Activation(act) => {
69 let x = inputs[0];
70 for (i, &v) in x.iter().enumerate() {
71 out[i] = match act {
72 Activation::Gelu | Activation::GeluApprox => {
73 v * 0.5 * (1.0 + (v * std::f32::consts::FRAC_1_SQRT_2).tanh())
74 }
75 Activation::Silu => v / (1.0 + (-v).exp()),
76 Activation::Relu => v.max(0.0),
77 Activation::Sigmoid => 1.0 / (1.0 + (-v).exp()),
78 Activation::Tanh => v.tanh(),
79 Activation::Exp => v.exp(),
80 Activation::Log => v.ln(),
81 Activation::Sqrt => v.sqrt(),
82 Activation::Rsqrt => 1.0 / v.sqrt(),
83 Activation::Neg => -v,
84 Activation::Abs => v.abs(),
85 Activation::Round => v.round(),
86 Activation::Sin => v.sin(),
87 Activation::Cos => v.cos(),
88 Activation::Tan => v.tan(),
89 Activation::Atan => v.atan(),
90 };
91 }
92 Some(out)
93 }
94 Op::Binary(op) => {
95 let lhs = inputs[0];
96 let rhs = inputs[1];
97 if lhs.len() != total || rhs.len() != total {
99 return None;
100 }
101 for i in 0..total {
102 out[i] = match op {
103 BinaryOp::Add => lhs[i] + rhs[i],
104 BinaryOp::Sub => lhs[i] - rhs[i],
105 BinaryOp::Mul => lhs[i] * rhs[i],
106 BinaryOp::Div => lhs[i] / rhs[i],
107 BinaryOp::Max => lhs[i].max(rhs[i]),
108 BinaryOp::Min => lhs[i].min(rhs[i]),
109 BinaryOp::Pow => lhs[i].powf(rhs[i]),
110 };
111 }
112 Some(out)
113 }
114 Op::Reshape { .. } | Op::Expand { .. } | Op::Cast { .. } => {
115 let src = inputs[0];
117 if src.len() == total {
118 Some(src.clone())
119 } else if src.len() == 1 {
120 Some(vec![src[0]; total])
121 } else {
122 None
123 }
124 }
125 _ => None,
126 }
127}
128
/// Serialises f32 values into an `Op::Constant` payload: each value as its
/// 4 little-endian bytes, in order.
fn encode_constant(data: &[f32]) -> Vec<u8> {
    let mut payload = Vec::with_capacity(std::mem::size_of_val(data));
    payload.extend(data.iter().flat_map(|value| value.to_le_bytes()));
    payload
}

138impl Pass for ConstantFolding {
139 fn name(&self) -> &str {
140 "constant_folding"
141 }
142
143 fn run(&self, graph: Graph) -> Graph {
144 let mut folded: HashSet<NodeId> = HashSet::new();
147 let mut values: HashMap<NodeId, Vec<f32>> = HashMap::new();
148
149 for node in graph.nodes() {
150 if let Op::Constant { data } = &node.op {
152 folded.insert(node.id);
153 let f32s: Vec<f32> = data
154 .chunks_exact(4)
155 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
156 .collect();
157 values.insert(node.id, f32s);
158 continue;
159 }
160 if matches!(node.op, Op::Input { .. } | Op::Param { .. }) {
162 continue;
163 }
164 if is_foldable(node.id, &graph, &folded) {
166 let inputs: Vec<&Vec<f32>> = node.inputs.iter().map(|i| &values[i]).collect();
167 if let Some(result) = evaluate(node, &inputs) {
168 folded.insert(node.id);
169 values.insert(node.id, result);
170 }
171 }
172 }
173
174 let mut new_graph = Graph::new(&graph.name);
176 let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
177 for node in graph.nodes() {
178 if folded.contains(&node.id)
181 && !matches!(
182 node.op,
183 Op::Constant { .. } | Op::Param { .. } | Op::Input { .. }
184 )
185 {
186 let bytes = encode_constant(&values[&node.id]);
187 let new_id =
188 new_graph.add_node(Op::Constant { data: bytes }, vec![], node.shape.clone());
189 id_map.insert(node.id, new_id);
190 continue;
191 }
192 let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
194 let new_id = new_graph.add_node(node.op.clone(), new_inputs, node.shape.clone());
195 id_map.insert(node.id, new_id);
196 }
197 let new_outputs: Vec<NodeId> = graph.outputs.iter().map(|i| id_map[i]).collect();
198 new_graph.set_outputs(new_outputs);
199 new_graph
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::*;

    /// Adds an `Op::Constant` whose payload is `values` as little-endian f32.
    fn f32_constant(graph: &mut Graph, values: &[f32], shape: Shape) -> NodeId {
        let data: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
        graph.add_node(Op::Constant { data }, vec![], shape)
    }

    /// Reads the first little-endian f32 stored in a constant payload.
    fn first_f32(data: &[u8]) -> f32 {
        f32::from_le_bytes([data[0], data[1], data[2], data[3]])
    }

    #[test]
    fn folds_constant_arithmetic() {
        let scalar = Shape::new(&[1], DType::F32);
        let mut g = Graph::new("test");
        let a = f32_constant(&mut g, &[2.0], scalar.clone());
        let b = f32_constant(&mut g, &[3.0], scalar.clone());
        let sum = g.binary(op::BinaryOp::Add, a, b, scalar);
        g.set_outputs(vec![sum]);

        let folded = ConstantFolding.run(g);
        let out_node = folded.node(folded.outputs[0]);
        match &out_node.op {
            Op::Constant { data } => assert!((first_f32(data) - 5.0).abs() < 1e-6),
            other => panic!("expected folded Constant, got {:?}", other),
        }
    }

    #[test]
    fn does_not_fold_input_dependent() {
        let vec4 = Shape::new(&[4], DType::F32);
        let mut g = Graph::new("test");
        let x = g.input("x", vec4.clone());
        let zeros = f32_constant(&mut g, &[0.0; 4], vec4.clone());
        let sum = g.binary(op::BinaryOp::Add, x, zeros, vec4);
        g.set_outputs(vec![sum]);

        let folded = ConstantFolding.run(g);
        assert!(matches!(folded.node(folded.outputs[0]).op, Op::Binary(_)));
    }
}