wave_compiler/optimize/
dce.rs1use std::collections::HashSet;
11
12use super::pass::Pass;
13use crate::mir::function::MirFunction;
14use crate::mir::value::ValueId;
15
16pub struct Dce;
18
19impl Pass for Dce {
20 fn name(&self) -> &'static str {
21 "dce"
22 }
23
24 fn run(&self, func: &mut MirFunction) -> bool {
25 let used_values = collect_used_values(func);
26 let mut changed = false;
27
28 for block in &mut func.blocks {
29 let original_len = block.instructions.len();
30 block.instructions.retain(|inst| {
31 if inst.has_side_effects() {
32 return true;
33 }
34 match inst.dest() {
35 Some(dest) => used_values.contains(&dest),
36 None => true,
37 }
38 });
39 if block.instructions.len() != original_len {
40 changed = true;
41 }
42 }
43
44 changed
45 }
46}
47
48fn collect_used_values(func: &MirFunction) -> HashSet<ValueId> {
49 let mut used = HashSet::new();
50 for block in &func.blocks {
51 for phi in &block.phis {
52 for (_, val) in &phi.incoming {
53 used.insert(*val);
54 }
55 }
56 for inst in &block.instructions {
57 for operand in inst.operands() {
58 used.insert(operand);
59 }
60 }
61 for operand in block.terminator.operands() {
62 used.insert(operand);
63 }
64 }
65 used
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::expr::BinOp;
    use crate::mir::basic_block::{BasicBlock, Terminator};
    use crate::mir::instruction::{ConstValue, MirInst};
    use crate::mir::types::MirType;
    use crate::mir::value::BlockId;

    /// Builds a constant instruction that defines `dest` with `value`.
    fn constant(dest: ValueId, value: ConstValue) -> MirInst {
        MirInst::Const { dest, value }
    }

    /// Builds a function named "test" consisting of a single block `BlockId(0)` that holds
    /// `instructions` in order and ends with `terminator`.
    fn single_block_function(instructions: Vec<MirInst>, terminator: Terminator) -> MirFunction {
        let mut function = MirFunction::new("test".into(), BlockId(0));
        let mut entry = BasicBlock::new(BlockId(0));
        for inst in instructions {
            entry.instructions.push(inst);
        }
        entry.terminator = terminator;
        function.blocks.push(entry);
        function
    }

    #[test]
    fn test_dce_removes_dead_code() {
        // Two constants that nothing reads: both must be eliminated.
        let mut func = single_block_function(
            vec![
                constant(ValueId(0), ConstValue::I32(42)),
                constant(ValueId(1), ConstValue::I32(99)),
            ],
            Terminator::Return,
        );

        let changed = Dce.run(&mut func);

        assert!(changed);
        assert!(func.blocks[0].instructions.is_empty());
    }

    #[test]
    fn test_dce_preserves_used_values() {
        // The add consumes both constants and the branch condition consumes the add,
        // so every instruction is live and the pass must report no change.
        let sum = MirInst::BinOp {
            dest: ValueId(2),
            op: BinOp::Add,
            lhs: ValueId(0),
            rhs: ValueId(1),
            ty: MirType::I32,
        };
        let mut func = single_block_function(
            vec![
                constant(ValueId(0), ConstValue::I32(1)),
                constant(ValueId(1), ConstValue::I32(2)),
                sum,
            ],
            Terminator::CondBranch {
                cond: ValueId(2),
                true_target: BlockId(0),
                false_target: BlockId(0),
            },
        );

        let changed = Dce.run(&mut func);

        assert!(!changed);
        assert_eq!(func.blocks[0].instructions.len(), 3);
    }
}