1use crate::ir::{Graph, NodeId, Op};
18use rustc_hash::{FxHashMap, FxHashSet};
19
/// Identifies one graph rewrite that an [`Optimizer`] can run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationPass {
    /// Replaces scalar multiplication by `0.0` with a constant and drops
    /// identity scalar ops (`* 1.0`, `+ 0.0`).
    ConstantFolding,
    /// Removes nodes that no registered output depends on.
    DeadCodeElimination,
    /// Placeholder: the pass currently returns the graph unchanged.
    ElementwiseFusion,
    /// Deduplicates nodes that compute the same expression.
    CommonSubexpressionElimination,
    /// Drops identity scalar ops (`* 1.0`, `+ 0.0`) and cancels double negation.
    AlgebraicSimplification,
    /// Placeholder: the pass currently copies the graph without rewriting any op.
    StrengthReduction,
}
36
37pub struct Optimizer {
39 passes: Vec<OptimizationPass>,
40}
41
42impl Optimizer {
43 pub fn new() -> Self {
45 Self { passes: Vec::new() }
46 }
47
48 pub fn default_passes() -> Self {
50 Self {
51 passes: vec![
52 OptimizationPass::ConstantFolding,
53 OptimizationPass::AlgebraicSimplification,
54 OptimizationPass::DeadCodeElimination,
55 OptimizationPass::CommonSubexpressionElimination,
56 ],
57 }
58 }
59
60 pub fn add_pass(&mut self, pass: OptimizationPass) {
62 self.passes.push(pass);
63 }
64
65 pub fn optimize(&self, mut graph: Graph) -> Graph {
67 for pass in &self.passes {
68 graph = self.run_pass(graph, *pass);
69 }
70 graph
71 }
72
73 fn run_pass(&self, graph: Graph, pass: OptimizationPass) -> Graph {
74 match pass {
75 OptimizationPass::ConstantFolding => constant_folding(graph),
76 OptimizationPass::DeadCodeElimination => dead_code_elimination(graph),
77 OptimizationPass::ElementwiseFusion => elementwise_fusion(graph),
78 OptimizationPass::CommonSubexpressionElimination => cse(graph),
79 OptimizationPass::AlgebraicSimplification => algebraic_simplification(graph),
80 OptimizationPass::StrengthReduction => strength_reduction(graph),
81 }
82 }
83}
84
85impl Default for Optimizer {
86 fn default() -> Self {
87 Self::default_passes()
88 }
89}
90
91fn constant_folding(graph: Graph) -> Graph {
93 let mut new_graph = Graph::new();
96 let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
97 let mut constants: FxHashMap<NodeId, f64> = FxHashMap::default();
98
99 for node in graph.nodes() {
100 if let Op::Constant { value } = &node.op {
102 constants.insert(node.id, *value);
103 }
104
105 let new_op = match &node.op {
107 Op::MulScalar { input, scalar } if *scalar == 1.0 => {
108 let new_input = node_map.get(input).copied().unwrap_or(*input);
110 node_map.insert(node.id, new_input);
111 continue;
112 }
113 Op::MulScalar { input: _, scalar } if *scalar == 0.0 => {
114 Op::Constant { value: 0.0 }
116 }
117 Op::AddScalar { input, scalar } if *scalar == 0.0 => {
118 let new_input = node_map.get(input).copied().unwrap_or(*input);
120 node_map.insert(node.id, new_input);
121 continue;
122 }
123 other => remap_op(other, &node_map),
124 };
125
126 let new_id = new_graph.add_node(new_op, node.dtype, node.shape.clone());
127 node_map.insert(node.id, new_id);
128 }
129
130 for (name, id) in graph.inputs() {
132 if let Some(&new_id) = node_map.get(id) {
133 new_graph.register_input(name, new_id);
134 }
135 }
136 for (name, id) in graph.outputs() {
137 if let Some(&new_id) = node_map.get(id) {
138 new_graph.register_output(name, new_id);
139 }
140 }
141
142 new_graph
143}
144
145fn dead_code_elimination(graph: Graph) -> Graph {
147 let mut live_nodes: FxHashSet<NodeId> = FxHashSet::default();
149 let mut worklist: Vec<NodeId> = graph.outputs().values().copied().collect();
150
151 while let Some(id) = worklist.pop() {
152 if live_nodes.insert(id) {
153 let node = graph.node(id);
154 for input_id in node.op.inputs() {
155 worklist.push(input_id);
156 }
157 }
158 }
159
160 let mut new_graph = Graph::new();
162 let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
163
164 for node in graph.nodes() {
165 if !live_nodes.contains(&node.id) {
166 continue;
167 }
168
169 let new_op = remap_op(&node.op, &node_map);
170 let new_id = new_graph.add_node(new_op, node.dtype, node.shape.clone());
171 node_map.insert(node.id, new_id);
172 }
173
174 for (name, id) in graph.inputs() {
176 if let Some(&new_id) = node_map.get(id) {
177 new_graph.register_input(name, new_id);
178 }
179 }
180 for (name, id) in graph.outputs() {
181 if let Some(&new_id) = node_map.get(id) {
182 new_graph.register_output(name, new_id);
183 }
184 }
185
186 new_graph
187}
188
189fn elementwise_fusion(graph: Graph) -> Graph {
191 graph
195}
196
197fn cse(graph: Graph) -> Graph {
199 let mut new_graph = Graph::new();
201 let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
202 let mut expr_map: FxHashMap<String, NodeId> = FxHashMap::default();
203
204 for node in graph.nodes() {
205 let remapped_op = remap_op(&node.op, &node_map);
206 let expr_key = format!("{:?}", remapped_op);
207
208 if let Some(&existing_id) = expr_map.get(&expr_key) {
209 node_map.insert(node.id, existing_id);
211 } else {
212 let new_id = new_graph.add_node(remapped_op, node.dtype, node.shape.clone());
213 node_map.insert(node.id, new_id);
214 expr_map.insert(expr_key, new_id);
215 }
216 }
217
218 for (name, id) in graph.inputs() {
220 if let Some(&new_id) = node_map.get(id) {
221 new_graph.register_input(name, new_id);
222 }
223 }
224 for (name, id) in graph.outputs() {
225 if let Some(&new_id) = node_map.get(id) {
226 new_graph.register_output(name, new_id);
227 }
228 }
229
230 new_graph
231}
232
233fn algebraic_simplification(graph: Graph) -> Graph {
235 let mut new_graph = Graph::new();
236 let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
237
238 for node in graph.nodes() {
239 let simplified_op = match &node.op {
240 Op::MulScalar { input, scalar } if *scalar == 1.0 => {
242 let new_input = node_map.get(input).copied().unwrap_or(*input);
243 node_map.insert(node.id, new_input);
244 continue;
245 }
246 Op::AddScalar { input, scalar } if *scalar == 0.0 => {
248 let new_input = node_map.get(input).copied().unwrap_or(*input);
249 node_map.insert(node.id, new_input);
250 continue;
251 }
252 Op::Neg { input } => {
256 let actual_input = node_map.get(input).copied().unwrap_or(*input);
257 if let Some(input_node) = new_graph.nodes().iter().find(|n| n.id == actual_input) {
258 if let Op::Neg { input: inner } = &input_node.op {
259 node_map.insert(node.id, *inner);
260 continue;
261 }
262 }
263 Op::Neg {
264 input: actual_input,
265 }
266 }
267 other => remap_op(other, &node_map),
268 };
269
270 let new_id = new_graph.add_node(simplified_op, node.dtype, node.shape.clone());
271 node_map.insert(node.id, new_id);
272 }
273
274 for (name, id) in graph.inputs() {
276 if let Some(&new_id) = node_map.get(id) {
277 new_graph.register_input(name, new_id);
278 }
279 }
280 for (name, id) in graph.outputs() {
281 if let Some(&new_id) = node_map.get(id) {
282 new_graph.register_output(name, new_id);
283 }
284 }
285
286 new_graph
287}
288
289fn strength_reduction(graph: Graph) -> Graph {
291 let mut new_graph = Graph::new();
292 let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
293
294 for node in graph.nodes() {
295 let reduced_op = match &node.op {
296 Op::Pow { .. } => {
298 remap_op(&node.op, &node_map)
301 }
302 Op::Div { .. } => {
304 remap_op(&node.op, &node_map)
306 }
307 other => remap_op(other, &node_map),
308 };
309
310 let new_id = new_graph.add_node(reduced_op, node.dtype, node.shape.clone());
311 node_map.insert(node.id, new_id);
312 }
313
314 for (name, id) in graph.inputs() {
316 if let Some(&new_id) = node_map.get(id) {
317 new_graph.register_input(name, new_id);
318 }
319 }
320 for (name, id) in graph.outputs() {
321 if let Some(&new_id) = node_map.get(id) {
322 new_graph.register_output(name, new_id);
323 }
324 }
325
326 new_graph
327}
328
329fn remap_op(op: &Op, node_map: &FxHashMap<NodeId, NodeId>) -> Op {
331 let remap = |id: &NodeId| node_map.get(id).copied().unwrap_or(*id);
332
333 match op {
334 Op::Input { name } => Op::Input { name: name.clone() },
335 Op::Output { name, input } => Op::Output {
336 name: name.clone(),
337 input: remap(input),
338 },
339 Op::Constant { value } => Op::Constant { value: *value },
340
341 Op::Add { lhs, rhs } => Op::Add {
342 lhs: remap(lhs),
343 rhs: remap(rhs),
344 },
345 Op::Sub { lhs, rhs } => Op::Sub {
346 lhs: remap(lhs),
347 rhs: remap(rhs),
348 },
349 Op::Mul { lhs, rhs } => Op::Mul {
350 lhs: remap(lhs),
351 rhs: remap(rhs),
352 },
353 Op::Div { lhs, rhs } => Op::Div {
354 lhs: remap(lhs),
355 rhs: remap(rhs),
356 },
357 Op::Pow { base, exp } => Op::Pow {
358 base: remap(base),
359 exp: remap(exp),
360 },
361 Op::Max { lhs, rhs } => Op::Max {
362 lhs: remap(lhs),
363 rhs: remap(rhs),
364 },
365 Op::Min { lhs, rhs } => Op::Min {
366 lhs: remap(lhs),
367 rhs: remap(rhs),
368 },
369
370 Op::Neg { input } => Op::Neg {
371 input: remap(input),
372 },
373 Op::Abs { input } => Op::Abs {
374 input: remap(input),
375 },
376 Op::Sqrt { input } => Op::Sqrt {
377 input: remap(input),
378 },
379 Op::Exp { input } => Op::Exp {
380 input: remap(input),
381 },
382 Op::Log { input } => Op::Log {
383 input: remap(input),
384 },
385 Op::Sin { input } => Op::Sin {
386 input: remap(input),
387 },
388 Op::Cos { input } => Op::Cos {
389 input: remap(input),
390 },
391 Op::Tanh { input } => Op::Tanh {
392 input: remap(input),
393 },
394
395 Op::Relu { input } => Op::Relu {
396 input: remap(input),
397 },
398 Op::Sigmoid { input } => Op::Sigmoid {
399 input: remap(input),
400 },
401 Op::Gelu { input } => Op::Gelu {
402 input: remap(input),
403 },
404 Op::Silu { input } => Op::Silu {
405 input: remap(input),
406 },
407
408 Op::AddScalar { input, scalar } => Op::AddScalar {
409 input: remap(input),
410 scalar: *scalar,
411 },
412 Op::MulScalar { input, scalar } => Op::MulScalar {
413 input: remap(input),
414 scalar: *scalar,
415 },
416
417 Op::Sum { input } => Op::Sum {
418 input: remap(input),
419 },
420 Op::SumAxis {
421 input,
422 axis,
423 keepdim,
424 } => Op::SumAxis {
425 input: remap(input),
426 axis: *axis,
427 keepdim: *keepdim,
428 },
429 Op::Mean { input } => Op::Mean {
430 input: remap(input),
431 },
432 Op::MeanAxis {
433 input,
434 axis,
435 keepdim,
436 } => Op::MeanAxis {
437 input: remap(input),
438 axis: *axis,
439 keepdim: *keepdim,
440 },
441 Op::MaxAxis {
442 input,
443 axis,
444 keepdim,
445 } => Op::MaxAxis {
446 input: remap(input),
447 axis: *axis,
448 keepdim: *keepdim,
449 },
450
451 Op::Reshape { input, shape } => Op::Reshape {
452 input: remap(input),
453 shape: shape.clone(),
454 },
455 Op::Transpose { input, dim0, dim1 } => Op::Transpose {
456 input: remap(input),
457 dim0: *dim0,
458 dim1: *dim1,
459 },
460 Op::Squeeze { input, dim } => Op::Squeeze {
461 input: remap(input),
462 dim: *dim,
463 },
464 Op::Unsqueeze { input, dim } => Op::Unsqueeze {
465 input: remap(input),
466 dim: *dim,
467 },
468 Op::Broadcast { input, shape } => Op::Broadcast {
469 input: remap(input),
470 shape: shape.clone(),
471 },
472
473 Op::MatMul { lhs, rhs } => Op::MatMul {
474 lhs: remap(lhs),
475 rhs: remap(rhs),
476 },
477
478 Op::Gt { lhs, rhs } => Op::Gt {
479 lhs: remap(lhs),
480 rhs: remap(rhs),
481 },
482 Op::Lt { lhs, rhs } => Op::Lt {
483 lhs: remap(lhs),
484 rhs: remap(rhs),
485 },
486 Op::Eq { lhs, rhs } => Op::Eq {
487 lhs: remap(lhs),
488 rhs: remap(rhs),
489 },
490
491 Op::Where { condition, x, y } => Op::Where {
492 condition: remap(condition),
493 x: remap(x),
494 y: remap(y),
495 },
496
497 Op::Cast { input, dtype } => Op::Cast {
498 input: remap(input),
499 dtype: *dtype,
500 },
501 Op::Contiguous { input } => Op::Contiguous {
502 input: remap(input),
503 },
504 }
505}
506
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::trace;

    /// Runs exactly one `pass` over `graph` through a fresh [`Optimizer`].
    fn run_single_pass(graph: Graph, pass: OptimizationPass) -> Graph {
        let mut optimizer = Optimizer::new();
        optimizer.add_pass(pass);
        optimizer.optimize(graph)
    }

    /// A product that never reaches an output must be pruned.
    #[test]
    fn test_dead_code_elimination() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[2, 3]);
            let _unused = a.mul(&b);
            let c = a.add(&b);
            tracer.output("result", c)
        });

        let optimized = run_single_pass(graph, OptimizationPass::DeadCodeElimination);

        let no_mul_left = optimized
            .nodes()
            .iter()
            .all(|n| !matches!(n.op, Op::Mul { .. }));
        assert!(no_mul_left);
    }

    /// Multiplying by `1.0` is an identity and must disappear.
    #[test]
    fn test_algebraic_simplification() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[2, 3]);
            let y = x.mul_scalar(1.0);
            tracer.output("y", y)
        });

        let optimized = run_single_pass(graph, OptimizationPass::AlgebraicSimplification);

        let no_mul_scalar_left = optimized
            .nodes()
            .iter()
            .all(|n| !matches!(n.op, Op::MulScalar { .. }));
        assert!(no_mul_scalar_left);
    }

    /// Multiplying by `0.0` must be replaced by a constant node.
    #[test]
    fn test_constant_folding() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[2, 3]);
            let y = x.mul_scalar(0.0);
            tracer.output("y", y)
        });

        let optimized = run_single_pass(graph, OptimizationPass::ConstantFolding);

        let has_constant = optimized
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::Constant { .. }));
        assert!(has_constant);
    }
}