Skip to main content

proof_engine/scripting/
compiler.rs

1//! Bytecode compiler — walks the AST and emits `Op` instructions into a `Proto`.
2//!
3//! # Architecture
4//! A single-pass recursive descent over the AST maintains:
5//! - A flat local-variable stack per function (slot numbers are u16).
6//! - A scope depth counter for block exits.
//! - Upvalue descriptors (`UpvalDesc`) for closures — not wired in yet: identifiers that are not locals of the current function compile as global accesses.
8//! - Per-loop break/continue patch lists.
9//!
10//! Jump offsets are relative (i32): positive = forward, negative = backward.
11
12use std::collections::HashMap;
13use std::sync::Arc;
14use super::ast::*;
15use super::vm::Value as VmValue;
16
17// ── Constant pool ─────────────────────────────────────────────────────────────
18
/// A value known at compile time, stored in a function's constant pool.
///
/// Constants are converted to runtime values when a `Proto` is lowered to a
/// `Chunk` (see `const_to_value`).
#[derive(Debug, Clone, PartialEq)]
pub enum Constant {
    /// The `nil` value.
    Nil,
    /// A boolean literal.
    Bool(bool),
    /// An integer literal.
    Int(i64),
    /// A floating-point literal.
    Float(f64),
    /// A string literal, or the name of a global, field or method.
    Str(String),
}
28
29// ── Op (bytecode instruction set) ────────────────────────────────────────────
30
/// One bytecode instruction emitted by the compiler.
///
/// Operands are stored inline in each variant so the interpreter never needs a
/// secondary table lookup on its hot path. Jump offsets are relative to the
/// instruction that follows the jump (`Proto::patch_jump` computes
/// `target - (jump_index + 1)`), so backward jumps carry negative offsets.
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, PartialEq)]
pub enum Op {
    // Literals
    /// Push `nil`.
    Nil,
    /// Push `true`.
    True,
    /// Push `false`.
    False,
    /// Push `proto.constants[idx]`.
    Const(u32),

    // Stack manipulation
    /// Discard the top value.
    Pop,
    /// Duplicate the top value.
    Dup,
    /// Exchange the two topmost values.
    Swap,

    // Locals (operand = local slot)
    /// Push the value of a local slot.
    GetLocal(u16),
    /// Store the top value into a local slot; the compiler relies on this
    /// consuming the value.
    SetLocal(u16),

    // Upvalues (operand = upvalue index)
    /// Read an upvalue of the running closure.
    GetUpval(u16),
    /// Write an upvalue of the running closure.
    SetUpval(u16),

    // Globals (operand = index of the string constant naming the global)
    /// Read a global.
    GetGlobal(u32),
    /// Write a global.
    SetGlobal(u32),

    // Tables
    /// Push a new table.
    NewTable,
    /// `SetField(kidx)`: pop val; table = peek; table[const_str] = val.
    /// NOTE(review): `compile_assign_target` pushes the value *below* the
    /// table before this op — confirm the VM's operand order.
    SetField(u32),
    /// `GetField(kidx)`: pop table; push table[const_str].
    GetField(u32),
    /// `SetIndex`: pop val, key; table = peek; table[key] = val.
    SetIndex,
    /// `GetIndex`: pop key, table; push table[key].
    GetIndex,
    /// `TableAppend`: pop val; table = peek; append val to the array part.
    TableAppend,
    /// `SetList(n)`: pop n values; table = peek; assign t[1..n].
    SetList(u16),

    // Unary operators
    /// `#x`
    Len,
    /// `-x`
    Neg,
    /// `not x`
    Not,
    /// `~x`
    BitNot,

    // Arithmetic (binary operators take the left operand pushed first)
    Add,
    Sub,
    Mul,
    Div,
    IDiv,
    Mod,
    Pow,
    /// Pops two values and pushes their concatenation.
    Concat,

    // Comparison
    Eq,
    NotEq,
    Lt,
    LtEq,
    Gt,
    GtEq,

    // Bitwise
    BitAnd,
    BitOr,
    BitXor,
    Shl,
    Shr,

    // Control flow
    /// Unconditional relative jump: `ip += offset` (may be negative).
    Jump(i32),
    /// Peek the top value; jump if it is truthy (no pop).
    JumpIf(i32),
    /// Peek the top value; jump if it is falsy (no pop).
    /// NOTE(review): `If`/`While` never emit a matching `Pop`, so the VM
    /// presumably consumes the value — confirm.
    JumpIfNot(i32),
    /// Pop the top value; jump if it was falsy — used for short-circuit `and`.
    /// NOTE(review): `Instruction::JumpIfNotPop`, which this lowers to, is
    /// documented as keeping the value when jumping — confirm.
    JumpIfNotPop(i32),
    /// Pop the top value; jump if it was truthy — used for short-circuit `or`.
    /// NOTE(review): `Instruction::JumpIfPop` is documented as keeping the
    /// value when jumping — confirm.
    JumpIfPop(i32),

    // Calls and returns
    /// `Call(nargs, nret)`: pop nargs + callee; push nret results (0 = all).
    Call(u8, u8),
    /// `CallMethod(name_kidx, nargs, nret)`: receiver on the stack; the
    /// method name is a string constant.
    CallMethod(u32, u8, u8),
    /// `Return(n)`: pop n values and return them (0 = return all).
    Return(u8),
    /// Tail call with the given argument count.
    TailCall(u8),

    // Closures
    /// Create a closure from `proto.protos[idx]`, capturing upvalues.
    Closure(u32),
    /// Close the upvalue that refers to the given local slot.
    Close(u16),

    // Loops
    /// Prepare a generic `for`; the operand is the number of loop variables.
    ForPrep(u16),
    /// Advance a generic `for` without an exit offset. Not emitted by the
    /// statement compiler; lowered to `Instruction::Nop`.
    ForStep,
    /// Advance a generic `for`, jumping by the offset once it is exhausted.
    ForStepJump(i32),
    /// Push and validate [start, limit, step] for a numeric `for`.
    NumForInit,
    /// Advance a numeric `for`; jump by the offset when it is done.
    NumForStep(i32),

    // Varargs
    /// Push `n` vararg values (0 = all).
    Vararg(u8),

    // Debug
    /// Source line marker.
    LineInfo(u32),
}
142
143// ── Proto (function prototype) ────────────────────────────────────────────────
144
145/// A compiled function — the unit of bytecode.
146#[derive(Debug, Clone)]
147pub struct Proto {
148    pub name:          String,
149    pub code:          Vec<Op>,
150    pub constants:     Vec<Constant>,
151    pub protos:        Vec<Proto>,      // nested closure prototypes
152    pub param_count:   u8,
153    pub is_vararg:     bool,
154    pub upvalue_count: u16,
155    pub max_stack:     u16,
156}
157
158impl Proto {
159    fn new(name: impl Into<String>) -> Self {
160        Proto {
161            name:          name.into(),
162            code:          Vec::new(),
163            constants:     Vec::new(),
164            protos:        Vec::new(),
165            param_count:   0,
166            is_vararg:     false,
167            upvalue_count: 0,
168            max_stack:     0,
169        }
170    }
171
172    /// Add a constant, deduplicating where possible.
173    pub fn add_const(&mut self, c: Constant) -> u32 {
174        for (i, existing) in self.constants.iter().enumerate() {
175            if *existing == c { return i as u32; }
176        }
177        let idx = self.constants.len() as u32;
178        self.constants.push(c);
179        idx
180    }
181
182    fn emit(&mut self, op: Op) -> usize {
183        self.code.push(op);
184        self.code.len() - 1
185    }
186
187    fn patch_jump(&mut self, instr_idx: usize) {
188        let target = self.code.len() as i32;
189        let from   = instr_idx as i32 + 1;
190        let offset = target - from;
191        match &mut self.code[instr_idx] {
192            Op::Jump(o) | Op::JumpIf(o) | Op::JumpIfNot(o)
193            | Op::JumpIfNotPop(o) | Op::JumpIfPop(o)
194            | Op::NumForStep(o) | Op::ForStepJump(o) => *o = offset,
195            _ => {}
196        }
197    }
198}
199
200// ── Instruction (VM-facing bytecode) ─────────────────────────────────────────
201
/// Instruction set executed by the VM, produced from `Op`s by
/// `Compiler::compile_script`.
///
/// Lowering is one-for-one (see `op_to_instruction`), so relative jump offsets
/// keep their meaning. `LoadInt`, `LoadFloat`, `LoadStr`, `JumpAbs` and
/// `MakeClosure` are never produced by that lowering.
#[derive(Debug, Clone, PartialEq)]
pub enum Instruction {
    // Literals
    LoadNil,
    LoadBool(bool),
    LoadInt(i64),
    LoadFloat(f64),
    LoadStr(String),
    /// Push the chunk constant at this index.
    LoadConst(usize),

    // Stack manipulation
    Pop,
    Dup,
    Swap,

    // Locals / upvalues / globals (globals are addressed by name)
    GetLocal(usize),
    SetLocal(usize),
    GetUpvalue(usize),
    SetUpvalue(usize),
    GetGlobal(String),
    SetGlobal(String),

    // Tables
    NewTable,
    SetField(String),
    GetField(String),
    SetIndex,
    GetIndex,
    TableAppend,

    // Unary operators
    Len,
    Neg,
    Not,
    BitNot,

    // Arithmetic
    Add,
    Sub,
    Mul,
    Div,
    IDiv,
    Mod,
    Pow,
    Concat,

    // Bitwise
    BitAnd,
    BitOr,
    BitXor,
    Shl,
    Shr,

    // Comparison
    Eq,
    NotEq,
    Lt,
    LtEq,
    Gt,
    GtEq,

    // Control flow (relative offsets unless noted)
    Jump(isize),
    JumpIf(isize),
    JumpIfNot(isize),
    /// Peek; if not truthy, jump and leave the value; if truthy, pop and continue.
    JumpIfNotPop(isize),
    /// Peek; if truthy, jump and leave the value; if falsy, pop and continue.
    JumpIfPop(isize),
    /// Jump to an absolute instruction index.
    JumpAbs(usize),

    // Calls (operand = argument count)
    Call(usize),
    CallMethod(String, usize),
    Return(usize),

    // Closures
    MakeFunction(usize),
    MakeClosure(usize, Vec<(bool, usize)>),
    CloseUpvalue(usize),

    // Iteration
    ForPrep(usize),
    /// Advance a for-loop: `local_idx` = loop variable slot, `jump_offset` =
    /// exit jump.
    ForStep(usize, isize),

    /// Does nothing; stands in for ops that have no VM counterpart.
    Nop,
}
265
266// ── Chunk (VM-facing function prototype) ─────────────────────────────────────
267
268/// A compiled function ready for the VM.
269#[derive(Debug, Clone)]
270pub struct Chunk {
271    pub name:         String,
272    pub instructions: Vec<Instruction>,
273    pub constants:    Vec<VmValue>,
274    pub sub_chunks:   Vec<Arc<Chunk>>,
275    pub param_count:  u8,
276    pub is_vararg:    bool,
277}
278
279fn const_to_value(c: &Constant) -> VmValue {
280    match c {
281        Constant::Nil      => VmValue::Nil,
282        Constant::Bool(b)  => VmValue::Bool(*b),
283        Constant::Int(i)   => VmValue::Int(*i),
284        Constant::Float(f) => VmValue::Float(*f),
285        Constant::Str(s)   => VmValue::Str(Arc::new(s.clone())),
286    }
287}
288
289fn proto_to_chunk(proto: &Proto) -> Arc<Chunk> {
290    let instructions = proto.code.iter()
291        .map(|op| op_to_instruction(op, &proto.constants))
292        .collect();
293    let constants = proto.constants.iter().map(const_to_value).collect();
294    let sub_chunks = proto.protos.iter().map(proto_to_chunk).collect();
295    Arc::new(Chunk {
296        name:         proto.name.clone(),
297        instructions,
298        constants,
299        sub_chunks,
300        param_count:  proto.param_count,
301        is_vararg:    proto.is_vararg,
302    })
303}
304
305fn op_to_instruction(op: &Op, constants: &[Constant]) -> Instruction {
306    let get_str = |kidx: u32| -> String {
307        match constants.get(kidx as usize) {
308            Some(Constant::Str(s)) => s.clone(),
309            _ => String::new(),
310        }
311    };
312    match op {
313        Op::Nil               => Instruction::LoadNil,
314        Op::True              => Instruction::LoadBool(true),
315        Op::False             => Instruction::LoadBool(false),
316        Op::Const(idx)        => Instruction::LoadConst(*idx as usize),
317        Op::Pop               => Instruction::Pop,
318        Op::Dup               => Instruction::Dup,
319        Op::Swap              => Instruction::Swap,
320        Op::GetLocal(s)       => Instruction::GetLocal(*s as usize),
321        Op::SetLocal(s)       => Instruction::SetLocal(*s as usize),
322        Op::GetUpval(i)       => Instruction::GetUpvalue(*i as usize),
323        Op::SetUpval(i)       => Instruction::SetUpvalue(*i as usize),
324        Op::GetGlobal(k)      => Instruction::GetGlobal(get_str(*k)),
325        Op::SetGlobal(k)      => Instruction::SetGlobal(get_str(*k)),
326        Op::NewTable          => Instruction::NewTable,
327        Op::SetField(k)       => Instruction::SetField(get_str(*k)),
328        Op::GetField(k)       => Instruction::GetField(get_str(*k)),
329        Op::SetIndex          => Instruction::SetIndex,
330        Op::GetIndex          => Instruction::GetIndex,
331        Op::TableAppend       => Instruction::TableAppend,
332        Op::SetList(_)        => Instruction::Nop,
333        Op::Len               => Instruction::Len,
334        Op::Neg               => Instruction::Neg,
335        Op::Not               => Instruction::Not,
336        Op::BitNot            => Instruction::BitNot,
337        Op::Add               => Instruction::Add,
338        Op::Sub               => Instruction::Sub,
339        Op::Mul               => Instruction::Mul,
340        Op::Div               => Instruction::Div,
341        Op::IDiv              => Instruction::IDiv,
342        Op::Mod               => Instruction::Mod,
343        Op::Pow               => Instruction::Pow,
344        Op::Concat            => Instruction::Concat,
345        Op::Eq                => Instruction::Eq,
346        Op::NotEq             => Instruction::NotEq,
347        Op::Lt                => Instruction::Lt,
348        Op::LtEq              => Instruction::LtEq,
349        Op::Gt                => Instruction::Gt,
350        Op::GtEq              => Instruction::GtEq,
351        Op::BitAnd            => Instruction::BitAnd,
352        Op::BitOr             => Instruction::BitOr,
353        Op::BitXor            => Instruction::BitXor,
354        Op::Shl               => Instruction::Shl,
355        Op::Shr               => Instruction::Shr,
356        Op::Jump(off)         => Instruction::Jump(*off as isize),
357        Op::JumpIf(off)       => Instruction::JumpIf(*off as isize),
358        Op::JumpIfNot(off)    => Instruction::JumpIfNot(*off as isize),
359        Op::JumpIfNotPop(off) => Instruction::JumpIfNotPop(*off as isize),
360        Op::JumpIfPop(off)    => Instruction::JumpIfPop(*off as isize),
361        Op::Call(na, _)       => Instruction::Call(*na as usize),
362        Op::CallMethod(k, na, _) => Instruction::CallMethod(get_str(*k), *na as usize),
363        Op::Return(n)         => Instruction::Return(*n as usize),
364        Op::TailCall(n)       => Instruction::Call(*n as usize),
365        Op::Closure(idx)      => Instruction::MakeFunction(*idx as usize),
366        Op::Close(s)          => Instruction::CloseUpvalue(*s as usize),
367        Op::ForPrep(n)        => Instruction::ForPrep(*n as usize),
368        Op::ForStep           => Instruction::Nop,
369        Op::ForStepJump(off)  => Instruction::ForStep(0, *off as isize),
370        Op::NumForInit        => Instruction::Nop,
371        Op::NumForStep(off)   => Instruction::ForStep(0, *off as isize),
372        Op::Vararg(_)         => Instruction::Nop,
373        Op::LineInfo(_)       => Instruction::Nop,
374    }
375}
376
377// ── Local variable tracking ───────────────────────────────────────────────────
378
/// A named local variable: the stack slot it occupies and the block depth at
/// which it was declared.
#[derive(Debug, Clone)]
struct Local {
    name: String,
    slot: u16,
    depth: usize,
}

/// Compile-time view of one function's local variables.
///
/// Locals are kept in declaration order. Slots are handed out sequentially and
/// reclaimed when the block that declared them closes.
struct Scope {
    locals: Vec<Local>,
    scope_depth: usize,
    next_slot: u16,
}

impl Scope {
    /// Returns an empty scope at depth 0 whose first local will get slot 0.
    fn new() -> Self {
        Self {
            locals: Vec::new(),
            scope_depth: 0,
            next_slot: 0,
        }
    }

    /// Enters a nested block.
    fn push_scope(&mut self) {
        self.scope_depth += 1;
    }

    /// Closes the innermost block.
    ///
    /// Forgets every local declared at the current depth and releases their
    /// slots. Returns how many locals were forgotten, which is the number of
    /// `Pop`s the caller must emit.
    fn pop_scope(&mut self) -> u16 {
        let closing_depth = self.scope_depth;
        let live_before = self.locals.len();
        self.locals.retain(|local| local.depth < closing_depth);
        let released = (live_before - self.locals.len()) as u16;
        self.next_slot -= released;
        self.scope_depth -= 1;
        released
    }

    /// Declares `name` in the current block and returns the slot assigned to it.
    fn add_local(&mut self, name: &str) -> u16 {
        let assigned = self.next_slot;
        let depth = self.scope_depth;
        self.locals.push(Local {
            name: name.to_owned(),
            slot: assigned,
            depth,
        });
        self.next_slot += 1;
        assigned
    }

    /// Returns the slot of the most recently declared local called `name`, so
    /// inner declarations shadow outer ones.
    fn resolve_local(&self, name: &str) -> Option<u16> {
        self.locals
            .iter()
            .rfind(|local| local.name == name)
            .map(|local| local.slot)
    }
}
428
429// ── Upvalue descriptor ────────────────────────────────────────────────────────
430
/// Describes where a closure finds one of its captured variables.
#[derive(Debug, Clone)]
pub struct UpvalDesc {
    /// Source-level name of the captured variable.
    pub name: String,
    /// `true`: `index` is a local slot of the immediately enclosing scope.
    /// `false`: `index` is an upvalue of the enclosing function.
    pub in_stack: bool,
    /// Slot or upvalue index, interpreted according to `in_stack`.
    pub index: u16,
}
440
441// ── Compiler ─────────────────────────────────────────────────────────────────
442
443/// Single-pass bytecode compiler.
444pub struct Compiler {
445    proto:  Proto,
446    scope:  Scope,
447    breaks: Vec<Vec<usize>>,   // break-patch points indexed by loop nesting
448    // Note: upvalue handling is simplified — outer-function locals captured
449    // as globals in this basic implementation.
450}
451
452impl Compiler {
453    // ── Public entry points ───────────────────────────────────────────────────
454
455    /// Compile an entire script into a top-level `Arc<Chunk>` for the VM.
456    pub fn compile_script(script: &Script) -> Arc<Chunk> {
457        proto_to_chunk(&Self::compile_to_proto(script))
458    }
459
460    /// Compile to the internal `Proto` representation (used by compiler tests).
461    pub fn compile_to_proto(script: &Script) -> Proto {
462        let mut c = Compiler {
463            proto:  Proto::new(&script.name),
464            scope:  Scope::new(),
465            breaks: Vec::new(),
466        };
467        c.proto.is_vararg = true;
468        c.compile_block_no_scope(&script.stmts);
469        c.proto.emit(Op::Return(0));
470        c.proto
471    }
472
473    // ── Block compilation ─────────────────────────────────────────────────────
474
475    fn compile_block(&mut self, stmts: &[Stmt]) {
476        self.scope.push_scope();
477        for s in stmts { self.compile_stmt(s); }
478        let popped = self.scope.pop_scope();
479        for _ in 0..popped { self.proto.emit(Op::Pop); }
480    }
481
482    fn compile_block_no_scope(&mut self, stmts: &[Stmt]) {
483        for s in stmts { self.compile_stmt(s); }
484    }
485
486    // ── Statement compilation ─────────────────────────────────────────────────
487
488    fn compile_stmt(&mut self, stmt: &Stmt) {
489        match stmt {
490            Stmt::LocalDecl { name, init } => {
491                if let Some(expr) = init {
492                    self.compile_expr(expr);
493                } else {
494                    self.proto.emit(Op::Nil);
495                }
496                self.scope.add_local(name);
497            }
498
499            Stmt::LocalMulti { names, inits } => {
500                for (i, name) in names.iter().enumerate() {
501                    if i < inits.len() {
502                        self.compile_expr(&inits[i]);
503                    } else {
504                        self.proto.emit(Op::Nil);
505                    }
506                    self.scope.add_local(name);
507                }
508            }
509
510            Stmt::Assign { target, value } => {
511                for (i, t) in target.iter().enumerate() {
512                    if i < value.len() {
513                        self.compile_expr(&value[i]);
514                    } else {
515                        self.proto.emit(Op::Nil);
516                    }
517                    self.compile_assign_target(t);
518                }
519            }
520
521            Stmt::CompoundAssign { target, op, value } => {
522                self.compile_expr(target);
523                self.compile_expr(value);
524                self.compile_binop(*op);
525                self.compile_assign_target(target);
526            }
527
528            Stmt::Call(expr) | Stmt::Expr(expr) => {
529                self.compile_expr(expr);
530                self.proto.emit(Op::Pop);
531            }
532
533            Stmt::Do(body) => {
534                self.compile_block(body);
535            }
536
537            Stmt::While { cond, body } => {
538                let loop_start = self.proto.code.len() as i32;
539                self.compile_expr(cond);
540                let exit = self.proto.emit(Op::JumpIfNot(0));
541                self.breaks.push(Vec::new());
542                self.compile_block(body);
543                let back = loop_start - self.proto.code.len() as i32 - 1;
544                self.proto.emit(Op::Jump(back));
545                self.proto.patch_jump(exit);
546                for b in self.breaks.pop().unwrap_or_default() {
547                    self.proto.patch_jump(b);
548                }
549            }
550
551            Stmt::RepeatUntil { body, cond } => {
552                let loop_start = self.proto.code.len() as i32;
553                self.breaks.push(Vec::new());
554                self.compile_block(body);
555                self.compile_expr(cond);
556                let back = loop_start - self.proto.code.len() as i32 - 1;
557                self.proto.emit(Op::JumpIfNot(back));
558                for b in self.breaks.pop().unwrap_or_default() {
559                    self.proto.patch_jump(b);
560                }
561            }
562
563            Stmt::If { cond, then_body, elseif_branches, else_body } => {
564                self.compile_expr(cond);
565                let skip_then = self.proto.emit(Op::JumpIfNot(0));
566                self.compile_block(then_body);
567
568                let mut end_jumps = Vec::new();
569                if !elseif_branches.is_empty() || else_body.is_some() {
570                    end_jumps.push(self.proto.emit(Op::Jump(0)));
571                }
572                self.proto.patch_jump(skip_then);
573
574                for (ei_cond, ei_body) in elseif_branches {
575                    self.compile_expr(ei_cond);
576                    let skip = self.proto.emit(Op::JumpIfNot(0));
577                    self.compile_block(ei_body);
578                    end_jumps.push(self.proto.emit(Op::Jump(0)));
579                    self.proto.patch_jump(skip);
580                }
581
582                if let Some(eb) = else_body {
583                    self.compile_block(eb);
584                }
585                for j in end_jumps { self.proto.patch_jump(j); }
586            }
587
588            Stmt::NumericFor { var, start, limit, step, body } => {
589                self.compile_expr(start);
590                self.compile_expr(limit);
591                if let Some(s) = step {
592                    self.compile_expr(s);
593                } else {
594                    let k = self.proto.add_const(Constant::Int(1));
595                    self.proto.emit(Op::Const(k));
596                }
597                self.proto.emit(Op::NumForInit);
598                let loop_top = self.proto.code.len();
599                let exit = self.proto.emit(Op::NumForStep(0));
600
601                self.scope.push_scope();
602                let slot = self.scope.add_local(var);
603                self.proto.emit(Op::GetLocal(slot));
604                self.breaks.push(Vec::new());
605                for s in body { self.compile_stmt(s); }
606                self.scope.pop_scope();
607
608                let back = loop_top as i32 - self.proto.code.len() as i32 - 1;
609                self.proto.emit(Op::Jump(back));
610                self.proto.patch_jump(exit);
611                // Pop limit, step, counter
612                for _ in 0..3 { self.proto.emit(Op::Pop); }
613                for b in self.breaks.pop().unwrap_or_default() {
614                    self.proto.patch_jump(b);
615                }
616            }
617
618            Stmt::GenericFor { vars, iter, body } => {
619                for expr in iter { self.compile_expr(expr); }
620                self.proto.emit(Op::ForPrep(vars.len() as u16));
621                let loop_top = self.proto.code.len();
622                let exit = self.proto.emit(Op::ForStepJump(0));
623
624                self.scope.push_scope();
625                for name in vars {
626                    let slot = self.scope.add_local(name);
627                    self.proto.emit(Op::GetLocal(slot));
628                }
629                self.breaks.push(Vec::new());
630                for s in body { self.compile_stmt(s); }
631                self.scope.pop_scope();
632
633                let back = loop_top as i32 - self.proto.code.len() as i32 - 1;
634                self.proto.emit(Op::Jump(back));
635                self.proto.patch_jump(exit);
636                for b in self.breaks.pop().unwrap_or_default() {
637                    self.proto.patch_jump(b);
638                }
639            }
640
641            Stmt::FuncDecl { name, params, vararg, body } => {
642                let fn_proto = self.compile_func(
643                    name.last().map(|s| s.as_str()).unwrap_or("?"),
644                    params, *vararg, body,
645                );
646                let idx = self.proto.protos.len() as u32;
647                self.proto.protos.push(fn_proto);
648                self.proto.emit(Op::Closure(idx));
649
650                if name.len() == 1 {
651                    if let Some(slot) = self.scope.resolve_local(&name[0]) {
652                        self.proto.emit(Op::SetLocal(slot));
653                    } else {
654                        let k = self.proto.add_const(Constant::Str(name[0].clone()));
655                        self.proto.emit(Op::SetGlobal(k));
656                    }
657                } else {
658                    // a.b.fn = closure
659                    self.compile_expr(&Expr::Ident(name[0].clone()));
660                    for part in &name[1..name.len()-1] {
661                        let k = self.proto.add_const(Constant::Str(part.clone()));
662                        self.proto.emit(Op::GetField(k));
663                    }
664                    let last = name.last().unwrap();
665                    let k = self.proto.add_const(Constant::Str(last.clone()));
666                    self.proto.emit(Op::SetField(k));
667                }
668            }
669
670            Stmt::LocalFunc { name, params, vararg, body } => {
671                let slot = self.scope.add_local(name);
672                self.proto.emit(Op::Nil); // placeholder until closure is made
673                let fn_proto = self.compile_func(name, params, *vararg, body);
674                let idx = self.proto.protos.len() as u32;
675                self.proto.protos.push(fn_proto);
676                self.proto.emit(Op::Closure(idx));
677                self.proto.emit(Op::SetLocal(slot));
678            }
679
680            Stmt::Return(vals) => {
681                let n = vals.len() as u8;
682                for v in vals { self.compile_expr(v); }
683                self.proto.emit(Op::Return(n));
684            }
685
686            Stmt::Break => {
687                let j = self.proto.emit(Op::Jump(0));
688                if let Some(list) = self.breaks.last_mut() {
689                    list.push(j);
690                }
691            }
692
693            Stmt::Continue => {
694                // Simplified continue: jump to -1 (loop should handle by re-check)
695                self.proto.emit(Op::Jump(-1));
696            }
697
698            Stmt::Match { expr, arms } => {
699                self.compile_expr(expr);
700                let mut end_jumps = Vec::new();
701
702                for arm in arms {
703                    self.proto.emit(Op::Dup);
704                    match &arm.pattern {
705                        MatchPattern::Wildcard => {
706                            self.proto.emit(Op::Pop);
707                            self.compile_block(&arm.body);
708                            end_jumps.push(self.proto.emit(Op::Jump(0)));
709                            continue;
710                        }
711                        MatchPattern::Ident(bind) => {
712                            let slot = self.scope.add_local(bind);
713                            self.proto.emit(Op::SetLocal(slot));
714                            self.compile_block(&arm.body);
715                            end_jumps.push(self.proto.emit(Op::Jump(0)));
716                            continue;
717                        }
718                        MatchPattern::Literal(lit) => {
719                            self.compile_expr(lit);
720                            self.proto.emit(Op::Eq);
721                        }
722                        MatchPattern::Table(_) => {
723                            self.proto.emit(Op::Pop);
724                            self.proto.emit(Op::True);
725                        }
726                    }
727                    let skip = self.proto.emit(Op::JumpIfNot(0));
728                    self.compile_block(&arm.body);
729                    end_jumps.push(self.proto.emit(Op::Jump(0)));
730                    self.proto.patch_jump(skip);
731                }
732
733                self.proto.emit(Op::Pop);
734                for j in end_jumps { self.proto.patch_jump(j); }
735            }
736
737            Stmt::Import { path, alias } => {
738                let k = self.proto.add_const(Constant::Str(path.clone()));
739                self.proto.emit(Op::Const(k));
740                let rk = self.proto.add_const(Constant::Str("require".to_string()));
741                self.proto.emit(Op::GetGlobal(rk));
742                self.proto.emit(Op::Swap);
743                self.proto.emit(Op::Call(1, 1));
744                let bind = alias.clone().unwrap_or_else(|| {
745                    path.split('/').last().unwrap_or(path).trim_end_matches(".lua").to_string()
746                });
747                let bk = self.proto.add_const(Constant::Str(bind));
748                self.proto.emit(Op::SetGlobal(bk));
749            }
750
751            Stmt::Export(name) => {
752                if let Some(slot) = self.scope.resolve_local(name) {
753                    self.proto.emit(Op::GetLocal(slot));
754                } else {
755                    let k = self.proto.add_const(Constant::Str(name.clone()));
756                    self.proto.emit(Op::GetGlobal(k));
757                }
758                let ek = self.proto.add_const(Constant::Str(name.clone()));
759                let exports_k = self.proto.add_const(Constant::Str("__exports".to_string()));
760                self.proto.emit(Op::GetGlobal(exports_k));
761                self.proto.emit(Op::Swap);
762                self.proto.emit(Op::SetField(ek));
763            }
764        }
765    }
766
767    fn compile_assign_target(&mut self, target: &Expr) {
768        match target {
769            Expr::Ident(name) => {
770                if let Some(slot) = self.scope.resolve_local(name) {
771                    self.proto.emit(Op::SetLocal(slot));
772                } else {
773                    let k = self.proto.add_const(Constant::Str(name.clone()));
774                    self.proto.emit(Op::SetGlobal(k));
775                }
776            }
777            Expr::Field { table, name } => {
778                self.compile_expr(table);
779                let k = self.proto.add_const(Constant::Str(name.clone()));
780                self.proto.emit(Op::SetField(k));
781            }
782            Expr::Index { table, key } => {
783                self.compile_expr(table);
784                self.compile_expr(key);
785                self.proto.emit(Op::SetIndex);
786            }
787            _ => {}
788        }
789    }
790
791    // ── Expression compilation ────────────────────────────────────────────────
792
793    fn compile_expr(&mut self, expr: &Expr) {
794        match expr {
795            Expr::Nil         => { self.proto.emit(Op::Nil); }
796            Expr::Bool(b)     => { self.proto.emit(if *b { Op::True } else { Op::False }); }
797            Expr::Int(n)      => { let k = self.proto.add_const(Constant::Int(*n)); self.proto.emit(Op::Const(k)); }
798            Expr::Float(f)    => { let k = self.proto.add_const(Constant::Float(*f)); self.proto.emit(Op::Const(k)); }
799            Expr::Str(s)      => { let k = self.proto.add_const(Constant::Str(s.clone())); self.proto.emit(Op::Const(k)); }
800            Expr::Vararg      => { self.proto.emit(Op::Vararg(0)); }
801
802            Expr::Ident(name) => {
803                if let Some(slot) = self.scope.resolve_local(name) {
804                    self.proto.emit(Op::GetLocal(slot));
805                } else {
806                    let k = self.proto.add_const(Constant::Str(name.clone()));
807                    self.proto.emit(Op::GetGlobal(k));
808                }
809            }
810
811            Expr::Field { table, name } => {
812                self.compile_expr(table);
813                let k = self.proto.add_const(Constant::Str(name.clone()));
814                self.proto.emit(Op::GetField(k));
815            }
816
817            Expr::Index { table, key } => {
818                self.compile_expr(table);
819                self.compile_expr(key);
820                self.proto.emit(Op::GetIndex);
821            }
822
823            Expr::Call { callee, args } => {
824                self.compile_expr(callee);
825                let nargs = args.len() as u8;
826                for a in args { self.compile_expr(a); }
827                self.proto.emit(Op::Call(nargs, 1));
828            }
829
830            Expr::MethodCall { obj, method, args } => {
831                self.compile_expr(obj);
832                let k = self.proto.add_const(Constant::Str(method.clone()));
833                let nargs = args.len() as u8;
834                for a in args { self.compile_expr(a); }
835                self.proto.emit(Op::CallMethod(k, nargs, 1));
836            }
837
838            Expr::Unary { op, expr } => {
839                self.compile_expr(expr);
840                match op {
841                    UnOp::Neg    => { self.proto.emit(Op::Neg); }
842                    UnOp::Not    => { self.proto.emit(Op::Not); }
843                    UnOp::Len    => { self.proto.emit(Op::Len); }
844                    UnOp::BitNot => { self.proto.emit(Op::BitNot); }
845                }
846            }
847
848            Expr::Binary { op, lhs, rhs } => {
849                match op {
850                    BinOp::And => {
851                        self.compile_expr(lhs);
852                        let j = self.proto.emit(Op::JumpIfNotPop(0));
853                        self.compile_expr(rhs);
854                        self.proto.patch_jump(j);
855                        return;
856                    }
857                    BinOp::Or => {
858                        self.compile_expr(lhs);
859                        let j = self.proto.emit(Op::JumpIfPop(0));
860                        self.compile_expr(rhs);
861                        self.proto.patch_jump(j);
862                        return;
863                    }
864                    _ => {}
865                }
866                self.compile_expr(lhs);
867                self.compile_expr(rhs);
868                self.compile_binop(*op);
869            }
870
871            Expr::TableCtor(fields) => {
872                self.proto.emit(Op::NewTable);
873                let mut array_count = 0u16;
874                for field in fields {
875                    match field {
876                        TableField::NameKey(name, val) => {
877                            self.proto.emit(Op::Dup);
878                            self.compile_expr(val);
879                            let k = self.proto.add_const(Constant::Str(name.clone()));
880                            self.proto.emit(Op::SetField(k));
881                        }
882                        TableField::ExprKey(key, val) => {
883                            self.proto.emit(Op::Dup);
884                            self.compile_expr(key);
885                            self.compile_expr(val);
886                            self.proto.emit(Op::SetIndex);
887                        }
888                        TableField::Value(val) => {
889                            self.proto.emit(Op::Dup);
890                            self.compile_expr(val);
891                            array_count += 1;
892                            let k = self.proto.add_const(Constant::Int(array_count as i64));
893                            self.proto.emit(Op::SetField(k));
894                        }
895                    }
896                }
897            }
898
899            Expr::FuncExpr { params, vararg, body } => {
900                let fn_proto = self.compile_func("<anon>", params, *vararg, body);
901                let idx = self.proto.protos.len() as u32;
902                self.proto.protos.push(fn_proto);
903                self.proto.emit(Op::Closure(idx));
904            }
905
906            Expr::Ternary { cond, then_val, else_val } => {
907                self.compile_expr(cond);
908                let skip = self.proto.emit(Op::JumpIfNot(0));
909                self.compile_expr(then_val);
910                let end = self.proto.emit(Op::Jump(0));
911                self.proto.patch_jump(skip);
912                self.compile_expr(else_val);
913                self.proto.patch_jump(end);
914            }
915        }
916    }
917
918    fn compile_binop(&mut self, op: BinOp) {
919        let instr = match op {
920            BinOp::Add    => Op::Add,
921            BinOp::Sub    => Op::Sub,
922            BinOp::Mul    => Op::Mul,
923            BinOp::Div    => Op::Div,
924            BinOp::IDiv   => Op::IDiv,
925            BinOp::Mod    => Op::Mod,
926            BinOp::Pow    => Op::Pow,
927            BinOp::Concat => Op::Concat,
928            BinOp::Eq     => Op::Eq,
929            BinOp::NotEq  => Op::NotEq,
930            BinOp::Lt     => Op::Lt,
931            BinOp::LtEq   => Op::LtEq,
932            BinOp::Gt     => Op::Gt,
933            BinOp::GtEq   => Op::GtEq,
934            BinOp::And    => Op::BitAnd,
935            BinOp::Or     => Op::BitOr,
936            BinOp::BitAnd => Op::BitAnd,
937            BinOp::BitOr  => Op::BitOr,
938            BinOp::BitXor => Op::BitXor,
939            BinOp::Shl    => Op::Shl,
940            BinOp::Shr    => Op::Shr,
941        };
942        self.proto.emit(instr);
943    }
944
945    fn compile_func(&mut self, name: &str, params: &[String], vararg: bool, body: &[Stmt]) -> Proto {
946        let mut child = Compiler {
947            proto:  Proto::new(name),
948            scope:  Scope::new(),
949            breaks: Vec::new(),
950        };
951        child.proto.param_count = params.len() as u8;
952        child.proto.is_vararg   = vararg;
953        child.scope.push_scope();
954        for p in params { child.scope.add_local(p); }
955        for s in body   { child.compile_stmt(s); }
956        child.scope.pop_scope();
957        child.proto.emit(Op::Return(0));
958        child.proto
959    }
960}
961
962// ── Tests ─────────────────────────────────────────────────────────────────────
963
#[cfg(test)]
mod tests {
    //! Compiler smoke tests: each snippet is parsed and compiled, then the
    //! resulting `Proto` is searched for the instructions, constants or child
    //! prototypes the construct is expected to produce.

    use super::*;
    use crate::scripting::parser;

    /// Parse `src` (chunk name `"test"`) and compile it to the top-level `Proto`.
    fn compile_src(src: &str) -> Proto {
        let script = parser::parse(src, "test").expect("parse failed");
        Compiler::compile_to_proto(&script)
    }

    /// Does any instruction in `proto.code` satisfy `pred`?
    fn any_op(proto: &Proto, pred: impl Fn(&Op) -> bool) -> bool {
        proto.code.iter().any(pred)
    }

    /// Does any entry of `proto.constants` satisfy `pred`?
    fn any_const(proto: &Proto, pred: impl Fn(&Constant) -> bool) -> bool {
        proto.constants.iter().any(pred)
    }

    #[test]
    fn test_compile_nil() {
        let proto = compile_src("local x");
        assert!(any_op(&proto, |op| *op == Op::Nil));
    }

    #[test]
    fn test_compile_int_const() {
        let proto = compile_src("local x = 42");
        assert!(any_const(&proto, |c| *c == Constant::Int(42)));
    }

    #[test]
    fn test_compile_float_const() {
        let proto = compile_src("local pi = 3.14");
        assert!(any_const(&proto, |c| matches!(c, Constant::Float(f) if (*f - 3.14).abs() < 1e-6)));
    }

    #[test]
    fn test_compile_string_const() {
        let proto = compile_src(r#"local s = "hello""#);
        assert!(any_const(&proto, |c| *c == Constant::Str("hello".to_string())));
    }

    #[test]
    fn test_compile_add() {
        let proto = compile_src("local z = 1 + 2");
        assert!(any_op(&proto, |op| *op == Op::Add));
    }

    #[test]
    fn test_compile_bool_true() {
        let proto = compile_src("local b = true");
        assert!(any_op(&proto, |op| *op == Op::True));
    }

    #[test]
    fn test_compile_while_has_back_jump() {
        let proto = compile_src("local i = 0 while i < 10 do i = i + 1 end");
        let has_exit = any_op(&proto, |op| matches!(op, Op::JumpIfNot(_)));
        let has_back = any_op(&proto, |op| matches!(op, Op::Jump(n) if *n < 0));
        assert!(has_exit, "expected JumpIfNot");
        assert!(has_back, "expected backward Jump");
    }

    #[test]
    fn test_compile_if_else() {
        let proto = compile_src("if x then return 1 else return 2 end");
        assert!(any_op(&proto, |op| matches!(op, Op::JumpIfNot(_))));
        assert!(any_op(&proto, |op| matches!(op, Op::Jump(_))));
    }

    #[test]
    fn test_compile_function_creates_proto() {
        let proto = compile_src("function add(a, b) return a + b end");
        assert!(!proto.protos.is_empty());
        assert_eq!(proto.protos[0].param_count, 2);
    }

    #[test]
    fn test_compile_local_function() {
        let proto = compile_src("local function square(x) return x * x end");
        assert!(any_op(&proto, |op| matches!(op, Op::Closure(_))));
        assert!(!proto.protos.is_empty());
    }

    #[test]
    fn test_compile_table_ctor() {
        let proto = compile_src("local t = {x = 1, y = 2}");
        assert!(any_op(&proto, |op| *op == Op::NewTable));
        assert!(any_op(&proto, |op| matches!(op, Op::SetField(_))));
    }

    #[test]
    fn test_compile_method_call() {
        let proto = compile_src("obj:doThing(1, 2)");
        assert!(any_op(&proto, |op| matches!(op, Op::CallMethod(..))));
    }

    #[test]
    fn test_compile_for_numeric() {
        let proto = compile_src("for i = 1, 10, 2 do end");
        assert!(any_op(&proto, |op| *op == Op::NumForInit));
        assert!(any_op(&proto, |op| matches!(op, Op::NumForStep(_))));
    }

    #[test]
    fn test_compile_for_generic() {
        let proto = compile_src("for k, v in pairs(t) do end");
        assert!(any_op(&proto, |op| matches!(op, Op::ForPrep(_))));
    }

    #[test]
    fn test_compile_and_short_circuit() {
        let proto = compile_src("local r = a and b");
        assert!(any_op(&proto, |op| matches!(op, Op::JumpIfNotPop(_))));
    }

    #[test]
    fn test_compile_or_short_circuit() {
        let proto = compile_src("local r = a or b");
        assert!(any_op(&proto, |op| matches!(op, Op::JumpIfPop(_))));
    }

    #[test]
    fn test_compile_ternary() {
        let proto = compile_src("local x = cond ? 1 : 2");
        assert!(any_op(&proto, |op| matches!(op, Op::JumpIfNot(_))));
    }

    #[test]
    fn test_compile_nested_function() {
        let proto = compile_src("
            function outer(x)
                local function inner(y) return x + y end
                return inner(10)
            end
        ");
        assert!(!proto.protos.is_empty());
        let outer = &proto.protos[0];
        assert!(!outer.protos.is_empty(), "expected inner proto");
    }

    #[test]
    fn test_compile_concat() {
        let proto = compile_src(r#"local s = "hello" .. " " .. "world""#);
        // At least one Concat instruction must be present.
        assert!(any_op(&proto, |op| *op == Op::Concat));
    }

    #[test]
    fn test_compile_repeat_until() {
        let proto = compile_src("local i = 0 repeat i = i + 1 until i >= 10");
        assert!(any_op(&proto, |op| matches!(op, Op::JumpIfNot(n) if *n < 0)));
    }

    #[test]
    fn test_compile_match() {
        let proto = compile_src("match x { case 1 => return 1, case 2 => return 2 }");
        assert!(any_op(&proto, |op| *op == Op::Dup));
        assert!(any_op(&proto, |op| *op == Op::Eq));
    }

    #[test]
    fn test_compile_import() {
        let proto = compile_src(r#"import "math" as m"#);
        assert!(any_const(&proto, |c| *c == Constant::Str("math".to_string())));
        assert!(any_const(&proto, |c| *c == Constant::Str("require".to_string())));
    }

    #[test]
    fn test_add_const_deduplication() {
        let mut proto = Proto::new("test");
        let first = proto.add_const(Constant::Int(42));
        let second = proto.add_const(Constant::Int(42));
        assert_eq!(first, second, "deduplication failed");
        assert_eq!(proto.constants.len(), 1);
    }
}