use crate::cfg::analysis::is_branch_point;
use crate::cfg::EdgeType;
use crate::cfg::{BlockId, Cfg, Terminator};
use petgraph::graph::NodeIndex;
use std::collections::HashSet;
/// Coarse classification of a CFG node by the shape of its outgoing edges,
/// as produced by `classify_branch`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BranchType {
/// The node has zero or one successor.
Linear,
/// The node has exactly two successors and `find_common_successor`
/// reports a node shared by both of them.
Conditional,
/// The node has three or more successors and its block ends in
/// `Terminator::SwitchInt`.
MultiWay,
/// Anything else: two successors without a shared node, or three or more
/// successors where the block is missing or not terminated by `SwitchInt`.
Unknown,
}
/// A two-way branch found by `detect_if_else_patterns`.
#[derive(Debug, Clone)]
pub struct IfElsePattern {
/// The branching node (the one with exactly two successors).
pub condition: NodeIndex,
/// Successor chosen as the "true" side by `order_branches_by_edge_type`.
pub true_branch: NodeIndex,
/// Successor chosen as the "false" side by `order_branches_by_edge_type`.
pub false_branch: NodeIndex,
/// Result of `find_common_successor` for the two successors; `None` when
/// no shared node was found.
pub merge_point: Option<NodeIndex>,
}
impl IfElsePattern {
    /// Reports whether this pattern recorded a merge point.
    ///
    /// The answer is derived solely from `merge_point`: `true` when it is
    /// `Some`, `false` when it is `None`.
    pub fn has_else(&self) -> bool {
        matches!(self.merge_point, Some(_))
    }

    /// Number of nodes counted for this pattern: the two branch nodes, plus
    /// one more when a merge point is present.
    pub fn size(&self) -> usize {
        match self.merge_point {
            Some(_) => 3,
            None => 2,
        }
    }
}
/// A multi-target switch found by `detect_match_patterns`.
#[derive(Debug, Clone)]
pub struct MatchPattern {
/// The node whose block ends in `Terminator::SwitchInt`.
pub switch_node: NodeIndex,
/// Graph nodes for the switch's `targets` block ids; ids that could not be
/// resolved to a node are left out by `detect_match_patterns`.
pub targets: Vec<NodeIndex>,
/// Graph node for the switch's `otherwise` block id.
pub otherwise: NodeIndex,
}
impl MatchPattern {
    /// Total number of arms: every resolved target plus the `otherwise` arm.
    pub fn branch_count(&self) -> usize {
        let default_arm = 1;
        default_arm + self.targets.len()
    }

    /// Always `true`: `otherwise` is a plain `NodeIndex`, so every
    /// `MatchPattern` value carries a default target.
    pub fn has_explicit_default(&self) -> bool {
        true
    }
}
/// Classifies `node` by how many successors it has in `cfg`.
///
/// * zero or one successor: [`BranchType::Linear`]
/// * two successors: [`BranchType::Conditional`] when
///   `find_common_successor` finds a node shared by both, otherwise
///   [`BranchType::Unknown`]
/// * three or more successors: [`BranchType::MultiWay`] when the node's block
///   ends in `Terminator::SwitchInt`, otherwise [`BranchType::Unknown`]
pub fn classify_branch(cfg: &Cfg, node: NodeIndex) -> BranchType {
    let successors: Vec<_> = cfg.neighbors(node).collect();

    match successors.as_slice() {
        [] | [_] => BranchType::Linear,
        [first, second] => match find_common_successor(cfg, *first, *second) {
            Some(_) => BranchType::Conditional,
            None => BranchType::Unknown,
        },
        _ => {
            // A missing node weight is treated the same as a non-switch block.
            let ends_in_switch = cfg
                .node_weight(node)
                .map_or(false, |block| {
                    matches!(block.terminator, Terminator::SwitchInt { .. })
                });
            if ends_in_switch {
                BranchType::MultiWay
            } else {
                BranchType::Unknown
            }
        }
    }
}
/// Finds a node that is reachable from both `n1` and `n2`.
///
/// `n1` and `n2` themselves are never returned and never traversed through
/// (apart from being the two starting points), so a branch that jumps straight
/// to the other branch does not count as a merge.
///
/// The search runs in two phases:
/// 1. collect every node reachable from `n1`;
/// 2. walk breadth-first from `n2` and return the first node that is also in
///    the set from phase 1, i.e. the shared node found nearest to `n2`.
///
/// Returns `None` when the two starting nodes share no reachable node.
fn find_common_successor(cfg: &Cfg, n1: NodeIndex, n2: NodeIndex) -> Option<NodeIndex> {
    // Phase 1: full reachability from n1. A node is marked when it is pushed
    // and always expanded when popped, so the walk continues past the direct
    // successors of n1 and still visits each node at most once.
    let mut reachable_from_n1 = HashSet::new();
    let mut worklist = vec![n1];
    while let Some(node) = worklist.pop() {
        for succ in cfg.neighbors(node) {
            if succ != n1 && succ != n2 && reachable_from_n1.insert(succ) {
                worklist.push(succ);
            }
        }
    }

    // Phase 2: breadth-first walk from n2, stopping at the first node that
    // phase 1 also reached.
    let mut visited = HashSet::new();
    let mut queue = std::collections::VecDeque::new();
    queue.push_back(n2);
    while let Some(node) = queue.pop_front() {
        for succ in cfg.neighbors(node) {
            if succ == n1 || succ == n2 || !visited.insert(succ) {
                continue;
            }
            if reachable_from_n1.contains(&succ) {
                return Some(succ);
            }
            queue.push_back(succ);
        }
    }

    None
}
/// Collects an [`IfElsePattern`] for every two-way branch in `cfg`.
///
/// A node qualifies when `is_branch_point` accepts it and it has exactly two
/// successors. Nodes whose block ends in a `Terminator::SwitchInt` with more
/// than one target are skipped. The two successors are ordered by
/// `order_branches_by_edge_type`, and the merge point comes from
/// `find_common_successor`.
pub fn detect_if_else_patterns(cfg: &Cfg) -> Vec<IfElsePattern> {
    let mut found = Vec::new();

    for condition in cfg.node_indices() {
        if !is_branch_point(cfg, condition) {
            continue;
        }

        let arms: Vec<_> = cfg.neighbors(condition).collect();
        if arms.len() != 2 {
            continue;
        }

        // A missing node weight does not disqualify the node; only a switch
        // with more than one target does.
        let is_multi_target_switch = cfg.node_weight(condition).map_or(false, |block| {
            matches!(
                &block.terminator,
                Terminator::SwitchInt { targets, .. } if targets.len() > 1
            )
        });
        if is_multi_target_switch {
            continue;
        }

        let (first, second) = (arms[0], arms[1]);
        let merge_point = find_common_successor(cfg, first, second);
        let (true_branch, false_branch) =
            order_branches_by_edge_type(cfg, condition, first, second);

        found.push(IfElsePattern {
            condition,
            true_branch,
            false_branch,
            merge_point,
        });
    }

    found
}
/// Orders two successors of `from` as `(true_branch, false_branch)` using the
/// edge labels on `from -> succ1` and `from -> succ2`.
///
/// Rules are tried in this order, and the first match decides:
/// 1. `succ1` is reached by a `TrueBranch` edge: keep `(succ1, succ2)`
/// 2. `succ2` is reached by a `TrueBranch` edge: swap to `(succ2, succ1)`
/// 3. `succ1` is reached by a `FalseBranch` edge: swap to `(succ2, succ1)`
/// 4. otherwise: keep `(succ1, succ2)`
fn order_branches_by_edge_type(
    cfg: &Cfg,
    from: NodeIndex,
    succ1: NodeIndex,
    succ2: NodeIndex,
) -> (NodeIndex, NodeIndex) {
    let edge_kind = |to: NodeIndex| {
        cfg.find_edge(from, to)
            .and_then(|edge| cfg.edge_weight(edge).copied())
    };
    let first_kind = edge_kind(succ1);
    let second_kind = edge_kind(succ2);

    let swap = if matches!(first_kind, Some(EdgeType::TrueBranch)) {
        false
    } else if matches!(second_kind, Some(EdgeType::TrueBranch)) {
        true
    } else {
        matches!(first_kind, Some(EdgeType::FalseBranch))
    };

    if swap {
        (succ2, succ1)
    } else {
        (succ1, succ2)
    }
}
/// Collects a [`MatchPattern`] for every node whose block ends in a
/// `Terminator::SwitchInt` with at least two targets.
///
/// Target block ids that do not resolve to a graph node are left out of
/// `targets`. A switch whose `otherwise` id does not resolve produces no
/// pattern at all.
pub fn detect_match_patterns(cfg: &Cfg) -> Vec<MatchPattern> {
    cfg.node_indices()
        .filter_map(|switch_node| {
            let block = cfg.node_weight(switch_node)?;
            let (target_ids, otherwise_id) = match &block.terminator {
                Terminator::SwitchInt { targets, otherwise } if targets.len() >= 2 => {
                    (targets, otherwise)
                }
                _ => return None,
            };

            let targets: Vec<_> = target_ids
                .iter()
                .filter_map(|&id| find_node_by_id(cfg, id))
                .collect();
            let otherwise = find_node_by_id(cfg, *otherwise_id)?;

            Some(MatchPattern {
                switch_node,
                targets,
                otherwise,
            })
        })
        .collect()
}
/// Returns the first node (in `node_indices` order) whose block has the given
/// `id`, or `None` when no block carries that id.
fn find_node_by_id(cfg: &Cfg, id: BlockId) -> Option<NodeIndex> {
    cfg.node_indices().find(|&candidate| {
        matches!(cfg.node_weight(candidate), Some(block) if block.id == id)
    })
}
/// Runs both detectors over `cfg` and returns their results as
/// `(if/else patterns, match patterns)`.
pub fn detect_all_patterns(cfg: &Cfg) -> (Vec<IfElsePattern>, Vec<MatchPattern>) {
    let if_else = detect_if_else_patterns(cfg);
    let switches = detect_match_patterns(cfg);
    (if_else, switches)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cfg::{BasicBlock, BlockKind, EdgeType, Terminator};
    use petgraph::graph::DiGraph;

    /// Builds a fixture block from the fields the tests vary; every other
    /// field gets the same placeholder value (`None` / `0`) in all fixtures.
    fn block(
        id: usize,
        kind: BlockKind,
        statements: &[&str],
        terminator: Terminator,
    ) -> BasicBlock {
        BasicBlock {
            id,
            db_id: None,
            kind,
            statements: statements.iter().map(|s| s.to_string()).collect(),
            terminator,
            source_location: None,
            coord_x: 0,
            coord_y: 0,
            coord_z: 0,
        }
    }

    /// Diamond: 0 -> 1, 1 -> {2 (true), 3 (false)}, 2 -> 4, 3 -> 4.
    fn create_diamond_cfg() -> Cfg {
        let mut graph = DiGraph::new();
        let entry = graph.add_node(block(
            0,
            BlockKind::Entry,
            &[],
            Terminator::Goto { target: 1 },
        ));
        let cond = graph.add_node(block(
            1,
            BlockKind::Normal,
            &[],
            Terminator::SwitchInt {
                targets: vec![2],
                otherwise: 3,
            },
        ));
        let then_arm = graph.add_node(block(
            2,
            BlockKind::Normal,
            &["true branch"],
            Terminator::Goto { target: 4 },
        ));
        let else_arm = graph.add_node(block(
            3,
            BlockKind::Normal,
            &["false branch"],
            Terminator::Goto { target: 4 },
        ));
        let merge = graph.add_node(block(4, BlockKind::Exit, &["merge"], Terminator::Return));

        graph.add_edge(entry, cond, EdgeType::Fallthrough);
        graph.add_edge(cond, then_arm, EdgeType::TrueBranch);
        graph.add_edge(cond, else_arm, EdgeType::FalseBranch);
        graph.add_edge(then_arm, merge, EdgeType::Fallthrough);
        graph.add_edge(else_arm, merge, EdgeType::Fallthrough);
        graph
    }

    #[test]
    fn test_detect_if_else_diamond() {
        let cfg = create_diamond_cfg();
        let patterns = detect_if_else_patterns(&cfg);
        assert_eq!(patterns.len(), 1);

        let diamond = &patterns[0];
        assert_eq!(diamond.condition.index(), 1);
        assert_eq!(diamond.true_branch.index(), 2);
        assert_eq!(diamond.false_branch.index(), 3);
        assert_eq!(diamond.merge_point, Some(NodeIndex::new(4)));
        assert!(diamond.has_else());
    }

    #[test]
    fn test_classify_branch() {
        let cfg = create_diamond_cfg();
        assert_eq!(classify_branch(&cfg, NodeIndex::new(0)), BranchType::Linear);
        assert_eq!(
            classify_branch(&cfg, NodeIndex::new(1)),
            BranchType::Conditional
        );
        assert_eq!(classify_branch(&cfg, NodeIndex::new(2)), BranchType::Linear);
        assert_eq!(classify_branch(&cfg, NodeIndex::new(4)), BranchType::Linear);
    }

    #[test]
    fn test_detect_match_patterns() {
        let mut graph = DiGraph::new();
        let switch = graph.add_node(block(
            0,
            BlockKind::Entry,
            &[],
            Terminator::SwitchInt {
                targets: vec![1, 2],
                otherwise: 3,
            },
        ));
        let case_one = graph.add_node(block(1, BlockKind::Exit, &["case 1"], Terminator::Return));
        let case_two = graph.add_node(block(2, BlockKind::Exit, &["case 2"], Terminator::Return));
        let fallback = graph.add_node(block(3, BlockKind::Exit, &["default"], Terminator::Return));

        graph.add_edge(switch, case_one, EdgeType::TrueBranch);
        graph.add_edge(switch, case_two, EdgeType::TrueBranch);
        graph.add_edge(switch, fallback, EdgeType::FalseBranch);

        let patterns = detect_match_patterns(&graph);
        assert_eq!(patterns.len(), 1);

        let found = &patterns[0];
        assert_eq!(found.switch_node.index(), 0);
        assert_eq!(found.targets.len(), 2);
        assert_eq!(found.otherwise.index(), 3);
        assert_eq!(found.branch_count(), 3);
    }

    #[test]
    fn test_classify_multiway() {
        let mut graph = DiGraph::new();
        let switch = graph.add_node(block(
            0,
            BlockKind::Entry,
            &[],
            Terminator::SwitchInt {
                targets: vec![1, 2],
                otherwise: 3,
            },
        ));
        for id in 1..=3 {
            graph.add_node(block(id, BlockKind::Exit, &[], Terminator::Return));
        }
        for id in 1..=3 {
            graph.add_edge(switch, NodeIndex::new(id), EdgeType::TrueBranch);
        }

        assert_eq!(
            classify_branch(&graph, NodeIndex::new(0)),
            BranchType::MultiWay
        );
    }

    #[test]
    fn test_detect_all_patterns() {
        // 0 -> 1; 1 branches to 2 (true) and 3 (false); 2 is a three-way
        // switch over 4, 5, 6; blocks 3..=6 all fall through to the exit 7.
        let mut graph = DiGraph::new();
        let entry = graph.add_node(block(
            0,
            BlockKind::Entry,
            &[],
            Terminator::Goto { target: 1 },
        ));
        let cond = graph.add_node(block(
            1,
            BlockKind::Normal,
            &[],
            Terminator::SwitchInt {
                targets: vec![2],
                otherwise: 3,
            },
        ));
        let switch = graph.add_node(block(
            2,
            BlockKind::Normal,
            &[],
            Terminator::SwitchInt {
                targets: vec![4, 5],
                otherwise: 6,
            },
        ));
        let else_arm = graph.add_node(block(
            3,
            BlockKind::Normal,
            &[],
            Terminator::Goto { target: 7 },
        ));
        let case_a = graph.add_node(block(
            4,
            BlockKind::Normal,
            &[],
            Terminator::Goto { target: 7 },
        ));
        let case_b = graph.add_node(block(
            5,
            BlockKind::Normal,
            &[],
            Terminator::Goto { target: 7 },
        ));
        let case_default = graph.add_node(block(
            6,
            BlockKind::Normal,
            &[],
            Terminator::Goto { target: 7 },
        ));
        let exit = graph.add_node(block(7, BlockKind::Exit, &[], Terminator::Return));

        graph.add_edge(entry, cond, EdgeType::Fallthrough);
        graph.add_edge(cond, switch, EdgeType::TrueBranch);
        graph.add_edge(cond, else_arm, EdgeType::FalseBranch);
        graph.add_edge(switch, case_a, EdgeType::TrueBranch);
        graph.add_edge(switch, case_b, EdgeType::TrueBranch);
        graph.add_edge(switch, case_default, EdgeType::FalseBranch);
        graph.add_edge(else_arm, exit, EdgeType::Fallthrough);
        graph.add_edge(case_a, exit, EdgeType::Fallthrough);
        graph.add_edge(case_b, exit, EdgeType::Fallthrough);
        graph.add_edge(case_default, exit, EdgeType::Fallthrough);

        let (if_patterns, match_patterns) = detect_all_patterns(&graph);
        assert_eq!(if_patterns.len(), 1);
        assert_eq!(if_patterns[0].condition.index(), 1);
        assert_eq!(match_patterns.len(), 1);
        assert_eq!(match_patterns[0].switch_node.index(), 2);
        assert_eq!(match_patterns[0].targets.len(), 2);
        assert_eq!(match_patterns[0].branch_count(), 3);
    }

    #[test]
    fn test_empty_cfg() {
        let cfg: Cfg = DiGraph::new();
        assert!(detect_if_else_patterns(&cfg).is_empty());
        assert!(detect_match_patterns(&cfg).is_empty());
    }

    #[test]
    fn test_linear_cfg_no_patterns() {
        let mut graph = DiGraph::new();
        let first = graph.add_node(block(
            0,
            BlockKind::Entry,
            &[],
            Terminator::Goto { target: 1 },
        ));
        let second = graph.add_node(block(
            1,
            BlockKind::Normal,
            &[],
            Terminator::Goto { target: 2 },
        ));
        let last = graph.add_node(block(2, BlockKind::Exit, &[], Terminator::Return));

        graph.add_edge(first, second, EdgeType::Fallthrough);
        graph.add_edge(second, last, EdgeType::Fallthrough);

        assert!(detect_if_else_patterns(&graph).is_empty());
        assert!(detect_match_patterns(&graph).is_empty());
    }
}