wave_compiler/analysis/
dominance.rs1use std::collections::{HashMap, HashSet};
10
11use super::cfg::Cfg;
12use crate::mir::value::BlockId;
13
14pub struct DomTree {
16 pub idom: HashMap<BlockId, BlockId>,
18 pub children: HashMap<BlockId, Vec<BlockId>>,
20 pub frontiers: HashMap<BlockId, HashSet<BlockId>>,
22}
23
24impl DomTree {
25 #[must_use]
27 pub fn compute(cfg: &Cfg) -> Self {
28 let rpo = cfg.reverse_postorder();
29 let mut block_to_rpo: HashMap<BlockId, usize> = HashMap::new();
30 for (i, &bid) in rpo.iter().enumerate() {
31 block_to_rpo.insert(bid, i);
32 }
33
34 let mut idom: HashMap<BlockId, BlockId> = HashMap::new();
35 idom.insert(cfg.entry, cfg.entry);
36
37 let mut changed = true;
38 while changed {
39 changed = false;
40 for &bid in &rpo {
41 if bid == cfg.entry {
42 continue;
43 }
44 let preds = cfg.preds(bid);
45 let mut new_idom: Option<BlockId> = None;
46
47 for &pred in preds {
48 if idom.contains_key(&pred) {
49 new_idom = Some(match new_idom {
50 None => pred,
51 Some(current) => intersect(current, pred, &idom, &block_to_rpo),
52 });
53 }
54 }
55
56 if let Some(new_id) = new_idom {
57 if idom.get(&bid) != Some(&new_id) {
58 idom.insert(bid, new_id);
59 changed = true;
60 }
61 }
62 }
63 }
64
65 let mut children: HashMap<BlockId, Vec<BlockId>> = HashMap::new();
66 for &bid in &cfg.blocks {
67 children.entry(bid).or_default();
68 }
69 for (&bid, &dom) in &idom {
70 if bid != dom {
71 children.entry(dom).or_default().push(bid);
72 }
73 }
74
75 let frontiers = compute_frontiers(cfg, &idom);
76
77 Self {
78 idom,
79 children,
80 frontiers,
81 }
82 }
83
84 #[must_use]
86 pub fn dominates(&self, a: BlockId, b: BlockId) -> bool {
87 if a == b {
88 return true;
89 }
90 let mut current = b;
91 loop {
92 match self.idom.get(¤t) {
93 Some(&dom) if dom == current => return false,
94 Some(&dom) if dom == a => return true,
95 Some(&dom) => current = dom,
96 None => return false,
97 }
98 }
99 }
100
101 #[must_use]
103 pub fn frontier(&self, block: BlockId) -> HashSet<BlockId> {
104 self.frontiers.get(&block).cloned().unwrap_or_default()
105 }
106}
107
108fn intersect(
109 mut a: BlockId,
110 mut b: BlockId,
111 idom: &HashMap<BlockId, BlockId>,
112 rpo: &HashMap<BlockId, usize>,
113) -> BlockId {
114 while a != b {
115 let a_rpo = rpo.get(&a).copied().unwrap_or(usize::MAX);
116 let b_rpo = rpo.get(&b).copied().unwrap_or(usize::MAX);
117 if a_rpo > b_rpo {
118 a = *idom.get(&a).unwrap_or(&a);
119 } else {
120 b = *idom.get(&b).unwrap_or(&b);
121 }
122 }
123 a
124}
125
126fn compute_frontiers(
127 cfg: &Cfg,
128 idom: &HashMap<BlockId, BlockId>,
129) -> HashMap<BlockId, HashSet<BlockId>> {
130 let mut frontiers: HashMap<BlockId, HashSet<BlockId>> = HashMap::new();
131 for &bid in &cfg.blocks {
132 frontiers.entry(bid).or_default();
133 }
134
135 for &bid in &cfg.blocks {
136 let preds = cfg.preds(bid);
137 if preds.len() >= 2 {
138 for &pred in preds {
139 let mut runner = pred;
140 while runner != *idom.get(&bid).unwrap_or(&bid) {
141 frontiers.entry(runner).or_default().insert(bid);
142 if runner == *idom.get(&runner).unwrap_or(&runner) {
143 break;
144 }
145 runner = *idom.get(&runner).unwrap_or(&runner);
146 }
147 }
148 }
149 }
150
151 frontiers
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mir::basic_block::{BasicBlock, Terminator};
    use crate::mir::function::MirFunction;
    use crate::mir::value::ValueId;

    /// Builds the diamond CFG `bb0 -> {bb1, bb2} -> bb3`.
    fn make_diamond() -> MirFunction {
        let mut func = MirFunction::new("test".into(), BlockId(0));
        let mut bb0 = BasicBlock::new(BlockId(0));
        bb0.terminator = Terminator::CondBranch {
            cond: ValueId(0),
            true_target: BlockId(1),
            false_target: BlockId(2),
        };
        let mut bb1 = BasicBlock::new(BlockId(1));
        bb1.terminator = Terminator::Branch { target: BlockId(3) };
        let mut bb2 = BasicBlock::new(BlockId(2));
        bb2.terminator = Terminator::Branch { target: BlockId(3) };
        let bb3 = BasicBlock::new(BlockId(3));
        func.blocks.push(bb0);
        func.blocks.push(bb1);
        func.blocks.push(bb2);
        func.blocks.push(bb3);
        func
    }

    /// Builds a CFG with one loop: `bb0 -> bb1`, `bb1 -> {bb2, bb3}`, and the
    /// back edge `bb2 -> bb1`.
    fn make_loop() -> MirFunction {
        let mut func = MirFunction::new("test".into(), BlockId(0));
        let mut bb0 = BasicBlock::new(BlockId(0));
        bb0.terminator = Terminator::Branch { target: BlockId(1) };
        let mut bb1 = BasicBlock::new(BlockId(1));
        bb1.terminator = Terminator::CondBranch {
            cond: ValueId(0),
            true_target: BlockId(2),
            false_target: BlockId(3),
        };
        let mut bb2 = BasicBlock::new(BlockId(2));
        bb2.terminator = Terminator::Branch { target: BlockId(1) };
        let bb3 = BasicBlock::new(BlockId(3));
        func.blocks.push(bb0);
        func.blocks.push(bb1);
        func.blocks.push(bb2);
        func.blocks.push(bb3);
        func
    }

    #[test]
    fn test_dominator_tree() {
        let func = make_diamond();
        let cfg = Cfg::build(&func);
        let dom = DomTree::compute(&cfg);

        assert!(dom.dominates(BlockId(0), BlockId(1)));
        assert!(dom.dominates(BlockId(0), BlockId(2)));
        assert!(dom.dominates(BlockId(0), BlockId(3)));
        assert!(!dom.dominates(BlockId(1), BlockId(2)));
        // Neither arm of the diamond dominates the join block.
        assert!(!dom.dominates(BlockId(1), BlockId(3)));
        assert!(!dom.dominates(BlockId(2), BlockId(3)));
        // Dominance is not symmetric.
        assert!(!dom.dominates(BlockId(3), BlockId(0)));
        assert!(dom.idom[&BlockId(3)] == BlockId(0));
    }

    #[test]
    fn test_dominator_children() {
        let func = make_diamond();
        let cfg = Cfg::build(&func);
        let dom = DomTree::compute(&cfg);

        let root_children = &dom.children[&BlockId(0)];
        assert_eq!(root_children.len(), 3);
        for bid in [BlockId(1), BlockId(2), BlockId(3)] {
            assert!(root_children.contains(&bid));
        }
        for bid in [BlockId(1), BlockId(2), BlockId(3)] {
            assert!(dom.children.get(&bid).map_or(true, |c| c.is_empty()));
        }
    }

    #[test]
    fn test_dominance_frontier() {
        let func = make_diamond();
        let cfg = Cfg::build(&func);
        let dom = DomTree::compute(&cfg);

        let df1 = dom.frontier(BlockId(1));
        assert!(df1.contains(&BlockId(3)));
        let df2 = dom.frontier(BlockId(2));
        assert!(df2.contains(&BlockId(3)));
        let df0 = dom.frontier(BlockId(0));
        assert!(df0.is_empty());
        let df3 = dom.frontier(BlockId(3));
        assert!(df3.is_empty());
    }

    #[test]
    fn test_self_dominance() {
        let func = make_diamond();
        let cfg = Cfg::build(&func);
        let dom = DomTree::compute(&cfg);
        for bid in [BlockId(0), BlockId(1), BlockId(2), BlockId(3)] {
            assert!(dom.dominates(bid, bid));
        }
    }

    #[test]
    fn test_loop_dominance() {
        let func = make_loop();
        let cfg = Cfg::build(&func);
        let dom = DomTree::compute(&cfg);

        assert!(dom.dominates(BlockId(1), BlockId(2)));
        assert!(dom.dominates(BlockId(1), BlockId(3)));
        assert!(dom.dominates(BlockId(0), BlockId(2)));
        // The back edge does not make the latch dominate the header.
        assert!(!dom.dominates(BlockId(2), BlockId(1)));
        assert!(!dom.dominates(BlockId(2), BlockId(3)));
    }

    #[test]
    fn test_loop_frontier() {
        let func = make_loop();
        let cfg = Cfg::build(&func);
        let dom = DomTree::compute(&cfg);

        // The loop header is in its own frontier and in the latch's frontier.
        assert!(dom.frontier(BlockId(1)).contains(&BlockId(1)));
        assert!(dom.frontier(BlockId(2)).contains(&BlockId(1)));
        assert!(dom.frontier(BlockId(0)).is_empty());
        assert!(dom.frontier(BlockId(3)).is_empty());
    }
}