Skip to main content

wave_compiler/optimize/
dce.rs

1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
4//! Dead code elimination pass.
5//!
6//! Marks all instructions whose results are used. Removes unmarked
7//! instructions that have no side effects (stores, barriers, atomics
8//! are always considered live).
9
use std::collections::{HashMap, HashSet};

use super::pass::Pass;
use crate::mir::function::MirFunction;
use crate::mir::value::ValueId;
15
/// Dead code elimination pass.
///
/// Stateless unit struct: construct it with `Dce` (or `Dce::default()`) and
/// hand it to whatever drives [`Pass`] implementations.
#[derive(Debug, Clone, Copy, Default)]
pub struct Dce;
18
19impl Pass for Dce {
20    fn name(&self) -> &'static str {
21        "dce"
22    }
23
24    fn run(&self, func: &mut MirFunction) -> bool {
25        let used_values = collect_used_values(func);
26        let mut changed = false;
27
28        for block in &mut func.blocks {
29            let original_len = block.instructions.len();
30            block.instructions.retain(|inst| {
31                if inst.has_side_effects() {
32                    return true;
33                }
34                match inst.dest() {
35                    Some(dest) => used_values.contains(&dest),
36                    None => true,
37                }
38            });
39            if block.instructions.len() != original_len {
40                changed = true;
41            }
42        }
43
44        changed
45    }
46}
47
48fn collect_used_values(func: &MirFunction) -> HashSet<ValueId> {
49    let mut used = HashSet::new();
50    for block in &func.blocks {
51        for phi in &block.phis {
52            for (_, val) in &phi.incoming {
53                used.insert(*val);
54            }
55        }
56        for inst in &block.instructions {
57            for operand in inst.operands() {
58                used.insert(operand);
59            }
60        }
61        for operand in block.terminator.operands() {
62            used.insert(operand);
63        }
64    }
65    used
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::expr::BinOp;
    use crate::mir::basic_block::{BasicBlock, Terminator};
    use crate::mir::instruction::{ConstValue, MirInst};
    use crate::mir::types::MirType;
    use crate::mir::value::BlockId;

    /// Shorthand for an `i32` constant definition.
    fn const_i32(dest: ValueId, value: i32) -> MirInst {
        MirInst::Const {
            dest,
            value: ConstValue::I32(value),
        }
    }

    /// Builds a function named "test" with a single entry block (`BlockId(0)`)
    /// holding `insts` in order and ending with `term`.
    fn single_block(insts: Vec<MirInst>, term: Terminator) -> MirFunction {
        let mut func = MirFunction::new("test".into(), BlockId(0));
        let mut block = BasicBlock::new(BlockId(0));
        block.instructions.extend(insts);
        block.terminator = term;
        func.blocks.push(block);
        func
    }

    #[test]
    fn test_dce_removes_dead_code() {
        // Two constants that nothing reads: both must go.
        let mut func = single_block(
            vec![const_i32(ValueId(0), 42), const_i32(ValueId(1), 99)],
            Terminator::Return,
        );

        assert!(Dce.run(&mut func));
        assert!(func.blocks[0].instructions.is_empty());
    }

    #[test]
    fn test_dce_preserves_used_values() {
        // v2 = v0 + v1 feeds the branch condition, so the whole chain is live.
        let sum = MirInst::BinOp {
            dest: ValueId(2),
            op: BinOp::Add,
            lhs: ValueId(0),
            rhs: ValueId(1),
            ty: MirType::I32,
        };
        let branch = Terminator::CondBranch {
            cond: ValueId(2),
            true_target: BlockId(0),
            false_target: BlockId(0),
        };
        let mut func = single_block(
            vec![const_i32(ValueId(0), 1), const_i32(ValueId(1), 2), sum],
            branch,
        );

        assert!(!Dce.run(&mut func));
        assert_eq!(func.blocks[0].instructions.len(), 3);
    }
}