Skip to main content

luadec_rust/lua51/
cfg.rs

1use std::collections::BTreeSet;
2use std::fmt;
3
4use crate::lua51::instruction::Instruction;
5use crate::lua51::opcodes::OpCode;
6
/// Classifies how control moves along an [`Edge`] between two basic blocks.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EdgeKind {
    /// Execution simply continues into the block that follows.
    Fallthrough,
    /// Transfer caused by an unconditional `JMP`.
    Jump,
    /// Conditional branch taken when the condition is met and the next
    /// instruction is skipped.
    ConditionalTrue,
    /// Conditional branch taken when the condition is not met and execution
    /// falls through to the `JMP`.
    ConditionalFalse,
    /// `FORLOOP` back-edge: the numeric loop runs another iteration.
    ForLoopBack,
    /// `FORLOOP` exit: the numeric loop has finished.
    ForLoopExit,
    /// `FORPREP` jump to the end of the loop body, where the `FORLOOP` sits.
    ForPrep,
    /// `TFORLOOP` continuation: the iterator returned a non-nil value.
    TForLoopBack,
    /// `TFORLOOP` exit: the iterator returned nil.
    TForLoopExit,
}
29
30/// A directed edge in the control flow graph.
31#[derive(Debug, Clone, Copy, PartialEq, Eq)]
32pub struct Edge {
33    pub from: usize,
34    pub to: usize,
35    pub kind: EdgeKind,
36}
37
/// A basic block: a maximal straight-line run of instructions that is entered
/// only at its first instruction and left only at its last.
#[derive(Clone, Debug)]
pub struct BasicBlock {
    /// Identifier of this block.
    pub id: usize,
    /// Index of the first instruction in the block (inclusive).
    pub start: usize,
    /// Index of the last instruction in the block (inclusive).
    pub end: usize,
    /// IDs of the blocks control may reach directly from this one.
    pub successors: Vec<usize>,
    /// IDs of the blocks from which control may directly arrive here.
    pub predecessors: Vec<usize>,
}
51
52/// Control flow graph for a single Lua function/chunk.
53#[derive(Debug)]
54pub struct ControlFlowGraph {
55    pub blocks: Vec<BasicBlock>,
56    pub edges: Vec<Edge>,
57    pub instructions: Vec<Instruction>,
58    /// Map from instruction index to block ID (for quick lookup).
59    inst_to_block: Vec<usize>,
60}
61
62impl ControlFlowGraph {
63    /// Build a CFG from a slice of raw 32-bit Lua 5.1 instructions.
64    pub fn build(raw_instructions: &[u32]) -> Self {
65        let instructions: Vec<Instruction> = raw_instructions
66            .iter()
67            .map(|&raw| Instruction::decode(raw).expect("invalid instruction"))
68            .collect();
69
70        if instructions.is_empty() {
71            return ControlFlowGraph {
72                blocks: Vec::new(),
73                edges: Vec::new(),
74                instructions,
75                inst_to_block: Vec::new(),
76            };
77        }
78
79        // Phase 1: Identify block leaders (first instruction of each block).
80        let leaders = find_leaders(&instructions);
81
82        // Phase 2: Create basic blocks from leader set.
83        let blocks = create_blocks(&leaders, instructions.len());
84
85        // Phase 3: Build inst_to_block map.
86        let mut inst_to_block = vec![0usize; instructions.len()];
87        for (block_id, block) in blocks.iter().enumerate() {
88            for pc in block.start..=block.end {
89                inst_to_block[pc] = block_id;
90            }
91        }
92
93        // Phase 4: Build edges.
94        let edges = build_edges(&blocks, &instructions, &inst_to_block);
95
96        // Phase 5: Fill in successor/predecessor lists.
97        let mut blocks = blocks;
98        for edge in &edges {
99            blocks[edge.from].successors.push(edge.to);
100            blocks[edge.to].predecessors.push(edge.from);
101        }
102
103        ControlFlowGraph {
104            blocks,
105            edges,
106            instructions,
107            inst_to_block,
108        }
109    }
110
111    /// Get the block ID that contains the given instruction index.
112    pub fn block_of(&self, pc: usize) -> usize {
113        self.inst_to_block[pc]
114    }
115
116    /// Get a block by its ID.
117    pub fn block(&self, id: usize) -> &BasicBlock {
118        &self.blocks[id]
119    }
120
121    /// Get instructions in a block.
122    pub fn block_instructions(&self, block: &BasicBlock) -> &[Instruction] {
123        &self.instructions[block.start..=block.end]
124    }
125
126    /// Number of basic blocks.
127    pub fn num_blocks(&self) -> usize {
128        self.blocks.len()
129    }
130}
131
132/// Identify all leader instruction indices.
133///
134/// Leaders are:
135/// - Index 0 (entry point)
136/// - Jump targets (pc + 1 + sBx for JMP/FORLOOP/FORPREP)
137/// - Instruction after a jump/branch/return (fallthrough point)
138/// - Instruction after a test opcode's following JMP target
139fn find_leaders(instructions: &[Instruction]) -> BTreeSet<usize> {
140    let len = instructions.len();
141    let mut leaders = BTreeSet::new();
142    leaders.insert(0);
143
144    for (pc, inst) in instructions.iter().enumerate() {
145        match inst.op {
146            // Unconditional jump
147            OpCode::Jmp => {
148                let target = (pc as i32 + 1 + inst.sbx()) as usize;
149                if target < len {
150                    leaders.insert(target);
151                }
152                // Instruction after JMP is a leader (if reachable)
153                if pc + 1 < len {
154                    leaders.insert(pc + 1);
155                }
156            }
157
158            // Comparison tests: EQ, LT, LE skip next instruction (which must be JMP)
159            // The conditional creates two paths:
160            //   - condition true: skip the JMP, go to pc+2
161            //   - condition false: execute the JMP at pc+1
162            OpCode::Eq | OpCode::Lt | OpCode::Le | OpCode::Test | OpCode::TestSet => {
163                // pc+1 should be a JMP; its target is a leader
164                if pc + 1 < len {
165                    let next = &instructions[pc + 1];
166                    debug_assert_eq!(next.op, OpCode::Jmp, "test at pc={} not followed by JMP", pc);
167                    let jmp_target = (pc as i32 + 2 + next.sbx()) as usize;
168                    if jmp_target < len {
169                        leaders.insert(jmp_target);
170                    }
171                }
172                // pc+2 is the fallthrough after skipping the JMP
173                if pc + 2 < len {
174                    leaders.insert(pc + 2);
175                }
176            }
177
178            // Numeric for loop: FORLOOP jumps back if loop continues
179            OpCode::ForLoop => {
180                let target = (pc as i32 + 1 + inst.sbx()) as usize;
181                if target < len {
182                    leaders.insert(target);
183                }
184                if pc + 1 < len {
185                    leaders.insert(pc + 1);
186                }
187            }
188
189            // FORPREP: jump forward to FORLOOP
190            OpCode::ForPrep => {
191                let target = (pc as i32 + 1 + inst.sbx()) as usize;
192                if target < len {
193                    leaders.insert(target);
194                }
195            }
196
197            // Generic for loop: TFORLOOP either continues or exits
198            OpCode::TForLoop => {
199                // TFORLOOP is followed by a JMP (back-edge)
200                if pc + 1 < len {
201                    let next = &instructions[pc + 1];
202                    debug_assert_eq!(
203                        next.op,
204                        OpCode::Jmp,
205                        "TFORLOOP at pc={} not followed by JMP",
206                        pc
207                    );
208                    let jmp_target = (pc as i32 + 2 + next.sbx()) as usize;
209                    if jmp_target < len {
210                        leaders.insert(jmp_target);
211                    }
212                }
213                if pc + 2 < len {
214                    leaders.insert(pc + 2);
215                }
216            }
217
218            // Return terminates the block; next instruction (if any) starts a new one
219            OpCode::Return | OpCode::TailCall => {
220                if pc + 1 < len {
221                    leaders.insert(pc + 1);
222                }
223            }
224
225            _ => {}
226        }
227    }
228
229    leaders
230}
231
232/// Create BasicBlock structs from the sorted leader set.
233fn create_blocks(leaders: &BTreeSet<usize>, num_instructions: usize) -> Vec<BasicBlock> {
234    let leader_vec: Vec<usize> = leaders.iter().copied().collect();
235    let mut blocks = Vec::with_capacity(leader_vec.len());
236
237    for (i, &start) in leader_vec.iter().enumerate() {
238        let end = if i + 1 < leader_vec.len() {
239            leader_vec[i + 1] - 1
240        } else {
241            num_instructions - 1
242        };
243        blocks.push(BasicBlock {
244            id: i,
245            start,
246            end,
247            successors: Vec::new(),
248            predecessors: Vec::new(),
249        });
250    }
251
252    blocks
253}
254
255/// Build edges between blocks based on the last instruction of each block.
256fn build_edges(
257    blocks: &[BasicBlock],
258    instructions: &[Instruction],
259    inst_to_block: &[usize],
260) -> Vec<Edge> {
261    let mut edges = Vec::new();
262    let len = instructions.len();
263
264    for block in blocks {
265        let last_pc = block.end;
266        let last = &instructions[last_pc];
267        let block_id = block.id;
268
269        match last.op {
270            OpCode::Jmp => {
271                let target = (last_pc as i32 + 1 + last.sbx()) as usize;
272
273                // Check if this JMP is preceded by a test opcode (EQ/LT/LE/TEST/TESTSET)
274                // within the same block - making this a conditional branch.
275                let is_conditional = last_pc > block.start && {
276                    let prev = &instructions[last_pc - 1];
277                    matches!(
278                        prev.op,
279                        OpCode::Eq
280                            | OpCode::Lt
281                            | OpCode::Le
282                            | OpCode::Test
283                            | OpCode::TestSet
284                    )
285                };
286
287                let is_tforloop = last_pc > block.start
288                    && instructions[last_pc - 1].op == OpCode::TForLoop;
289
290                if is_conditional {
291                    // False branch: follow the JMP
292                    if target < len {
293                        edges.push(Edge {
294                            from: block_id,
295                            to: inst_to_block[target],
296                            kind: EdgeKind::ConditionalFalse,
297                        });
298                    }
299                    // True branch: skip the JMP, go to pc+1
300                    if last_pc + 1 < len {
301                        edges.push(Edge {
302                            from: block_id,
303                            to: inst_to_block[last_pc + 1],
304                            kind: EdgeKind::ConditionalTrue,
305                        });
306                    }
307                } else if is_tforloop {
308                    // Back-edge: jump back to iterator call
309                    if target < len {
310                        edges.push(Edge {
311                            from: block_id,
312                            to: inst_to_block[target],
313                            kind: EdgeKind::TForLoopBack,
314                        });
315                    }
316                    // Exit: fall through past the JMP
317                    if last_pc + 1 < len {
318                        edges.push(Edge {
319                            from: block_id,
320                            to: inst_to_block[last_pc + 1],
321                            kind: EdgeKind::TForLoopExit,
322                        });
323                    }
324                } else {
325                    if target < len {
326                        edges.push(Edge {
327                            from: block_id,
328                            to: inst_to_block[target],
329                            kind: EdgeKind::Jump,
330                        });
331                    }
332                }
333            }
334
335            // Test opcodes where the JMP is NOT in this block (rare edge case)
336            OpCode::Eq | OpCode::Lt | OpCode::Le | OpCode::Test | OpCode::TestSet => {
337                let jmp_pc = last_pc + 1;
338                if jmp_pc < len {
339                    edges.push(Edge {
340                        from: block_id,
341                        to: inst_to_block[jmp_pc],
342                        kind: EdgeKind::Fallthrough,
343                    });
344                }
345            }
346
347            OpCode::ForLoop => {
348                // Back-edge: loop continues, jump to loop body start
349                let target = (last_pc as i32 + 1 + last.sbx()) as usize;
350                if target < len {
351                    edges.push(Edge {
352                        from: block_id,
353                        to: inst_to_block[target],
354                        kind: EdgeKind::ForLoopBack,
355                    });
356                }
357                // Exit: loop ends, fall through
358                if last_pc + 1 < len {
359                    edges.push(Edge {
360                        from: block_id,
361                        to: inst_to_block[last_pc + 1],
362                        kind: EdgeKind::ForLoopExit,
363                    });
364                }
365            }
366
367            OpCode::ForPrep => {
368                // Jump to FORLOOP
369                let target = (last_pc as i32 + 1 + last.sbx()) as usize;
370                if target < len {
371                    edges.push(Edge {
372                        from: block_id,
373                        to: inst_to_block[target],
374                        kind: EdgeKind::ForPrep,
375                    });
376                }
377            }
378
379            OpCode::TForLoop => {
380                // TFORLOOP at last_pc, but normally the block also contains
381                // the following JMP. Check if last_pc is actually the block end
382                // (the JMP case is handled in the Jmp arm when preceded by TFORLOOP).
383                let jmp_pc = last_pc + 1;
384                if jmp_pc < len {
385                    edges.push(Edge {
386                        from: block_id,
387                        to: inst_to_block[jmp_pc],
388                        kind: EdgeKind::Fallthrough,
389                    });
390                }
391            }
392
393            OpCode::Return | OpCode::TailCall => {
394                // No successors (terminal block)
395            }
396
397            _ => {
398                // Default: fallthrough to next block
399                if last_pc + 1 < len {
400                    edges.push(Edge {
401                        from: block_id,
402                        to: inst_to_block[last_pc + 1],
403                        kind: EdgeKind::Fallthrough,
404                    });
405                }
406            }
407        }
408    }
409
410    edges
411}
412
413impl fmt::Display for ControlFlowGraph {
414    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
415        for block in &self.blocks {
416            writeln!(
417                f,
418                "Block {} [pc {}..{}] pred={:?} succ={:?}",
419                block.id, block.start, block.end, block.predecessors, block.successors
420            )?;
421            for pc in block.start..=block.end {
422                writeln!(f, "  [{:4}] {}", pc, self.instructions[pc])?;
423            }
424        }
425        if !self.edges.is_empty() {
426            writeln!(f, "Edges:")?;
427            for edge in &self.edges {
428                writeln!(f, "  B{} -> B{} ({:?})", edge.from, edge.to, edge.kind)?;
429            }
430        }
431        Ok(())
432    }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lua51::instruction::{encode_abc, encode_abx, encode_asbx};
    use crate::lua51::opcodes::OpCode;

    /// Shorthand for building the edge an assertion expects.
    fn edge(from: usize, to: usize, kind: EdgeKind) -> Edge {
        Edge { from, to, kind }
    }

    /// Straight-line code forms a single block with no edges.
    #[test]
    fn linear_code() {
        // LOADK 0 0; LOADK 1 1; RETURN 0 1
        let code = vec![
            encode_abx(OpCode::LoadK, 0, 0),
            encode_abx(OpCode::LoadK, 1, 1),
            encode_abc(OpCode::Return, 0, 1, 0),
        ];
        let cfg = ControlFlowGraph::build(&code);
        assert_eq!(cfg.num_blocks(), 1);
        assert_eq!(cfg.blocks[0].start, 0);
        assert_eq!(cfg.blocks[0].end, 2);
        assert!(cfg.edges.is_empty());
        assert!(cfg.blocks[0].successors.is_empty());
        assert!(cfg.blocks[0].predecessors.is_empty());
    }

    /// An unconditional jump splits the code at both its target and the
    /// instruction that follows it.
    #[test]
    fn simple_jump() {
        // [0] JMP +1      -> jumps to [2]
        // [1] LOADK 0 0
        // [2] RETURN 0 1
        let code = vec![
            encode_asbx(OpCode::Jmp, 0, 1),
            encode_abx(OpCode::LoadK, 0, 0),
            encode_abc(OpCode::Return, 0, 1, 0),
        ];
        let cfg = ControlFlowGraph::build(&code);
        // Leaders: 0, 1, 2 -> 3 blocks
        assert_eq!(cfg.num_blocks(), 3);
        // Block 0 jumps to block 2; block 1 falls through into block 2.
        assert!(cfg.edges.contains(&edge(0, 2, EdgeKind::Jump)));
        assert!(cfg.edges.contains(&edge(1, 2, EdgeKind::Fallthrough)));
        assert_eq!(cfg.edges.len(), 2);
        assert_eq!(cfg.blocks[0].successors, vec![2]);
        assert_eq!(cfg.blocks[2].predecessors, vec![0, 1]);
    }

    /// A test opcode and its JMP share a block that has one true and one
    /// false successor.
    #[test]
    fn conditional_branch() {
        // [0] EQ 0 0 1     -- if R(0) == R(1)
        // [1] JMP +1       -- jump to [3] if condition false
        // [2] LOADK 2 0    -- condition true path
        // [3] RETURN 0 1
        let code = vec![
            encode_abc(OpCode::Eq, 0, 0, 1),
            encode_asbx(OpCode::Jmp, 0, 1),
            encode_abx(OpCode::LoadK, 2, 0),
            encode_abc(OpCode::Return, 0, 1, 0),
        ];
        let cfg = ControlFlowGraph::build(&code);
        // Leaders: 0, 2, 3 -> 3 blocks
        // Block 0: [0,1] (EQ + JMP)
        assert_eq!(cfg.num_blocks(), 3);
        assert_eq!(cfg.blocks[0].start, 0);
        assert_eq!(cfg.blocks[0].end, 1);
        assert_eq!(cfg.block_of(1), 0);
        assert_eq!(cfg.block_of(2), 1);
        assert_eq!(cfg.block_of(3), 2);
        // ConditionalTrue -> block 1 (pc=2), ConditionalFalse -> block 2 (pc=3)
        assert!(cfg.edges.contains(&edge(0, 1, EdgeKind::ConditionalTrue)));
        assert!(cfg.edges.contains(&edge(0, 2, EdgeKind::ConditionalFalse)));
        // The true path rejoins at the RETURN.
        assert!(cfg.edges.contains(&edge(1, 2, EdgeKind::Fallthrough)));
        assert_eq!(cfg.edges.len(), 3);
    }

    /// A numeric for loop yields the FORPREP jump, the FORLOOP back-edge and
    /// the FORLOOP exit.
    #[test]
    fn numeric_for_loop() {
        // [0] LOADK 0 0      -- init
        // [1] LOADK 1 1      -- limit
        // [2] LOADK 2 2      -- step
        // [3] FORPREP 0 +1   -- jump to [5] (FORLOOP)
        // [4] LOADK 4 0      -- loop body
        // [5] FORLOOP 0 -2   -- jump back to [4] if loop continues
        // [6] RETURN 0 1
        let code = vec![
            encode_abx(OpCode::LoadK, 0, 0),
            encode_abx(OpCode::LoadK, 1, 1),
            encode_abx(OpCode::LoadK, 2, 2),
            encode_asbx(OpCode::ForPrep, 0, 1),
            encode_abx(OpCode::LoadK, 4, 0),
            encode_asbx(OpCode::ForLoop, 0, -2),
            encode_abc(OpCode::Return, 0, 1, 0),
        ];
        let cfg = ControlFlowGraph::build(&code);
        // Leaders: 0, 4, 5, 6 -> blocks [0..3], [4], [5], [6]
        assert_eq!(cfg.num_blocks(), 4);
        assert_eq!(cfg.blocks[0].end, 3);
        // FORPREP jumps to the FORLOOP block.
        assert!(cfg.edges.contains(&edge(0, 2, EdgeKind::ForPrep)));
        // The body falls through into the FORLOOP.
        assert!(cfg.edges.contains(&edge(1, 2, EdgeKind::Fallthrough)));
        // FORLOOP has a back-edge to the body and an exit edge to the RETURN.
        assert!(cfg.edges.contains(&edge(2, 1, EdgeKind::ForLoopBack)));
        assert!(cfg.edges.contains(&edge(2, 3, EdgeKind::ForLoopExit)));
        assert_eq!(cfg.edges.len(), 4);
    }

    /// An empty instruction stream produces an empty graph.
    #[test]
    fn empty_function() {
        let cfg = ControlFlowGraph::build(&[]);
        assert_eq!(cfg.num_blocks(), 0);
        assert!(cfg.edges.is_empty());
        assert!(cfg.instructions.is_empty());
    }
}