1use crate::ir::{Graph, NodeId, Op};
6use rustc_hash::{FxHashMap, FxHashSet};
7
/// A single graph-rewriting pass that an [`Optimizer`] can run.
///
/// `Hash` is derived alongside `Eq` so passes can be stored in sets or used as map
/// keys (e.g. to deduplicate or look up a pipeline configuration).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationPass {
    /// Fold scalar identities (`x * 1`, `x + 0`) and replace `x * 0` with a zero constant.
    ConstantFolding,
    /// Remove computation that no registered output depends on.
    DeadCodeElimination,
    /// Fuse chains of element-wise ops. Currently performs no rewrites.
    ElementwiseFusion,
    /// Merge nodes that compute the same expression.
    CommonSubexpressionElimination,
    /// Apply algebraic identities such as `x * 1 → x`, `x + 0 → x`, `-(-x) → x`.
    AlgebraicSimplification,
    /// Replace expensive ops with cheaper equivalents. Currently performs no rewrites.
    StrengthReduction,
}
24
25pub struct Optimizer {
27 passes: Vec<OptimizationPass>,
28}
29
30impl Optimizer {
31 pub fn new() -> Self {
33 Self { passes: Vec::new() }
34 }
35
36 pub fn default_passes() -> Self {
38 Self {
39 passes: vec![
40 OptimizationPass::ConstantFolding,
41 OptimizationPass::AlgebraicSimplification,
42 OptimizationPass::DeadCodeElimination,
43 OptimizationPass::CommonSubexpressionElimination,
44 ],
45 }
46 }
47
48 pub fn add_pass(&mut self, pass: OptimizationPass) {
50 self.passes.push(pass);
51 }
52
53 pub fn optimize(&self, mut graph: Graph) -> Graph {
55 for pass in &self.passes {
56 graph = self.run_pass(graph, *pass);
57 }
58 graph
59 }
60
61 fn run_pass(&self, graph: Graph, pass: OptimizationPass) -> Graph {
62 match pass {
63 OptimizationPass::ConstantFolding => constant_folding(graph),
64 OptimizationPass::DeadCodeElimination => dead_code_elimination(graph),
65 OptimizationPass::ElementwiseFusion => elementwise_fusion(graph),
66 OptimizationPass::CommonSubexpressionElimination => cse(graph),
67 OptimizationPass::AlgebraicSimplification => algebraic_simplification(graph),
68 OptimizationPass::StrengthReduction => strength_reduction(graph),
69 }
70 }
71}
72
73impl Default for Optimizer {
74 fn default() -> Self {
75 Self::default_passes()
76 }
77}
78
79fn constant_folding(graph: Graph) -> Graph {
81 let mut new_graph = Graph::new();
84 let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
85 let mut constants: FxHashMap<NodeId, f64> = FxHashMap::default();
86
87 for node in graph.nodes() {
88 if let Op::Constant { value } = &node.op {
90 constants.insert(node.id, *value);
91 }
92
93 let new_op = match &node.op {
95 Op::MulScalar { input, scalar } if *scalar == 1.0 => {
96 let new_input = node_map.get(input).copied().unwrap_or(*input);
98 node_map.insert(node.id, new_input);
99 continue;
100 }
101 Op::MulScalar { input: _, scalar } if *scalar == 0.0 => {
102 Op::Constant { value: 0.0 }
104 }
105 Op::AddScalar { input, scalar } if *scalar == 0.0 => {
106 let new_input = node_map.get(input).copied().unwrap_or(*input);
108 node_map.insert(node.id, new_input);
109 continue;
110 }
111 other => remap_op(other, &node_map),
112 };
113
114 let new_id = new_graph.add_node(new_op, node.dtype, node.shape.clone());
115 node_map.insert(node.id, new_id);
116 }
117
118 for (name, id) in graph.inputs() {
120 if let Some(&new_id) = node_map.get(id) {
121 new_graph.register_input(name, new_id);
122 }
123 }
124 for (name, id) in graph.outputs() {
125 if let Some(&new_id) = node_map.get(id) {
126 new_graph.register_output(name, new_id);
127 }
128 }
129
130 new_graph
131}
132
133fn dead_code_elimination(graph: Graph) -> Graph {
135 let mut live_nodes: FxHashSet<NodeId> = FxHashSet::default();
137 let mut worklist: Vec<NodeId> = graph.outputs().values().copied().collect();
138
139 while let Some(id) = worklist.pop() {
140 if live_nodes.insert(id) {
141 let node = graph.node(id);
142 for input_id in node.op.inputs() {
143 worklist.push(input_id);
144 }
145 }
146 }
147
148 let mut new_graph = Graph::new();
150 let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
151
152 for node in graph.nodes() {
153 if !live_nodes.contains(&node.id) {
154 continue;
155 }
156
157 let new_op = remap_op(&node.op, &node_map);
158 let new_id = new_graph.add_node(new_op, node.dtype, node.shape.clone());
159 node_map.insert(node.id, new_id);
160 }
161
162 for (name, id) in graph.inputs() {
164 if let Some(&new_id) = node_map.get(id) {
165 new_graph.register_input(name, new_id);
166 }
167 }
168 for (name, id) in graph.outputs() {
169 if let Some(&new_id) = node_map.get(id) {
170 new_graph.register_output(name, new_id);
171 }
172 }
173
174 new_graph
175}
176
177fn elementwise_fusion(graph: Graph) -> Graph {
179 graph
183}
184
185fn cse(graph: Graph) -> Graph {
187 let mut new_graph = Graph::new();
189 let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
190 let mut expr_map: FxHashMap<String, NodeId> = FxHashMap::default();
191
192 for node in graph.nodes() {
193 let remapped_op = remap_op(&node.op, &node_map);
194 let expr_key = format!("{:?}", remapped_op);
195
196 if let Some(&existing_id) = expr_map.get(&expr_key) {
197 node_map.insert(node.id, existing_id);
199 } else {
200 let new_id = new_graph.add_node(remapped_op, node.dtype, node.shape.clone());
201 node_map.insert(node.id, new_id);
202 expr_map.insert(expr_key, new_id);
203 }
204 }
205
206 for (name, id) in graph.inputs() {
208 if let Some(&new_id) = node_map.get(id) {
209 new_graph.register_input(name, new_id);
210 }
211 }
212 for (name, id) in graph.outputs() {
213 if let Some(&new_id) = node_map.get(id) {
214 new_graph.register_output(name, new_id);
215 }
216 }
217
218 new_graph
219}
220
221fn algebraic_simplification(graph: Graph) -> Graph {
223 let mut new_graph = Graph::new();
224 let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
225
226 for node in graph.nodes() {
227 let simplified_op = match &node.op {
228 Op::MulScalar { input, scalar } if *scalar == 1.0 => {
230 let new_input = node_map.get(input).copied().unwrap_or(*input);
231 node_map.insert(node.id, new_input);
232 continue;
233 }
234 Op::AddScalar { input, scalar } if *scalar == 0.0 => {
236 let new_input = node_map.get(input).copied().unwrap_or(*input);
237 node_map.insert(node.id, new_input);
238 continue;
239 }
240 Op::Neg { input } => {
244 let actual_input = node_map.get(input).copied().unwrap_or(*input);
245 if let Some(input_node) = new_graph.nodes().iter().find(|n| n.id == actual_input) {
246 if let Op::Neg { input: inner } = &input_node.op {
247 node_map.insert(node.id, *inner);
248 continue;
249 }
250 }
251 Op::Neg {
252 input: actual_input,
253 }
254 }
255 other => remap_op(other, &node_map),
256 };
257
258 let new_id = new_graph.add_node(simplified_op, node.dtype, node.shape.clone());
259 node_map.insert(node.id, new_id);
260 }
261
262 for (name, id) in graph.inputs() {
264 if let Some(&new_id) = node_map.get(id) {
265 new_graph.register_input(name, new_id);
266 }
267 }
268 for (name, id) in graph.outputs() {
269 if let Some(&new_id) = node_map.get(id) {
270 new_graph.register_output(name, new_id);
271 }
272 }
273
274 new_graph
275}
276
277fn strength_reduction(graph: Graph) -> Graph {
279 let mut new_graph = Graph::new();
280 let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
281
282 for node in graph.nodes() {
283 let reduced_op = match &node.op {
284 Op::Pow { .. } => {
286 remap_op(&node.op, &node_map)
289 }
290 Op::Div { .. } => {
292 remap_op(&node.op, &node_map)
294 }
295 other => remap_op(other, &node_map),
296 };
297
298 let new_id = new_graph.add_node(reduced_op, node.dtype, node.shape.clone());
299 node_map.insert(node.id, new_id);
300 }
301
302 for (name, id) in graph.inputs() {
304 if let Some(&new_id) = node_map.get(id) {
305 new_graph.register_input(name, new_id);
306 }
307 }
308 for (name, id) in graph.outputs() {
309 if let Some(&new_id) = node_map.get(id) {
310 new_graph.register_output(name, new_id);
311 }
312 }
313
314 new_graph
315}
316
/// Returns a copy of `op` whose node-id operands are translated through `node_map`.
///
/// Every non-id payload (names, scalar values, axes, `keepdim` flags, shapes,
/// dtypes) is copied unchanged. An operand id with no entry in `node_map` is kept
/// as-is via the `unwrap_or(*id)` fallback.
///
/// The match lists every `Op` variant with no wildcard arm. Adding a variant to
/// `Op` therefore breaks compilation here until the new variant is handled.
fn remap_op(op: &Op, node_map: &FxHashMap<NodeId, NodeId>) -> Op {
    // Translate one operand id; ids missing from the map pass through unchanged.
    let remap = |id: &NodeId| node_map.get(id).copied().unwrap_or(*id);

    match op {
        // Leaves and graph boundary ops.
        Op::Input { name } => Op::Input { name: name.clone() },
        Op::Output { name, input } => Op::Output {
            name: name.clone(),
            input: remap(input),
        },
        Op::Constant { value } => Op::Constant { value: *value },

        // Binary ops.
        Op::Add { lhs, rhs } => Op::Add {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },
        Op::Sub { lhs, rhs } => Op::Sub {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },
        Op::Mul { lhs, rhs } => Op::Mul {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },
        Op::Div { lhs, rhs } => Op::Div {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },
        Op::Pow { base, exp } => Op::Pow {
            base: remap(base),
            exp: remap(exp),
        },
        Op::Max { lhs, rhs } => Op::Max {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },
        Op::Min { lhs, rhs } => Op::Min {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },

        // Unary math ops.
        Op::Neg { input } => Op::Neg {
            input: remap(input),
        },
        Op::Abs { input } => Op::Abs {
            input: remap(input),
        },
        Op::Sqrt { input } => Op::Sqrt {
            input: remap(input),
        },
        Op::Exp { input } => Op::Exp {
            input: remap(input),
        },
        Op::Log { input } => Op::Log {
            input: remap(input),
        },
        Op::Sin { input } => Op::Sin {
            input: remap(input),
        },
        Op::Cos { input } => Op::Cos {
            input: remap(input),
        },
        Op::Tanh { input } => Op::Tanh {
            input: remap(input),
        },

        // Activations.
        Op::Relu { input } => Op::Relu {
            input: remap(input),
        },
        Op::Sigmoid { input } => Op::Sigmoid {
            input: remap(input),
        },
        Op::Gelu { input } => Op::Gelu {
            input: remap(input),
        },
        Op::Silu { input } => Op::Silu {
            input: remap(input),
        },

        // Ops with an inline scalar operand.
        Op::AddScalar { input, scalar } => Op::AddScalar {
            input: remap(input),
            scalar: *scalar,
        },
        Op::MulScalar { input, scalar } => Op::MulScalar {
            input: remap(input),
            scalar: *scalar,
        },

        // Reductions.
        Op::Sum { input } => Op::Sum {
            input: remap(input),
        },
        Op::SumAxis {
            input,
            axis,
            keepdim,
        } => Op::SumAxis {
            input: remap(input),
            axis: *axis,
            keepdim: *keepdim,
        },
        Op::Mean { input } => Op::Mean {
            input: remap(input),
        },
        Op::MeanAxis {
            input,
            axis,
            keepdim,
        } => Op::MeanAxis {
            input: remap(input),
            axis: *axis,
            keepdim: *keepdim,
        },
        Op::MaxAxis {
            input,
            axis,
            keepdim,
        } => Op::MaxAxis {
            input: remap(input),
            axis: *axis,
            keepdim: *keepdim,
        },

        // Shape / layout ops.
        Op::Reshape { input, shape } => Op::Reshape {
            input: remap(input),
            shape: shape.clone(),
        },
        Op::Transpose { input, dim0, dim1 } => Op::Transpose {
            input: remap(input),
            dim0: *dim0,
            dim1: *dim1,
        },
        Op::Squeeze { input, dim } => Op::Squeeze {
            input: remap(input),
            dim: *dim,
        },
        Op::Unsqueeze { input, dim } => Op::Unsqueeze {
            input: remap(input),
            dim: *dim,
        },
        Op::Broadcast { input, shape } => Op::Broadcast {
            input: remap(input),
            shape: shape.clone(),
        },

        // Matrix multiplication.
        Op::MatMul { lhs, rhs } => Op::MatMul {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },

        // Comparisons.
        Op::Gt { lhs, rhs } => Op::Gt {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },
        Op::Lt { lhs, rhs } => Op::Lt {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },
        Op::Eq { lhs, rhs } => Op::Eq {
            lhs: remap(lhs),
            rhs: remap(rhs),
        },

        // Selection.
        Op::Where { condition, x, y } => Op::Where {
            condition: remap(condition),
            x: remap(x),
            y: remap(y),
        },

        // Type conversion and memory layout.
        Op::Cast { input, dtype } => Op::Cast {
            input: remap(input),
            dtype: *dtype,
        },
        Op::Contiguous { input } => Op::Contiguous {
            input: remap(input),
        },
    }
}
494
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::trace;

    /// True when any node of `graph` has an op satisfying `pred`.
    fn any_op(graph: &Graph, pred: impl Fn(&Op) -> bool) -> bool {
        graph.nodes().iter().any(|n| pred(&n.op))
    }

    #[test]
    fn test_dead_code_elimination() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[2, 3]);
            let _unused = a.mul(&b);
            let c = a.add(&b);
            tracer.output("result", c)
        });

        let mut optimizer = Optimizer::new();
        optimizer.add_pass(OptimizationPass::DeadCodeElimination);
        let optimized = optimizer.optimize(graph);

        // The unused product is gone; the value feeding the output survives.
        assert!(!any_op(&optimized, |op| matches!(op, Op::Mul { .. })));
        assert!(any_op(&optimized, |op| matches!(op, Op::Add { .. })));
    }

    #[test]
    fn test_algebraic_simplification() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[2, 3]);
            let y = x.mul_scalar(1.0);
            tracer.output("y", y)
        });

        let mut optimizer = Optimizer::new();
        optimizer.add_pass(OptimizationPass::AlgebraicSimplification);
        let optimized = optimizer.optimize(graph);

        assert!(!any_op(&optimized, |op| matches!(op, Op::MulScalar { .. })));
    }

    #[test]
    fn test_constant_folding() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[2, 3]);
            let y = x.mul_scalar(0.0);
            tracer.output("y", y)
        });

        let mut optimizer = Optimizer::new();
        optimizer.add_pass(OptimizationPass::ConstantFolding);
        let optimized = optimizer.optimize(graph);

        // `x * 0` is replaced by a constant rather than kept alongside it.
        assert!(any_op(&optimized, |op| matches!(op, Op::Constant { .. })));
        assert!(!any_op(&optimized, |op| matches!(op, Op::MulScalar { .. })));
    }

    #[test]
    fn test_new_is_empty_and_default_uses_default_passes() {
        assert!(Optimizer::new().passes.is_empty());
        assert_eq!(
            Optimizer::default().passes,
            Optimizer::default_passes().passes
        );
    }

    #[test]
    fn test_default_pipeline_removes_identity_scalar_mul() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[2, 3]);
            let y = x.mul_scalar(1.0);
            tracer.output("y", y)
        });

        let optimized = Optimizer::default().optimize(graph);

        assert!(!any_op(&optimized, |op| matches!(op, Op::MulScalar { .. })));
    }
}