Skip to main content

mirage/cfg/
patterns.rs

1//! Branching pattern recovery from CFG shapes
2
3use crate::cfg::analysis::is_branch_point;
4use crate::cfg::EdgeType;
5use crate::cfg::{BlockId, Cfg, Terminator};
6use petgraph::graph::NodeIndex;
7use std::collections::HashSet;
8
/// How control leaves a node, as classified by [`classify_branch`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BranchType {
    /// Zero or one successor: control flows straight through.
    Linear,
    /// Two successors that reconverge at a common merge point (if/else shape).
    Conditional,
    /// Three or more successors leaving a `SwitchInt` terminator (match/switch shape).
    MultiWay,
    /// Branching that fits none of the shapes above.
    Unknown,
}
21
22/// Represents an if/else structure detected in the CFG
23#[derive(Debug, Clone)]
24pub struct IfElsePattern {
25    /// Condition node (branch point)
26    pub condition: NodeIndex,
27    /// True branch target
28    pub true_branch: NodeIndex,
29    /// False branch target
30    pub false_branch: NodeIndex,
31    /// Merge point (where branches reconverge)
32    /// None if branches don't merge (e.g., early return)
33    pub merge_point: Option<NodeIndex>,
34}
35
36impl IfElsePattern {
37    /// Check if this is a complete if/else (branches merge)
38    pub fn has_else(&self) -> bool {
39        self.merge_point.is_some()
40    }
41
42    /// Get the number of blocks in this pattern
43    pub fn size(&self) -> usize {
44        2 + if self.merge_point.is_some() { 1 } else { 0 }
45    }
46}
47
48/// Represents a match/switch structure detected in the CFG
49#[derive(Debug, Clone)]
50pub struct MatchPattern {
51    /// Switch node (contains SwitchInt terminator)
52    pub switch_node: NodeIndex,
53    /// Branch targets (excluding default/otherwise)
54    pub targets: Vec<NodeIndex>,
55    /// Default/otherwise branch
56    pub otherwise: NodeIndex,
57}
58
59impl MatchPattern {
60    /// Get total number of branches
61    pub fn branch_count(&self) -> usize {
62        self.targets.len() + 1 // +1 for otherwise
63    }
64
65    /// Check if this match is exhaustive (all branches defined)
66    /// This is a simplified check - true exhaustiveness requires type info
67    pub fn has_explicit_default(&self) -> bool {
68        // In our representation, otherwise always exists
69        // A more sophisticated version would check if it's reachable
70        true
71    }
72}
73
74/// Classify a node's branching structure
75///
76/// Returns the type of control flow at this node based on
77/// outgoing edges and terminator type.
78///
79/// # Example
80/// ```rust,no_run
81/// # use mirage_analyzer::cfg::patterns::{classify_branch, BranchType};
82/// # use mirage_analyzer::cfg::Cfg;
83/// # let graph: Cfg = unimplemented!();
84/// for node in graph.node_indices() {
85///     match classify_branch(&graph, node) {
86///         BranchType::Conditional => println!("if/else at {:?}", node),
87///         BranchType::MultiWay => println!("match at {:?}", node),
88///         _ => {}
89///     }
90/// }
91/// ```
92pub fn classify_branch(cfg: &Cfg, node: NodeIndex) -> BranchType {
93    let successors: Vec<_> = cfg.neighbors(node).collect();
94
95    match successors.len() {
96        0 | 1 => BranchType::Linear,
97        2 => {
98            // Check if it's a diamond pattern (if/else)
99            let merge = find_common_successor(cfg, successors[0], successors[1]);
100            if merge.is_some() {
101                BranchType::Conditional
102            } else {
103                // Could be if without else, if with early return
104                BranchType::Unknown
105            }
106        }
107        3.. => {
108            // Multi-way branch - check for SwitchInt
109            if let Some(block) = cfg.node_weight(node) {
110                if matches!(block.terminator, Terminator::SwitchInt { .. }) {
111                    BranchType::MultiWay
112                } else {
113                    BranchType::Unknown
114                }
115            } else {
116                BranchType::Unknown
117            }
118        }
119    }
120}
121
122/// Find common successor of two nodes (merge point)
123///
124/// Returns the first node reachable from both n1 and n2 (excluding
125/// the nodes themselves). This identifies where branches reconverge.
126fn find_common_successor(cfg: &Cfg, n1: NodeIndex, n2: NodeIndex) -> Option<NodeIndex> {
127    // Collect reachable nodes from n1 (excluding n1 and n2)
128    let mut reachable_from_n1 = HashSet::new();
129    let mut worklist = vec![n1];
130
131    while let Some(node) = worklist.pop() {
132        // Skip n1 and n2 themselves - we want their successors
133        if node == n1 || node == n2 {
134            for succ in cfg.neighbors(node) {
135                if succ != n1 && succ != n2 && !reachable_from_n1.contains(&succ) {
136                    worklist.push(succ);
137                    reachable_from_n1.insert(succ);
138                }
139            }
140            continue;
141        }
142
143        if !reachable_from_n1.insert(node) {
144            continue;
145        }
146        for succ in cfg.neighbors(node) {
147            if succ != n1 && succ != n2 && !reachable_from_n1.contains(&succ) {
148                worklist.push(succ);
149            }
150        }
151    }
152
153    // Check nodes reachable from n2
154    let mut visited = HashSet::new();
155    let mut worklist = vec![n2];
156
157    while let Some(node) = worklist.pop() {
158        if node == n1 || node == n2 {
159            for succ in cfg.neighbors(node) {
160                if succ != n1 && succ != n2 && !visited.contains(&succ) {
161                    if reachable_from_n1.contains(&succ) {
162                        return Some(succ);
163                    }
164                    visited.insert(succ);
165                    worklist.push(succ);
166                }
167            }
168            continue;
169        }
170
171        if reachable_from_n1.contains(&node) {
172            return Some(node);
173        }
174
175        if !visited.insert(node) {
176            continue;
177        }
178        for succ in cfg.neighbors(node) {
179            if succ != n1 && succ != n2 && !visited.contains(&succ) {
180                worklist.push(succ);
181            }
182        }
183    }
184
185    None
186}
187
188/// Detect if/else patterns by looking for diamond structures
189///
190/// A diamond structure is:
191/// - A branch point with 2 successors
192/// - Both successors eventually merge to a common point
193/// - NOT a multi-way SwitchInt (that's a match, not if/else)
194///
195/// Note: This distinguishes if/else from match by checking if the SwitchInt
196/// has more than 1 target (if/else has 1 target + otherwise = 2 branches,
197/// match has 2+ targets + otherwise = 3+ branches).
198///
199/// Returns an empty vec if no patterns found.
200///
201/// # Example
202/// ```rust,no_run
203/// # use mirage_analyzer::cfg::patterns::detect_if_else_patterns;
204/// # let graph = unimplemented!();
205/// let patterns = detect_if_else_patterns(&graph);
206/// for pattern in patterns {
207///     println!("if/else at {:?}, merges at {:?}", pattern.condition, pattern.merge_point);
208/// }
209/// ```
210pub fn detect_if_else_patterns(cfg: &Cfg) -> Vec<IfElsePattern> {
211    let mut patterns = Vec::new();
212
213    for branch in cfg.node_indices().filter(|&n| is_branch_point(cfg, n)) {
214        let successors: Vec<_> = cfg.neighbors(branch).collect();
215
216        if successors.len() == 2 {
217            // Exclude multi-way SwitchInt terminators (3+ branches) - those are matches
218            // If/else uses SwitchInt with 1 target (2 branches total)
219            if let Some(block) = cfg.node_weight(branch) {
220                if let Terminator::SwitchInt { targets, .. } = &block.terminator {
221                    if targets.len() > 1 {
222                        // This is a match, not if/else
223                        continue;
224                    }
225                }
226            }
227
228            // Check for diamond pattern (merge point)
229            let merge_point = find_common_successor(cfg, successors[0], successors[1]);
230
231            // Determine which branch is true/false based on edge type
232            let (true_branch, false_branch) =
233                order_branches_by_edge_type(cfg, branch, successors[0], successors[1]);
234
235            patterns.push(IfElsePattern {
236                condition: branch,
237                true_branch,
238                false_branch,
239                merge_point,
240            });
241        }
242    }
243
244    patterns
245}
246
247/// Order branches as (true, false) based on edge type
248///
249/// Uses EdgeType to determine which successor is the true branch
250/// and which is the false branch.
251fn order_branches_by_edge_type(
252    cfg: &Cfg,
253    from: NodeIndex,
254    succ1: NodeIndex,
255    succ2: NodeIndex,
256) -> (NodeIndex, NodeIndex) {
257    let edge1_type = cfg
258        .find_edge(from, succ1)
259        .and_then(|e| cfg.edge_weight(e).copied());
260    let edge2_type = cfg
261        .find_edge(from, succ2)
262        .and_then(|e| cfg.edge_weight(e).copied());
263
264    match (edge1_type, edge2_type) {
265        (Some(EdgeType::TrueBranch), _) => (succ1, succ2),
266        (_, Some(EdgeType::TrueBranch)) => (succ2, succ1),
267        (Some(EdgeType::FalseBranch), _) => (succ2, succ1),
268        (_, Some(EdgeType::FalseBranch)) => (succ1, succ2),
269        _ => (succ1, succ2), // Default order if unclear
270    }
271}
272
273/// Detect match patterns by looking for SwitchInt terminators
274///
275/// SwitchInt terminators with 2+ targets indicate multi-way branches (match/switch).
276/// Two-way SwitchInt (1 target + otherwise) represents if/else, not match.
277///
278/// The pattern includes all branch targets plus the default/otherwise.
279///
280/// Returns an empty vec if no patterns found.
281///
282/// # Example
283/// ```rust,no_run
284/// # use mirage_analyzer::cfg::patterns::detect_match_patterns;
285/// # let graph = unimplemented!();
286/// let patterns = detect_match_patterns(&graph);
287/// for pattern in patterns {
288///     println!("match at {:?} with {} branches", pattern.switch_node, pattern.branch_count());
289/// }
290/// ```
291pub fn detect_match_patterns(cfg: &Cfg) -> Vec<MatchPattern> {
292    let mut patterns = Vec::new();
293
294    for node in cfg.node_indices() {
295        if let Some(block) = cfg.node_weight(node) {
296            if let Terminator::SwitchInt { targets, otherwise } = &block.terminator {
297                // Only detect multi-way matches (2+ targets = 3+ branches)
298                // Single-target SwitchInt (2 branches) is if/else, not match
299                if targets.len() < 2 {
300                    continue;
301                }
302
303                // Convert BlockIds to NodeIndices
304                let target_indices: Vec<_> = targets
305                    .iter()
306                    .filter_map(|&id| find_node_by_id(cfg, id))
307                    .collect();
308
309                if let Some(otherwise_idx) = find_node_by_id(cfg, *otherwise) {
310                    patterns.push(MatchPattern {
311                        switch_node: node,
312                        targets: target_indices,
313                        otherwise: otherwise_idx,
314                    });
315                }
316            }
317        }
318    }
319
320    patterns
321}
322
323/// Helper: find NodeIndex by BlockId
324fn find_node_by_id(cfg: &Cfg, id: BlockId) -> Option<NodeIndex> {
325    cfg.node_indices()
326        .find(|&n| cfg.node_weight(n).map_or(false, |b| b.id == id))
327}
328
329/// Get all branching patterns in the CFG
330///
331/// Returns both if/else and match patterns for a complete view
332/// of control flow structure.
333pub fn detect_all_patterns(cfg: &Cfg) -> (Vec<IfElsePattern>, Vec<MatchPattern>) {
334    (detect_if_else_patterns(cfg), detect_match_patterns(cfg))
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cfg::{BasicBlock, BlockKind, EdgeType, Terminator};
    use petgraph::graph::DiGraph;

    /// Build a `BasicBlock` with the given id, kind, statements and terminator.
    ///
    /// Every other field (`db_id`, `source_location`, coordinates) is left at
    /// its empty/zero value, which is all these pattern tests need.
    fn make_block(
        id: BlockId,
        kind: BlockKind,
        statements: &[&str],
        terminator: Terminator,
    ) -> BasicBlock {
        BasicBlock {
            id,
            db_id: None,
            kind,
            statements: statements.iter().map(|s| s.to_string()).collect(),
            terminator,
            source_location: None,
            coord_x: 0,
            coord_y: 0,
            coord_z: 0,
        }
    }

    /// Create a diamond pattern CFG (if/else):
    /// 0 -> 1, 1 -> {2, 3}, {2, 3} -> 4
    fn create_diamond_cfg() -> Cfg {
        let mut g = DiGraph::new();

        // Block 0: entry, goes to 1
        let b0 = g.add_node(make_block(0, BlockKind::Entry, &[], Terminator::Goto { target: 1 }));
        // Block 1: if condition
        let b1 = g.add_node(make_block(
            1,
            BlockKind::Normal,
            &[],
            Terminator::SwitchInt {
                targets: vec![2],
                otherwise: 3,
            },
        ));
        // Block 2: true branch
        let b2 = g.add_node(make_block(
            2,
            BlockKind::Normal,
            &["true branch"],
            Terminator::Goto { target: 4 },
        ));
        // Block 3: false branch
        let b3 = g.add_node(make_block(
            3,
            BlockKind::Normal,
            &["false branch"],
            Terminator::Goto { target: 4 },
        ));
        // Block 4: merge point
        let b4 = g.add_node(make_block(4, BlockKind::Exit, &["merge"], Terminator::Return));

        g.add_edge(b0, b1, EdgeType::Fallthrough);
        g.add_edge(b1, b2, EdgeType::TrueBranch);
        g.add_edge(b1, b3, EdgeType::FalseBranch);
        g.add_edge(b2, b4, EdgeType::Fallthrough);
        g.add_edge(b3, b4, EdgeType::Fallthrough);

        g
    }

    #[test]
    fn test_detect_if_else_diamond() {
        let cfg = create_diamond_cfg();
        let patterns = detect_if_else_patterns(&cfg);

        assert_eq!(patterns.len(), 1);

        let pattern = &patterns[0];
        assert_eq!(pattern.condition.index(), 1);
        assert_eq!(pattern.true_branch.index(), 2);
        assert_eq!(pattern.false_branch.index(), 3);
        assert_eq!(pattern.merge_point, Some(NodeIndex::new(4)));
        assert!(pattern.has_else());
    }

    #[test]
    fn test_classify_branch() {
        let cfg = create_diamond_cfg();

        assert_eq!(classify_branch(&cfg, NodeIndex::new(0)), BranchType::Linear);
        assert_eq!(
            classify_branch(&cfg, NodeIndex::new(1)),
            BranchType::Conditional
        );
        assert_eq!(classify_branch(&cfg, NodeIndex::new(2)), BranchType::Linear);
        assert_eq!(classify_branch(&cfg, NodeIndex::new(4)), BranchType::Linear);
    }

    /// Both arms of the branch return, so there is no merge point. The
    /// FalseBranch edge is added first to check that true/false ordering
    /// follows edge types rather than insertion order.
    #[test]
    fn test_if_with_early_returns_has_no_merge_point() {
        let mut g = DiGraph::new();

        let b0 = g.add_node(make_block(0, BlockKind::Entry, &[], Terminator::Goto { target: 1 }));
        let b1 = g.add_node(make_block(
            1,
            BlockKind::Normal,
            &[],
            Terminator::SwitchInt {
                targets: vec![2],
                otherwise: 3,
            },
        ));
        let b2 = g.add_node(make_block(2, BlockKind::Exit, &["return a"], Terminator::Return));
        let b3 = g.add_node(make_block(3, BlockKind::Exit, &["return b"], Terminator::Return));

        g.add_edge(b0, b1, EdgeType::Fallthrough);
        g.add_edge(b1, b3, EdgeType::FalseBranch);
        g.add_edge(b1, b2, EdgeType::TrueBranch);

        assert_eq!(classify_branch(&g, b1), BranchType::Unknown);

        let patterns = detect_if_else_patterns(&g);
        assert_eq!(patterns.len(), 1);
        assert_eq!(patterns[0].condition, b1);
        assert_eq!(patterns[0].true_branch, b2);
        assert_eq!(patterns[0].false_branch, b3);
        assert_eq!(patterns[0].merge_point, None);
        assert!(!patterns[0].has_else());
    }

    #[test]
    fn test_detect_match_patterns() {
        let mut g = DiGraph::new();

        // Block 0: match with 3 branches
        let b0 = g.add_node(make_block(
            0,
            BlockKind::Entry,
            &[],
            Terminator::SwitchInt {
                targets: vec![1, 2],
                otherwise: 3,
            },
        ));
        let b1 = g.add_node(make_block(1, BlockKind::Exit, &["case 1"], Terminator::Return));
        let b2 = g.add_node(make_block(2, BlockKind::Exit, &["case 2"], Terminator::Return));
        let b3 = g.add_node(make_block(3, BlockKind::Exit, &["default"], Terminator::Return));

        g.add_edge(b0, b1, EdgeType::TrueBranch);
        g.add_edge(b0, b2, EdgeType::TrueBranch);
        g.add_edge(b0, b3, EdgeType::FalseBranch);

        let patterns = detect_match_patterns(&g);
        assert_eq!(patterns.len(), 1);

        let pattern = &patterns[0];
        assert_eq!(pattern.switch_node.index(), 0);
        assert_eq!(pattern.targets.len(), 2);
        assert_eq!(pattern.otherwise.index(), 3);
        assert_eq!(pattern.branch_count(), 3);
    }

    #[test]
    fn test_classify_multiway() {
        let mut g = DiGraph::new();

        let b0 = g.add_node(make_block(
            0,
            BlockKind::Entry,
            &[],
            Terminator::SwitchInt {
                targets: vec![1, 2],
                otherwise: 3,
            },
        ));

        for id in 1..=3 {
            g.add_node(make_block(id, BlockKind::Exit, &[], Terminator::Return));
        }

        for i in 1..=3 {
            g.add_edge(b0, NodeIndex::new(i), EdgeType::TrueBranch);
        }

        assert_eq!(classify_branch(&g, NodeIndex::new(0)), BranchType::MultiWay);
    }

    #[test]
    fn test_detect_all_patterns() {
        let mut g = DiGraph::new();

        // CFG with an if/else at block 1 whose true arm contains a 3-way match.

        // Entry
        let b0 = g.add_node(make_block(0, BlockKind::Entry, &[], Terminator::Goto { target: 1 }));
        // If/else at block 1
        let b1 = g.add_node(make_block(
            1,
            BlockKind::Normal,
            &[],
            Terminator::SwitchInt {
                targets: vec![2],
                otherwise: 3,
            },
        ));
        // True branch: multi-way match with 2 targets (3 branches total)
        let b2 = g.add_node(make_block(
            2,
            BlockKind::Normal,
            &[],
            Terminator::SwitchInt {
                targets: vec![4, 5],
                otherwise: 6,
            },
        ));
        // False branch
        let b3 = g.add_node(make_block(3, BlockKind::Normal, &[], Terminator::Goto { target: 7 }));
        // Match branches
        let b4 = g.add_node(make_block(4, BlockKind::Normal, &[], Terminator::Goto { target: 7 }));
        let b5 = g.add_node(make_block(5, BlockKind::Normal, &[], Terminator::Goto { target: 7 }));
        let b6 = g.add_node(make_block(6, BlockKind::Normal, &[], Terminator::Goto { target: 7 }));
        // Merge point
        let b7 = g.add_node(make_block(7, BlockKind::Exit, &[], Terminator::Return));

        g.add_edge(b0, b1, EdgeType::Fallthrough);
        g.add_edge(b1, b2, EdgeType::TrueBranch);
        g.add_edge(b1, b3, EdgeType::FalseBranch);
        g.add_edge(b2, b4, EdgeType::TrueBranch);
        g.add_edge(b2, b5, EdgeType::TrueBranch);
        g.add_edge(b2, b6, EdgeType::FalseBranch);
        g.add_edge(b3, b7, EdgeType::Fallthrough);
        g.add_edge(b4, b7, EdgeType::Fallthrough);
        g.add_edge(b5, b7, EdgeType::Fallthrough);
        g.add_edge(b6, b7, EdgeType::Fallthrough);

        let (if_patterns, match_patterns) = detect_all_patterns(&g);

        // Exactly one if/else, at block 1 (block 2 is a 3-way switch, not an if/else).
        assert_eq!(if_patterns.len(), 1);
        assert_eq!(if_patterns[0].condition.index(), 1);

        // Exactly one match, at block 2: 2 targets + otherwise = 3 branches.
        assert_eq!(match_patterns.len(), 1);
        assert_eq!(match_patterns[0].switch_node.index(), 2);
        assert_eq!(match_patterns[0].targets.len(), 2);
        assert_eq!(match_patterns[0].branch_count(), 3);
    }

    #[test]
    fn test_empty_cfg() {
        let cfg: Cfg = DiGraph::new();
        assert!(detect_if_else_patterns(&cfg).is_empty());
        assert!(detect_match_patterns(&cfg).is_empty());
    }

    #[test]
    fn test_linear_cfg_no_patterns() {
        let mut g = DiGraph::new();

        // Linear: 0 -> 1 -> 2
        let b0 = g.add_node(make_block(0, BlockKind::Entry, &[], Terminator::Goto { target: 1 }));
        let b1 = g.add_node(make_block(1, BlockKind::Normal, &[], Terminator::Goto { target: 2 }));
        let b2 = g.add_node(make_block(2, BlockKind::Exit, &[], Terminator::Return));

        g.add_edge(b0, b1, EdgeType::Fallthrough);
        g.add_edge(b1, b2, EdgeType::Fallthrough);

        assert!(detect_if_else_patterns(&g).is_empty());
        assert!(detect_match_patterns(&g).is_empty());
    }
}