1use crate::ir::{Graph, NodeId, Op};
43use rustc_hash::{FxHashMap, FxHashSet};
44
/// A graph rewrite that an [`Optimizer`] can be configured to run.
///
/// Passes are applied in the order in which they were registered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationPass {
    /// Removes `x * 1.0` and `x + 0.0`, and replaces `x * 0.0` with a zero constant.
    ConstantFolding,
    /// Drops every node that no registered graph output depends on.
    DeadCodeElimination,
    /// Placeholder pass: currently hands the graph back unchanged.
    ElementwiseFusion,
    /// Merges nodes that compute the same expression so it is only emitted once.
    CommonSubexpressionElimination,
    /// Removes `x * 1.0`, `x + 0.0` and double negation `-(-x)`.
    AlgebraicSimplification,
    /// Placeholder pass: currently rebuilds the graph without substituting any op.
    StrengthReduction,
}
65
66pub struct Optimizer {
72 passes: Vec<OptimizationPass>,
73}
74
75impl Optimizer {
76 pub fn new() -> Self {
78 Self { passes: Vec::new() }
79 }
80
81 pub fn default_passes() -> Self {
83 Self {
84 passes: vec![
85 OptimizationPass::ConstantFolding,
86 OptimizationPass::AlgebraicSimplification,
87 OptimizationPass::DeadCodeElimination,
88 OptimizationPass::CommonSubexpressionElimination,
89 ],
90 }
91 }
92
93 pub fn add_pass(&mut self, pass: OptimizationPass) {
95 self.passes.push(pass);
96 }
97
98 pub fn optimize(&self, mut graph: Graph) -> Graph {
100 for pass in &self.passes {
101 graph = self.run_pass(graph, *pass);
102 }
103 graph
104 }
105
106 fn run_pass(&self, graph: Graph, pass: OptimizationPass) -> Graph {
107 match pass {
108 OptimizationPass::ConstantFolding => constant_folding(graph),
109 OptimizationPass::DeadCodeElimination => dead_code_elimination(graph),
110 OptimizationPass::ElementwiseFusion => elementwise_fusion(graph),
111 OptimizationPass::CommonSubexpressionElimination => cse(graph),
112 OptimizationPass::AlgebraicSimplification => algebraic_simplification(graph),
113 OptimizationPass::StrengthReduction => strength_reduction(graph),
114 }
115 }
116}
117
118impl Default for Optimizer {
119 fn default() -> Self {
120 Self::default_passes()
121 }
122}
123
124fn constant_folding(graph: Graph) -> Graph {
134 let mut new_graph = Graph::new();
137 let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
138 let mut constants: FxHashMap<NodeId, f64> = FxHashMap::default();
139
140 for node in graph.nodes() {
141 if let Op::Constant { value } = &node.op {
143 constants.insert(node.id, *value);
144 }
145
146 let new_op = match &node.op {
148 Op::MulScalar { input, scalar } if *scalar == 1.0 => {
149 let new_input = node_map.get(input).copied().unwrap_or(*input);
151 node_map.insert(node.id, new_input);
152 continue;
153 }
154 Op::MulScalar { input: _, scalar } if *scalar == 0.0 => {
155 Op::Constant { value: 0.0 }
157 }
158 Op::AddScalar { input, scalar } if *scalar == 0.0 => {
159 let new_input = node_map.get(input).copied().unwrap_or(*input);
161 node_map.insert(node.id, new_input);
162 continue;
163 }
164 other => remap_op(other, &node_map),
165 };
166
167 let new_id = new_graph.add_node(new_op, node.dtype, node.shape.clone());
168 node_map.insert(node.id, new_id);
169 }
170
171 for (name, id) in graph.inputs() {
173 if let Some(&new_id) = node_map.get(id) {
174 new_graph.register_input(name, new_id);
175 }
176 }
177 for (name, id) in graph.outputs() {
178 if let Some(&new_id) = node_map.get(id) {
179 new_graph.register_output(name, new_id);
180 }
181 }
182
183 new_graph
184}
185
186fn dead_code_elimination(graph: Graph) -> Graph {
192 let mut live_nodes: FxHashSet<NodeId> = FxHashSet::default();
194 let mut worklist: Vec<NodeId> = graph.outputs().values().copied().collect();
195
196 while let Some(id) = worklist.pop() {
197 if live_nodes.insert(id) {
198 let node = graph.node(id);
199 for input_id in node.op.inputs() {
200 worklist.push(input_id);
201 }
202 }
203 }
204
205 let mut new_graph = Graph::new();
207 let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
208
209 for node in graph.nodes() {
210 if !live_nodes.contains(&node.id) {
211 continue;
212 }
213
214 let new_op = remap_op(&node.op, &node_map);
215 let new_id = new_graph.add_node(new_op, node.dtype, node.shape.clone());
216 node_map.insert(node.id, new_id);
217 }
218
219 for (name, id) in graph.inputs() {
221 if let Some(&new_id) = node_map.get(id) {
222 new_graph.register_input(name, new_id);
223 }
224 }
225 for (name, id) in graph.outputs() {
226 if let Some(&new_id) = node_map.get(id) {
227 new_graph.register_output(name, new_id);
228 }
229 }
230
231 new_graph
232}
233
/// Elementwise-fusion pass.
///
/// Currently a no-op: the graph is handed back untouched, so selecting
/// `OptimizationPass::ElementwiseFusion` does not change the result.
fn elementwise_fusion(graph: Graph) -> Graph {
    graph
}
245
246fn cse(graph: Graph) -> Graph {
252 let mut new_graph = Graph::new();
254 let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
255 let mut expr_map: FxHashMap<String, NodeId> = FxHashMap::default();
256
257 for node in graph.nodes() {
258 let remapped_op = remap_op(&node.op, &node_map);
259 let expr_key = format!("{:?}", remapped_op);
260
261 if let Some(&existing_id) = expr_map.get(&expr_key) {
262 node_map.insert(node.id, existing_id);
264 } else {
265 let new_id = new_graph.add_node(remapped_op, node.dtype, node.shape.clone());
266 node_map.insert(node.id, new_id);
267 expr_map.insert(expr_key, new_id);
268 }
269 }
270
271 for (name, id) in graph.inputs() {
273 if let Some(&new_id) = node_map.get(id) {
274 new_graph.register_input(name, new_id);
275 }
276 }
277 for (name, id) in graph.outputs() {
278 if let Some(&new_id) = node_map.get(id) {
279 new_graph.register_output(name, new_id);
280 }
281 }
282
283 new_graph
284}
285
286fn algebraic_simplification(graph: Graph) -> Graph {
292 let mut new_graph = Graph::new();
293 let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
294
295 for node in graph.nodes() {
296 let simplified_op = match &node.op {
297 Op::MulScalar { input, scalar } if *scalar == 1.0 => {
299 let new_input = node_map.get(input).copied().unwrap_or(*input);
300 node_map.insert(node.id, new_input);
301 continue;
302 }
303 Op::AddScalar { input, scalar } if *scalar == 0.0 => {
305 let new_input = node_map.get(input).copied().unwrap_or(*input);
306 node_map.insert(node.id, new_input);
307 continue;
308 }
309 Op::Neg { input } => {
313 let actual_input = node_map.get(input).copied().unwrap_or(*input);
314 if let Some(input_node) = new_graph.nodes().iter().find(|n| n.id == actual_input) {
315 if let Op::Neg { input: inner } = &input_node.op {
316 node_map.insert(node.id, *inner);
317 continue;
318 }
319 }
320 Op::Neg {
321 input: actual_input,
322 }
323 }
324 other => remap_op(other, &node_map),
325 };
326
327 let new_id = new_graph.add_node(simplified_op, node.dtype, node.shape.clone());
328 node_map.insert(node.id, new_id);
329 }
330
331 for (name, id) in graph.inputs() {
333 if let Some(&new_id) = node_map.get(id) {
334 new_graph.register_input(name, new_id);
335 }
336 }
337 for (name, id) in graph.outputs() {
338 if let Some(&new_id) = node_map.get(id) {
339 new_graph.register_output(name, new_id);
340 }
341 }
342
343 new_graph
344}
345
346fn strength_reduction(graph: Graph) -> Graph {
352 let mut new_graph = Graph::new();
353 let mut node_map: FxHashMap<NodeId, NodeId> = FxHashMap::default();
354
355 for node in graph.nodes() {
356 let reduced_op = match &node.op {
357 Op::Pow { .. } => {
359 remap_op(&node.op, &node_map)
362 }
363 Op::Div { .. } => {
365 remap_op(&node.op, &node_map)
367 }
368 other => remap_op(other, &node_map),
369 };
370
371 let new_id = new_graph.add_node(reduced_op, node.dtype, node.shape.clone());
372 node_map.insert(node.id, new_id);
373 }
374
375 for (name, id) in graph.inputs() {
377 if let Some(&new_id) = node_map.get(id) {
378 new_graph.register_input(name, new_id);
379 }
380 }
381 for (name, id) in graph.outputs() {
382 if let Some(&new_id) = node_map.get(id) {
383 new_graph.register_output(name, new_id);
384 }
385 }
386
387 new_graph
388}
389
390fn remap_op(op: &Op, node_map: &FxHashMap<NodeId, NodeId>) -> Op {
396 let remap = |id: &NodeId| node_map.get(id).copied().unwrap_or(*id);
397
398 match op {
399 Op::Input { name } => Op::Input { name: name.clone() },
400 Op::Output { name, input } => Op::Output {
401 name: name.clone(),
402 input: remap(input),
403 },
404 Op::Constant { value } => Op::Constant { value: *value },
405
406 Op::Add { lhs, rhs } => Op::Add {
407 lhs: remap(lhs),
408 rhs: remap(rhs),
409 },
410 Op::Sub { lhs, rhs } => Op::Sub {
411 lhs: remap(lhs),
412 rhs: remap(rhs),
413 },
414 Op::Mul { lhs, rhs } => Op::Mul {
415 lhs: remap(lhs),
416 rhs: remap(rhs),
417 },
418 Op::Div { lhs, rhs } => Op::Div {
419 lhs: remap(lhs),
420 rhs: remap(rhs),
421 },
422 Op::Pow { base, exp } => Op::Pow {
423 base: remap(base),
424 exp: remap(exp),
425 },
426 Op::Max { lhs, rhs } => Op::Max {
427 lhs: remap(lhs),
428 rhs: remap(rhs),
429 },
430 Op::Min { lhs, rhs } => Op::Min {
431 lhs: remap(lhs),
432 rhs: remap(rhs),
433 },
434
435 Op::Neg { input } => Op::Neg {
436 input: remap(input),
437 },
438 Op::Abs { input } => Op::Abs {
439 input: remap(input),
440 },
441 Op::Sqrt { input } => Op::Sqrt {
442 input: remap(input),
443 },
444 Op::Exp { input } => Op::Exp {
445 input: remap(input),
446 },
447 Op::Log { input } => Op::Log {
448 input: remap(input),
449 },
450 Op::Sin { input } => Op::Sin {
451 input: remap(input),
452 },
453 Op::Cos { input } => Op::Cos {
454 input: remap(input),
455 },
456 Op::Tanh { input } => Op::Tanh {
457 input: remap(input),
458 },
459
460 Op::Relu { input } => Op::Relu {
461 input: remap(input),
462 },
463 Op::Sigmoid { input } => Op::Sigmoid {
464 input: remap(input),
465 },
466 Op::Gelu { input } => Op::Gelu {
467 input: remap(input),
468 },
469 Op::Silu { input } => Op::Silu {
470 input: remap(input),
471 },
472
473 Op::AddScalar { input, scalar } => Op::AddScalar {
474 input: remap(input),
475 scalar: *scalar,
476 },
477 Op::MulScalar { input, scalar } => Op::MulScalar {
478 input: remap(input),
479 scalar: *scalar,
480 },
481
482 Op::Sum { input } => Op::Sum {
483 input: remap(input),
484 },
485 Op::SumAxis {
486 input,
487 axis,
488 keepdim,
489 } => Op::SumAxis {
490 input: remap(input),
491 axis: *axis,
492 keepdim: *keepdim,
493 },
494 Op::Mean { input } => Op::Mean {
495 input: remap(input),
496 },
497 Op::MeanAxis {
498 input,
499 axis,
500 keepdim,
501 } => Op::MeanAxis {
502 input: remap(input),
503 axis: *axis,
504 keepdim: *keepdim,
505 },
506 Op::MaxAxis {
507 input,
508 axis,
509 keepdim,
510 } => Op::MaxAxis {
511 input: remap(input),
512 axis: *axis,
513 keepdim: *keepdim,
514 },
515
516 Op::Reshape { input, shape } => Op::Reshape {
517 input: remap(input),
518 shape: shape.clone(),
519 },
520 Op::Transpose { input, dim0, dim1 } => Op::Transpose {
521 input: remap(input),
522 dim0: *dim0,
523 dim1: *dim1,
524 },
525 Op::Squeeze { input, dim } => Op::Squeeze {
526 input: remap(input),
527 dim: *dim,
528 },
529 Op::Unsqueeze { input, dim } => Op::Unsqueeze {
530 input: remap(input),
531 dim: *dim,
532 },
533 Op::Broadcast { input, shape } => Op::Broadcast {
534 input: remap(input),
535 shape: shape.clone(),
536 },
537
538 Op::MatMul { lhs, rhs } => Op::MatMul {
539 lhs: remap(lhs),
540 rhs: remap(rhs),
541 },
542
543 Op::Gt { lhs, rhs } => Op::Gt {
544 lhs: remap(lhs),
545 rhs: remap(rhs),
546 },
547 Op::Lt { lhs, rhs } => Op::Lt {
548 lhs: remap(lhs),
549 rhs: remap(rhs),
550 },
551 Op::Eq { lhs, rhs } => Op::Eq {
552 lhs: remap(lhs),
553 rhs: remap(rhs),
554 },
555
556 Op::Where { condition, x, y } => Op::Where {
557 condition: remap(condition),
558 x: remap(x),
559 y: remap(y),
560 },
561
562 Op::Cast { input, dtype } => Op::Cast {
563 input: remap(input),
564 dtype: *dtype,
565 },
566 Op::Contiguous { input } => Op::Contiguous {
567 input: remap(input),
568 },
569 }
570}
571
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::trace;

    /// Runs exactly one pass over `graph` and returns the rewritten graph.
    fn run_single_pass(pass: OptimizationPass, graph: Graph) -> Graph {
        let mut optimizer = Optimizer::new();
        optimizer.add_pass(pass);
        optimizer.optimize(graph)
    }

    #[test]
    fn test_dead_code_elimination() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[2, 3]);
            // The product is never routed to an output, so it is dead.
            let _unused = a.mul(&b);
            let c = a.add(&b);
            tracer.output("result", c)
        });

        let optimized = run_single_pass(OptimizationPass::DeadCodeElimination, graph);

        let mul_free = optimized
            .nodes()
            .iter()
            .all(|n| !matches!(n.op, Op::Mul { .. }));
        assert!(mul_free);
    }

    #[test]
    fn test_algebraic_simplification() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[2, 3]);
            // Multiplying by one is an identity and should be removed.
            let y = x.mul_scalar(1.0);
            tracer.output("y", y)
        });

        let optimized = run_single_pass(OptimizationPass::AlgebraicSimplification, graph);

        let mul_scalar_free = optimized
            .nodes()
            .iter()
            .all(|n| !matches!(n.op, Op::MulScalar { .. }));
        assert!(mul_scalar_free);
    }

    #[test]
    fn test_constant_folding() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[2, 3]);
            // Multiplying by zero should fold to a constant.
            let y = x.mul_scalar(0.0);
            tracer.output("y", y)
        });

        let optimized = run_single_pass(OptimizationPass::ConstantFolding, graph);

        let has_constant = optimized
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::Constant { .. }));
        assert!(has_constant);
    }
}
643}