use crate::{ir::prelude::*, opt::prelude::*};
use std::collections::{HashMap, HashSet};
/// Dead Code Elimination (DCE) pass.
///
/// Simplifies trivial branches, bypasses trivial blocks, prunes unused
/// instructions and unreachable blocks, and merges blocks that trivially
/// follow one another.
pub struct DeadCodeElim;
impl Pass for DeadCodeElim {
    /// Run dead code elimination on the control flow graph of `unit`.
    ///
    /// The pass proceeds in several phases:
    ///
    /// 1. Collect all non-terminator instructions and determine which
    ///    branches and blocks are trivial.
    /// 2. Replace trivial branches with an unconditional `br`.
    /// 3. Redirect uses of trivial blocks to the block they forward to.
    /// 4. Prune unused instructions and unreachable blocks.
    /// 5. Merge blocks that have a single predecessor into that predecessor.
    ///
    /// Returns `true` if any of the phases changed the unit.
    ///
    /// # Panics
    ///
    /// Panics if a phi node in a block being merged references any block
    /// other than the block it is merged into.
    fn run_on_cfg(_ctx: &PassContext, unit: &mut UnitBuilder) -> bool {
        info!("DCE [{}]", unit.name());
        let mut modified = false;

        // Gather the non-terminator instructions and investigate which
        // branches and blocks are trivial.
        let mut insts = vec![];
        let mut trivial_branches = HashMap::new();
        let mut trivial_blocks = HashMap::new();
        let entry = unit.entry();
        for bb in unit.blocks() {
            let term = unit.terminator(bb);
            check_branch_trivial(unit, bb, term, &mut trivial_blocks, &mut trivial_branches);
            for inst in unit.insts(bb) {
                if inst != term {
                    insts.push(inst);
                }
            }
        }
        check_block_retargetable(unit, entry, &mut trivial_blocks, &mut trivial_branches);
        trace!("Trivial Blocks: {:?}", trivial_blocks);
        trace!("Trivial Branches: {:?}", trivial_branches);

        // Simplify trivial branches. Entries mapped to `None` are not trivial
        // and are skipped.
        for (inst, target) in trivial_branches
            .into_iter()
            .filter_map(|(i, t)| t.map(|t| (i, t)))
        {
            // Nothing to do if the branch already is a plain `br` to the
            // resolved target.
            if unit[inst].opcode() == Opcode::Br && unit[inst].blocks() == [target] {
                continue;
            }
            debug!(
                "Replacing {} with br {}",
                inst.dump(&unit),
                target.dump(&unit)
            );
            unit.insert_before(inst);
            unit.ins().br(target);
            unit.delete_inst(inst);
            modified = true;
        }

        // Replace trivial blocks with the block they forward to.
        for (from, to) in trivial_blocks
            .into_iter()
            .filter_map(|(b, w)| w.map(|w| (b, w)))
            .filter(|(from, to)| from != to)
        {
            debug!(
                "Replacing trivial block {} with {}",
                from.dump(&unit),
                to.dump(&unit)
            );
            unit.replace_block_use(from, to);
            // If the entry block is being replaced, swap the two blocks.
            // NOTE(review): presumably this moves the replacement into the
            // entry position -- confirm against `swap_blocks`.
            if from == entry {
                unit.swap_blocks(from, to);
            }
            modified = true;
        }

        // Prune instructions and unreachable blocks.
        for inst in insts {
            modified |= unit.prune_if_unused(inst);
        }
        modified |= prune_blocks(unit);

        // Detect trivially sequential blocks: a non-entry block with exactly
        // one predecessor for which `is_sole_succ` holds is scheduled to be
        // merged into that predecessor. If the predecessor itself has
        // already been scheduled for merging, merge into its destination.
        let pt = unit.temporal_predtbl();
        let mut merge_blocks = Vec::new();
        let mut already_merged = HashMap::new();
        for bb in unit.blocks().filter(|&bb| bb != entry) {
            let preds = pt.pred_set(bb);
            if preds.len() == 1 {
                let pred = preds.iter().cloned().next().unwrap();
                if pt.is_sole_succ(bb, pred) {
                    let into = already_merged.get(&pred).cloned().unwrap_or(pred);
                    merge_blocks.push((bb, into));
                    already_merged.insert(bb, into);
                }
            }
        }

        // Concatenate blocks.
        for (block, into) in merge_blocks {
            debug!("Merge {} into {}", block.dump(&unit), into.dump(&unit));
            let term = unit.terminator(into);
            while let Some(inst) = unit.first_inst(block) {
                unit.remove_inst(inst);
                // Phi nodes are not migrated; they are expected to reference
                // only `into` (asserted below) and are replaced by their
                // first argument.
                if unit[inst].opcode() == Opcode::Phi {
                    assert_eq!(
                        unit[inst].blocks(),
                        &[into],
                        "Phi node must be trivially removable"
                    );
                    let phi = unit.inst_result(inst);
                    let repl = unit[inst].args()[0];
                    unit.replace_use(phi, repl);
                } else {
                    unit.insert_inst_before(inst, term);
                }
            }
            // The old terminator of `into` is superseded by the terminator
            // that was just moved over from `block`.
            unit.remove_inst(term);
            unit.replace_block_use(block, into);
            unit.delete_block(block);
            // Merging blocks changes the unit and must be reported.
            modified = true;
        }

        modified
    }
}
/// Determine whether the terminator `inst` can be replaced by an
/// unconditional branch, and if so, to which block.
///
/// Results are memoized in `triv_br`. Returns `Some(target)` if the
/// instruction is equivalent to `br target`, or `None` otherwise. Only `Br`
/// and `BrCond` instructions are ever considered trivial.
fn check_branch_trivial(
    unit: &UnitBuilder,
    _block: Block,
    inst: Inst,
    triv_bb: &mut HashMap<Block, Option<Block>>,
    triv_br: &mut HashMap<Inst, Option<Block>>,
) -> Option<Block> {
    // Reuse an earlier verdict if this instruction has been visited before.
    if let Some(&cached) = triv_br.get(&inst) {
        return cached;
    }

    // Record a placeholder before recursing, such that revisiting this
    // instruction during the recursion hits the lookup above and returns.
    triv_br.insert(inst, None);
    trace!("Checking if trivial {}", inst.dump(&unit));

    let data = &unit[inst];
    let resolved = match data.opcode() {
        // An unconditional branch goes wherever its destination resolves to.
        Opcode::Br => check_block_retargetable(unit, data.blocks()[0], triv_bb, triv_br),
        Opcode::BrCond => {
            let cond = data.args()[0];
            let dests = data.blocks();
            let mut resolved_dests = Vec::with_capacity(dests.len());
            for &dest in dests {
                resolved_dests.push(check_block_retargetable(unit, dest, triv_bb, triv_br));
            }
            match unit.get_const_int(cond) {
                // Constant condition: block 0 is taken for zero, block 1
                // otherwise.
                Some(imm) => resolved_dests[!imm.is_zero() as usize],
                // Both destinations resolve to the same result, so the
                // condition does not matter.
                None if resolved_dests[0] == resolved_dests[1] => resolved_dests[0],
                None => None,
            }
        }
        _ => None,
    };

    triv_br.insert(inst, resolved);
    resolved
}
/// Determine the block that branches to `block` may be retargeted to.
///
/// The result is recorded in `triv_bb` and returned:
///
/// - `None` if the block contains a phi instruction;
/// - `Some(block)` if the block's first and last instruction differ, or if
///   its sole instruction is not a trivial branch;
/// - `Some(target)` if the block's sole instruction is a trivial branch to
///   `target`.
fn check_block_retargetable(
    unit: &UnitBuilder,
    block: Block,
    triv_bb: &mut HashMap<Block, Option<Block>>,
    triv_br: &mut HashMap<Inst, Option<Block>>,
) -> Option<Block> {
    trace!("Checking if trivial {}", block.dump(&unit));

    let has_phi = unit.insts(block).any(|inst| unit[inst].opcode().is_phi());
    let verdict = if has_phi {
        // Blocks with phi nodes are never retargeted.
        None
    } else if unit.first_inst(block) != unit.last_inst(block) {
        // The block has more than one instruction; it stays as it is.
        Some(block)
    } else {
        // The block's only instruction is its terminator. Forward to
        // wherever that branch trivially leads, or to the block itself if
        // the branch is not trivial.
        let term = unit.terminator(block);
        check_branch_trivial(unit, block, term, triv_bb, triv_br).or(Some(block))
    };

    triv_bb.insert(block, verdict);
    verdict
}
/// Delete all blocks that cannot be reached from the first block of `unit`
/// by following the blocks referenced by terminator instructions.
///
/// Returns `true` if at least one block was deleted.
///
/// # Panics
///
/// Panics if the unit has no first block.
fn prune_blocks(unit: &mut UnitBuilder) -> bool {
    let root = unit.first_block().unwrap();

    // Start by assuming every block except the root is unreachable, then
    // strike off each block encountered while walking the CFG.
    let mut unreachable: HashSet<Block> = unit.blocks().collect();
    unreachable.remove(&root);

    let mut worklist: Vec<Block> = vec![root];
    while let Some(current) = worklist.pop() {
        let term = unit.terminator(current);
        // `remove` returns true only the first time a block is seen, so
        // every block is queued at most once.
        worklist.extend(
            unit[term]
                .blocks()
                .iter()
                .cloned()
                .filter(|succ| unreachable.remove(succ)),
        );
    }

    // Whatever is left was never visited.
    let modified = !unreachable.is_empty();
    for bb in unreachable {
        debug!("Prune unreachable block {}", bb.dump(&unit));
        unit.delete_block(bb);
    }
    modified
}