1use crate::MirFunction;
6use std::collections::{HashMap, HashSet};
7
/// Control-flow graph of one MIR function, keyed by basic-block id (e.g. `bb0`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ControlFlowGraph {
    /// Every parsed basic block, keyed by its id.
    pub blocks: HashMap<String, BasicBlock>,
    /// Successor block ids per block id. When built by `from_mir_function`,
    /// blocks with no successors have no entry here.
    pub edges: HashMap<String, Vec<String>>,
    /// Id of the block where execution starts.
    pub entry_block: String,
    /// Ids of the blocks whose terminator is `return`.
    pub exit_blocks: Vec<String>,
}

/// A single MIR basic block: its id, its statement lines and its terminator.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BasicBlock {
    /// Block id such as `bb3`.
    pub id: String,
    /// Trimmed statement lines, in the order they appear in the MIR text.
    pub statements: Vec<String>,
    /// How control leaves this block.
    pub terminator: Terminator,
}

/// The instruction that ends a basic block and selects its successors.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Terminator {
    /// Unconditional jump to `target`.
    Goto { target: String },

    /// Multi-way branch on an integer-like operand.
    SwitchInt {
        /// Operand text inside `switchInt(...)`, e.g. `move _5`.
        condition: String,
        /// `(value, block)` pairs for the explicit cases.
        targets: Vec<(String, String)>,
        /// Block taken when no explicit case matches.
        otherwise: Option<String>,
    },

    /// Returns from the function.
    Return,

    /// Function call.
    Call {
        /// Block reached after a normal return; `None` when none was printed.
        return_target: Option<String>,
        /// Block named by an `unwind: bbN` label, if any.
        unwind_target: Option<String>,
    },

    /// Runtime assertion.
    Assert {
        /// Block reached when the assertion holds.
        success_target: String,
        /// Block named by the `unwind: bbN` label, if any.
        failure_target: Option<String>,
    },

    /// Drop of a place.
    Drop {
        /// Block reached after the drop completes.
        target: String,
        /// Block named by an `unwind: bbN` label, if any.
        unwind_target: Option<String>,
    },

    /// Marks code that can never be reached.
    Unreachable,

    /// Any terminator line that is not modelled above, kept verbatim.
    Unknown(String),
}
78
79impl ControlFlowGraph {
80 pub fn from_mir_function(function: &MirFunction) -> Self {
82 let mut blocks = HashMap::new();
83 let mut edges = HashMap::new();
84 let mut exit_blocks = Vec::new();
85
86 let parsed_blocks = Self::parse_basic_blocks(&function.body);
88
89 for (id, block) in parsed_blocks {
90 let successors = Self::extract_successors(&block.terminator);
92 if !successors.is_empty() {
93 edges.insert(id.clone(), successors);
94 }
95
96 if matches!(block.terminator, Terminator::Return) {
98 exit_blocks.push(id.clone());
99 }
100
101 blocks.insert(id, block);
102 }
103
104 ControlFlowGraph {
105 blocks,
106 edges,
107 entry_block: "bb0".to_string(),
108 exit_blocks,
109 }
110 }
111
112 pub fn block_count(&self) -> usize {
114 self.blocks.len()
115 }
116
117 fn parse_basic_blocks(body: &[String]) -> HashMap<String, BasicBlock> {
119 let mut blocks = HashMap::new();
120 let mut current_block_id: Option<String> = None;
121 let mut current_statements = Vec::new();
122 let mut current_terminator: Option<Terminator> = None;
123
124 for line in body {
125 let trimmed = line.trim();
126
127 if let Some(block_id) = Self::extract_block_id(trimmed) {
129 if let Some(id) = current_block_id.take() {
131 if let Some(term) = current_terminator.take() {
132 blocks.insert(
133 id.clone(),
134 BasicBlock {
135 id,
136 statements: std::mem::take(&mut current_statements),
137 terminator: term,
138 },
139 );
140 }
141 }
142
143 current_block_id = Some(block_id);
145 current_statements.clear();
146 current_terminator = None;
147 }
148 else if trimmed.starts_with("goto ")
150 || trimmed.starts_with("switchInt")
151 || trimmed.starts_with("return")
152 || trimmed.contains(" -> [return:")
153 || trimmed.starts_with("assert(")
154 || trimmed.starts_with("drop(")
155 || trimmed.starts_with("unreachable")
156 {
157 if trimmed.contains(" = ") && trimmed.contains(" -> [return:") {
160 current_statements.push(trimmed.to_string());
161 }
162 current_terminator = Some(Self::parse_terminator(trimmed));
163 }
164 else if !trimmed.is_empty()
166 && !trimmed.starts_with("}")
167 && !trimmed.starts_with("scope")
168 && !trimmed.starts_with("debug")
169 && !trimmed.starts_with("let")
170 {
171 current_statements.push(trimmed.to_string());
172 }
173 }
174
175 if let Some(id) = current_block_id {
177 if let Some(term) = current_terminator {
178 blocks.insert(
179 id.clone(),
180 BasicBlock {
181 id,
182 statements: current_statements,
183 terminator: term,
184 },
185 );
186 }
187 }
188
189 blocks
190 }
191
192 fn extract_block_id(line: &str) -> Option<String> {
194 if line.starts_with("bb") && line.contains(": {") {
195 let id = line.split(": {").next()?;
196 Some(id.to_string())
197 } else {
198 None
199 }
200 }
201
202 fn parse_terminator(line: &str) -> Terminator {
204 let line = line.trim().trim_end_matches(';');
205
206 if let Some(rest) = line.strip_prefix("goto -> ") {
208 return Terminator::Goto {
209 target: rest.to_string(),
210 };
211 }
212
213 if line == "return" {
215 return Terminator::Return;
216 }
217
218 if line == "unreachable" {
220 return Terminator::Unreachable;
221 }
222
223 if let Some(rest) = line.strip_prefix("switchInt(") {
225 if let Some(paren_end) = rest.find(") -> [") {
226 let condition = rest[..paren_end].to_string();
227 let targets_str = &rest[paren_end + 6..]; let mut targets = Vec::new();
230 let mut otherwise = None;
231
232 for part in targets_str.trim_end_matches(']').split(", ") {
234 if let Some((value, block)) = part.split_once(": ") {
235 if value == "otherwise" {
236 otherwise = Some(block.to_string());
237 } else {
238 targets.push((value.to_string(), block.to_string()));
239 }
240 }
241 }
242
243 return Terminator::SwitchInt {
244 condition,
245 targets,
246 otherwise,
247 };
248 }
249 }
250
251 if line.contains(" -> [return:") {
253 let mut return_target = None;
254 let mut unwind_target = None;
255
256 if let Some(arrow_pos) = line.find(" -> [") {
257 let targets_str = &line[arrow_pos + 5..]; for part in targets_str.trim_end_matches(']').split(", ") {
260 if let Some(rest) = part.strip_prefix("return: ") {
261 return_target = Some(rest.to_string());
262 } else if let Some(rest) = part.strip_prefix("unwind: ") {
263 unwind_target = Some(rest.to_string());
264 }
265 }
266 }
267
268 return Terminator::Call {
269 return_target,
270 unwind_target,
271 };
272 }
273
274 if let Some(rest) = line.strip_prefix("assert(") {
276 if let Some(arrow_pos) = rest.find(" -> [") {
277 let targets_str = &rest[arrow_pos + 5..];
278 let mut success_target = String::new();
279 let mut failure_target = None;
280
281 for part in targets_str.trim_end_matches(']').split(", ") {
282 if let Some(rest) = part.strip_prefix("success: ") {
283 success_target = rest.to_string();
284 } else if let Some(rest) = part.strip_prefix("unwind: ") {
285 failure_target = Some(rest.to_string());
286 }
287 }
288
289 return Terminator::Assert {
290 success_target,
291 failure_target,
292 };
293 }
294 }
295
296 if let Some(rest) = line.strip_prefix("drop(") {
298 if let Some(arrow_pos) = rest.find(" -> [") {
299 let targets_str = &rest[arrow_pos + 5..];
300 let mut target = String::new();
301 let mut unwind_target = None;
302
303 for part in targets_str.trim_end_matches(']').split(", ") {
304 if let Some(rest) = part.strip_prefix("return: ") {
305 target = rest.to_string();
306 } else if let Some(rest) = part.strip_prefix("unwind: ") {
307 unwind_target = Some(rest.to_string());
308 }
309 }
310
311 return Terminator::Drop {
312 target,
313 unwind_target,
314 };
315 }
316 }
317
318 Terminator::Unknown(line.to_string())
320 }
321
322 fn extract_successors(terminator: &Terminator) -> Vec<String> {
324 match terminator {
325 Terminator::Goto { target } => vec![target.clone()],
326
327 Terminator::SwitchInt {
328 targets, otherwise, ..
329 } => {
330 let mut successors: Vec<String> =
331 targets.iter().map(|(_, block)| block.clone()).collect();
332 if let Some(other) = otherwise {
333 successors.push(other.clone());
334 }
335 successors
336 }
337
338 Terminator::Return | Terminator::Unreachable => vec![],
339
340 Terminator::Call {
341 return_target,
342 unwind_target,
343 } => {
344 let mut successors = Vec::new();
345 if let Some(ret) = return_target {
346 successors.push(ret.clone());
347 }
348 if let Some(unw) = unwind_target {
349 successors.push(unw.clone());
350 }
351 successors
352 }
353
354 Terminator::Assert {
355 success_target,
356 failure_target,
357 } => {
358 let mut successors = vec![success_target.clone()];
359 if let Some(fail) = failure_target {
360 successors.push(fail.clone());
361 }
362 successors
363 }
364
365 Terminator::Drop {
366 target,
367 unwind_target,
368 } => {
369 let mut successors = vec![target.clone()];
370 if let Some(unw) = unwind_target {
371 successors.push(unw.clone());
372 }
373 successors
374 }
375
376 Terminator::Unknown(_) => vec![],
377 }
378 }
379
380 pub fn branch_count(&self) -> usize {
382 self.edges
383 .values()
384 .filter(|successors| successors.len() > 1)
385 .count()
386 }
387
388 pub fn is_too_complex_for_path_enumeration(&self) -> bool {
402 const MAX_BLOCKS_FOR_PATH_ENUM: usize = 500;
403 const MAX_BRANCHES_FOR_PATH_ENUM: usize = 100;
404
405 self.blocks.len() > MAX_BLOCKS_FOR_PATH_ENUM
406 || self.branch_count() > MAX_BRANCHES_FOR_PATH_ENUM
407 }
408
409 pub fn get_all_paths(&self) -> (Vec<Vec<String>>, bool) {
417 if self.is_too_complex_for_path_enumeration() {
420 return (Vec::new(), true);
422 }
423
424 let mut paths = Vec::new();
426 let mut current_path = Vec::new();
427 let mut visited = HashSet::new();
428
429 const MAX_PATHS: usize = 1000;
432 const MAX_DEPTH: usize = 50;
433 self.dfs_paths(
434 &self.entry_block,
435 &mut current_path,
436 &mut visited,
437 &mut paths,
438 0,
439 MAX_DEPTH,
440 MAX_PATHS,
441 );
442
443 (paths, false)
445 }
446
447 fn dfs_paths(
449 &self,
450 current_block: &str,
451 current_path: &mut Vec<String>,
452 visited: &mut HashSet<String>,
453 paths: &mut Vec<Vec<String>>,
454 depth: usize,
455 max_depth: usize,
456 max_paths: usize,
457 ) {
458 if depth > max_depth || visited.contains(current_block) || paths.len() >= max_paths {
460 return;
461 }
462
463 current_path.push(current_block.to_string());
464 visited.insert(current_block.to_string());
465
466 if self.exit_blocks.contains(¤t_block.to_string()) {
468 paths.push(current_path.clone());
469 } else if let Some(successors) = self.edges.get(current_block) {
470 for successor in successors {
472 if paths.len() >= max_paths {
473 break; }
475 self.dfs_paths(
476 successor,
477 current_path,
478 visited,
479 paths,
480 depth + 1,
481 max_depth,
482 max_paths,
483 );
484 }
485 }
486
487 current_path.pop();
489 visited.remove(current_block);
490 }
491
492 pub fn get_block(&self, block_id: &str) -> Option<&BasicBlock> {
494 self.blocks.get(block_id)
495 }
496
497 pub fn has_branching(&self) -> bool {
499 self.edges.values().any(|successors| successors.len() > 1)
500 }
501}
502
#[cfg(test)]
mod tests {
    use super::*;

    /// `goto -> bbN;` parses into a `Goto` that targets `bbN`.
    #[test]
    fn test_parse_goto() {
        if let Terminator::Goto { target } = ControlFlowGraph::parse_terminator("goto -> bb5;") {
            assert_eq!(target, "bb5");
        } else {
            panic!("Expected Goto");
        }
    }

    /// A bare `return;` parses into `Return`.
    #[test]
    fn test_parse_return() {
        let parsed = ControlFlowGraph::parse_terminator("return;");
        assert!(matches!(parsed, Terminator::Return));
    }

    /// Explicit switch cases are separated from the `otherwise` fallback.
    #[test]
    fn test_parse_switch_int() {
        let parsed =
            ControlFlowGraph::parse_terminator("switchInt(move _5) -> [0: bb12, otherwise: bb7];");
        let (condition, targets, otherwise) = match parsed {
            Terminator::SwitchInt {
                condition,
                targets,
                otherwise,
            } => (condition, targets, otherwise),
            _ => panic!("Expected SwitchInt"),
        };

        assert_eq!(condition, "move _5");
        assert_eq!(targets.len(), 1);
        assert_eq!(targets[0], ("0".to_string(), "bb12".to_string()));
        assert_eq!(otherwise, Some("bb7".to_string()));
    }

    /// A call with a `[return: bbN, ...]` list records `bbN` as its return target.
    #[test]
    fn test_parse_call() {
        let parsed =
            ControlFlowGraph::parse_terminator("some_func() -> [return: bb2, unwind continue];");
        if let Terminator::Call { return_target, .. } = parsed {
            assert_eq!(return_target, Some("bb2".to_string()));
        } else {
            panic!("Expected Call");
        }
    }
}