Skip to main content

mir_extractor/dataflow/
cfg.rs

1//! Control Flow Graph (CFG) extraction from MIR
2//!
3//! This module parses MIR basic blocks to build a control flow graph for path-sensitive analysis.
4
5use crate::MirFunction;
6use std::collections::{HashMap, HashSet};
7
/// Control Flow Graph extracted from MIR basic blocks.
///
/// Built from the textual MIR body of a function by `ControlFlowGraph::from_mir_function`.
#[derive(Debug, Clone)]
pub struct ControlFlowGraph {
    /// Map from block ID (e.g. `"bb3"`) to its parsed [`BasicBlock`]
    pub blocks: HashMap<String, BasicBlock>,
    /// Map from block ID to successor block IDs, in the order the block's terminator
    /// lists them. Blocks with no successors (e.g. `return`) have no entry here.
    pub edges: HashMap<String, Vec<String>>,
    /// Entry block; `from_mir_function` always sets this to `"bb0"`
    pub entry_block: String,
    /// Exit blocks: blocks whose terminator is [`Terminator::Return`]
    pub exit_blocks: Vec<String>,
}
20
/// A basic block in the control flow graph
#[derive(Debug, Clone)]
pub struct BasicBlock {
    /// Block ID parsed from the MIR block header, e.g. `"bb0"`
    pub id: String,
    /// Trimmed non-terminator lines of the block (assignments, calls, etc.).
    /// A call terminator that assigns its result (`_x = f(..) -> [return: ..]`) is also
    /// recorded here, so taint can be tracked through the call's result.
    pub statements: Vec<String>,
    /// Terminator that determines control flow out of this block
    pub terminator: Terminator,
}
30
/// Terminator instruction that ends a basic block, parsed from one line of MIR text
#[derive(Debug, Clone)]
pub enum Terminator {
    /// Unconditional jump: `goto -> bb5`
    Goto { target: String },

    /// Conditional branch: `switchInt(x) -> [0: bb1, otherwise: bb2]`
    SwitchInt {
        /// Operand text between `switchInt(` and `)`, e.g. `move _5`
        condition: String,
        /// Explicit arms in the order they appear in the MIR text
        targets: Vec<(String, String)>, // (value, target_block)
        /// Target of the `otherwise:` arm, if present
        otherwise: Option<String>,
    },

    /// Function return (`return`)
    Return,

    /// Function call: `_2 = f(..) -> [return: bb2, unwind: bb3]`
    Call {
        /// Block from the `return: bbN` entry (call returned normally)
        return_target: Option<String>,
        /// Block from the `unwind: bbN` entry; `None` when the unwind entry names no
        /// block (e.g. `unwind continue`) or is absent
        unwind_target: Option<String>,
    },

    /// Assertion: `assert(cond, ..) -> [success: bb5, unwind: bb6]`
    Assert {
        /// Block from the `success: bbN` entry, taken if the assertion passes
        success_target: String,
        /// Block from the `unwind: bbN` entry, taken if the assertion fails;
        /// `None` if the target list has no such entry
        failure_target: Option<String>,
    },

    /// Drop of a value: `drop(_x) -> [return: bb3, unwind: bb4]` (can unwind)
    Drop {
        /// Block from the `return: bbN` entry (normal continuation)
        target: String,
        /// Block from the `unwind: bbN` entry, if any
        unwind_target: Option<String>,
    },

    /// Unreachable code (`unreachable`)
    Unreachable,

    /// Terminator text that none of the other variants matched
    Unknown(String),
}
78
79impl ControlFlowGraph {
80    /// Extract CFG from a MIR function's body
81    pub fn from_mir_function(function: &MirFunction) -> Self {
82        let mut blocks = HashMap::new();
83        let mut edges = HashMap::new();
84        let mut exit_blocks = Vec::new();
85
86        // Parse basic blocks from function body
87        let parsed_blocks = Self::parse_basic_blocks(&function.body);
88
89        for (id, block) in parsed_blocks {
90            // Extract edges from terminator
91            let successors = Self::extract_successors(&block.terminator);
92            if !successors.is_empty() {
93                edges.insert(id.clone(), successors);
94            }
95
96            // Track exit blocks
97            if matches!(block.terminator, Terminator::Return) {
98                exit_blocks.push(id.clone());
99            }
100
101            blocks.insert(id, block);
102        }
103
104        ControlFlowGraph {
105            blocks,
106            edges,
107            entry_block: "bb0".to_string(),
108            exit_blocks,
109        }
110    }
111
112    /// Get the number of basic blocks in the CFG
113    pub fn block_count(&self) -> usize {
114        self.blocks.len()
115    }
116
117    /// Parse basic blocks from MIR body lines
118    fn parse_basic_blocks(body: &[String]) -> HashMap<String, BasicBlock> {
119        let mut blocks = HashMap::new();
120        let mut current_block_id: Option<String> = None;
121        let mut current_statements = Vec::new();
122        let mut current_terminator: Option<Terminator> = None;
123
124        for line in body {
125            let trimmed = line.trim();
126
127            // Check for block start: "bb0: {"
128            if let Some(block_id) = Self::extract_block_id(trimmed) {
129                // Save previous block if any
130                if let Some(id) = current_block_id.take() {
131                    if let Some(term) = current_terminator.take() {
132                        blocks.insert(
133                            id.clone(),
134                            BasicBlock {
135                                id,
136                                statements: std::mem::take(&mut current_statements),
137                                terminator: term,
138                            },
139                        );
140                    }
141                }
142
143                // Start new block
144                current_block_id = Some(block_id);
145                current_statements.clear();
146                current_terminator = None;
147            }
148            // Check for terminator instructions
149            else if trimmed.starts_with("goto ")
150                || trimmed.starts_with("switchInt")
151                || trimmed.starts_with("return")
152                || trimmed.contains(" -> [return:")
153                || trimmed.starts_with("assert(")
154                || trimmed.starts_with("drop(")
155                || trimmed.starts_with("unreachable")
156            {
157                // If this is a call (has " = " and " -> [return:"), also add it as a statement
158                // This ensures we can analyze taint flow through function call results
159                if trimmed.contains(" = ") && trimmed.contains(" -> [return:") {
160                    current_statements.push(trimmed.to_string());
161                }
162                current_terminator = Some(Self::parse_terminator(trimmed));
163            }
164            // Regular statement
165            else if !trimmed.is_empty()
166                && !trimmed.starts_with("}")
167                && !trimmed.starts_with("scope")
168                && !trimmed.starts_with("debug")
169                && !trimmed.starts_with("let")
170            {
171                current_statements.push(trimmed.to_string());
172            }
173        }
174
175        // Save last block
176        if let Some(id) = current_block_id {
177            if let Some(term) = current_terminator {
178                blocks.insert(
179                    id.clone(),
180                    BasicBlock {
181                        id,
182                        statements: current_statements,
183                        terminator: term,
184                    },
185                );
186            }
187        }
188
189        blocks
190    }
191
192    /// Extract block ID from line like "bb0: {"
193    fn extract_block_id(line: &str) -> Option<String> {
194        if line.starts_with("bb") && line.contains(": {") {
195            let id = line.split(": {").next()?;
196            Some(id.to_string())
197        } else {
198            None
199        }
200    }
201
202    /// Parse a terminator instruction
203    fn parse_terminator(line: &str) -> Terminator {
204        let line = line.trim().trim_end_matches(';');
205
206        // goto -> bb5
207        if let Some(rest) = line.strip_prefix("goto -> ") {
208            return Terminator::Goto {
209                target: rest.to_string(),
210            };
211        }
212
213        // return
214        if line == "return" {
215            return Terminator::Return;
216        }
217
218        // unreachable
219        if line == "unreachable" {
220            return Terminator::Unreachable;
221        }
222
223        // switchInt(move _5) -> [0: bb12, otherwise: bb7]
224        if let Some(rest) = line.strip_prefix("switchInt(") {
225            if let Some(paren_end) = rest.find(") -> [") {
226                let condition = rest[..paren_end].to_string();
227                let targets_str = &rest[paren_end + 6..]; // Skip ") -> ["
228
229                let mut targets = Vec::new();
230                let mut otherwise = None;
231
232                // Parse targets: "0: bb12, 1: bb13, otherwise: bb7]"
233                for part in targets_str.trim_end_matches(']').split(", ") {
234                    if let Some((value, block)) = part.split_once(": ") {
235                        if value == "otherwise" {
236                            otherwise = Some(block.to_string());
237                        } else {
238                            targets.push((value.to_string(), block.to_string()));
239                        }
240                    }
241                }
242
243                return Terminator::SwitchInt {
244                    condition,
245                    targets,
246                    otherwise,
247                };
248            }
249        }
250
251        // Function call: some_func() -> [return: bb2, unwind continue]
252        if line.contains(" -> [return:") {
253            let mut return_target = None;
254            let mut unwind_target = None;
255
256            if let Some(arrow_pos) = line.find(" -> [") {
257                let targets_str = &line[arrow_pos + 5..]; // Skip " -> ["
258
259                for part in targets_str.trim_end_matches(']').split(", ") {
260                    if let Some(rest) = part.strip_prefix("return: ") {
261                        return_target = Some(rest.to_string());
262                    } else if let Some(rest) = part.strip_prefix("unwind: ") {
263                        unwind_target = Some(rest.to_string());
264                    }
265                }
266            }
267
268            return Terminator::Call {
269                return_target,
270                unwind_target,
271            };
272        }
273
274        // assert: assert(cond) -> [success: bb5, unwind: bb6]
275        if let Some(rest) = line.strip_prefix("assert(") {
276            if let Some(arrow_pos) = rest.find(" -> [") {
277                let targets_str = &rest[arrow_pos + 5..];
278                let mut success_target = String::new();
279                let mut failure_target = None;
280
281                for part in targets_str.trim_end_matches(']').split(", ") {
282                    if let Some(rest) = part.strip_prefix("success: ") {
283                        success_target = rest.to_string();
284                    } else if let Some(rest) = part.strip_prefix("unwind: ") {
285                        failure_target = Some(rest.to_string());
286                    }
287                }
288
289                return Terminator::Assert {
290                    success_target,
291                    failure_target,
292                };
293            }
294        }
295
296        // drop: drop(_x) -> [return: bb3, unwind: bb4]
297        if let Some(rest) = line.strip_prefix("drop(") {
298            if let Some(arrow_pos) = rest.find(" -> [") {
299                let targets_str = &rest[arrow_pos + 5..];
300                let mut target = String::new();
301                let mut unwind_target = None;
302
303                for part in targets_str.trim_end_matches(']').split(", ") {
304                    if let Some(rest) = part.strip_prefix("return: ") {
305                        target = rest.to_string();
306                    } else if let Some(rest) = part.strip_prefix("unwind: ") {
307                        unwind_target = Some(rest.to_string());
308                    }
309                }
310
311                return Terminator::Drop {
312                    target,
313                    unwind_target,
314                };
315            }
316        }
317
318        // Unknown terminator
319        Terminator::Unknown(line.to_string())
320    }
321
322    /// Extract successor block IDs from a terminator
323    fn extract_successors(terminator: &Terminator) -> Vec<String> {
324        match terminator {
325            Terminator::Goto { target } => vec![target.clone()],
326
327            Terminator::SwitchInt {
328                targets, otherwise, ..
329            } => {
330                let mut successors: Vec<String> =
331                    targets.iter().map(|(_, block)| block.clone()).collect();
332                if let Some(other) = otherwise {
333                    successors.push(other.clone());
334                }
335                successors
336            }
337
338            Terminator::Return | Terminator::Unreachable => vec![],
339
340            Terminator::Call {
341                return_target,
342                unwind_target,
343            } => {
344                let mut successors = Vec::new();
345                if let Some(ret) = return_target {
346                    successors.push(ret.clone());
347                }
348                if let Some(unw) = unwind_target {
349                    successors.push(unw.clone());
350                }
351                successors
352            }
353
354            Terminator::Assert {
355                success_target,
356                failure_target,
357            } => {
358                let mut successors = vec![success_target.clone()];
359                if let Some(fail) = failure_target {
360                    successors.push(fail.clone());
361                }
362                successors
363            }
364
365            Terminator::Drop {
366                target,
367                unwind_target,
368            } => {
369                let mut successors = vec![target.clone()];
370                if let Some(unw) = unwind_target {
371                    successors.push(unw.clone());
372                }
373                successors
374            }
375
376            Terminator::Unknown(_) => vec![],
377        }
378    }
379
380    /// Count branch points in the CFG (blocks with multiple successors)
381    pub fn branch_count(&self) -> usize {
382        self.edges
383            .values()
384            .filter(|successors| successors.len() > 1)
385            .count()
386    }
387
388    /// Check if this CFG is too complex for full path enumeration.
389    /// Returns true if enumeration should be skipped to prevent memory explosion.
390    ///
391    /// The heuristic is based on empirical observation:
392    /// - influxdb3_id: 78 blocks, 20 branches → works fine (~100 MB)
393    /// - influxdb3 serve: 2442 blocks, 381 branches → 58 GB explosion
394    ///
395    /// A CFG with B branches can have up to 2^B paths. With MAX_DEPTH=50,
396    /// we limit path length, but the DFS exploration tree still grows exponentially.
397    ///
398    /// Safe thresholds (empirically derived):
399    /// - MAX_BLOCKS: 500 (well above typical functions, catches mega-closures)
400    /// - MAX_BRANCHES: 100 (limits exponential exploration)
401    pub fn is_too_complex_for_path_enumeration(&self) -> bool {
402        const MAX_BLOCKS_FOR_PATH_ENUM: usize = 500;
403        const MAX_BRANCHES_FOR_PATH_ENUM: usize = 100;
404
405        self.blocks.len() > MAX_BLOCKS_FOR_PATH_ENUM
406            || self.branch_count() > MAX_BRANCHES_FOR_PATH_ENUM
407    }
408
409    /// Enumerate all paths from entry to exit blocks
410    /// Returns paths as sequences of block IDs
411    ///
412    /// For extremely complex CFGs (large async closures, state machines),
413    /// returns an empty vec to prevent memory explosion.
414    ///
415    /// Returns (paths, was_skipped_due_to_complexity)
416    pub fn get_all_paths(&self) -> (Vec<Vec<String>>, bool) {
417        // Guard against memory explosion on huge CFGs
418        // (e.g., influxdb3's serve::command::{closure#0} with 2442 blocks, 381 branches)
419        if self.is_too_complex_for_path_enumeration() {
420            // Return empty - caller should fall back to path-insensitive analysis
421            return (Vec::new(), true);
422        }
423
424        // println!("[DEBUG] get_all_paths: entry={}, exit_blocks={:?}, blocks={}", self.entry_block, self.exit_blocks, self.blocks.len());
425        let mut paths = Vec::new();
426        let mut current_path = Vec::new();
427        let mut visited = HashSet::new();
428
429        // Aggressive limits to prevent memory explosion on large crates
430        // With 1000 functions x 2 analyses x 32 paths = 64,000 max path analyses
431        const MAX_PATHS: usize = 1000;
432        const MAX_DEPTH: usize = 50;
433        self.dfs_paths(
434            &self.entry_block,
435            &mut current_path,
436            &mut visited,
437            &mut paths,
438            0,
439            MAX_DEPTH,
440            MAX_PATHS,
441        );
442
443        // println!("[DEBUG] Found {} paths", paths.len());
444        (paths, false)
445    }
446
447    /// Depth-first search to enumerate paths
448    fn dfs_paths(
449        &self,
450        current_block: &str,
451        current_path: &mut Vec<String>,
452        visited: &mut HashSet<String>,
453        paths: &mut Vec<Vec<String>>,
454        depth: usize,
455        max_depth: usize,
456        max_paths: usize,
457    ) {
458        // Prevent infinite loops, path explosion, and excessive path count
459        if depth > max_depth || visited.contains(current_block) || paths.len() >= max_paths {
460            return;
461        }
462
463        current_path.push(current_block.to_string());
464        visited.insert(current_block.to_string());
465
466        // Check if this is an exit block
467        if self.exit_blocks.contains(&current_block.to_string()) {
468            paths.push(current_path.clone());
469        } else if let Some(successors) = self.edges.get(current_block) {
470            // Explore each successor
471            for successor in successors {
472                if paths.len() >= max_paths {
473                    break; // Stop exploring if we hit the limit
474                }
475                self.dfs_paths(
476                    successor,
477                    current_path,
478                    visited,
479                    paths,
480                    depth + 1,
481                    max_depth,
482                    max_paths,
483                );
484            }
485        }
486
487        // Backtrack
488        current_path.pop();
489        visited.remove(current_block);
490    }
491
492    /// Get the basic block for a given ID
493    pub fn get_block(&self, block_id: &str) -> Option<&BasicBlock> {
494        self.blocks.get(block_id)
495    }
496
497    /// Check if the CFG has any branching (multiple paths)
498    pub fn has_branching(&self) -> bool {
499        self.edges.values().any(|successors| successors.len() > 1)
500    }
501}
502
#[cfg(test)]
mod tests {
    use super::*;

    /// `goto -> bbN;` yields a `Goto` whose target has the trailing `;` removed.
    #[test]
    fn test_parse_goto() {
        if let Terminator::Goto { target } = ControlFlowGraph::parse_terminator("goto -> bb5;") {
            assert_eq!(target.as_str(), "bb5");
        } else {
            panic!("Expected Goto");
        }
    }

    /// A bare `return;` maps to `Return`.
    #[test]
    fn test_parse_return() {
        let term = ControlFlowGraph::parse_terminator("return;");
        assert!(matches!(term, Terminator::Return));
    }

    /// A two-arm `switchInt` yields one explicit arm plus the `otherwise` target.
    #[test]
    fn test_parse_switch_int() {
        let input = "switchInt(move _5) -> [0: bb12, otherwise: bb7];";
        if let Terminator::SwitchInt {
            condition,
            targets,
            otherwise,
        } = ControlFlowGraph::parse_terminator(input)
        {
            assert_eq!(condition.as_str(), "move _5");
            assert_eq!(targets, vec![("0".to_string(), "bb12".to_string())]);
            assert_eq!(otherwise.as_deref(), Some("bb7"));
        } else {
            panic!("Expected SwitchInt");
        }
    }

    /// A call with `unwind continue` still exposes its `return:` target.
    #[test]
    fn test_parse_call() {
        let input = "some_func() -> [return: bb2, unwind continue];";
        if let Terminator::Call { return_target, .. } = ControlFlowGraph::parse_terminator(input)
        {
            assert_eq!(return_target.as_deref(), Some("bb2"));
        } else {
            panic!("Expected Call");
        }
    }
}
548            }
549            _ => panic!("Expected Call"),
550        }
551    }
552}