use super::super::{Block, Function, Instruction, Register};
use super::{Transform, TransformCategory, TransformLevel};
use std::collections::{HashMap, HashSet};
/// Counters reported by one call to [`DeadCodeElimination::apply_internal`].
#[derive(Debug, Default)]
pub struct DeadCodeStats {
/// Number of instructions deleted across all blocks of the function.
pub instructions_removed: usize,
/// NOTE(review): not updated by `DeadCodeElimination::apply_internal` in this file, so it
/// stays at its `Default` value of 0 — confirm whether anything else populates it.
pub registers_freed: usize,
}
/// Dead-code elimination pass: deletes side-effect-free, non-terminator instructions
/// whose results are not live, using a backward liveness analysis over the function's
/// blocks. The pass is stateless; construct it as `DeadCodeElimination`.
#[derive(Default)]
pub struct DeadCodeElimination;
impl Transform for DeadCodeElimination {
    /// Stable identifier under which this pass is registered.
    fn name(&self) -> &'static str {
        "dead_code_elimination"
    }

    /// Human-readable one-line summary of the pass.
    fn description(&self) -> &'static str {
        "Removes instructions that define registers which are never used"
    }

    /// The pass belongs to the dead-code-elimination category.
    fn category(&self) -> TransformCategory {
        TransformCategory::DeadCodeElimination
    }

    /// The pass is marked as stable.
    fn level(&self) -> TransformLevel {
        TransformLevel::Stable
    }

    /// Runs the pass on `func`.
    ///
    /// Returns `Ok(true)` when at least one instruction was deleted and `Ok(false)`
    /// otherwise. Any error from the underlying analysis is forwarded unchanged.
    fn apply(&self, func: &mut Function) -> Result<bool, String> {
        let stats = self.apply_internal(func)?;
        Ok(stats.instructions_removed != 0)
    }
}
impl DeadCodeElimination {
    /// Runs dead-code elimination on `func` and reports what was removed.
    ///
    /// Each round computes block-level liveness, then walks every block backwards and
    /// deletes side-effect-free, non-terminator instructions whose results are not live.
    /// A single round already removes chains of dead instructions inside one block. A
    /// value whose only consumer was deleted in a *different* block only becomes dead
    /// once liveness is recomputed. Rounds therefore repeat until one removes nothing.
    /// Every continuing round deletes at least one instruction, so the loop terminates.
    /// When nothing is dead, exactly one round runs.
    ///
    /// # Errors
    ///
    /// Returns the liveness-analysis error (see [`Self::compute_liveness`]) if it occurs
    /// before anything has been removed; `func` is then unmodified. If a later round's
    /// analysis fails, the removals already applied are sound. In that case iteration
    /// stops and the statistics gathered so far are returned.
    pub fn apply_internal(&self, func: &mut Function) -> Result<DeadCodeStats, String> {
        let mut stats = DeadCodeStats::default();
        loop {
            let live_out_map = match self.compute_liveness(func) {
                Ok(map) => map,
                // Nothing has been modified yet: surface the failure to the caller.
                Err(err) if stats.instructions_removed == 0 => return Err(err),
                // Earlier rounds already applied sound removals; keep them (best effort).
                Err(_) => break,
            };

            let mut removed_this_round = 0;
            for block in &mut func.blocks {
                let mut live_regs = live_out_map.get(&block.label).cloned().unwrap_or_default();
                removed_this_round += self.remove_dead_instructions_in_block(block, &mut live_regs);
            }
            stats.instructions_removed += removed_this_round;

            if removed_this_round == 0 {
                break;
            }
        }
        Ok(stats)
    }

    /// Computes the set of registers live on exit from each block, keyed by block label.
    ///
    /// This is a backward dataflow analysis iterated to a fixed point:
    /// `live_out(B)` is the union of `live_in(S)` over the successors `S` named by `B`'s
    /// last instruction. `live_in(B)` is obtained by walking `B` backwards from
    /// `live_out(B)`, removing each defined register and then adding each used one.
    /// Blocks are visited in reverse order, which typically reaches the fixed point in
    /// fewer passes for a backward problem.
    ///
    /// Only `Jmp`, `Br`, and `Switch` contribute successors (see
    /// [`Self::successor_labels`]); a block ending in any other instruction gets an empty
    /// `live_out`. Successor labels that do not name a block in `func` are ignored.
    ///
    /// # Errors
    ///
    /// Returns an error if no fixed point is reached within `MAX_ITERATIONS` passes, or
    /// if a block label is missing from the internal maps (an internal invariant violation).
    fn compute_liveness(
        &self,
        func: &Function,
    ) -> Result<HashMap<String, HashSet<Register>>, String> {
        const MAX_ITERATIONS: usize = 1000;

        let mut live_in: HashMap<String, HashSet<Register>> = HashMap::new();
        let mut live_out: HashMap<String, HashSet<Register>> = HashMap::new();
        for block in &func.blocks {
            live_in.insert(block.label.clone(), HashSet::new());
            live_out.insert(block.label.clone(), HashSet::new());
        }

        // At most MAX_ITERATIONS full passes (the previous `>` check allowed one extra).
        for _ in 0..MAX_ITERATIONS {
            let mut changed = false;
            for block in func.blocks.iter().rev() {
                let label = &block.label;

                // live_out(B) = union of live_in over B's successors.
                let mut current_live_out = HashSet::new();
                if let Some(terminator) = block.instructions.last() {
                    for succ in Self::successor_labels(terminator) {
                        if let Some(succ_live_in) = live_in.get(succ) {
                            current_live_out.extend(succ_live_in.iter().cloned());
                        }
                    }
                }

                // live_in(B): walk B backwards; a definition kills, a use generates.
                // Kill before gen so `r = r + 1` keeps `r` live on entry.
                let mut current_live_in = current_live_out.clone();
                for instr in block.instructions.iter().rev() {
                    if let Some(def) = instr.def_reg() {
                        current_live_in.remove(def);
                    }
                    for use_reg in instr.use_regs() {
                        current_live_in.insert(use_reg.clone());
                    }
                }

                let prev_live_out = live_out.get_mut(label).ok_or_else(|| {
                    format!(
                        "Block '{}' not found in live_out map - internal error in liveness analysis",
                        label
                    )
                })?;
                if *prev_live_out != current_live_out {
                    *prev_live_out = current_live_out;
                    changed = true;
                }

                let prev_live_in = live_in.get_mut(label).ok_or_else(|| {
                    format!(
                        "Block '{}' not found in live_in map - internal error in liveness analysis",
                        label
                    )
                })?;
                if *prev_live_in != current_live_in {
                    *prev_live_in = current_live_in;
                    changed = true;
                }
            }

            if !changed {
                return Ok(live_out);
            }
        }
        Err("Liveness analysis failed to converge".to_string())
    }

    /// Returns the labels of the blocks that `terminator` can transfer control to.
    ///
    /// Only `Jmp`, `Br`, and `Switch` name successors. Every other instruction, including
    /// `Ret`, `TailCall`, and `Unreachable`, yields an empty list.
    fn successor_labels(terminator: &Instruction) -> Vec<&String> {
        match terminator {
            Instruction::Jmp { target } => vec![target],
            Instruction::Br {
                true_target,
                false_target,
                ..
            } => vec![true_target, false_target],
            Instruction::Switch { cases, default, .. } => std::iter::once(default)
                .chain(cases.iter().map(|(_, case_target)| case_target))
                .collect(),
            _ => Vec::new(),
        }
    }

    /// Deletes dead instructions from `block` and returns how many were removed.
    ///
    /// `live_regs` must hold the registers live on exit from `block`. The block is
    /// walked backwards, so removing one instruction can expose its operands'
    /// definitions as dead in the same walk. On return, `live_regs` holds the registers
    /// live on entry to the (pruned) block. `block.instructions` is left untouched when
    /// nothing is removed.
    fn remove_dead_instructions_in_block(
        &self,
        block: &mut Block,
        live_regs: &mut HashSet<Register>,
    ) -> usize {
        // Decide each instruction's fate in a backward walk. Record it in program order
        // so it can be applied in place without cloning the surviving instructions.
        let mut keep = vec![true; block.instructions.len()];
        let mut removed_count = 0;
        for (idx, instr) in block.instructions.iter().enumerate().rev() {
            if self.is_dead_instruction(instr, live_regs) {
                keep[idx] = false;
                removed_count += 1;
                continue;
            }
            if let Some(def_reg) = instr.def_reg() {
                live_regs.remove(def_reg);
            }
            for use_reg in instr.use_regs() {
                live_regs.insert(use_reg.clone());
            }
        }

        if removed_count > 0 {
            // `Vec::retain` visits elements exactly once in their original order,
            // so the flags line up with the instructions.
            let mut flags = keep.into_iter();
            block
                .instructions
                .retain(|_| flags.next().unwrap_or(true));
        }
        removed_count
    }

    /// Returns `true` if `instr` can be deleted, given the registers live just after it.
    ///
    /// Terminators are never dead, and an instruction whose defined register is live is
    /// kept. Otherwise (its result is unused, or it defines nothing) it is dead exactly
    /// when [`Self::has_no_side_effects`] reports it as pure.
    fn is_dead_instruction(&self, instr: &Instruction, live_regs: &HashSet<Register>) -> bool {
        if instr.is_terminator() {
            return false;
        }
        match instr.def_reg() {
            Some(def_reg) if live_regs.contains(def_reg) => false,
            _ => self.has_no_side_effects(instr),
        }
    }

    /// Returns `true` for instructions that may be deleted when their result is unused.
    ///
    /// These are arithmetic, comparisons, `Select`, `VectorOp`, `Lea`, and (with the
    /// `nightly` feature) the register-only SIMD operations. Loads and stores, calls,
    /// control flow, safepoints, stack maps, patch points, comments, and (with `nightly`)
    /// SIMD memory accesses, atomics, and fences are all treated as having side effects
    /// and are never removed. The match has no wildcard arm, so a newly added instruction
    /// variant forces an explicit decision here.
    fn has_no_side_effects(&self, instr: &Instruction) -> bool {
        match instr {
            Instruction::IntBinary { .. }
            | Instruction::FloatBinary { .. }
            | Instruction::FloatUnary { .. }
            | Instruction::IntCmp { .. }
            | Instruction::FloatCmp { .. }
            | Instruction::Select { .. }
            | Instruction::VectorOp { .. }
            | Instruction::Lea { .. } => true,
            #[cfg(feature = "nightly")]
            Instruction::SimdBinary { .. }
            | Instruction::SimdUnary { .. }
            | Instruction::SimdTernary { .. }
            | Instruction::SimdShuffle { .. }
            | Instruction::SimdExtract { .. }
            | Instruction::SimdInsert { .. } => true,
            Instruction::Load { .. }
            | Instruction::Store { .. }
            | Instruction::Call { .. }
            | Instruction::TailCall { .. }
            | Instruction::Ret { .. }
            | Instruction::Jmp { .. }
            | Instruction::Br { .. }
            | Instruction::Switch { .. }
            | Instruction::Unreachable
            | Instruction::SafePoint
            | Instruction::StackMap { .. }
            | Instruction::PatchPoint { .. }
            | Instruction::Comment { .. } => false,
            #[cfg(feature = "nightly")]
            Instruction::SimdLoad { .. }
            | Instruction::SimdStore { .. }
            | Instruction::AtomicLoad { .. }
            | Instruction::AtomicStore { .. }
            | Instruction::AtomicBinary { .. }
            | Instruction::AtomicCompareExchange { .. }
            | Instruction::Fence { .. } => false,
        }
    }
}
// Unit tests for `DeadCodeElimination`. Each test builds a small function with
// `FunctionBuilder`, runs the pass through `Transform::apply`, and inspects the
// resulting instruction lists.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
use super::*;
use crate::mir::{
FunctionBuilder, Immediate, IntBinOp, MirType, Operand, ScalarType, VirtualReg,
};
// r1 = r0 + 42 is never read and must be removed. r2 = r0 + 10 feeds the return and
// must survive as the first instruction.
#[test]
fn test_dead_code_elimination_basic() {
let func = FunctionBuilder::new("test")
.param(VirtualReg::gpr(0).into(), MirType::Scalar(ScalarType::I64))
.returns(MirType::Scalar(ScalarType::I64))
.block("entry")
.instr(Instruction::IntBinary {
op: IntBinOp::Add,
ty: MirType::Scalar(ScalarType::I64),
dst: VirtualReg::gpr(1).into(),
lhs: Operand::Register(VirtualReg::gpr(0).into()),
rhs: Operand::Immediate(Immediate::I64(42)),
})
.instr(Instruction::IntBinary {
op: IntBinOp::Add,
ty: MirType::Scalar(ScalarType::I64),
dst: VirtualReg::gpr(2).into(),
lhs: Operand::Register(VirtualReg::gpr(0).into()),
rhs: Operand::Immediate(Immediate::I64(10)),
})
.instr(Instruction::Ret {
value: Some(Operand::Register(VirtualReg::gpr(2).into())),
})
.build();
let mut func = func;
let dce = DeadCodeElimination;
let changed = dce.apply(&mut func).expect("DCE should succeed");
assert!(changed);
let entry = func.get_block("entry").expect("entry block exists");
assert_eq!(entry.instructions.len(), 2);
match &entry.instructions[0] {
Instruction::IntBinary { dst, .. } => {
assert_eq!(dst, &VirtualReg::gpr(2).into());
}
_ => panic!("Expected IntBinary"),
}
}
// A function whose only instruction is a bare `ret` has nothing to remove, so
// `apply` must succeed and report "no change".
#[test]
fn test_dce_empty_function() {
let mut func = FunctionBuilder::new("empty")
.returns(MirType::Scalar(ScalarType::I64))
.block("entry")
.instr(Instruction::Ret { value: None })
.build();
let dce = DeadCodeElimination;
let result = dce.apply(&mut func);
assert!(result.is_ok());
// `false` means no instruction was removed.
assert!(!result.unwrap()); }
// Every defined register is read, so the pass must leave the block unchanged.
#[test]
fn test_dce_single_block_all_live() {
let mut func = FunctionBuilder::new("all_live")
.param(VirtualReg::gpr(0).into(), MirType::Scalar(ScalarType::I64))
.returns(MirType::Scalar(ScalarType::I64))
.block("entry")
.instr(Instruction::IntBinary {
op: IntBinOp::Add,
ty: MirType::Scalar(ScalarType::I64),
dst: VirtualReg::gpr(1).into(),
lhs: Operand::Register(VirtualReg::gpr(0).into()),
rhs: Operand::Immediate(Immediate::I64(1)),
})
.instr(Instruction::Ret {
value: Some(Operand::Register(VirtualReg::gpr(1).into())),
})
.build();
let dce = DeadCodeElimination;
let changed = dce.apply(&mut func).expect("should succeed");
assert!(!changed);
assert_eq!(func.blocks[0].instructions.len(), 2);
}
// Terminators are never removed, even in a block that contains nothing else.
#[test]
fn test_dce_preserves_terminators() {
let mut func = FunctionBuilder::new("terminators")
.returns(MirType::Scalar(ScalarType::I64))
.block("entry")
.instr(Instruction::Jmp {
target: "exit".to_string(),
})
.block("exit")
.instr(Instruction::Ret { value: None })
.build();
let dce = DeadCodeElimination;
let result = dce.apply(&mut func);
assert!(result.is_ok());
let entry = func.get_block("entry").unwrap();
assert!(matches!(
entry.instructions.last(),
Some(Instruction::Jmp { .. })
));
}
// Store and Call define no live register but are side-effecting, so both must
// survive alongside the `ret`.
#[test]
fn test_dce_preserves_side_effects() {
let mut func = FunctionBuilder::new("side_effects")
.returns(MirType::Scalar(ScalarType::I64))
.block("entry")
.instr(Instruction::Store {
ty: MirType::Scalar(ScalarType::I64),
src: Operand::Immediate(Immediate::I64(42)),
addr: crate::mir::AddressMode::BaseOffset {
base: VirtualReg::gpr(0).into(),
offset: 0,
},
attrs: crate::mir::MemoryAttrs::default(),
})
.instr(Instruction::Call {
name: "print".to_string(),
args: vec![Operand::Immediate(Immediate::I64(1))],
ret: None,
})
.instr(Instruction::Ret {
value: Some(Operand::Immediate(Immediate::I64(0))),
})
.build();
let dce = DeadCodeElimination;
let changed = dce.apply(&mut func).expect("should succeed");
assert!(!changed);
assert_eq!(func.blocks[0].instructions.len(), 3);
}
// r1 -> r2 -> r3 form a chain whose final value is never used (the return is an
// immediate). The backward walk must remove all three, leaving only `ret`.
#[test]
fn test_dce_chain_of_dead_instructions() {
let mut func = FunctionBuilder::new("dead_chain")
.param(VirtualReg::gpr(0).into(), MirType::Scalar(ScalarType::I64))
.returns(MirType::Scalar(ScalarType::I64))
.block("entry")
.instr(Instruction::IntBinary {
op: IntBinOp::Add,
ty: MirType::Scalar(ScalarType::I64),
dst: VirtualReg::gpr(1).into(),
lhs: Operand::Register(VirtualReg::gpr(0).into()),
rhs: Operand::Immediate(Immediate::I64(1)),
})
.instr(Instruction::IntBinary {
op: IntBinOp::Add,
ty: MirType::Scalar(ScalarType::I64),
dst: VirtualReg::gpr(2).into(),
lhs: Operand::Register(VirtualReg::gpr(1).into()),
rhs: Operand::Immediate(Immediate::I64(2)),
})
.instr(Instruction::IntBinary {
op: IntBinOp::Add,
ty: MirType::Scalar(ScalarType::I64),
dst: VirtualReg::gpr(3).into(),
lhs: Operand::Register(VirtualReg::gpr(2).into()),
rhs: Operand::Immediate(Immediate::I64(3)),
})
.instr(Instruction::Ret {
value: Some(Operand::Immediate(Immediate::I64(99))),
})
.build();
let dce = DeadCodeElimination;
let changed = dce.apply(&mut func).expect("should succeed");
assert!(changed);
assert_eq!(func.blocks[0].instructions.len(), 1);
}
// r1 is defined in `entry` but only read in `exit`. Cross-block liveness must
// propagate the use back through the `jmp` and keep the definition.
#[test]
fn test_dce_multi_block_liveness() {
let mut func = FunctionBuilder::new("multi_block")
.param(VirtualReg::gpr(0).into(), MirType::Scalar(ScalarType::I64))
.returns(MirType::Scalar(ScalarType::I64))
.block("entry")
.instr(Instruction::IntBinary {
op: IntBinOp::Add,
ty: MirType::Scalar(ScalarType::I64),
dst: VirtualReg::gpr(1).into(),
lhs: Operand::Register(VirtualReg::gpr(0).into()),
rhs: Operand::Immediate(Immediate::I64(10)),
})
.instr(Instruction::Jmp {
target: "exit".to_string(),
})
.block("exit")
.instr(Instruction::Ret {
value: Some(Operand::Register(VirtualReg::gpr(1).into())),
})
.build();
let dce = DeadCodeElimination;
let changed = dce.apply(&mut func).expect("should succeed");
assert!(!changed);
let entry = func.get_block("entry").unwrap();
assert_eq!(entry.instructions.len(), 2);
}
// r0 is redefined inside `loop` and read by the branch condition, by the next
// iteration (via the back edge to `loop`), and by `exit`. The add must survive.
#[test]
fn test_dce_loop_back_edge_liveness() {
let mut func = FunctionBuilder::new("loop_liveness")
.param(VirtualReg::gpr(0).into(), MirType::Scalar(ScalarType::I64))
.returns(MirType::Scalar(ScalarType::I64))
.block("entry")
.instr(Instruction::Jmp {
target: "loop".to_string(),
})
.block("loop")
.instr(Instruction::IntBinary {
op: IntBinOp::Add,
ty: MirType::Scalar(ScalarType::I64),
dst: VirtualReg::gpr(0).into(),
lhs: Operand::Register(VirtualReg::gpr(0).into()),
rhs: Operand::Immediate(Immediate::I64(1)),
})
.instr(Instruction::Br {
cond: VirtualReg::gpr(0).into(),
true_target: "loop".to_string(),
false_target: "exit".to_string(),
})
.block("exit")
.instr(Instruction::Ret {
value: Some(Operand::Register(VirtualReg::gpr(0).into())),
})
.build();
let dce = DeadCodeElimination;
let result = dce.apply(&mut func);
assert!(result.is_ok());
let loop_block = func.get_block("loop").unwrap();
assert_eq!(loop_block.instructions.len(), 2);
}
// 500 independent dead adds with immediate-only operands in a single block. The
// pass must finish and remove all of them, leaving only the trailing `ret`.
#[test]
fn test_dce_does_not_infinite_loop() {
let mut func = FunctionBuilder::new("stress")
.returns(MirType::Scalar(ScalarType::I64))
.block("entry")
.build();
for i in 0..500 {
func.blocks[0].instructions.insert(
0,
Instruction::IntBinary {
op: IntBinOp::Add,
ty: MirType::Scalar(ScalarType::I64),
dst: VirtualReg::gpr(i + 1).into(),
lhs: Operand::Immediate(Immediate::I64(i as i64)),
rhs: Operand::Immediate(Immediate::I64(1)),
},
);
}
func.blocks[0].instructions.push(Instruction::Ret {
value: Some(Operand::Immediate(Immediate::I64(0))),
});
let dce = DeadCodeElimination;
let result = dce.apply(&mut func);
assert!(result.is_ok());
assert_eq!(func.blocks[0].instructions.len(), 1);
}
}