1use crate::cfg::analysis::is_branch_point;
4use crate::cfg::EdgeType;
5use crate::cfg::{BlockId, Cfg, Terminator};
6use petgraph::graph::NodeIndex;
7use std::collections::HashSet;
8
/// Coarse classification of a CFG node by its outgoing control flow,
/// as computed by [`classify_branch`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BranchType {
    /// Zero or one successor: control falls through or ends here.
    Linear,
    /// Exactly two successors that reconverge at a common successor.
    Conditional,
    /// Three or more successors and the terminator is `SwitchInt`.
    MultiWay,
    /// Anything else (e.g. two arms that never reconverge, or 3+ successors
    /// without a `SwitchInt` terminator).
    Unknown,
}
21
22#[derive(Debug, Clone)]
24pub struct IfElsePattern {
25 pub condition: NodeIndex,
27 pub true_branch: NodeIndex,
29 pub false_branch: NodeIndex,
31 pub merge_point: Option<NodeIndex>,
34}
35
36impl IfElsePattern {
37 pub fn has_else(&self) -> bool {
39 self.merge_point.is_some()
40 }
41
42 pub fn size(&self) -> usize {
44 2 + if self.merge_point.is_some() { 1 } else { 0 }
45 }
46}
47
48#[derive(Debug, Clone)]
50pub struct MatchPattern {
51 pub switch_node: NodeIndex,
53 pub targets: Vec<NodeIndex>,
55 pub otherwise: NodeIndex,
57}
58
59impl MatchPattern {
60 pub fn branch_count(&self) -> usize {
62 self.targets.len() + 1 }
64
65 pub fn has_explicit_default(&self) -> bool {
68 true
71 }
72}
73
74pub fn classify_branch(cfg: &Cfg, node: NodeIndex) -> BranchType {
93 let successors: Vec<_> = cfg.neighbors(node).collect();
94
95 match successors.len() {
96 0 | 1 => BranchType::Linear,
97 2 => {
98 let merge = find_common_successor(cfg, successors[0], successors[1]);
100 if merge.is_some() {
101 BranchType::Conditional
102 } else {
103 BranchType::Unknown
105 }
106 }
107 3.. => {
108 if let Some(block) = cfg.node_weight(node) {
110 if matches!(block.terminator, Terminator::SwitchInt { .. }) {
111 BranchType::MultiWay
112 } else {
113 BranchType::Unknown
114 }
115 } else {
116 BranchType::Unknown
117 }
118 }
119 }
120}
121
122fn find_common_successor(cfg: &Cfg, n1: NodeIndex, n2: NodeIndex) -> Option<NodeIndex> {
127 let mut reachable_from_n1 = HashSet::new();
129 let mut worklist = vec![n1];
130
131 while let Some(node) = worklist.pop() {
132 if node == n1 || node == n2 {
134 for succ in cfg.neighbors(node) {
135 if succ != n1 && succ != n2 && !reachable_from_n1.contains(&succ) {
136 worklist.push(succ);
137 reachable_from_n1.insert(succ);
138 }
139 }
140 continue;
141 }
142
143 if !reachable_from_n1.insert(node) {
144 continue;
145 }
146 for succ in cfg.neighbors(node) {
147 if succ != n1 && succ != n2 && !reachable_from_n1.contains(&succ) {
148 worklist.push(succ);
149 }
150 }
151 }
152
153 let mut visited = HashSet::new();
155 let mut worklist = vec![n2];
156
157 while let Some(node) = worklist.pop() {
158 if node == n1 || node == n2 {
159 for succ in cfg.neighbors(node) {
160 if succ != n1 && succ != n2 && !visited.contains(&succ) {
161 if reachable_from_n1.contains(&succ) {
162 return Some(succ);
163 }
164 visited.insert(succ);
165 worklist.push(succ);
166 }
167 }
168 continue;
169 }
170
171 if reachable_from_n1.contains(&node) {
172 return Some(node);
173 }
174
175 if !visited.insert(node) {
176 continue;
177 }
178 for succ in cfg.neighbors(node) {
179 if succ != n1 && succ != n2 && !visited.contains(&succ) {
180 worklist.push(succ);
181 }
182 }
183 }
184
185 None
186}
187
188pub fn detect_if_else_patterns(cfg: &Cfg) -> Vec<IfElsePattern> {
211 let mut patterns = Vec::new();
212
213 for branch in cfg.node_indices().filter(|&n| is_branch_point(cfg, n)) {
214 let successors: Vec<_> = cfg.neighbors(branch).collect();
215
216 if successors.len() == 2 {
217 if let Some(block) = cfg.node_weight(branch) {
220 if let Terminator::SwitchInt { targets, .. } = &block.terminator {
221 if targets.len() > 1 {
222 continue;
224 }
225 }
226 }
227
228 let merge_point = find_common_successor(cfg, successors[0], successors[1]);
230
231 let (true_branch, false_branch) =
233 order_branches_by_edge_type(cfg, branch, successors[0], successors[1]);
234
235 patterns.push(IfElsePattern {
236 condition: branch,
237 true_branch,
238 false_branch,
239 merge_point,
240 });
241 }
242 }
243
244 patterns
245}
246
247fn order_branches_by_edge_type(
252 cfg: &Cfg,
253 from: NodeIndex,
254 succ1: NodeIndex,
255 succ2: NodeIndex,
256) -> (NodeIndex, NodeIndex) {
257 let edge1_type = cfg
258 .find_edge(from, succ1)
259 .and_then(|e| cfg.edge_weight(e).copied());
260 let edge2_type = cfg
261 .find_edge(from, succ2)
262 .and_then(|e| cfg.edge_weight(e).copied());
263
264 match (edge1_type, edge2_type) {
265 (Some(EdgeType::TrueBranch), _) => (succ1, succ2),
266 (_, Some(EdgeType::TrueBranch)) => (succ2, succ1),
267 (Some(EdgeType::FalseBranch), _) => (succ2, succ1),
268 (_, Some(EdgeType::FalseBranch)) => (succ1, succ2),
269 _ => (succ1, succ2), }
271}
272
273pub fn detect_match_patterns(cfg: &Cfg) -> Vec<MatchPattern> {
292 let mut patterns = Vec::new();
293
294 for node in cfg.node_indices() {
295 if let Some(block) = cfg.node_weight(node) {
296 if let Terminator::SwitchInt { targets, otherwise } = &block.terminator {
297 if targets.len() < 2 {
300 continue;
301 }
302
303 let target_indices: Vec<_> = targets
305 .iter()
306 .filter_map(|&id| find_node_by_id(cfg, id))
307 .collect();
308
309 if let Some(otherwise_idx) = find_node_by_id(cfg, *otherwise) {
310 patterns.push(MatchPattern {
311 switch_node: node,
312 targets: target_indices,
313 otherwise: otherwise_idx,
314 });
315 }
316 }
317 }
318 }
319
320 patterns
321}
322
323fn find_node_by_id(cfg: &Cfg, id: BlockId) -> Option<NodeIndex> {
325 cfg.node_indices()
326 .find(|&n| cfg.node_weight(n).map_or(false, |b| b.id == id))
327}
328
329pub fn detect_all_patterns(cfg: &Cfg) -> (Vec<IfElsePattern>, Vec<MatchPattern>) {
334 (detect_if_else_patterns(cfg), detect_match_patterns(cfg))
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cfg::{BasicBlock, BlockKind, EdgeType, Terminator};
    use petgraph::graph::DiGraph;

    /// Builds a `BasicBlock` with the given id, kind, statements and
    /// terminator. Every other field is `None` or zero.
    fn block(
        id: BlockId,
        kind: BlockKind,
        statements: &[&str],
        terminator: Terminator,
    ) -> BasicBlock {
        BasicBlock {
            id,
            db_id: None,
            kind,
            statements: statements.iter().map(|s| s.to_string()).collect(),
            terminator,
            source_location: None,
            coord_x: 0,
            coord_y: 0,
            coord_z: 0,
        }
    }

    /// Diamond: 0 → 1, 1 → {2 (true), 3 (false)}, {2, 3} → 4.
    fn create_diamond_cfg() -> Cfg {
        let mut g: Cfg = DiGraph::new();

        let entry = g.add_node(block(0, BlockKind::Entry, &[], Terminator::Goto { target: 1 }));
        let cond = g.add_node(block(
            1,
            BlockKind::Normal,
            &[],
            Terminator::SwitchInt {
                targets: vec![2],
                otherwise: 3,
            },
        ));
        let then_arm = g.add_node(block(
            2,
            BlockKind::Normal,
            &["true branch"],
            Terminator::Goto { target: 4 },
        ));
        let else_arm = g.add_node(block(
            3,
            BlockKind::Normal,
            &["false branch"],
            Terminator::Goto { target: 4 },
        ));
        let merge = g.add_node(block(4, BlockKind::Exit, &["merge"], Terminator::Return));

        g.add_edge(entry, cond, EdgeType::Fallthrough);
        g.add_edge(cond, then_arm, EdgeType::TrueBranch);
        g.add_edge(cond, else_arm, EdgeType::FalseBranch);
        g.add_edge(then_arm, merge, EdgeType::Fallthrough);
        g.add_edge(else_arm, merge, EdgeType::Fallthrough);

        g
    }

    #[test]
    fn test_detect_if_else_diamond() {
        let cfg = create_diamond_cfg();
        let patterns = detect_if_else_patterns(&cfg);

        assert_eq!(patterns.len(), 1);

        let found = &patterns[0];
        assert_eq!(found.condition.index(), 1);
        assert_eq!(found.true_branch.index(), 2);
        assert_eq!(found.false_branch.index(), 3);
        assert_eq!(found.merge_point, Some(NodeIndex::new(4)));
        assert!(found.has_else());
    }

    #[test]
    fn test_classify_branch() {
        let cfg = create_diamond_cfg();

        assert_eq!(classify_branch(&cfg, NodeIndex::new(0)), BranchType::Linear);
        assert_eq!(
            classify_branch(&cfg, NodeIndex::new(1)),
            BranchType::Conditional
        );
        assert_eq!(classify_branch(&cfg, NodeIndex::new(2)), BranchType::Linear);
        assert_eq!(classify_branch(&cfg, NodeIndex::new(4)), BranchType::Linear);
    }

    #[test]
    fn test_detect_match_patterns() {
        let mut g: Cfg = DiGraph::new();

        let switch = g.add_node(block(
            0,
            BlockKind::Entry,
            &[],
            Terminator::SwitchInt {
                targets: vec![1, 2],
                otherwise: 3,
            },
        ));
        let case1 = g.add_node(block(1, BlockKind::Exit, &["case 1"], Terminator::Return));
        let case2 = g.add_node(block(2, BlockKind::Exit, &["case 2"], Terminator::Return));
        let fallback = g.add_node(block(3, BlockKind::Exit, &["default"], Terminator::Return));

        g.add_edge(switch, case1, EdgeType::TrueBranch);
        g.add_edge(switch, case2, EdgeType::TrueBranch);
        g.add_edge(switch, fallback, EdgeType::FalseBranch);

        let patterns = detect_match_patterns(&g);
        assert_eq!(patterns.len(), 1);

        let found = &patterns[0];
        assert_eq!(found.switch_node.index(), 0);
        assert_eq!(found.targets.len(), 2);
        assert_eq!(found.otherwise.index(), 3);
        assert_eq!(found.branch_count(), 3);
    }

    #[test]
    fn test_classify_multiway() {
        let mut g: Cfg = DiGraph::new();

        let switch = g.add_node(block(
            0,
            BlockKind::Entry,
            &[],
            Terminator::SwitchInt {
                targets: vec![1, 2],
                otherwise: 3,
            },
        ));

        for id in 1..=3 {
            g.add_node(block(id, BlockKind::Exit, &[], Terminator::Return));
        }

        for i in 1..=3 {
            g.add_edge(switch, NodeIndex::new(i), EdgeType::TrueBranch);
        }

        assert_eq!(classify_branch(&g, NodeIndex::new(0)), BranchType::MultiWay);
    }

    #[test]
    fn test_detect_all_patterns() {
        let mut g: Cfg = DiGraph::new();

        // An if/else at block 1 whose true arm (block 2) is itself a 3-way switch.
        let b0 = g.add_node(block(0, BlockKind::Entry, &[], Terminator::Goto { target: 1 }));
        let b1 = g.add_node(block(
            1,
            BlockKind::Normal,
            &[],
            Terminator::SwitchInt {
                targets: vec![2],
                otherwise: 3,
            },
        ));
        let b2 = g.add_node(block(
            2,
            BlockKind::Normal,
            &[],
            Terminator::SwitchInt {
                targets: vec![4, 5],
                otherwise: 6,
            },
        ));
        let b3 = g.add_node(block(3, BlockKind::Normal, &[], Terminator::Goto { target: 7 }));
        let b4 = g.add_node(block(4, BlockKind::Normal, &[], Terminator::Goto { target: 7 }));
        let b5 = g.add_node(block(5, BlockKind::Normal, &[], Terminator::Goto { target: 7 }));
        let b6 = g.add_node(block(6, BlockKind::Normal, &[], Terminator::Goto { target: 7 }));
        let b7 = g.add_node(block(7, BlockKind::Exit, &[], Terminator::Return));

        g.add_edge(b0, b1, EdgeType::Fallthrough);
        g.add_edge(b1, b2, EdgeType::TrueBranch);
        g.add_edge(b1, b3, EdgeType::FalseBranch);
        g.add_edge(b2, b4, EdgeType::TrueBranch);
        g.add_edge(b2, b5, EdgeType::TrueBranch);
        g.add_edge(b2, b6, EdgeType::FalseBranch);
        g.add_edge(b3, b7, EdgeType::Fallthrough);
        g.add_edge(b4, b7, EdgeType::Fallthrough);
        g.add_edge(b5, b7, EdgeType::Fallthrough);
        g.add_edge(b6, b7, EdgeType::Fallthrough);

        let (if_patterns, match_patterns) = detect_all_patterns(&g);

        assert_eq!(if_patterns.len(), 1);
        assert_eq!(if_patterns[0].condition.index(), 1);

        assert_eq!(match_patterns.len(), 1);
        assert_eq!(match_patterns[0].switch_node.index(), 2);
        assert_eq!(match_patterns[0].targets.len(), 2);
        assert_eq!(match_patterns[0].branch_count(), 3);
    }

    #[test]
    fn test_empty_cfg() {
        let cfg: Cfg = DiGraph::new();
        assert!(detect_if_else_patterns(&cfg).is_empty());
        assert!(detect_match_patterns(&cfg).is_empty());
    }

    #[test]
    fn test_linear_cfg_no_patterns() {
        let mut g: Cfg = DiGraph::new();

        let b0 = g.add_node(block(0, BlockKind::Entry, &[], Terminator::Goto { target: 1 }));
        let b1 = g.add_node(block(1, BlockKind::Normal, &[], Terminator::Goto { target: 2 }));
        let b2 = g.add_node(block(2, BlockKind::Exit, &[], Terminator::Return));

        g.add_edge(b0, b1, EdgeType::Fallthrough);
        g.add_edge(b1, b2, EdgeType::Fallthrough);

        assert!(detect_if_else_patterns(&g).is_empty());
        assert!(detect_match_patterns(&g).is_empty());
    }
}