1use crate::ir::{Graph, NodeId, Op};
6use rustc_hash::{FxHashMap, FxHashSet};
7
/// Identifies one graph-rewriting pass that `Optimizer` can run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationPass {
    /// Drops `x * 1.0` / `x + 0.0` and rewrites `x * 0.0` to a zero constant.
    ConstantFolding,
    /// Removes nodes that no registered graph output depends on.
    DeadCodeElimination,
    /// Placeholder: the pass currently returns the graph unchanged.
    ElementwiseFusion,
    /// Deduplicates nodes that compute the same expression.
    CommonSubexpressionElimination,
    /// Removes scalar identities (`x * 1.0`, `x + 0.0`) and double negation.
    AlgebraicSimplification,
    /// Placeholder: the pass currently rebuilds the graph without rewriting any op.
    StrengthReduction,
}
24
25pub struct Optimizer {
27 passes: Vec<OptimizationPass>,
28}
29
30impl Optimizer {
31 pub fn new() -> Self {
33 Self { passes: Vec::new() }
34 }
35
36 pub fn default_passes() -> Self {
38 Self {
39 passes: vec![
40 OptimizationPass::ConstantFolding,
41 OptimizationPass::AlgebraicSimplification,
42 OptimizationPass::DeadCodeElimination,
43 OptimizationPass::CommonSubexpressionElimination,
44 ],
45 }
46 }
47
48 pub fn add_pass(&mut self, pass: OptimizationPass) {
50 self.passes.push(pass);
51 }
52
53 pub fn optimize(&self, mut graph: Graph) -> Graph {
55 for pass in &self.passes {
56 graph = self.run_pass(graph, *pass);
57 }
58 graph
59 }
60
61 fn run_pass(&self, graph: Graph, pass: OptimizationPass) -> Graph {
62 match pass {
63 OptimizationPass::ConstantFolding => constant_folding(graph),
64 OptimizationPass::DeadCodeElimination => dead_code_elimination(graph),
65 OptimizationPass::ElementwiseFusion => elementwise_fusion(graph),
66 OptimizationPass::CommonSubexpressionElimination => cse(graph),
67 OptimizationPass::AlgebraicSimplification => algebraic_simplification(graph),
68 OptimizationPass::StrengthReduction => strength_reduction(graph),
69 }
70 }
71}
72
73impl Default for Optimizer {
74 fn default() -> Self {
75 Self::default_passes()
76 }
77}
78
79fn constant_folding(graph: Graph) -> Graph {
81 let mut new_graph = Graph::new();
84 let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
85 let mut constants: FxHashMap<NodeId, f64> = FxHashMap::default();
86
87 for node in graph.nodes() {
88 if let Op::Constant { value } = &node.op {
90 constants.insert(node.id, *value);
91 }
92
93 let new_op = match &node.op {
95 Op::MulScalar { input, scalar } if *scalar == 1.0 => {
96 let new_input = node_map.get(input).copied().unwrap_or(*input);
98 node_map.insert(node.id, new_input);
99 continue;
100 }
101 Op::MulScalar { input: _, scalar } if *scalar == 0.0 => {
102 Op::Constant { value: 0.0 }
104 }
105 Op::AddScalar { input, scalar } if *scalar == 0.0 => {
106 let new_input = node_map.get(input).copied().unwrap_or(*input);
108 node_map.insert(node.id, new_input);
109 continue;
110 }
111 other => remap_op(other, &node_map),
112 };
113
114 let new_id = new_graph.add_node(new_op, node.dtype, node.shape.clone());
115 node_map.insert(node.id, new_id);
116 }
117
118 for (name, id) in graph.inputs() {
120 if let Some(&new_id) = node_map.get(id) {
121 new_graph.register_input(name, new_id);
122 }
123 }
124 for (name, id) in graph.outputs() {
125 if let Some(&new_id) = node_map.get(id) {
126 new_graph.register_output(name, new_id);
127 }
128 }
129
130 new_graph
131}
132
133fn dead_code_elimination(graph: Graph) -> Graph {
135 let mut live_nodes: FxHashSet<NodeId> = FxHashSet::default();
137 let mut worklist: Vec<NodeId> = graph.outputs().values().copied().collect();
138
139 while let Some(id) = worklist.pop() {
140 if live_nodes.insert(id) {
141 let node = graph.node(id);
142 for input_id in node.op.inputs() {
143 worklist.push(input_id);
144 }
145 }
146 }
147
148 let mut new_graph = Graph::new();
150 let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
151
152 for node in graph.nodes() {
153 if !live_nodes.contains(&node.id) {
154 continue;
155 }
156
157 let new_op = remap_op(&node.op, &node_map);
158 let new_id = new_graph.add_node(new_op, node.dtype, node.shape.clone());
159 node_map.insert(node.id, new_id);
160 }
161
162 for (name, id) in graph.inputs() {
164 if let Some(&new_id) = node_map.get(id) {
165 new_graph.register_input(name, new_id);
166 }
167 }
168 for (name, id) in graph.outputs() {
169 if let Some(&new_id) = node_map.get(id) {
170 new_graph.register_output(name, new_id);
171 }
172 }
173
174 new_graph
175}
176
177fn elementwise_fusion(graph: Graph) -> Graph {
179 graph
183}
184
185fn cse(graph: Graph) -> Graph {
187 let mut new_graph = Graph::new();
189 let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
190 let mut expr_map: FxHashMap<String, NodeId> = FxHashMap::default();
191
192 for node in graph.nodes() {
193 let remapped_op = remap_op(&node.op, &node_map);
194 let expr_key = format!("{:?}", remapped_op);
195
196 if let Some(&existing_id) = expr_map.get(&expr_key) {
197 node_map.insert(node.id, existing_id);
199 } else {
200 let new_id = new_graph.add_node(remapped_op, node.dtype, node.shape.clone());
201 node_map.insert(node.id, new_id);
202 expr_map.insert(expr_key, new_id);
203 }
204 }
205
206 for (name, id) in graph.inputs() {
208 if let Some(&new_id) = node_map.get(id) {
209 new_graph.register_input(name, new_id);
210 }
211 }
212 for (name, id) in graph.outputs() {
213 if let Some(&new_id) = node_map.get(id) {
214 new_graph.register_output(name, new_id);
215 }
216 }
217
218 new_graph
219}
220
221fn algebraic_simplification(graph: Graph) -> Graph {
223 let mut new_graph = Graph::new();
224 let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
225
226 for node in graph.nodes() {
227 let simplified_op = match &node.op {
228 Op::MulScalar { input, scalar } if *scalar == 1.0 => {
230 let new_input = node_map.get(input).copied().unwrap_or(*input);
231 node_map.insert(node.id, new_input);
232 continue;
233 }
234 Op::AddScalar { input, scalar } if *scalar == 0.0 => {
236 let new_input = node_map.get(input).copied().unwrap_or(*input);
237 node_map.insert(node.id, new_input);
238 continue;
239 }
240 Op::Neg { input } => {
244 let actual_input = node_map.get(input).copied().unwrap_or(*input);
245 if let Some(input_node) = new_graph.nodes().iter().find(|n| n.id == actual_input) {
246 if let Op::Neg { input: inner } = &input_node.op {
247 node_map.insert(node.id, *inner);
248 continue;
249 }
250 }
251 Op::Neg { input: actual_input }
252 }
253 other => remap_op(other, &node_map),
254 };
255
256 let new_id = new_graph.add_node(simplified_op, node.dtype, node.shape.clone());
257 node_map.insert(node.id, new_id);
258 }
259
260 for (name, id) in graph.inputs() {
262 if let Some(&new_id) = node_map.get(id) {
263 new_graph.register_input(name, new_id);
264 }
265 }
266 for (name, id) in graph.outputs() {
267 if let Some(&new_id) = node_map.get(id) {
268 new_graph.register_output(name, new_id);
269 }
270 }
271
272 new_graph
273}
274
275fn strength_reduction(graph: Graph) -> Graph {
277 let mut new_graph = Graph::new();
278 let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
279
280 for node in graph.nodes() {
281 let reduced_op = match &node.op {
282 Op::Pow { .. } => {
284 remap_op(&node.op, &node_map)
287 }
288 Op::Div { .. } => {
290 remap_op(&node.op, &node_map)
292 }
293 other => remap_op(other, &node_map),
294 };
295
296 let new_id = new_graph.add_node(reduced_op, node.dtype, node.shape.clone());
297 node_map.insert(node.id, new_id);
298 }
299
300 for (name, id) in graph.inputs() {
302 if let Some(&new_id) = node_map.get(id) {
303 new_graph.register_input(name, new_id);
304 }
305 }
306 for (name, id) in graph.outputs() {
307 if let Some(&new_id) = node_map.get(id) {
308 new_graph.register_output(name, new_id);
309 }
310 }
311
312 new_graph
313}
314
315fn remap_op(op: &Op, node_map: &FxHashMap<NodeId, NodeId>) -> Op {
317 let remap = |id: &NodeId| node_map.get(id).copied().unwrap_or(*id);
318
319 match op {
320 Op::Input { name } => Op::Input { name: name.clone() },
321 Op::Output { name, input } => Op::Output { name: name.clone(), input: remap(input) },
322 Op::Constant { value } => Op::Constant { value: *value },
323
324 Op::Add { lhs, rhs } => Op::Add { lhs: remap(lhs), rhs: remap(rhs) },
325 Op::Sub { lhs, rhs } => Op::Sub { lhs: remap(lhs), rhs: remap(rhs) },
326 Op::Mul { lhs, rhs } => Op::Mul { lhs: remap(lhs), rhs: remap(rhs) },
327 Op::Div { lhs, rhs } => Op::Div { lhs: remap(lhs), rhs: remap(rhs) },
328 Op::Pow { base, exp } => Op::Pow { base: remap(base), exp: remap(exp) },
329 Op::Max { lhs, rhs } => Op::Max { lhs: remap(lhs), rhs: remap(rhs) },
330 Op::Min { lhs, rhs } => Op::Min { lhs: remap(lhs), rhs: remap(rhs) },
331
332 Op::Neg { input } => Op::Neg { input: remap(input) },
333 Op::Abs { input } => Op::Abs { input: remap(input) },
334 Op::Sqrt { input } => Op::Sqrt { input: remap(input) },
335 Op::Exp { input } => Op::Exp { input: remap(input) },
336 Op::Log { input } => Op::Log { input: remap(input) },
337 Op::Sin { input } => Op::Sin { input: remap(input) },
338 Op::Cos { input } => Op::Cos { input: remap(input) },
339 Op::Tanh { input } => Op::Tanh { input: remap(input) },
340
341 Op::Relu { input } => Op::Relu { input: remap(input) },
342 Op::Sigmoid { input } => Op::Sigmoid { input: remap(input) },
343 Op::Gelu { input } => Op::Gelu { input: remap(input) },
344 Op::Silu { input } => Op::Silu { input: remap(input) },
345
346 Op::AddScalar { input, scalar } => Op::AddScalar { input: remap(input), scalar: *scalar },
347 Op::MulScalar { input, scalar } => Op::MulScalar { input: remap(input), scalar: *scalar },
348
349 Op::Sum { input } => Op::Sum { input: remap(input) },
350 Op::SumAxis { input, axis, keepdim } => Op::SumAxis { input: remap(input), axis: *axis, keepdim: *keepdim },
351 Op::Mean { input } => Op::Mean { input: remap(input) },
352 Op::MeanAxis { input, axis, keepdim } => Op::MeanAxis { input: remap(input), axis: *axis, keepdim: *keepdim },
353 Op::MaxAxis { input, axis, keepdim } => Op::MaxAxis { input: remap(input), axis: *axis, keepdim: *keepdim },
354
355 Op::Reshape { input, shape } => Op::Reshape { input: remap(input), shape: shape.clone() },
356 Op::Transpose { input, dim0, dim1 } => Op::Transpose { input: remap(input), dim0: *dim0, dim1: *dim1 },
357 Op::Squeeze { input, dim } => Op::Squeeze { input: remap(input), dim: *dim },
358 Op::Unsqueeze { input, dim } => Op::Unsqueeze { input: remap(input), dim: *dim },
359 Op::Broadcast { input, shape } => Op::Broadcast { input: remap(input), shape: shape.clone() },
360
361 Op::MatMul { lhs, rhs } => Op::MatMul { lhs: remap(lhs), rhs: remap(rhs) },
362
363 Op::Gt { lhs, rhs } => Op::Gt { lhs: remap(lhs), rhs: remap(rhs) },
364 Op::Lt { lhs, rhs } => Op::Lt { lhs: remap(lhs), rhs: remap(rhs) },
365 Op::Eq { lhs, rhs } => Op::Eq { lhs: remap(lhs), rhs: remap(rhs) },
366
367 Op::Where { condition, x, y } => Op::Where { condition: remap(condition), x: remap(x), y: remap(y) },
368
369 Op::Cast { input, dtype } => Op::Cast { input: remap(input), dtype: *dtype },
370 Op::Contiguous { input } => Op::Contiguous { input: remap(input) },
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::trace;

    /// Builds an optimizer that runs exactly one pass.
    fn single_pass(pass: OptimizationPass) -> Optimizer {
        let mut optimizer = Optimizer::new();
        optimizer.add_pass(pass);
        optimizer
    }

    /// A product that feeds no output must disappear after DCE.
    #[test]
    fn test_dead_code_elimination() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[2, 3]);
            let _unused = a.mul(&b);
            let c = a.add(&b);
            tracer.output("result", c)
        });

        let optimized = single_pass(OptimizationPass::DeadCodeElimination).optimize(graph);

        let mul_free = optimized
            .nodes()
            .iter()
            .all(|n| !matches!(n.op, Op::Mul { .. }));
        assert!(mul_free);
    }

    /// Multiplying by the scalar 1.0 must leave no `MulScalar` behind.
    #[test]
    fn test_algebraic_simplification() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[2, 3]);
            let y = x.mul_scalar(1.0);
            tracer.output("y", y)
        });

        let optimized = single_pass(OptimizationPass::AlgebraicSimplification).optimize(graph);

        let mul_scalar_free = optimized
            .nodes()
            .iter()
            .all(|n| !matches!(n.op, Op::MulScalar { .. }));
        assert!(mul_scalar_free);
    }

    /// Multiplying by the scalar 0.0 must produce a `Constant` node.
    #[test]
    fn test_constant_folding() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[2, 3]);
            let y = x.mul_scalar(0.0);
            tracer.output("y", y)
        });

        let optimized = single_pass(OptimizationPass::ConstantFolding).optimize(graph);

        let has_constant = optimized
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::Constant { .. }));
        assert!(has_constant);
    }
}