Skip to main content

luadec_rust/lua51/emit/
mod.rs

1use std::fmt::Write;
2
3use crate::lua51::ast::*;
4
5mod normalize;
6mod cache;
7#[cfg(test)]
8mod tests;
9
10/// Emit a Lua function as source code.
11pub fn emit_function(func: &Function) -> String {
12    let mut out = String::new();
13    let body = normalize::normalize_block(&func.body);
14    let body = fold_elseif_chain(&body);
15    emit_block(&body, &mut out, 0);
16    out
17}
18
19/// Emit a complete chunk (wrapping the main function body).
20pub fn emit_chunk(func: &Function) -> String {
21    let mut out = String::new();
22    let body = normalize::normalize_block(&func.body);
23    let body = fold_elseif_chain(&body);
24    emit_block(&body, &mut out, 0);
25    // Remove trailing blank lines
26    while out.ends_with("\n\n") {
27        out.pop();
28    }
29    if !out.ends_with('\n') {
30        out.push('\n');
31    }
32    out
33}
34
35/// Fold nested if-else chains into elseif clauses.
36/// Transforms: `if A then B else if C then D else E end end`
37/// into:       `if A then B elseif C then D else E end`
38fn fold_elseif_chain(block: &Block) -> Block {
39    block.iter().map(|stat| fold_stat(stat)).collect()
40}
41
42fn fold_stat(stat: &Stat) -> Stat {
43    match stat {
44        Stat::If {
45            cond,
46            then_block,
47            elseif_clauses,
48            else_block,
49        } => {
50            let then_block = fold_elseif_chain(then_block);
51            let mut new_elseifs: Vec<(Expr, Block)> = elseif_clauses
52                .iter()
53                .map(|(c, b)| (c.clone(), fold_elseif_chain(b)))
54                .collect();
55
56            // Try to flatten: if else_block is a single `if` statement, merge it
57            let new_else = if let Some(eb) = else_block {
58                let folded = fold_elseif_chain(eb);
59                if folded.len() == 1 {
60                    if let Stat::If {
61                        cond: inner_cond,
62                        then_block: inner_then,
63                        elseif_clauses: inner_elseifs,
64                        else_block: inner_else,
65                    } = &folded[0]
66                    {
67                        // Merge into elseif
68                        new_elseifs.push((inner_cond.clone(), inner_then.clone()));
69                        new_elseifs.extend(inner_elseifs.iter().cloned());
70                        inner_else.clone()
71                    } else {
72                        Some(folded)
73                    }
74                } else {
75                    Some(folded)
76                }
77            } else {
78                None
79            };
80
81            Stat::If {
82                cond: cond.clone(),
83                then_block,
84                elseif_clauses: new_elseifs,
85                else_block: new_else,
86            }
87        }
88        Stat::While { cond, body } => Stat::While {
89            cond: cond.clone(),
90            body: fold_elseif_chain(body),
91        },
92        Stat::Repeat { body, cond } => Stat::Repeat {
93            body: fold_elseif_chain(body),
94            cond: cond.clone(),
95        },
96        Stat::NumericFor { name, start, limit, step, body } => Stat::NumericFor {
97            name: name.clone(),
98            start: start.clone(),
99            limit: limit.clone(),
100            step: step.clone(),
101            body: fold_elseif_chain(body),
102        },
103        Stat::GenericFor { names, iterators, body } => Stat::GenericFor {
104            names: names.clone(),
105            iterators: iterators.clone(),
106            body: fold_elseif_chain(body),
107        },
108        Stat::DoBlock(body) => Stat::DoBlock(fold_elseif_chain(body)),
109        other => other.clone(),
110    }
111}
112
113fn emit_block(block: &Block, out: &mut String, indent: usize) {
114    for (i, stat) in block.iter().enumerate() {
115        // Add blank lines between top-level logical groups
116        if indent == 0 && i > 0 {
117            let prev = &block[i - 1];
118            if should_separate(prev, stat) {
119                out.push('\n');
120            }
121        }
122        emit_stat(stat, out, indent);
123    }
124}
125
126/// Decide whether to insert a blank line between two consecutive top-level statements.
127fn should_separate(prev: &Stat, curr: &Stat) -> bool {
128    // Always separate after a function definition
129    if is_func_def(prev) {
130        return true;
131    }
132    // Separate before a function definition
133    if is_func_def(curr) {
134        return true;
135    }
136    // Separate between different "kinds" of statements
137    // (e.g., after a block of assignments before a call, or vice versa)
138    false
139}
140
141fn is_func_def(stat: &Stat) -> bool {
142    match stat {
143        Stat::Assign { values, .. } => {
144            values.len() == 1 && matches!(&values[0], Expr::FunctionDef(_))
145        }
146        Stat::LocalAssign { exprs, .. } => {
147            exprs.len() == 1 && matches!(&exprs[0], Expr::FunctionDef(_))
148        }
149        _ => false,
150    }
151}
152
153fn emit_stat(stat: &Stat, out: &mut String, indent: usize) {
154    let pad = "  ".repeat(indent);
155    match stat {
156        Stat::LocalAssign { names, exprs } => {
157            write!(out, "{}local {}", pad, names.join(", ")).unwrap();
158            if !exprs.is_empty() {
159                out.push_str(" = ");
160                emit_expr_list(exprs, out);
161            }
162            out.push('\n');
163        }
164        Stat::Assign { targets, values } => {
165            // Pretty-print `name = function(...) end` as `function name(...) end`
166            if targets.len() == 1 && values.len() == 1 {
167                if let Expr::FunctionDef(func) = &values[0] {
168                    let name = match &targets[0] {
169                        Expr::Global(n) => Some(n.clone()),
170                        Expr::Name(n) => Some(n.clone()),
171                        Expr::Field(table, field) => {
172                            // t.method = function(...) -> function t.method(...)
173                            let mut s = String::new();
174                            emit_expr(table, &mut s, 10);
175                            s.push('.');
176                            s.push_str(field);
177                            Some(s)
178                        }
179                        _ => None,
180                    };
181                    if let Some(fname) = name {
182                        write!(out, "{}function {}(", pad, fname).unwrap();
183                        let mut params = func.params.join(", ");
184                        if func.is_vararg {
185                            if !params.is_empty() {
186                                params.push_str(", ");
187                            }
188                            params.push_str("...");
189                        }
190                        out.push_str(&params);
191                        out.push_str(")\n");
192                        emit_block(&func.body, out, indent + 1);
193                        writeln!(out, "{}end", pad).unwrap();
194                        return;
195                    }
196                }
197            }
198            write!(out, "{}", pad).unwrap();
199            emit_expr_list(targets, out);
200            out.push_str(" = ");
201            emit_expr_list(values, out);
202            out.push('\n');
203        }
204        Stat::Call(call) => {
205            write!(out, "{}", pad).unwrap();
206            emit_call(call, out);
207            out.push('\n');
208        }
209        Stat::DoBlock(body) => {
210            writeln!(out, "{}do", pad).unwrap();
211            emit_block(body, out, indent + 1);
212            writeln!(out, "{}end", pad).unwrap();
213        }
214        Stat::While { cond, body } => {
215            write!(out, "{}while ", pad).unwrap();
216            emit_expr(cond, out, 0);
217            out.push_str(" do\n");
218            emit_block(body, out, indent + 1);
219            writeln!(out, "{}end", pad).unwrap();
220        }
221        Stat::Repeat { body, cond } => {
222            writeln!(out, "{}repeat", pad).unwrap();
223            emit_block(body, out, indent + 1);
224            write!(out, "{}until ", pad).unwrap();
225            emit_expr(cond, out, 0);
226            out.push('\n');
227        }
228        Stat::If {
229            cond,
230            then_block,
231            elseif_clauses,
232            else_block,
233        } => {
234            write!(out, "{}if ", pad).unwrap();
235            emit_expr(cond, out, 0);
236            out.push_str(" then\n");
237            emit_block(then_block, out, indent + 1);
238            for (ec, eb) in elseif_clauses {
239                write!(out, "{}elseif ", pad).unwrap();
240                emit_expr(ec, out, 0);
241                out.push_str(" then\n");
242                emit_block(eb, out, indent + 1);
243            }
244            if let Some(eb) = else_block {
245                writeln!(out, "{}else", pad).unwrap();
246                emit_block(eb, out, indent + 1);
247            }
248            writeln!(out, "{}end", pad).unwrap();
249        }
250        Stat::NumericFor {
251            name,
252            start,
253            limit,
254            step,
255            body,
256        } => {
257            write!(out, "{}for {} = ", pad, name).unwrap();
258            emit_expr(start, out, 0);
259            out.push_str(", ");
260            emit_expr(limit, out, 0);
261            if let Some(s) = step {
262                out.push_str(", ");
263                emit_expr(s, out, 0);
264            }
265            out.push_str(" do\n");
266            emit_block(body, out, indent + 1);
267            writeln!(out, "{}end", pad).unwrap();
268        }
269        Stat::GenericFor {
270            names,
271            iterators,
272            body,
273        } => {
274            write!(out, "{}for {} in ", pad, names.join(", ")).unwrap();
275            emit_expr_list(iterators, out);
276            out.push_str(" do\n");
277            emit_block(body, out, indent + 1);
278            writeln!(out, "{}end", pad).unwrap();
279        }
280        Stat::Return(exprs) => {
281            write!(out, "{}return", pad).unwrap();
282            if !exprs.is_empty() {
283                out.push(' ');
284                emit_expr_list(exprs, out);
285            }
286            out.push('\n');
287        }
288        Stat::Break => {
289            writeln!(out, "{}break", pad).unwrap();
290        }
291        Stat::Comment(text) => {
292            writeln!(out, "{}-- {}", pad, text).unwrap();
293        }
294    }
295}
296
297fn emit_expr_list(exprs: &[Expr], out: &mut String) {
298    for (i, e) in exprs.iter().enumerate() {
299        if i > 0 {
300            out.push_str(", ");
301        }
302        emit_expr(e, out, 0);
303    }
304}
305
306/// Emit an expression, handling operator precedence and parenthesization.
307/// `parent_prec` is the precedence of the enclosing operator (0 = no enclosing op).
308fn emit_expr(expr: &Expr, out: &mut String, parent_prec: u8) {
309    match expr {
310        Expr::Nil => out.push_str("nil"),
311        Expr::Bool(true) => out.push_str("true"),
312        Expr::Bool(false) => out.push_str("false"),
313        Expr::Number(n) => emit_number(n, out),
314        Expr::StringLit(s) => emit_string(s, out),
315        Expr::VarArg => out.push_str("..."),
316        Expr::Name(n) => out.push_str(n),
317        Expr::Global(n) => out.push_str(n),
318        Expr::Register(r) => write!(out, "r{}", r).unwrap(),
319        Expr::Upvalue(u) => write!(out, "upval{}", u).unwrap(),
320        Expr::Index(table, key) => {
321            emit_expr(table, out, 10);
322            out.push('[');
323            emit_expr(key, out, 0);
324            out.push(']');
325        }
326        Expr::Field(table, field) => {
327            emit_expr(table, out, 10);
328            out.push('.');
329            out.push_str(field);
330        }
331        Expr::BinOp(op, lhs, rhs) => {
332            let prec = op.precedence();
333            let needs_parens = prec < parent_prec;
334            if needs_parens {
335                out.push('(');
336            }
337            emit_expr(lhs, out, prec);
338            write!(out, " {} ", op.symbol()).unwrap();
339            // Right-associative: right child needs prec+1 to avoid unnecessary parens
340            let rhs_prec = if op.is_right_assoc() { prec } else { prec + 1 };
341            emit_expr(rhs, out, rhs_prec);
342            if needs_parens {
343                out.push(')');
344            }
345        }
346        Expr::UnOp(op, operand) => {
347            let prec = op.precedence();
348            let needs_parens = prec < parent_prec;
349            if needs_parens {
350                out.push('(');
351            }
352            out.push_str(op.symbol());
353            emit_expr(operand, out, prec);
354            if needs_parens {
355                out.push(')');
356            }
357        }
358        Expr::FuncCall(call) => {
359            emit_call(call, out);
360        }
361        Expr::MethodCall(call) => {
362            emit_call(call, out);
363        }
364        Expr::FunctionDef(func) => {
365            emit_function_def(func, out, false);
366        }
367        Expr::Table(fields) => {
368            emit_table(fields, out);
369        }
370    }
371}
372
373fn emit_call(call: &CallExpr, out: &mut String) {
374    emit_expr(&call.func, out, 10);
375    out.push('(');
376    emit_expr_list(&call.args, out);
377    out.push(')');
378}
379
380fn emit_function_def(func: &Function, out: &mut String, _as_stat: bool) {
381    out.push_str("function(");
382    let mut params = func.params.join(", ");
383    if func.is_vararg {
384        if !params.is_empty() {
385            params.push_str(", ");
386        }
387        params.push_str("...");
388    }
389    out.push_str(&params);
390    out.push_str(")\n");
391
392    // Estimate current indent from trailing whitespace
393    let current_indent = count_trailing_indent(out);
394    emit_block(&func.body, out, current_indent + 1);
395
396    let pad = "  ".repeat(current_indent);
397    write!(out, "{}end", pad).unwrap();
398}
399
400fn emit_table(fields: &[TableField], out: &mut String) {
401    if fields.is_empty() {
402        out.push_str("{}");
403        return;
404    }
405
406    // Estimate inline length to decide single-line vs multi-line
407    let current_indent = count_trailing_indent(out);
408    let inline = emit_table_inline(fields);
409    // Use multi-line if: inline is too long, or has nested tables, or many fields
410    let use_multiline = inline.len() > 80 || fields.len() > 4 && inline.len() > 60;
411
412    if !use_multiline {
413        out.push_str(&inline);
414        return;
415    }
416
417    // Multi-line format
418    let inner_pad = "  ".repeat(current_indent + 1);
419    let outer_pad = "  ".repeat(current_indent);
420    out.push_str("{\n");
421    for (i, field) in fields.iter().enumerate() {
422        out.push_str(&inner_pad);
423        match field {
424            TableField::IndexField(key, val) => {
425                out.push('[');
426                emit_expr(key, out, 0);
427                out.push_str("] = ");
428                emit_expr(val, out, 0);
429            }
430            TableField::NameField(name, val) => {
431                out.push_str(name);
432                out.push_str(" = ");
433                emit_expr(val, out, 0);
434            }
435            TableField::Value(val) => {
436                emit_expr(val, out, 0);
437            }
438        }
439        if i + 1 < fields.len() {
440            out.push(',');
441        }
442        out.push('\n');
443    }
444    write!(out, "{}}}", outer_pad).unwrap();
445}
446
447/// Emit a table as a single-line string (for length estimation).
448fn emit_table_inline(fields: &[TableField]) -> String {
449    let mut s = String::new();
450    s.push('{');
451    for (i, field) in fields.iter().enumerate() {
452        if i > 0 {
453            s.push_str(", ");
454        }
455        match field {
456            TableField::IndexField(key, val) => {
457                s.push('[');
458                emit_expr(key, &mut s, 0);
459                s.push_str("] = ");
460                emit_expr(val, &mut s, 0);
461            }
462            TableField::NameField(name, val) => {
463                s.push_str(name);
464                s.push_str(" = ");
465                emit_expr(val, &mut s, 0);
466            }
467            TableField::Value(val) => {
468                emit_expr(val, &mut s, 0);
469            }
470        }
471    }
472    s.push('}');
473    s
474}
475
476fn emit_number(n: &NumLit, out: &mut String) {
477    match n {
478        NumLit::Int(v) => write!(out, "{}", v).unwrap(),
479        NumLit::Float(v) => {
480            if v.fract() == 0.0 && v.abs() < 1e15 {
481                // Emit as integer-looking float if it has no fractional part
482                write!(out, "{}", *v as i64).unwrap();
483            } else {
484                write!(out, "{}", v).unwrap();
485            }
486        }
487    }
488}
489
490/// Emit a string literal with proper escaping.
491/// Supports multi-line string detection, GBK decoding, and non-UTF-8 byte escapes.
492fn emit_string(bytes: &[u8], out: &mut String) {
493    // Try UTF-8 first
494    if let Ok(s) = std::str::from_utf8(bytes) {
495        emit_string_content(s, bytes, out);
496        return;
497    }
498
499    // Try GBK decoding (common in JX3 Lua scripts)
500    let (decoded, _, had_errors) = encoding_rs::GBK.decode(bytes);
501    if !had_errors {
502        emit_string_content(&decoded, bytes, out);
503        return;
504    }
505
506    // Fallback: emit with byte escapes for non-ASCII
507    out.push('"');
508    for &b in bytes {
509        emit_byte_escaped(b, out);
510    }
511    out.push('"');
512}
513
/// Emit a string that has been successfully decoded to text.
///
/// Text containing a newline is written as a `[[...]]` long string when that
/// form reproduces it exactly. It is not used when the text
/// * contains `]]` or ends with `]` (the literal would close early),
/// * contains `[[` (rejected as nested long brackets by stock Lua 5.1),
/// * contains `\r` (long strings normalize line endings to `\n`), or
/// * contains control characters other than `\n` and `\t`.
///
/// Everything else is written as a double-quoted string. Control characters
/// without a named escape are written as decimal byte escapes of their UTF-8
/// encoding. A decimal escape is zero-padded to three digits when the next
/// character is a digit, so that e.g. byte 1 followed by `2` becomes `\0012`
/// rather than `\12`.
///
/// `_raw` (the undecoded bytes) is currently unused.
fn emit_string_content(text: &str, _raw: &[u8], out: &mut String) {
    let has_newlines = text.contains('\n');
    let bracket_safe = !text.contains("]]")
        && !text.ends_with(']')
        && !text.contains("[[")
        && !text.contains('\r');
    let is_printable = text
        .chars()
        .all(|c| !c.is_control() || c == '\n' || c == '\t');

    if has_newlines && bracket_safe && is_printable {
        // Use [[...]] long string
        out.push_str("[[");
        // Lua drops a newline that immediately follows `[[`, so a leading
        // newline has to be doubled to survive.
        if text.starts_with('\n') {
            out.push('\n');
        }
        out.push_str(text);
        out.push_str("]]");
        return;
    }

    // Use quoted string with escapes
    out.push('"');
    let mut chars = text.chars().peekable();
    while let Some(ch) = chars.next() {
        // A decimal escape reads up to three digits, so a short escape must
        // not be followed directly by a digit character.
        let digit_follows = chars.peek().map_or(false, |next| next.is_ascii_digit());
        match ch {
            '\\' => out.push_str("\\\\"),
            '"' => out.push_str("\\\""),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            '\0' => out.push_str(if digit_follows { "\\000" } else { "\\0" }),
            '\x07' => out.push_str("\\a"),
            '\x08' => out.push_str("\\b"),
            '\x0C' => out.push_str("\\f"),
            '\x0B' => out.push_str("\\v"),
            // Printable ASCII and printable Unicode (incl. CJK)
            c if !c.is_control() => out.push(c),
            c => {
                // Control character: emit as byte escapes
                let mut buf = [0u8; 4];
                let s = c.encode_utf8(&mut buf);
                for &b in s.as_bytes() {
                    if digit_follows {
                        write!(out, "\\{:03}", b).unwrap();
                    } else {
                        write!(out, "\\{}", b).unwrap();
                    }
                }
            }
        }
    }
    out.push('"');
}
559
/// Append one raw byte to a double-quoted Lua string literal.
///
/// Backslash, double quote and the control bytes with a named escape are
/// written with that escape; other printable ASCII bytes (0x20..=0x7E) are
/// written as-is. Every remaining byte, NUL included, is written as a decimal
/// escape padded to three digits. The byte that follows is not known here,
/// and an unpadded escape such as `\1` would merge with a following digit.
fn emit_byte_escaped(b: u8, out: &mut String) {
    match b {
        b'\\' => out.push_str("\\\\"),
        b'"' => out.push_str("\\\""),
        b'\n' => out.push_str("\\n"),
        b'\r' => out.push_str("\\r"),
        b'\t' => out.push_str("\\t"),
        0x07 => out.push_str("\\a"),
        0x08 => out.push_str("\\b"),
        0x0C => out.push_str("\\f"),
        0x0B => out.push_str("\\v"),
        0x20..=0x7E => out.push(b as char),
        _ => {
            write!(out, "\\{:03}", b).unwrap();
        }
    }
}
578
/// Estimate the indentation level of the last line of `s`.
///
/// The last line is the text after the final `'\n'`, or all of `s` when it
/// contains no newline. The level is the number of leading whitespace bytes
/// on that line divided by two (one level = two spaces), rounded down.
fn count_trailing_indent(s: &str) -> usize {
    // `rsplit` always yields at least one piece, so the fallback is never hit.
    let last_line = s.rsplit('\n').next().unwrap_or(s);
    let leading = last_line.len() - last_line.trim_start().len();
    leading / 2
}