1use super::types::{BasicBlockId, MirFunction, TerminatorKind};
4use std::collections::{HashMap, HashSet, VecDeque};
5
6#[derive(Debug)]
9pub struct ControlFlowGraph {
10 successors: HashMap<BasicBlockId, Vec<BasicBlockId>>,
12 predecessors: HashMap<BasicBlockId, Vec<BasicBlockId>>,
14}
15
16impl ControlFlowGraph {
17 pub fn build(mir: &MirFunction) -> Self {
19 let mut successors: HashMap<BasicBlockId, Vec<BasicBlockId>> = HashMap::new();
20 let mut predecessors: HashMap<BasicBlockId, Vec<BasicBlockId>> = HashMap::new();
21
22 for block in &mir.blocks {
23 let succs = Self::terminator_successors(&block.terminator.kind);
24 for &succ in &succs {
25 predecessors.entry(succ).or_default().push(block.id);
26 }
27 successors.insert(block.id, succs);
28 }
29
30 ControlFlowGraph {
31 successors,
32 predecessors,
33 }
34 }
35
36 pub fn successors(&self, block: BasicBlockId) -> &[BasicBlockId] {
38 self.successors.get(&block).map_or(&[], |v| v.as_slice())
39 }
40
41 pub fn predecessors(&self, block: BasicBlockId) -> &[BasicBlockId] {
43 self.predecessors.get(&block).map_or(&[], |v| v.as_slice())
44 }
45
46 pub fn reverse_postorder(&self) -> Vec<BasicBlockId> {
48 let mut visited = HashSet::new();
49 let mut postorder = Vec::new();
50 let entry = BasicBlockId(0);
51
52 self.dfs_postorder(entry, &mut visited, &mut postorder);
53 postorder.reverse();
54 postorder
55 }
56
57 fn dfs_postorder(
58 &self,
59 block: BasicBlockId,
60 visited: &mut HashSet<BasicBlockId>,
61 postorder: &mut Vec<BasicBlockId>,
62 ) {
63 if !visited.insert(block) {
64 return;
65 }
66 for &succ in self.successors(block) {
67 self.dfs_postorder(succ, visited, postorder);
68 }
69 postorder.push(block);
70 }
71
72 pub fn dominators(&self) -> HashMap<BasicBlockId, BasicBlockId> {
74 let rpo = self.reverse_postorder();
75 let entry = BasicBlockId(0);
76 let mut doms: HashMap<BasicBlockId, BasicBlockId> = HashMap::new();
77 doms.insert(entry, entry);
78
79 let mut changed = true;
80 while changed {
81 changed = false;
82 for &b in &rpo {
83 if b == entry {
84 continue;
85 }
86 let preds = self.predecessors(b);
87 let mut new_idom = None;
88 for &p in preds {
89 if doms.contains_key(&p) {
90 new_idom = Some(match new_idom {
91 None => p,
92 Some(current) => self.intersect(current, p, &doms, &rpo),
93 });
94 }
95 }
96 if let Some(new_idom) = new_idom {
97 if doms.get(&b) != Some(&new_idom) {
98 doms.insert(b, new_idom);
99 changed = true;
100 }
101 }
102 }
103 }
104
105 doms
106 }
107
108 fn intersect(
109 &self,
110 mut a: BasicBlockId,
111 mut b: BasicBlockId,
112 doms: &HashMap<BasicBlockId, BasicBlockId>,
113 rpo: &[BasicBlockId],
114 ) -> BasicBlockId {
115 let rpo_index: HashMap<BasicBlockId, usize> =
116 rpo.iter().enumerate().map(|(i, &bb)| (bb, i)).collect();
117 while a != b {
118 while rpo_index.get(&a).copied().unwrap_or(0) > rpo_index.get(&b).copied().unwrap_or(0)
119 {
120 a = *doms.get(&a).unwrap_or(&a);
121 }
122 while rpo_index.get(&b).copied().unwrap_or(0) > rpo_index.get(&a).copied().unwrap_or(0)
123 {
124 b = *doms.get(&b).unwrap_or(&b);
125 }
126 }
127 a
128 }
129
130 pub fn is_reachable(&self, target: BasicBlockId) -> bool {
132 let mut visited = HashSet::new();
133 let mut queue = VecDeque::new();
134 queue.push_back(BasicBlockId(0));
135 visited.insert(BasicBlockId(0));
136
137 while let Some(block) = queue.pop_front() {
138 if block == target {
139 return true;
140 }
141 for &succ in self.successors(block) {
142 if visited.insert(succ) {
143 queue.push_back(succ);
144 }
145 }
146 }
147 false
148 }
149
150 fn terminator_successors(kind: &TerminatorKind) -> Vec<BasicBlockId> {
151 match kind {
152 TerminatorKind::Goto(target) => vec![*target],
153 TerminatorKind::SwitchBool {
154 true_bb, false_bb, ..
155 } => vec![*true_bb, *false_bb],
156 TerminatorKind::Call { next, .. } => vec![*next],
157 TerminatorKind::Return | TerminatorKind::Unreachable => vec![],
158 }
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mir::types::*;

    fn span() -> shape_ast::ast::Span {
        shape_ast::ast::Span { start: 0, end: 1 }
    }

    fn make_terminator(kind: TerminatorKind) -> super::super::types::Terminator {
        super::super::types::Terminator { kind, span: span() }
    }

    /// Builds a statement-free, parameterless function whose block at position
    /// `i` has id `BasicBlockId(i)` and ends with `terminators[i]`.
    fn make_mir(terminators: Vec<TerminatorKind>) -> MirFunction {
        MirFunction {
            name: "test".to_string(),
            blocks: terminators
                .into_iter()
                .enumerate()
                .map(|(i, kind)| BasicBlock {
                    id: BasicBlockId(i as _),
                    statements: vec![],
                    terminator: make_terminator(kind),
                })
                .collect(),
            num_locals: 0,
            param_slots: vec![],
            param_reference_kinds: vec![],
            local_types: vec![],
            span: span(),
        }
    }

    /// Two-way branch on a constant condition.
    fn switch(true_bb: BasicBlockId, false_bb: BasicBlockId) -> TerminatorKind {
        TerminatorKind::SwitchBool {
            operand: Operand::Constant(MirConstant::Bool(true)),
            true_bb,
            false_bb,
        }
    }

    #[test]
    fn test_linear_cfg() {
        // 0 -> 1
        let cfg = ControlFlowGraph::build(&make_mir(vec![
            TerminatorKind::Goto(BasicBlockId(1)),
            TerminatorKind::Return,
        ]));
        assert_eq!(cfg.successors(BasicBlockId(0)), &[BasicBlockId(1)]);
        assert_eq!(cfg.predecessors(BasicBlockId(1)), &[BasicBlockId(0)]);
    }

    #[test]
    fn test_branch_cfg() {
        // 0 -> {1, 2} -> 3
        let cfg = ControlFlowGraph::build(&make_mir(vec![
            switch(BasicBlockId(1), BasicBlockId(2)),
            TerminatorKind::Goto(BasicBlockId(3)),
            TerminatorKind::Goto(BasicBlockId(3)),
            TerminatorKind::Return,
        ]));
        let rpo = cfg.reverse_postorder();
        assert_eq!(rpo.len(), 4);
        assert_eq!(rpo[0], BasicBlockId(0));
        assert!(cfg.is_reachable(BasicBlockId(3)));

        // Neither arm dominates the join, so every block's idom is the entry.
        let doms = cfg.dominators();
        for id in 0..4 {
            assert_eq!(doms[&BasicBlockId(id)], BasicBlockId(0));
        }
    }

    #[test]
    fn test_loop_cfg() {
        // 0 -> 1 -> {2, 3}; 2 -> 1 is the back edge.
        let cfg = ControlFlowGraph::build(&make_mir(vec![
            TerminatorKind::Goto(BasicBlockId(1)),
            switch(BasicBlockId(2), BasicBlockId(3)),
            TerminatorKind::Goto(BasicBlockId(1)),
            TerminatorKind::Return,
        ]));
        let preds = cfg.predecessors(BasicBlockId(1));
        assert_eq!(preds, &[BasicBlockId(0), BasicBlockId(2)]);

        let doms = cfg.dominators();
        assert_eq!(doms[&BasicBlockId(1)], BasicBlockId(0));
        assert_eq!(doms[&BasicBlockId(2)], BasicBlockId(1));
        assert_eq!(doms[&BasicBlockId(3)], BasicBlockId(1));
    }

    #[test]
    fn test_unreachable_block() {
        // Block 1 jumps back to the entry, but nothing jumps to block 1.
        let cfg = ControlFlowGraph::build(&make_mir(vec![
            TerminatorKind::Return,
            TerminatorKind::Goto(BasicBlockId(0)),
        ]));
        assert!(cfg.is_reachable(BasicBlockId(0)));
        assert!(!cfg.is_reachable(BasicBlockId(1)));
        assert_eq!(cfg.reverse_postorder(), vec![BasicBlockId(0)]);
        assert!(!cfg.dominators().contains_key(&BasicBlockId(1)));
    }
}