Skip to main content

luadec_rust/lua51/lifter/
mod.rs

1use std::collections::{HashMap, HashSet};
2
3use luac_parser::LuaChunk;
4
5use crate::lua51::ast::*;
6use crate::lua51::cfg::ControlFlowGraph;
7use crate::lua51::dominator::{find_loops, DominatorTree, LoopKind, NaturalLoop};
8use crate::lua51::liveness::{compute_liveness, LivenessInfo};
9use crate::lua51::opcodes::OpCode;
10
11mod instruction;
12mod naming;
13mod or_and;
14mod register;
15mod util;
16
17use util::negate_expr;
18
19/// Context for decompiling a single Lua function.
20pub struct Lifter<'a> {
21    chunk: &'a LuaChunk,
22    cfg: ControlFlowGraph,
23    _dom: DominatorTree,
24    loops: Vec<NaturalLoop>,
25    liveness: LivenessInfo,
26    /// Register expressions: tracks what expression is currently held in each register.
27    regs: Vec<Option<Expr>>,
28    /// Pending tables being constructed (register -> accumulated fields).
29    pending_tables: HashMap<u32, Vec<TableField>>,
30    /// Stable references a register value has been assigned into.
31    capture_aliases: HashMap<u32, Expr>,
32    /// Registers that are updated from their own previous value across branches.
33    accumulator_regs: HashSet<u32>,
34    /// Blocks already visited to prevent infinite recursion.
35    visited_blocks: HashSet<usize>,
36    /// Local variable names assigned to registers (reg -> name).
37    local_names: HashMap<u32, String>,
38    /// Registers that have been declared as `local`.
39    declared_locals: HashSet<u32>,
40    /// Number of parameters (these registers are implicitly declared).
41    num_params: u32,
42    /// Whether this chunk has debug info (locals/upvalue names).
43    has_debug_info: bool,
44    /// Upvalue expressions resolved from the parent closure site.
45    resolved_upvalues: Vec<Option<Expr>>,
46    /// Active loop headers being lifted to avoid re-entering the same loop.
47    active_loop_headers: Vec<usize>,
48    /// Exit blocks for active loops, innermost last.
49    active_loop_exits: Vec<usize>,
50}
51
52impl<'a> Lifter<'a> {
53    pub fn decompile(chunk: &'a LuaChunk) -> Function {
54        Self::decompile_with_upvalues(chunk, Vec::new())
55    }
56
57    fn decompile_with_upvalues(
58        chunk: &'a LuaChunk,
59        resolved_upvalues: Vec<Option<Expr>>,
60    ) -> Function {
61        let cfg = ControlFlowGraph::build(&chunk.instructions);
62        let dom = DominatorTree::build(&cfg);
63        let loops = find_loops(&cfg, &dom);
64        let liveness = compute_liveness(&cfg, chunk.max_stack as usize);
65        let has_debug_info = !chunk.locals.is_empty();
66
67        let max_stack = chunk.max_stack as usize;
68        let mut lifter = Lifter {
69            chunk,
70            cfg,
71            _dom: dom,
72            loops,
73            liveness,
74            regs: vec![None; max_stack.max(256)],
75            pending_tables: HashMap::new(),
76            capture_aliases: HashMap::new(),
77            accumulator_regs: HashSet::new(),
78            visited_blocks: HashSet::new(),
79            local_names: HashMap::new(),
80            declared_locals: HashSet::new(),
81            num_params: chunk.num_params as u32,
82            has_debug_info,
83            resolved_upvalues,
84            active_loop_headers: Vec::new(),
85            active_loop_exits: Vec::new(),
86        };
87
88        lifter.accumulator_regs = lifter.find_accumulator_regs();
89
90        let params: Vec<String> = (0..chunk.num_params as u32)
91            .map(|i| {
92                let name = lifter.local_name(i, 0);
93                lifter.local_names.insert(i, name.clone());
94                lifter.declared_locals.insert(i);
95                lifter.set_reg(i, Expr::Name(name.clone()));
96                name
97            })
98            .collect();
99        let is_vararg = chunk.is_vararg.is_some();
100
101        let body = if lifter.cfg.num_blocks() > 0 {
102            lifter.lift_block_range(0, lifter.cfg.num_blocks())
103        } else {
104            Vec::new()
105        };
106
107        Function {
108            params,
109            is_vararg,
110            body,
111        }
112    }
113
114    /// Lift a range of blocks into a statement list, handling control structures.
115    fn lift_block_range(&mut self, start_block: usize, end_block: usize) -> Block {
116        let mut stmts = Vec::new();
117        let mut block_idx = start_block;
118
119        while block_idx < end_block && block_idx < self.cfg.num_blocks() {
120            // Prevent revisiting blocks
121            if self.visited_blocks.contains(&block_idx) {
122                block_idx += 1;
123                continue;
124            }
125
126            // Check if this block is a loop header
127            if let Some(lp) = self.find_loop_at(block_idx) {
128                let lp = lp.clone();
129                let next = self.lift_loop(&lp, &mut stmts);
130                if next <= block_idx {
131                    // Safety: avoid infinite loop
132                    block_idx += 1;
133                } else {
134                    block_idx = next;
135                }
136                continue;
137            }
138
139            self.visited_blocks.insert(block_idx);
140
141            let block = self.cfg.blocks[block_idx].clone();
142            let _last_pc = block.end;
143
144            // Check for conditional (if/elseif/else)
145            if self.is_conditional_block(&block) {
146                let next = self.lift_conditional(block_idx, &mut stmts);
147                if next <= block_idx {
148                    // Safety: avoid infinite loop, lift as normal instructions
149                    self.lift_instructions(block.start, block.end, &mut stmts);
150                    block_idx += 1;
151                } else {
152                    block_idx = next;
153                }
154                continue;
155            }
156
157            // Normal block: lift instructions sequentially
158            self.lift_instructions(block.start, block.end, &mut stmts);
159
160            // Check if this block ends with a JMP to a loop exit (break)
161            let last_inst = self.cfg.instructions[block.end];
162            if last_inst.op == OpCode::Jmp && block.successors.len() == 1 {
163                let target = block.successors[0];
164                if self.current_loop_exit() == Some(target) {
165                    stmts.push(Stat::Break);
166                }
167                // Follow unconditional JMP: skip ahead to target block
168                // (blocks between here and target are only reachable via the target)
169                if target > block_idx + 1 {
170                    block_idx = target;
171                    continue;
172                }
173            }
174
175            block_idx += 1;
176        }
177
178        stmts
179    }
180
181    /// Lift a single loop structure.
182    fn lift_loop(&mut self, lp: &NaturalLoop, stmts: &mut Block) -> usize {
183        match lp.kind {
184            LoopKind::NumericFor => self.lift_numeric_for(lp, stmts),
185            LoopKind::GenericFor => self.lift_generic_for(lp, stmts),
186            LoopKind::WhileRepeat => self.lift_while(lp, stmts),
187        }
188    }
189
190    fn lift_numeric_for(&mut self, lp: &NaturalLoop, stmts: &mut Block) -> usize {
191        let header = &self.cfg.blocks[lp.header].clone();
192        let loop_exit = self.max_loop_block(lp) + 1;
193
194        // Find the FORPREP instruction: it's in the block preceding the header
195        // or in the header itself.  The FORPREP's A register tells us the for-loop slots.
196        let forprep_block = self.find_forprep_block(lp.header);
197        let forprep_inst = if let Some(fb) = forprep_block {
198            let b = &self.cfg.blocks[fb];
199            self.cfg.instructions[b.end]
200        } else {
201            self.cfg.instructions[header.start]
202        };
203
204        let base = forprep_inst.a;
205        let var_name = self.local_name(base + 3, header.start);
206
207        // Lift the pre-loop setup to get init/limit/step
208        if let Some(fb) = forprep_block {
209            if !self.visited_blocks.contains(&fb) {
210                let b = &self.cfg.blocks[fb].clone();
211                self.visited_blocks.insert(fb);
212                // Lift instructions before FORPREP to set up init/limit/step
213                if b.end > b.start {
214                    self.lift_instructions(b.start, b.end - 1, stmts);
215                }
216            }
217        }
218
219        let start_expr = self.reg_expr(base);
220        let limit_expr = self.reg_expr(base + 1);
221        let step_expr = self.reg_expr(base + 2);
222        let step = if matches!(&step_expr, Expr::Number(NumLit::Int(1))) {
223            None
224        } else {
225            Some(step_expr)
226        };
227
228        self.active_loop_headers.push(lp.header);
229        self.active_loop_exits.push(loop_exit);
230        let body = self.lift_block_range(lp.header, lp.latch + 1);
231        self.active_loop_exits.pop();
232        self.active_loop_headers.pop();
233
234        stmts.push(Stat::NumericFor {
235            name: var_name,
236            start: start_expr,
237            limit: limit_expr,
238            step,
239            body,
240        });
241
242        // Return the block after the loop exit
243        loop_exit
244    }
245
246    fn lift_generic_for(&mut self, lp: &NaturalLoop, stmts: &mut Block) -> usize {
247        let header = &self.cfg.blocks[lp.header].clone();
248
249        // Find TFORLOOP instruction in the header or latch block
250        let mut tforloop_inst = None;
251        for pc in header.start..=header.end {
252            if self.cfg.instructions[pc].op == OpCode::TForLoop {
253                tforloop_inst = Some(self.cfg.instructions[pc]);
254                break;
255            }
256        }
257        if tforloop_inst.is_none() {
258            let latch_block = &self.cfg.blocks[lp.latch].clone();
259            for pc in latch_block.start..=latch_block.end {
260                if self.cfg.instructions[pc].op == OpCode::TForLoop {
261                    tforloop_inst = Some(self.cfg.instructions[pc]);
262                    break;
263                }
264            }
265        }
266        let tfl = tforloop_inst.unwrap_or(self.cfg.instructions[header.end]);
267
268        let base = tfl.a;
269        let num_vars = tfl.c();
270
271        let names: Vec<String> = (0..num_vars)
272            .map(|i| self.local_name(base + 3 + i, header.start))
273            .collect();
274
275        // Register loop variable names so the body can reference them
276        for (i, name) in names.iter().enumerate() {
277            let r = base + 3 + i as u32;
278            self.local_names.insert(r, name.clone());
279            self.declared_locals.insert(r);
280            self.set_reg(r, Expr::Name(name.clone()));
281        }
282
283        let iter_expr = self.reg_expr(base);
284
285        // Body: loop blocks excluding the header, sorted by block ID
286        let mut body_blocks: Vec<usize> = lp.body.iter()
287            .filter(|&&b| b != lp.header)
288            .copied()
289            .collect();
290        body_blocks.sort();
291
292        self.active_loop_headers.push(lp.header);
293        self.active_loop_exits.push(self.max_loop_block(lp) + 1);
294        let body = if !body_blocks.is_empty() {
295            let first = *body_blocks.first().unwrap();
296            let last = *body_blocks.last().unwrap();
297            self.lift_block_range(first, last + 1)
298        } else {
299            Vec::new()
300        };
301        self.active_loop_exits.pop();
302        self.active_loop_headers.pop();
303
304        stmts.push(Stat::GenericFor {
305            names,
306            iterators: vec![iter_expr],
307            body,
308        });
309
310        self.max_loop_block(lp) + 1
311    }
312
313    fn lift_while(&mut self, lp: &NaturalLoop, stmts: &mut Block) -> usize {
314        let _header = &self.cfg.blocks[lp.header].clone();
315
316        // Try to extract condition from header block
317        let cond = self.extract_condition(lp.header).unwrap_or(Expr::Bool(true));
318
319        // Body: blocks in the loop excluding header
320        self.active_loop_headers.push(lp.header);
321        self.active_loop_exits.push(self.max_loop_block(lp) + 1);
322        let body_start = lp.header + 1;
323        let body_end = lp.latch + 1;
324        let body = self.lift_block_range(body_start, body_end);
325        self.active_loop_exits.pop();
326        self.active_loop_headers.pop();
327
328        stmts.push(Stat::While { cond, body });
329
330        self.max_loop_block(lp) + 1
331    }
332
333    /// Lift an if/elseif/else chain.
334    fn lift_conditional(&mut self, block_idx: usize, stmts: &mut Block) -> usize {
335        // Try to detect and lift OR/AND short-circuit chains first
336        if let Some(next) = self.try_lift_or_and_chain(block_idx, stmts) {
337            return next;
338        }
339
340        let block = self.cfg.blocks[block_idx].clone();
341
342        // Lift any instructions before the test/JMP at the end of this block.
343        // The test is typically the second-to-last instruction (before JMP).
344        let test_pc = self.find_test_pc(&block);
345        if let Some(tp) = test_pc {
346            if tp > block.start {
347                self.lift_instructions(block.start, tp - 1, stmts);
348            }
349        }
350
351        let cond = self.extract_condition(block_idx).unwrap_or(Expr::Bool(true));
352
353        // Find the two branches
354        let succs = block.successors.clone();
355        if succs.len() != 2 {
356            // Not a proper conditional; just lift as normal
357            self.lift_instructions(block.start, block.end, stmts);
358            return block_idx + 1;
359        }
360
361        // The edge order: ConditionalFalse is added first, ConditionalTrue second
362        // ConditionalFalse = the JMP target (condition NOT met)
363        // ConditionalTrue = fallthrough after JMP (condition met, test passed -> skip JMP)
364        let false_target = succs[0]; // Where JMP goes (condition false)
365        let true_target = succs[1];  // Fallthrough past JMP (condition true)
366
367        // Detect guard clause pattern: `if not cond then return end`
368        // In Lua bytecode this is: TEST/EQ -> JMP past return -> RETURN -> continuation
369        // The true_target block (condition true = skip JMP) is a small block ending in RETURN
370        // and false_target is the continuation.
371        // Only match if the return block is NOT a merge point (has only 1 predecessor).
372        if self.is_return_block(true_target) && false_target > true_target
373            && self.cfg.blocks[true_target].predecessors.len() <= 1
374        {
375            // Guard clause: the "true" path is just a return
376            let guard_body = self.lift_block_range(true_target, true_target + 1);
377            stmts.push(Stat::If {
378                cond,
379                then_block: guard_body,
380                elseif_clauses: Vec::new(),
381                else_block: None,
382            });
383            return false_target;
384        }
385
386        // Detect inverted guard: `if cond then <continue> else return end`
387        // Here false_target is a return block and true_target is the continuation
388        // Only match if the return block is NOT a merge point.
389        if self.is_return_block(false_target) && true_target < false_target
390            && self.cfg.blocks[false_target].predecessors.len() <= 1
391        {
392            let guard_body = self.lift_block_range(false_target, false_target + 1);
393            let inv_cond = negate_expr(cond);
394            stmts.push(Stat::If {
395                cond: inv_cond,
396                then_block: guard_body,
397                elseif_clauses: Vec::new(),
398                else_block: None,
399            });
400            return true_target;
401        }
402
403        // Find the merge point
404        let merge = self.find_merge_point(block_idx, true_target, false_target);
405
406        let then_end = merge.unwrap_or(false_target);
407        let then_block = self.lift_block_range(true_target, then_end);
408
409        let else_block = if let Some(merge) = merge {
410            if false_target < merge {
411                let eb = self.lift_block_range(false_target, merge);
412                if eb.is_empty() { None } else { Some(eb) }
413            } else {
414                None
415            }
416        } else {
417            None
418        };
419
420        stmts.push(Stat::If {
421            cond,
422            then_block,
423            elseif_clauses: Vec::new(),
424            else_block,
425        });
426
427        merge.unwrap_or(false_target.max(true_target) + 1)
428    }
429}