llhd/pass/
dce.rs

1// Copyright (c) 2017-2021 Fabian Schuiki
2
3//! Dead Code Elimination
4
5use crate::{ir::prelude::*, opt::prelude::*};
6use std::collections::{HashMap, HashSet};
7
/// Dead Code Elimination
///
/// This pass implements dead code elimination. It removes instructions whose
/// value is never used, trivial blocks, and blocks which cannot be reached.
///
/// The type is a stateless marker; it carries no data and is only used to
/// select the `Pass` implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct DeadCodeElim;
13
14impl Pass for DeadCodeElim {
15    fn run_on_cfg(_ctx: &PassContext, unit: &mut UnitBuilder) -> bool {
16        info!("DCE [{}]", unit.name());
17        let mut modified = false;
18
19        // Gather a list of instructions and investigate which branches and
20        // blocks are trivial.
21        let mut insts = vec![];
22        let mut trivial_branches = HashMap::new();
23        let mut trivial_blocks = HashMap::new();
24        let entry = unit.entry();
25        for bb in unit.blocks() {
26            let term = unit.terminator(bb);
27            check_branch_trivial(unit, bb, term, &mut trivial_blocks, &mut trivial_branches);
28            for inst in unit.insts(bb) {
29                if inst != term {
30                    insts.push(inst);
31                }
32            }
33        }
34        check_block_retargetable(unit, entry, &mut trivial_blocks, &mut trivial_branches);
35        trace!("Trivial Blocks: {:?}", trivial_blocks);
36        trace!("Trivial Branches: {:?}", trivial_branches);
37
38        // Simplify trivial branches.
39        for (inst, target) in trivial_branches
40            .into_iter()
41            .flat_map(|(i, t)| t.map(|t| (i, t)))
42        {
43            if unit[inst].opcode() == Opcode::Br && unit[inst].blocks() == [target] {
44                continue;
45            }
46            debug!(
47                "Replacing {} with br {}",
48                inst.dump(&unit),
49                target.dump(&unit)
50            );
51            unit.insert_before(inst);
52            unit.ins().br(target);
53            unit.delete_inst(inst);
54            modified |= true;
55        }
56
57        // Replace trivial blocks.
58        for (from, to) in trivial_blocks
59            .into_iter()
60            .flat_map(|(b, w)| w.map(|w| (b, w)))
61            .filter(|(from, to)| from != to)
62        {
63            debug!(
64                "Replacing trivial block {} with {}",
65                from.dump(&unit),
66                to.dump(&unit)
67            );
68            unit.replace_block_use(from, to);
69            // If this is the entry block, hoist the target up as the first block.
70            if from == entry {
71                unit.swap_blocks(from, to);
72            }
73            modified |= true;
74        }
75
76        // Prune instructions and unreachable blocks.
77        for inst in insts {
78            modified |= unit.prune_if_unused(inst);
79        }
80        modified |= prune_blocks(unit);
81
82        // Detect trivially sequential blocks. We use a temporal predecessor
83        // table here to avoid treating wait instructions as branches.
84        let pt = unit.temporal_predtbl();
85        let mut merge_blocks = Vec::new();
86        let mut already_merged = HashMap::new();
87        for bb in unit.blocks().filter(|&bb| bb != entry) {
88            let preds = pt.pred_set(bb);
89            if preds.len() == 1 {
90                let pred = preds.iter().cloned().next().unwrap();
91                if pt.is_sole_succ(bb, pred) {
92                    let into = already_merged.get(&pred).cloned().unwrap_or(pred);
93                    merge_blocks.push((bb, into));
94                    already_merged.insert(bb, into);
95                }
96            }
97        }
98
99        // Concatenate trivially sequential blocks.
100        for (block, into) in merge_blocks {
101            debug!("Merge {} into {}", block.dump(&unit), into.dump(&unit));
102            let term = unit.terminator(into);
103            while let Some(inst) = unit.first_inst(block) {
104                unit.remove_inst(inst);
105                // Do not migrate phi nodes, which at this point have only the
106                // `into` block as predecessor and can be trivially replaced.
107                if unit[inst].opcode() == Opcode::Phi {
108                    assert_eq!(
109                        unit[inst].blocks(),
110                        &[into],
111                        "Phi node must be trivially removable"
112                    );
113                    let phi = unit.inst_result(inst);
114                    let repl = unit[inst].args()[0];
115                    unit.replace_use(phi, repl);
116                } else {
117                    unit.insert_inst_before(inst, term);
118                }
119            }
120            unit.remove_inst(term);
121            unit.replace_block_use(block, into);
122            unit.delete_block(block);
123        }
124
125        modified
126    }
127}
128
129/// Check if a branch that terminates a block is trivial.
130fn check_branch_trivial(
131    unit: &UnitBuilder,
132    _block: Block,
133    inst: Inst,
134    triv_bb: &mut HashMap<Block, Option<Block>>,
135    triv_br: &mut HashMap<Inst, Option<Block>>,
136) -> Option<Block> {
137    // Insert a sentinel value to avoid recursion.
138    if let Some(&entry) = triv_br.get(&inst) {
139        return entry;
140    }
141    triv_br.insert(inst, None);
142    trace!("Checking if trivial {}", inst.dump(&unit));
143
144    // Now we know the block is empty. Check for a few common cases of trivial
145    // branches.
146    let data = &unit[inst];
147    let target = match data.opcode() {
148        Opcode::Br => {
149            let bb = data.blocks()[0];
150            check_block_retargetable(unit, bb, triv_bb, triv_br)
151        }
152        Opcode::BrCond => {
153            let arg = data.args()[0];
154            let bbs = data.blocks();
155            let bbs: Vec<_> = bbs
156                .iter()
157                .map(|&bb| check_block_retargetable(unit, bb, triv_bb, triv_br))
158                .collect();
159            if let Some(imm) = unit.get_const_int(arg) {
160                bbs[!imm.is_zero() as usize]
161            } else if bbs[0] == bbs[1] {
162                bbs[0]
163            } else {
164                None
165            }
166        }
167        _ => None,
168    };
169    triv_br.insert(inst, target);
170    target
171}
172
173/// Check if a block can be trivially addressed from a different block, and if
174/// so, return a potential immediate forward through the block if trivial.
175fn check_block_retargetable(
176    unit: &UnitBuilder,
177    block: Block,
178    triv_bb: &mut HashMap<Block, Option<Block>>,
179    triv_br: &mut HashMap<Inst, Option<Block>>,
180) -> Option<Block> {
181    trace!("Checking if trivial {}", block.dump(&unit));
182
183    // Check that there are no phi nodes on the target block.
184    if unit.insts(block).any(|inst| unit[inst].opcode().is_phi()) {
185        triv_bb.insert(block, None);
186        return None;
187    }
188
189    // If the block is not trivially empty, it is retargetable but cannot be
190    // "jumped through".
191    if unit.first_inst(block) != unit.last_inst(block) {
192        triv_bb.insert(block, Some(block));
193        return Some(block);
194    }
195
196    // Dig up the terminator instruction and potentially resolve the target to
197    // its trivial successor.
198    let inst = unit.terminator(block);
199    let target = Some(check_branch_trivial(unit, block, inst, triv_bb, triv_br).unwrap_or(block));
200    triv_bb.insert(block, target);
201    target
202}
203
204/// Eliminate unreachable and trivial blocks in a function layout.
205fn prune_blocks(unit: &mut UnitBuilder) -> bool {
206    let mut modified = false;
207
208    // Find all blocks reachable from the entry point.
209    let first_bb = unit.first_block().unwrap();
210    let mut unreachable: HashSet<Block> = unit.blocks().collect();
211    let mut todo: Vec<Block> = Default::default();
212    todo.push(first_bb);
213    unreachable.remove(&first_bb);
214    while let Some(block) = todo.pop() {
215        let term_inst = unit.terminator(block);
216        for &bb in unit[term_inst].blocks() {
217            if unreachable.remove(&bb) {
218                todo.push(bb);
219            }
220        }
221    }
222
223    // Remove all unreachable blocks.
224    for bb in unreachable {
225        debug!("Prune unreachable block {}", bb.dump(&unit));
226        modified |= true;
227        unit.delete_block(bb);
228    }
229
230    modified
231}