1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
use cfgraph::*;
use prelude::{BTreeSet, BTreeMap};

/// Optimization pass that deletes every basic block not reachable from the
/// entry block (`BlockId(0)`) and renumbers the surviving blocks densely,
/// rewriting all branch targets to the new numbering.
pub struct RemoveDeadBasicBlocks;

impl Optimizer for RemoveDeadBasicBlocks {
    type Return = ();

    /// Removes every basic block that cannot be reached from the entry block
    /// (`BlockId(0)`) and renumbers the surviving blocks to `0..n`.
    ///
    /// Surviving blocks are kept in ascending `BlockId` order (the reachable
    /// set is a `BTreeSet`), so, assuming `BlockId` orders by its index, the
    /// entry block stays at index 0. Every branch target in the surviving
    /// blocks is rewritten to the new numbering. An empty graph is left as is.
    ///
    /// # Panics
    ///
    /// Panics if a visited block has no terminating branch (`br` is `None`)
    /// or if a branch targets an index outside `cfg.blocks`. Presumably both
    /// are ruled out by `CFGraph::validate` — TODO confirm.
    fn optimize(&self, cfg: &mut CFGraph) -> OptimizeResult<()> {
        if cfg.blocks.is_empty() {
            return Ok(());
        }

        // Depth-first search from the entry block to collect reachable blocks.
        // The set is ordered, which fixes the output order of the blocks below.
        let mut reachable: BTreeSet<BlockId> = BTreeSet::new();
        {
            let mut dfs_stack: Vec<BlockId> = vec![BlockId(0)];

            while let Some(blk_id) = dfs_stack.pop() {
                // `insert` returns false when the block was already visited.
                if !reachable.insert(blk_id) {
                    continue;
                }

                let blk = &cfg.blocks[blk_id.0];
                match *blk.br.as_ref().expect("basic block without a terminating branch") {
                    Branch::Jmp(t) => dfs_stack.push(t),
                    Branch::JmpEither(a, b) => {
                        dfs_stack.push(a);
                        dfs_stack.push(b);
                    }
                    Branch::JmpTable(ref targets, otherwise) => {
                        dfs_stack.extend(targets.iter().cloned());
                        dfs_stack.push(otherwise);
                    }
                    Branch::Return => {}
                }
            }
        }

        // Old block id -> new (dense) block id.
        let mut block_id_mappings: BTreeMap<BlockId, BlockId> = BTreeMap::new();

        // Reachable basic blocks, in ascending old-id order.
        let mut new_basic_blocks = Vec::with_capacity(reachable.len());

        {
            let mut old_basic_blocks = ::prelude::mem::replace(&mut cfg.blocks, Vec::new());

            // `reachable` is a set, so each old block is moved out at most once.
            for (new_idx, &blk_id) in reachable.iter().enumerate() {
                block_id_mappings.insert(blk_id, BlockId(new_idx));
                // Leave an empty placeholder behind; the old vector, including
                // the unreachable blocks, is dropped at the end of this scope.
                new_basic_blocks.push(::prelude::mem::replace(
                    &mut old_basic_blocks[blk_id.0],
                    BasicBlock::new(),
                ));
            }
        }

        // A reachable block can only branch to reachable blocks, so every
        // target has an entry in the mapping and the index cannot panic.
        let remap = |id: BlockId| block_id_mappings[&id];

        for bb in &mut new_basic_blocks {
            let old_br = bb.br.take().expect("basic block without a terminating branch");
            bb.br = Some(match old_br {
                Branch::Jmp(id) => Branch::Jmp(remap(id)),
                Branch::JmpEither(a, b) => Branch::JmpEither(remap(a), remap(b)),
                Branch::JmpTable(targets, otherwise) => Branch::JmpTable(
                    targets.into_iter().map(&remap).collect(),
                    remap(otherwise),
                ),
                Branch::Return => Branch::Return,
            });
        }

        cfg.blocks = new_basic_blocks;

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use opcode::Opcode;

    /// bb 0 ends in an unconditional jump over bb 1 and nothing else targets
    /// bb 1, so the pass must drop it; old blocks 0, 2, 3 become new 0, 1, 2.
    #[test]
    fn test_remove_dead_basic_blocks() {
        let opcodes = vec![
            // bb 0: jumps straight over bb 1
            Opcode::I32Const(100), // 0
            Opcode::Jmp(3),        // 1
            // bb 1: never reached
            Opcode::I32Const(50),  // 2
            // bb 2: starts here because it is a jump target
            Opcode::I32Const(25),  // 3
            Opcode::JmpIf(0),      // 4
            // bb 3
            Opcode::Return,        // 5
        ];

        let mut graph = CFGraph::from_function(&opcodes[..]).unwrap();
        graph.validate().unwrap();
        graph.optimize(RemoveDeadBasicBlocks).unwrap();
        graph.validate().unwrap();

        // Terminating branch of each surviving block, in the new numbering.
        let expected_branches = [
            Branch::Jmp(BlockId(1)),
            Branch::JmpEither(BlockId(0), BlockId(2)),
            Branch::Return,
        ];

        assert_eq!(graph.blocks.len(), expected_branches.len());
        for (block, expected) in graph.blocks.iter().zip(expected_branches.iter()) {
            assert_eq!(block.br.as_ref(), Some(expected));
        }

        eprintln!("{:?}", graph);

        eprintln!("{:?}", graph.gen_opcodes());
    }
}