1use std::collections::{HashMap, HashSet};
19
20use crate::{
21 BinaryOp, BlockId, CompareOp, ConstantValue, IrModule, IrNode, IrType, ScalarType, Terminator,
22 UnaryOp, ValueId,
23};
24
25pub trait OptimizationPass {
31 fn run(&self, module: &mut IrModule) -> OptimizationResult;
33
34 fn name(&self) -> &'static str;
36}
37
/// Summary of the changes one or more optimization passes made to a module.
#[derive(Debug, Clone, Default)]
pub struct OptimizationResult {
    /// `true` if the module was modified at all.
    pub changed: bool,
    /// Number of instructions deleted.
    pub instructions_removed: usize,
    /// Number of instructions rewritten in place.
    pub instructions_modified: usize,
    /// Number of basic blocks deleted.
    pub blocks_removed: usize,
}

impl OptimizationResult {
    /// A result that reports no change, with every counter at zero.
    pub fn unchanged() -> Self {
        Self {
            changed: false,
            instructions_removed: 0,
            instructions_modified: 0,
            blocks_removed: 0,
        }
    }

    /// A result with the `changed` flag set and every counter at zero.
    pub fn changed() -> Self {
        Self {
            changed: true,
            ..Self::unchanged()
        }
    }

    /// Accumulates `other` into `self`: the flags are OR-ed and the counters summed.
    pub fn merge(&mut self, other: OptimizationResult) {
        let OptimizationResult {
            changed,
            instructions_removed,
            instructions_modified,
            blocks_removed,
        } = other;

        self.changed = self.changed || changed;
        self.instructions_removed += instructions_removed;
        self.instructions_modified += instructions_modified;
        self.blocks_removed += blocks_removed;
    }
}
73
74pub struct DeadCodeElimination;
82
83impl DeadCodeElimination {
84 pub fn new() -> Self {
86 Self
87 }
88
89 fn find_used_values(&self, module: &IrModule) -> HashSet<ValueId> {
91 let mut used = HashSet::new();
92
93 for param in &module.parameters {
95 used.insert(param.value_id);
96 }
97
98 for block in module.blocks.values() {
100 for inst in &block.instructions {
102 self.collect_uses(&inst.node, &mut used);
103 }
104
105 if let Some(ref term) = block.terminator {
107 self.collect_terminator_uses(term, &mut used);
108 }
109 }
110
111 used
112 }
113
114 fn collect_uses(&self, node: &IrNode, used: &mut HashSet<ValueId>) {
116 match node {
117 IrNode::BinaryOp(_, lhs, rhs) => {
118 used.insert(*lhs);
119 used.insert(*rhs);
120 }
121 IrNode::UnaryOp(_, operand) => {
122 used.insert(*operand);
123 }
124 IrNode::Compare(_, lhs, rhs) => {
125 used.insert(*lhs);
126 used.insert(*rhs);
127 }
128 IrNode::Cast(_, value, _) => {
129 used.insert(*value);
130 }
131 IrNode::Load(ptr) => {
132 used.insert(*ptr);
133 }
134 IrNode::Store(ptr, value) => {
135 used.insert(*ptr);
136 used.insert(*value);
137 }
138 IrNode::GetElementPtr(base, indices) => {
139 used.insert(*base);
140 for idx in indices {
141 used.insert(*idx);
142 }
143 }
144 IrNode::Select(cond, then_val, else_val) => {
145 used.insert(*cond);
146 used.insert(*then_val);
147 used.insert(*else_val);
148 }
149 IrNode::Phi(incoming) => {
150 for (_, value) in incoming {
151 used.insert(*value);
152 }
153 }
154 IrNode::Atomic(_, ptr, value) => {
155 used.insert(*ptr);
156 used.insert(*value);
157 }
158 IrNode::AtomicCas(ptr, expected, desired) => {
159 used.insert(*ptr);
160 used.insert(*expected);
161 used.insert(*desired);
162 }
163 IrNode::WarpVote(_, pred) => {
164 used.insert(*pred);
165 }
166 IrNode::WarpShuffle(_, value, lane) => {
167 used.insert(*value);
168 used.insert(*lane);
169 }
170 IrNode::WarpReduce(_, value) => {
171 used.insert(*value);
172 }
173 IrNode::Math(_, args) => {
174 for arg in args {
175 used.insert(*arg);
176 }
177 }
178 IrNode::Call(_, args) => {
179 for arg in args {
180 used.insert(*arg);
181 }
182 }
183 IrNode::K2HEnqueue(value) => {
184 used.insert(*value);
185 }
186 IrNode::K2KSend(dest, msg) => {
187 used.insert(*dest);
188 used.insert(*msg);
189 }
190 IrNode::HlcUpdate(ts) => {
191 used.insert(*ts);
192 }
193 IrNode::ExtractField(value, _) => {
194 used.insert(*value);
195 }
196 IrNode::InsertField(base, _, value) => {
197 used.insert(*base);
198 used.insert(*value);
199 }
200 IrNode::Constant(_)
202 | IrNode::Parameter(_)
203 | IrNode::Undef
204 | IrNode::ThreadId(_)
205 | IrNode::BlockId(_)
206 | IrNode::BlockDim(_)
207 | IrNode::GridDim(_)
208 | IrNode::GlobalThreadId(_)
209 | IrNode::WarpId
210 | IrNode::LaneId
211 | IrNode::Barrier
212 | IrNode::MemoryFence(_)
213 | IrNode::GridSync
214 | IrNode::Alloca(_)
215 | IrNode::SharedAlloc(_, _)
216 | IrNode::H2KDequeue
217 | IrNode::H2KIsEmpty
218 | IrNode::K2KRecv
219 | IrNode::K2KTryRecv
220 | IrNode::HlcNow
221 | IrNode::HlcTick => {}
222 }
223 }
224
225 fn collect_terminator_uses(&self, term: &Terminator, used: &mut HashSet<ValueId>) {
227 match term {
228 Terminator::Return(Some(value)) => {
229 used.insert(*value);
230 }
231 Terminator::CondBranch(cond, _, _) => {
232 used.insert(*cond);
233 }
234 Terminator::Switch(value, _, _) => {
235 used.insert(*value);
236 }
237 Terminator::Return(None) | Terminator::Branch(_) | Terminator::Unreachable => {}
238 }
239 }
240
241 fn has_side_effects(&self, node: &IrNode) -> bool {
243 matches!(
244 node,
245 IrNode::Store(_, _)
246 | IrNode::Atomic(_, _, _)
247 | IrNode::AtomicCas(_, _, _)
248 | IrNode::Barrier
249 | IrNode::MemoryFence(_)
250 | IrNode::GridSync
251 | IrNode::Call(_, _)
252 | IrNode::K2HEnqueue(_)
253 | IrNode::K2KSend(_, _)
254 | IrNode::HlcTick
255 | IrNode::HlcUpdate(_)
256 )
257 }
258}
259
260impl Default for DeadCodeElimination {
261 fn default() -> Self {
262 Self::new()
263 }
264}
265
266impl OptimizationPass for DeadCodeElimination {
267 fn run(&self, module: &mut IrModule) -> OptimizationResult {
268 let used = self.find_used_values(module);
269 let mut result = OptimizationResult::unchanged();
270
271 for block in module.blocks.values_mut() {
273 let original_len = block.instructions.len();
274
275 block.instructions.retain(|inst| {
276 used.contains(&inst.result) || self.has_side_effects(&inst.node)
278 });
279
280 let removed = original_len - block.instructions.len();
281 if removed > 0 {
282 result.changed = true;
283 result.instructions_removed += removed;
284 }
285 }
286
287 result
288 }
289
290 fn name(&self) -> &'static str {
291 "dead-code-elimination"
292 }
293}
294
295pub struct ConstantFolding {
303 #[allow(dead_code)]
305 constants: HashMap<ValueId, ConstantValue>,
306}
307
308impl ConstantFolding {
309 pub fn new() -> Self {
311 Self {
312 constants: HashMap::new(),
313 }
314 }
315
316 fn fold_binary_op(
318 &self,
319 op: BinaryOp,
320 lhs: &ConstantValue,
321 rhs: &ConstantValue,
322 ) -> Option<ConstantValue> {
323 match (lhs, rhs) {
324 (ConstantValue::I32(l), ConstantValue::I32(r)) => {
325 Some(ConstantValue::I32(Self::fold_binary_i32(op, *l, *r)?))
326 }
327 (ConstantValue::U32(l), ConstantValue::U32(r)) => {
328 Some(ConstantValue::U32(Self::fold_binary_u32(op, *l, *r)?))
329 }
330 (ConstantValue::F32(l), ConstantValue::F32(r)) => {
331 Some(ConstantValue::F32(Self::fold_binary_f32(op, *l, *r)?))
332 }
333 (ConstantValue::I64(l), ConstantValue::I64(r)) => {
334 Some(ConstantValue::I64(Self::fold_binary_i64(op, *l, *r)?))
335 }
336 (ConstantValue::U64(l), ConstantValue::U64(r)) => {
337 Some(ConstantValue::U64(Self::fold_binary_u64(op, *l, *r)?))
338 }
339 (ConstantValue::F64(l), ConstantValue::F64(r)) => {
340 Some(ConstantValue::F64(Self::fold_binary_f64(op, *l, *r)?))
341 }
342 _ => None,
343 }
344 }
345
346 fn fold_binary_i32(op: BinaryOp, l: i32, r: i32) -> Option<i32> {
347 Some(match op {
348 BinaryOp::Add => l.wrapping_add(r),
349 BinaryOp::Sub => l.wrapping_sub(r),
350 BinaryOp::Mul => l.wrapping_mul(r),
351 BinaryOp::Div => l.checked_div(r)?,
352 BinaryOp::Rem => l.checked_rem(r)?,
353 BinaryOp::And => l & r,
354 BinaryOp::Or => l | r,
355 BinaryOp::Xor => l ^ r,
356 BinaryOp::Shl => l.wrapping_shl(r as u32),
357 BinaryOp::Shr => l.wrapping_shr(r as u32),
358 BinaryOp::Sar => l >> (r as u32),
359 BinaryOp::Min => l.min(r),
360 BinaryOp::Max => l.max(r),
361 _ => return None,
362 })
363 }
364
365 fn fold_binary_u32(op: BinaryOp, l: u32, r: u32) -> Option<u32> {
366 Some(match op {
367 BinaryOp::Add => l.wrapping_add(r),
368 BinaryOp::Sub => l.wrapping_sub(r),
369 BinaryOp::Mul => l.wrapping_mul(r),
370 BinaryOp::Div => l.checked_div(r)?,
371 BinaryOp::Rem => l.checked_rem(r)?,
372 BinaryOp::And => l & r,
373 BinaryOp::Or => l | r,
374 BinaryOp::Xor => l ^ r,
375 BinaryOp::Shl => l.wrapping_shl(r),
376 BinaryOp::Shr => l.wrapping_shr(r),
377 BinaryOp::Sar => l >> r,
378 BinaryOp::Min => l.min(r),
379 BinaryOp::Max => l.max(r),
380 _ => return None,
381 })
382 }
383
384 fn fold_binary_i64(op: BinaryOp, l: i64, r: i64) -> Option<i64> {
385 Some(match op {
386 BinaryOp::Add => l.wrapping_add(r),
387 BinaryOp::Sub => l.wrapping_sub(r),
388 BinaryOp::Mul => l.wrapping_mul(r),
389 BinaryOp::Div => l.checked_div(r)?,
390 BinaryOp::Rem => l.checked_rem(r)?,
391 BinaryOp::And => l & r,
392 BinaryOp::Or => l | r,
393 BinaryOp::Xor => l ^ r,
394 BinaryOp::Shl => l.wrapping_shl(r as u32),
395 BinaryOp::Shr => l.wrapping_shr(r as u32),
396 BinaryOp::Sar => l >> (r as u32),
397 BinaryOp::Min => l.min(r),
398 BinaryOp::Max => l.max(r),
399 _ => return None,
400 })
401 }
402
403 fn fold_binary_u64(op: BinaryOp, l: u64, r: u64) -> Option<u64> {
404 Some(match op {
405 BinaryOp::Add => l.wrapping_add(r),
406 BinaryOp::Sub => l.wrapping_sub(r),
407 BinaryOp::Mul => l.wrapping_mul(r),
408 BinaryOp::Div => l.checked_div(r)?,
409 BinaryOp::Rem => l.checked_rem(r)?,
410 BinaryOp::And => l & r,
411 BinaryOp::Or => l | r,
412 BinaryOp::Xor => l ^ r,
413 BinaryOp::Shl => l.wrapping_shl(r as u32),
414 BinaryOp::Shr => l.wrapping_shr(r as u32),
415 BinaryOp::Sar => l >> (r as u32),
416 BinaryOp::Min => l.min(r),
417 BinaryOp::Max => l.max(r),
418 _ => return None,
419 })
420 }
421
422 fn fold_binary_f32(op: BinaryOp, l: f32, r: f32) -> Option<f32> {
423 Some(match op {
424 BinaryOp::Add => l + r,
425 BinaryOp::Sub => l - r,
426 BinaryOp::Mul => l * r,
427 BinaryOp::Div => l / r,
428 BinaryOp::Rem => l % r,
429 BinaryOp::Min => l.min(r),
430 BinaryOp::Max => l.max(r),
431 BinaryOp::Pow => l.powf(r),
432 _ => return None,
433 })
434 }
435
436 fn fold_binary_f64(op: BinaryOp, l: f64, r: f64) -> Option<f64> {
437 Some(match op {
438 BinaryOp::Add => l + r,
439 BinaryOp::Sub => l - r,
440 BinaryOp::Mul => l * r,
441 BinaryOp::Div => l / r,
442 BinaryOp::Rem => l % r,
443 BinaryOp::Min => l.min(r),
444 BinaryOp::Max => l.max(r),
445 BinaryOp::Pow => l.powf(r),
446 _ => return None,
447 })
448 }
449
450 fn fold_unary_op(&self, op: UnaryOp, operand: &ConstantValue) -> Option<ConstantValue> {
452 match operand {
453 ConstantValue::I32(v) => Some(ConstantValue::I32(Self::fold_unary_i32(op, *v)?)),
454 ConstantValue::U32(v) => Some(ConstantValue::U32(Self::fold_unary_u32(op, *v)?)),
455 ConstantValue::F32(v) => Some(ConstantValue::F32(Self::fold_unary_f32(op, *v)?)),
456 ConstantValue::F64(v) => Some(ConstantValue::F64(Self::fold_unary_f64(op, *v)?)),
457 ConstantValue::Bool(v) => {
458 if op == UnaryOp::LogicalNot {
459 Some(ConstantValue::Bool(!v))
460 } else {
461 None
462 }
463 }
464 _ => None,
465 }
466 }
467
468 fn fold_unary_i32(op: UnaryOp, v: i32) -> Option<i32> {
469 Some(match op {
470 UnaryOp::Neg => -v,
471 UnaryOp::Not => !v,
472 UnaryOp::Abs => v.abs(),
473 UnaryOp::Sign => v.signum(),
474 _ => return None,
475 })
476 }
477
478 fn fold_unary_u32(op: UnaryOp, v: u32) -> Option<u32> {
479 Some(match op {
480 UnaryOp::Not => !v,
481 _ => return None,
482 })
483 }
484
485 fn fold_unary_f32(op: UnaryOp, v: f32) -> Option<f32> {
486 Some(match op {
487 UnaryOp::Neg => -v,
488 UnaryOp::Abs => v.abs(),
489 UnaryOp::Sqrt => v.sqrt(),
490 UnaryOp::Rsqrt => 1.0 / v.sqrt(),
491 UnaryOp::Floor => v.floor(),
492 UnaryOp::Ceil => v.ceil(),
493 UnaryOp::Round => v.round(),
494 UnaryOp::Trunc => v.trunc(),
495 UnaryOp::Sign => v.signum(),
496 _ => return None,
497 })
498 }
499
500 fn fold_unary_f64(op: UnaryOp, v: f64) -> Option<f64> {
501 Some(match op {
502 UnaryOp::Neg => -v,
503 UnaryOp::Abs => v.abs(),
504 UnaryOp::Sqrt => v.sqrt(),
505 UnaryOp::Rsqrt => 1.0 / v.sqrt(),
506 UnaryOp::Floor => v.floor(),
507 UnaryOp::Ceil => v.ceil(),
508 UnaryOp::Round => v.round(),
509 UnaryOp::Trunc => v.trunc(),
510 UnaryOp::Sign => v.signum(),
511 _ => return None,
512 })
513 }
514
515 fn fold_compare(
517 &self,
518 op: CompareOp,
519 lhs: &ConstantValue,
520 rhs: &ConstantValue,
521 ) -> Option<ConstantValue> {
522 let result = match (lhs, rhs) {
523 (ConstantValue::I32(l), ConstantValue::I32(r)) => Self::compare_i32(op, *l, *r),
524 (ConstantValue::U32(l), ConstantValue::U32(r)) => Self::compare_u32(op, *l, *r),
525 (ConstantValue::F32(l), ConstantValue::F32(r)) => Self::compare_f32(op, *l, *r),
526 (ConstantValue::Bool(l), ConstantValue::Bool(r)) => match op {
527 CompareOp::Eq => *l == *r,
528 CompareOp::Ne => *l != *r,
529 _ => return None,
530 },
531 _ => return None,
532 };
533 Some(ConstantValue::Bool(result))
534 }
535
536 fn compare_i32(op: CompareOp, l: i32, r: i32) -> bool {
537 match op {
538 CompareOp::Eq => l == r,
539 CompareOp::Ne => l != r,
540 CompareOp::Lt => l < r,
541 CompareOp::Le => l <= r,
542 CompareOp::Gt => l > r,
543 CompareOp::Ge => l >= r,
544 }
545 }
546
547 fn compare_u32(op: CompareOp, l: u32, r: u32) -> bool {
548 match op {
549 CompareOp::Eq => l == r,
550 CompareOp::Ne => l != r,
551 CompareOp::Lt => l < r,
552 CompareOp::Le => l <= r,
553 CompareOp::Gt => l > r,
554 CompareOp::Ge => l >= r,
555 }
556 }
557
558 fn compare_f32(op: CompareOp, l: f32, r: f32) -> bool {
559 match op {
560 CompareOp::Eq => l == r,
561 CompareOp::Ne => l != r,
562 CompareOp::Lt => l < r,
563 CompareOp::Le => l <= r,
564 CompareOp::Gt => l > r,
565 CompareOp::Ge => l >= r,
566 }
567 }
568
569 #[allow(dead_code)]
571 fn get_constant<'a>(&'a self, id: ValueId, module: &'a IrModule) -> Option<&'a ConstantValue> {
572 if let Some(c) = self.constants.get(&id) {
574 return Some(c);
575 }
576
577 if let Some(value) = module.get_value(id) {
579 if let IrNode::Constant(ref c) = value.node {
580 return Some(c);
581 }
582 }
583
584 None
585 }
586}
587
588impl Default for ConstantFolding {
589 fn default() -> Self {
590 Self::new()
591 }
592}
593
594impl OptimizationPass for ConstantFolding {
595 fn run(&self, module: &mut IrModule) -> OptimizationResult {
596 let mut result = OptimizationResult::unchanged();
597 let mut constants = HashMap::new();
598
599 for value in module.values.values() {
601 if let IrNode::Constant(ref c) = value.node {
602 constants.insert(value.id, c.clone());
603 }
604 }
605
606 for block in module.blocks.values_mut() {
608 for inst in &mut block.instructions {
609 let folded = match &inst.node {
610 IrNode::BinaryOp(op, lhs, rhs) => {
611 let lhs_const = constants.get(lhs);
612 let rhs_const = constants.get(rhs);
613
614 if let (Some(l), Some(r)) = (lhs_const, rhs_const) {
615 Self::new().fold_binary_op(*op, l, r)
616 } else {
617 None
618 }
619 }
620 IrNode::UnaryOp(op, operand) => {
621 if let Some(c) = constants.get(operand) {
622 Self::new().fold_unary_op(*op, c)
623 } else {
624 None
625 }
626 }
627 IrNode::Compare(op, lhs, rhs) => {
628 let lhs_const = constants.get(lhs);
629 let rhs_const = constants.get(rhs);
630
631 if let (Some(l), Some(r)) = (lhs_const, rhs_const) {
632 Self::new().fold_compare(*op, l, r)
633 } else {
634 None
635 }
636 }
637 IrNode::Select(cond, then_val, else_val) => {
638 if let Some(ConstantValue::Bool(c)) = constants.get(cond) {
639 let selected = if *c { then_val } else { else_val };
641 constants.get(selected).cloned()
642 } else {
643 None
644 }
645 }
646 _ => None,
647 };
648
649 if let Some(constant) = folded {
650 let new_type = constant.ir_type();
652 inst.node = IrNode::Constant(constant.clone());
653 inst.result_type = new_type;
654 constants.insert(inst.result, constant);
655 result.changed = true;
656 result.instructions_modified += 1;
657 }
658 }
659 }
660
661 result
662 }
663
664 fn name(&self) -> &'static str {
665 "constant-folding"
666 }
667}
668
669pub struct DeadBlockElimination;
677
678impl DeadBlockElimination {
679 pub fn new() -> Self {
681 Self
682 }
683
684 fn find_reachable_blocks(&self, module: &IrModule) -> HashSet<BlockId> {
686 let mut reachable = HashSet::new();
687 let mut worklist = vec![module.entry_block];
688
689 while let Some(block_id) = worklist.pop() {
690 if !reachable.insert(block_id) {
691 continue;
692 }
693
694 if let Some(block) = module.get_block(block_id) {
695 match &block.terminator {
697 Some(Terminator::Branch(target)) => {
698 worklist.push(*target);
699 }
700 Some(Terminator::CondBranch(_, then_target, else_target)) => {
701 worklist.push(*then_target);
702 worklist.push(*else_target);
703 }
704 Some(Terminator::Switch(_, default, cases)) => {
705 worklist.push(*default);
706 for (_, target) in cases {
707 worklist.push(*target);
708 }
709 }
710 _ => {}
711 }
712 }
713 }
714
715 reachable
716 }
717}
718
719impl Default for DeadBlockElimination {
720 fn default() -> Self {
721 Self::new()
722 }
723}
724
725impl OptimizationPass for DeadBlockElimination {
726 fn run(&self, module: &mut IrModule) -> OptimizationResult {
727 let reachable = self.find_reachable_blocks(module);
728 let mut result = OptimizationResult::unchanged();
729
730 let unreachable: Vec<BlockId> = module
732 .blocks
733 .keys()
734 .filter(|id| !reachable.contains(id))
735 .copied()
736 .collect();
737
738 for block_id in unreachable {
740 module.blocks.remove(&block_id);
741 result.changed = true;
742 result.blocks_removed += 1;
743 }
744
745 result
746 }
747
748 fn name(&self) -> &'static str {
749 "dead-block-elimination"
750 }
751}
752
753pub struct AlgebraicSimplification;
769
770impl AlgebraicSimplification {
771 pub fn new() -> Self {
773 Self
774 }
775
776 fn is_zero(c: &ConstantValue) -> bool {
778 match c {
779 ConstantValue::I32(0) => true,
780 ConstantValue::U32(0) => true,
781 ConstantValue::I64(0) => true,
782 ConstantValue::U64(0) => true,
783 ConstantValue::F32(f) => *f == 0.0,
784 ConstantValue::F64(f) => *f == 0.0,
785 _ => false,
786 }
787 }
788
789 fn is_one(c: &ConstantValue) -> bool {
791 match c {
792 ConstantValue::I32(1) => true,
793 ConstantValue::U32(1) => true,
794 ConstantValue::I64(1) => true,
795 ConstantValue::U64(1) => true,
796 ConstantValue::F32(f) => *f == 1.0,
797 ConstantValue::F64(f) => *f == 1.0,
798 _ => false,
799 }
800 }
801
802 fn zero_for_type(ty: &IrType) -> Option<ConstantValue> {
804 Some(match ty {
805 IrType::Scalar(ScalarType::I32) => ConstantValue::I32(0),
806 IrType::Scalar(ScalarType::U32) => ConstantValue::U32(0),
807 IrType::Scalar(ScalarType::I64) => ConstantValue::I64(0),
808 IrType::Scalar(ScalarType::U64) => ConstantValue::U64(0),
809 IrType::Scalar(ScalarType::F32) => ConstantValue::F32(0.0),
810 IrType::Scalar(ScalarType::F64) => ConstantValue::F64(0.0),
811 _ => return None,
812 })
813 }
814}
815
816impl Default for AlgebraicSimplification {
817 fn default() -> Self {
818 Self::new()
819 }
820}
821
822impl OptimizationPass for AlgebraicSimplification {
823 fn run(&self, module: &mut IrModule) -> OptimizationResult {
824 let mut result = OptimizationResult::unchanged();
825
826 let mut constants = HashMap::new();
828 for value in module.values.values() {
829 if let IrNode::Constant(ref c) = value.node {
830 constants.insert(value.id, c.clone());
831 }
832 }
833
834 for block in module.blocks.values_mut() {
836 for inst in &mut block.instructions {
837 let simplified = match &inst.node {
838 IrNode::BinaryOp(op, lhs, rhs) => {
839 let lhs_const = constants.get(lhs);
840 let rhs_const = constants.get(rhs);
841
842 match op {
843 BinaryOp::Add if rhs_const.is_some_and(Self::is_zero) => {
845 Some(IrNode::Parameter(0)) }
847 BinaryOp::Add if lhs_const.is_some_and(Self::is_zero) => {
849 Some(IrNode::Parameter(1))
850 }
851 BinaryOp::Mul if rhs_const.is_some_and(Self::is_one) => {
853 Some(IrNode::Parameter(0))
854 }
855 BinaryOp::Mul if lhs_const.is_some_and(Self::is_one) => {
857 Some(IrNode::Parameter(1))
858 }
859 BinaryOp::Mul
861 if rhs_const.is_some_and(Self::is_zero)
862 || lhs_const.is_some_and(Self::is_zero) =>
863 {
864 Self::zero_for_type(&inst.result_type).map(IrNode::Constant)
865 }
866 BinaryOp::Sub if rhs_const.is_some_and(Self::is_zero) => {
868 Some(IrNode::Parameter(0))
869 }
870 BinaryOp::Div if rhs_const.is_some_and(Self::is_one) => {
872 Some(IrNode::Parameter(0))
873 }
874 BinaryOp::And if rhs_const.is_some_and(Self::is_zero) => {
876 Self::zero_for_type(&inst.result_type).map(IrNode::Constant)
877 }
878 BinaryOp::Or if rhs_const.is_some_and(Self::is_zero) => {
880 Some(IrNode::Parameter(0))
881 }
882 BinaryOp::Xor if rhs_const.is_some_and(Self::is_zero) => {
884 Some(IrNode::Parameter(0))
885 }
886 _ => None,
887 }
888 }
889 _ => None,
890 };
891
892 if let Some(simplified_node) = simplified {
894 match simplified_node {
895 IrNode::Parameter(0) => {
896 }
901 IrNode::Parameter(1) => {
902 }
905 IrNode::Constant(c) => {
906 inst.node = IrNode::Constant(c.clone());
907 constants.insert(inst.result, c);
908 result.changed = true;
909 result.instructions_modified += 1;
910 }
911 _ => {}
912 }
913 }
914 }
915 }
916
917 result
918 }
919
920 fn name(&self) -> &'static str {
921 "algebraic-simplification"
922 }
923}
924
925pub struct PassManager {
931 passes: Vec<Box<dyn OptimizationPass>>,
932 max_iterations: usize,
933}
934
935impl PassManager {
936 pub fn new() -> Self {
938 Self {
939 passes: vec![
940 Box::new(ConstantFolding::new()),
941 Box::new(AlgebraicSimplification::new()),
942 Box::new(DeadCodeElimination::new()),
943 Box::new(DeadBlockElimination::new()),
944 ],
945 max_iterations: 10,
946 }
947 }
948
949 pub fn empty() -> Self {
951 Self {
952 passes: Vec::new(),
953 max_iterations: 10,
954 }
955 }
956
957 pub fn add_pass<P: OptimizationPass + 'static>(&mut self, pass: P) -> &mut Self {
959 self.passes.push(Box::new(pass));
960 self
961 }
962
963 pub fn max_iterations(&mut self, n: usize) -> &mut Self {
965 self.max_iterations = n;
966 self
967 }
968
969 pub fn run(&self, module: &mut IrModule) -> OptimizationResult {
971 let mut total_result = OptimizationResult::unchanged();
972
973 for iteration in 0..self.max_iterations {
974 let mut changed = false;
975
976 for pass in &self.passes {
977 let pass_result = pass.run(module);
978 changed |= pass_result.changed;
979 total_result.merge(pass_result);
980 }
981
982 if !changed {
983 break;
984 }
985
986 if iteration == self.max_iterations - 1 {
988 eprintln!(
989 "Warning: optimization reached max iterations ({})",
990 self.max_iterations
991 );
992 }
993 }
994
995 total_result
996 }
997}
998
999impl Default for PassManager {
1000 fn default() -> Self {
1001 Self::new()
1002 }
1003}
1004
1005pub fn optimize(module: &mut IrModule) -> OptimizationResult {
1011 PassManager::new().run(module)
1012}
1013
1014pub fn run_dce(module: &mut IrModule) -> OptimizationResult {
1016 DeadCodeElimination::new().run(module)
1017}
1018
1019pub fn run_constant_folding(module: &mut IrModule) -> OptimizationResult {
1021 ConstantFolding::new().run(module)
1022}
1023
#[cfg(test)]
mod tests {
    use super::*;
    use crate::IrBuilder;

    /// DCE should report removing at least one instruction when an `add` result is unused.
    #[test]
    fn test_dce_removes_unused() {
        let mut ir = IrBuilder::new("test");

        // `10 + 20` is computed but its result is never consumed.
        let ten = ir.const_i32(10);
        let twenty = ir.const_i32(20);
        let _dead_sum = ir.add(ten, twenty);

        // `5 * 5` is returned.
        let five = ir.const_i32(5);
        let square = ir.mul(five, five);
        ir.ret_value(square);

        let mut module = ir.build();
        let outcome = DeadCodeElimination::new().run(&mut module);

        assert!(outcome.changed);
        assert!(outcome.instructions_removed > 0);
    }

    /// `2 + 3` with constant operands should be folded.
    #[test]
    fn test_constant_folding_binary() {
        let mut ir = IrBuilder::new("test");

        let two = ir.const_i32(2);
        let three = ir.const_i32(3);
        let total = ir.add(two, three);
        ir.ret_value(total);

        let mut module = ir.build();
        let outcome = ConstantFolding::new().run(&mut module);

        assert!(outcome.changed);
        assert!(outcome.instructions_modified > 0);
    }

    /// Negating a constant should be folded.
    #[test]
    fn test_constant_folding_unary() {
        let mut ir = IrBuilder::new("test");

        let five = ir.const_i32(5);
        let negated = ir.neg(five);
        ir.ret_value(negated);

        let mut module = ir.build();
        let outcome = ConstantFolding::new().run(&mut module);

        assert!(outcome.changed);
    }

    /// The default pipeline should change a module containing foldable work.
    #[test]
    fn test_pass_manager() {
        let mut ir = IrBuilder::new("test");

        let two = ir.const_i32(2);
        let three = ir.const_i32(3);
        let total = ir.add(two, three);
        let _orphan = ir.const_i32(999);
        ir.ret_value(total);

        let mut module = ir.build();
        let outcome = PassManager::new().run(&mut module);

        assert!(outcome.changed);
    }

    /// `merge` ORs the flags and sums the counters.
    #[test]
    fn test_optimization_result_merge() {
        let mut first = OptimizationResult {
            changed: true,
            instructions_removed: 5,
            instructions_modified: 3,
            blocks_removed: 1,
        };
        let second = OptimizationResult {
            changed: false,
            instructions_removed: 2,
            instructions_modified: 1,
            blocks_removed: 0,
        };

        first.merge(second);

        assert!(first.changed);
        assert_eq!(first.instructions_removed, 7);
        assert_eq!(first.instructions_modified, 4);
        assert_eq!(first.blocks_removed, 1);
    }
}