rslua_march1917/
compiler.rs

1use crate::ast::*;
2use crate::ast_walker::{ast_walker, AstVisitor};
3use crate::consts::Const;
4use crate::opcodes::*;
5use crate::proto::{Proto, ProtoContext};
6use crate::types::Source;
7use crate::{debuggable, error, success};
8
/// Walks a parsed AST (`Block`) and emits a `Proto`.
pub struct Compiler {
    // Debug switch; `Compiler::new` initialises it to `false`.
    // NOTE(review): presumably read by the crate's `debuggable!` / `error!` macros — confirm there.
    debug: bool,
    // Stack of per-function compile contexts; the last element is the one
    // currently being compiled (see `context`).
    proto_contexts: Vec<ProtoContext>,
}
13
/// Error raised while compiling; wraps a human-readable message.
///
/// Derives `Debug` (so `Result<_, CompileError>` can be `unwrap`ped or
/// `expect`ed), `Clone` and equality, and implements `Display` and
/// `std::error::Error` so it can be printed or boxed as a `dyn Error`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompileError(pub String);

impl CompileError {
    /// Builds an error from a borrowed message, copying it into an owned `String`.
    pub fn new(str: &str) -> Self {
        CompileError(str.to_string())
    }
}

impl std::fmt::Display for CompileError {
    /// Writes the wrapped message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl std::error::Error for CompileError {}
21
/// Outcome of a whole compilation: the finished `Proto` or a `CompileError`.
type CompileResult = Result<Proto, CompileError>;
23
/// Formats `[compile error] <msg> at line [<line>].` from `$error.0` and
/// `$source.line`, then passes the message to the crate's `error!` macro
/// together with `$self` and the `CompileError` type.
// NOTE(review): `error!` is defined elsewhere; presumably it evaluates to
// `Err(CompileError(..))` — confirm against its definition.
macro_rules! compile_error {
    ($self:ident, $error:ident, $source:ident) => {{
        let error_msg = format!("[compile error] {} at line [{}].", $error.0, $source.line);
        error!($self, CompileError, error_msg)
    }};
}
30
/// A register operand produced while compiling an expression.
pub struct Reg {
    // register index
    pub reg: u32,
    // true when created through `Reg::new_temp`; only temp registers are
    // released by `Reg::free`
    pub temp: bool,
    // false for registers created through `ExprResult::new_const_reg`;
    // `Reg::is_const` reports `!mutable`
    pub mutable: bool,
}
36
37impl Reg {
38    pub fn new(reg: u32) -> Self {
39        Reg {
40            reg,
41            temp: false,
42            mutable: true,
43        }
44    }
45
46    pub fn new_temp(reg: u32) -> Self {
47        Reg {
48            reg,
49            temp: true,
50            mutable: true,
51        }
52    }
53
54    pub fn is_temp(&self) -> bool {
55        self.temp
56    }
57
58    pub fn is_const(&self) -> bool {
59        !self.mutable
60    }
61
62    pub fn free(&self, context: &mut ProtoContext) {
63        if self.is_temp() {
64            context.free_reg(1)
65        }
66    }
67}
68
/// A pending conditional jump produced by a comparison or test.
pub struct Jump {
    // register that receives the boolean result when the jump is freed
    pub reg: Reg,
    // position of the jump instruction (as returned by `code_jmp`)
    pub pc: usize,
    // positions of other jumps that `fix` patches to the "true" position
    pub true_jumps: Vec<usize>,
    // positions of other jumps that `fix` patches to the "false" position
    pub false_jumps: Vec<usize>,
    // when set, `free` first emits a move from this register into `reg`
    pub reg_should_move: Option<u32>,
}
76
77impl Jump {
78    pub fn new(reg: Reg, pc: usize) -> Self {
79        Jump {
80            reg,
81            pc,
82            true_jumps: Vec::new(),
83            false_jumps: Vec::new(),
84            reg_should_move: None,
85        }
86    }
87
88    pub fn free(&self, context: &mut ProtoContext) {
89        let proto = &mut context.proto;
90        let target = self.reg.reg;
91        if let Some(from) = self.reg_should_move {
92            proto.code_move(target, from);
93        }
94        let false_pos = proto.code_bool(target, false, 1);
95        let true_pos = proto.code_bool(target, true, 0);
96        self.fix(true_pos, false_pos, proto);
97        self.reg.free(context);
98    }
99
100    pub fn free_reg(&self, context: &mut ProtoContext) {
101        self.reg.free(context);
102    }
103
104    pub fn inverse_cond(&self, context: &mut ProtoContext) {
105        let proto = &mut context.proto;
106        let cond = self.pc - 1;
107        let instruction = proto.get_instruction(cond);
108        instruction.set_arg_A(1 - instruction.get_arg_A());
109    }
110
111    pub fn concat_true_jumps(&mut self, other: &mut Jump) {
112        self.true_jumps.append(&mut other.true_jumps);
113        self.true_jumps.push(other.pc);
114    }
115
116    pub fn concat_false_jumps(&mut self, other: &mut Jump) {
117        self.false_jumps.append(&mut other.false_jumps);
118        self.false_jumps.push(other.pc);
119    }
120
121    pub fn set_reg_should_move(&mut self, from: u32) {
122        self.reg_should_move = Some(from)
123    }
124
125    fn fix(&self, true_pos: usize, false_pos: usize, proto: &mut Proto) {
126        proto.fix_cond_jump_pos(true_pos, false_pos, self.pc);
127        for pc in self.true_jumps.iter() {
128            proto.fix_jump_pos(true_pos, *pc)
129        }
130        for pc in self.false_jumps.iter() {
131            proto.fix_jump_pos(false_pos, *pc)
132        }
133    }
134}
135
/// What compiling an expression produced.
pub enum ExprResult {
    // a constant value; `get_rk` adds it to the proto's constants on demand
    Const(Const),
    // the value lives in a register
    Reg(Reg),
    // a pending conditional jump; `resolve` turns it into a boolean in a register
    Jump(Jump),
    // the literals nil / true / false
    Nil,
    True,
    False,
}
144
145impl ExprResult {
146    pub fn new_const(k: Const) -> Self {
147        ExprResult::Const(k)
148    }
149
150    pub fn new_const_reg(reg: u32) -> Self {
151        ExprResult::Reg(Reg {
152            reg,
153            temp: false,
154            mutable: false,
155        })
156    }
157
158    pub fn new_jump(reg: Reg, pc: usize) -> Self {
159        ExprResult::Jump(Jump::new(reg, pc))
160    }
161
162    pub fn get_rk(&self, context: &mut ProtoContext) -> u32 {
163        match self {
164            ExprResult::Const(k) => {
165                let index = context.proto.add_const(k.clone());
166                MASK_K | index
167            }
168            ExprResult::Reg(i) => i.reg,
169            ExprResult::Jump(j) => j.reg.reg,
170            _ => unreachable!(),
171        }
172    }
173
174    pub fn resolve(&self, context: &mut ProtoContext) {
175        match self {
176            ExprResult::Reg(r) => r.free(context),
177            ExprResult::Jump(j) => j.free(context),
178            _ => (),
179        };
180    }
181}
182
183impl Compiler {
184    pub fn new() -> Self {
185        Compiler {
186            debug: false,
187            proto_contexts: Vec::new(),
188        }
189    }
190
    /// Compiles `block` as the main function and returns the resulting proto.
    pub fn run(&mut self, block: &Block) -> CompileResult {
        self.main_func(block)
    }
194
195    fn main_func(&mut self, block: &Block) -> CompileResult {
196        self.push_proto();
197        self.proto().open();
198        ast_walker::walk_block(block, self)?;
199        self.proto().close();
200        Ok(self.pop_proto())
201    }
202
    // push a fresh proto context; it becomes the current one
    fn push_proto(&mut self) {
        self.proto_contexts.push(ProtoContext::new());
    }
206
207    fn pop_proto(&mut self) -> Proto {
208        if let Some(context) = self.proto_contexts.pop() {
209            return context.proto;
210        }
211        unreachable!()
212    }
213
    // get a mutable reference to the proto of the current (topmost) context
    fn proto(&mut self) -> &mut Proto {
        &mut self.context().proto
    }
218
219    // get current proto context
220    fn context(&mut self) -> &mut ProtoContext {
221        if let Some(last) = self.proto_contexts.last_mut() {
222            return last;
223        }
224        unreachable!()
225    }
226
227    fn adjust_assign(&mut self, num_left: usize, right_exprs: &Vec<Expr>) -> i32 {
228        let extra = num_left as i32 - right_exprs.len() as i32;
229        if let Some(last_expr) = right_exprs.last() {
230            if last_expr.has_multi_ret() {
231                // TODO : process multi return value
232                todo!("process mult ret")
233            }
234        }
235
236        if extra > 0 {
237            let context = self.context();
238            let from = context.get_reg_top();
239            context.reserve_regs(extra as u32);
240            context.proto.code_nil(from, extra as u32);
241        }
242
243        extra
244    }
245
246    // process expr and return const index or register index
247    fn expr(&mut self, expr: &Expr, reg: Option<u32>) -> Result<ExprResult, CompileError> {
248        let proto = self.proto();
249        let result = match expr {
250            Expr::Int(i) => ExprResult::new_const(Const::Int(*i)),
251            Expr::Float(f) => ExprResult::new_const(Const::Float(*f)),
252            Expr::String(s) => {
253                // const string will always be added to consts
254                let k = Const::Str(s.clone());
255                proto.add_const(k.clone());
256                ExprResult::new_const(k)
257            }
258            Expr::Nil => ExprResult::Nil,
259            Expr::True => ExprResult::True,
260            Expr::False => ExprResult::False,
261            Expr::Name(name) => {
262                if let Some(src) = proto.get_local_var(name) {
263                    return Ok(ExprResult::new_const_reg(src));
264                }
265                // TODO : process upval and globals
266                todo!()
267            }
268            Expr::BinExpr(_) | Expr::UnExpr(_) => self.folding_or_code(expr, reg)?,
269            Expr::ParenExpr(expr) => self.folding_or_code(&expr, reg)?,
270            _ => todo!(),
271        };
272        Ok(result)
273    }
274
275    // try constant foding first, if failed then generate code
276    fn folding_or_code(
277        &mut self,
278        expr: &Expr,
279        reg: Option<u32>,
280    ) -> Result<ExprResult, CompileError> {
281        if let Some(k) = self.try_const_folding(expr)? {
282            Ok(ExprResult::new_const(k))
283        } else {
284            self.code_expr(expr, reg)
285        }
286    }
287
288    // try constant folding expr
289    fn try_const_folding(&self, expr: &Expr) -> Result<Option<Const>, CompileError> {
290        match expr {
291            Expr::Int(i) => return success!(Const::Int(*i)),
292            Expr::Float(f) => return success!(Const::Float(*f)),
293            Expr::String(s) => return success!(Const::Str(s.clone())),
294            Expr::BinExpr(bin) => match bin.op {
295                BinOp::Add
296                | BinOp::Minus
297                | BinOp::Mul
298                | BinOp::Div
299                | BinOp::IDiv
300                | BinOp::Mod
301                | BinOp::Pow
302                | BinOp::BAnd
303                | BinOp::BOr
304                | BinOp::BXor
305                | BinOp::Shl
306                | BinOp::Shr => {
307                    if let (Some(l), Some(r)) = (
308                        self.try_const_folding(&bin.left)?,
309                        self.try_const_folding(&bin.right)?,
310                    ) {
311                        if let Some(k) = self.const_folding_bin_op(bin.op, l, r)? {
312                            return success!(k);
313                        }
314                    }
315                }
316                _ => (),
317            },
318            Expr::UnExpr(un) => match un.op {
319                UnOp::BNot | UnOp::Minus => {
320                    if let Some(k) = self.try_const_folding(&un.expr)? {
321                        if let Some(k) = self.const_folding_un_op(un.op, k)? {
322                            return success!(k);
323                        }
324                    }
325                }
326                _ => (),
327            },
328            Expr::ParenExpr(expr) => return self.try_const_folding(&expr),
329            _ => (),
330        }
331        Ok(None)
332    }
333
334    fn code_expr(&mut self, expr: &Expr, reg: Option<u32>) -> Result<ExprResult, CompileError> {
335        match expr {
336            Expr::BinExpr(bin) => match bin.op {
337                BinOp::And => self.code_and(reg, &bin.left, &bin.right),
338                _ => self.code_bin_op(bin.op, reg, &bin.left, &bin.right),
339            },
340            Expr::UnExpr(un) => {
341                if un.op == UnOp::Not {
342                    self.code_not(reg, &un.expr)
343                } else {
344                    let result = self.expr(&un.expr, reg)?;
345                    self.code_un_op(un.op, reg, result)
346                }
347            }
348            _ => unreachable!(),
349        }
350    }
351
352    fn const_folding_bin_op(
353        &self,
354        op: BinOp,
355        l: Const,
356        r: Const,
357    ) -> Result<Option<Const>, CompileError> {
358        let result = match op {
359            BinOp::Add => l.add(r)?,
360            BinOp::Minus => l.sub(r)?,
361            BinOp::Mul => l.mul(r)?,
362            BinOp::Div => l.div(r)?,
363            BinOp::IDiv => l.idiv(r)?,
364            BinOp::Mod => l.mod_(r)?,
365            BinOp::Pow => l.pow(r)?,
366            BinOp::BAnd => l.band(r)?,
367            BinOp::BOr => l.bor(r)?,
368            BinOp::BXor => l.bxor(r)?,
369            BinOp::Shl => l.shl(r)?,
370            BinOp::Shr => l.shr(r)?,
371            _ => None,
372        };
373        Ok(result)
374    }
375
376    fn const_folding_un_op(&self, op: UnOp, k: Const) -> Result<Option<Const>, CompileError> {
377        let result = match op {
378            UnOp::Minus => k.minus()?,
379            UnOp::BNot => k.bnot()?,
380            _ => None,
381        };
382        Ok(result)
383    }
384
385    fn get_right_input(&mut self, input: Option<u32>, left: &ExprResult) -> Option<u32> {
386        let mut right_input = None;
387        let is_input_reusable = |r: u32, input: u32| r < input;
388        if let Some(input_reg) = input {
389            right_input = match &left {
390                ExprResult::Reg(r) if !is_input_reusable(r.reg, input_reg) => None,
391                ExprResult::Jump(j) if !is_input_reusable(j.reg.reg, input_reg) => None,
392                _ => input,
393            };
394        };
395        right_input
396    }
397
398    fn alloc_reg(&mut self, input: &Option<u32>) -> Reg {
399        let reg = input.unwrap_or_else(|| self.context().reserve_regs(1));
400        if Some(reg) == *input {
401            Reg::new(reg)
402        } else {
403            Reg::new_temp(reg)
404        }
405    }
406
407    fn code_bin_op(
408        &mut self,
409        op: BinOp,
410        input: Option<u32>,
411        left_expr: &Expr,
412        right_expr: &Expr,
413    ) -> Result<ExprResult, CompileError> {
414        // get left expr result
415        let left = self.expr(left_expr, input)?;
416        // resolve previous expr result
417        left.resolve(self.context());
418
419        // if input reg is not used by left expr, apply it to right expr
420        let right_input = self.get_right_input(input, &left);
421
422        // get right expr result
423        let right = self.expr(right_expr, right_input)?;
424
425        // resolve previous expr result
426        right.resolve(self.context());
427
428        let alloc_reg = self.alloc_reg(&input);
429        let reg = alloc_reg.reg;
430        let mut result = ExprResult::Reg(alloc_reg);
431
432        // get rk of left and right expr
433        let mut get_rk = || {
434            let left_rk = left.get_rk(self.context());
435            let right_rk = right.get_rk(self.context());
436            (left_rk, right_rk)
437        };
438
439        // gennerate opcode of binop
440        match op {
441            _ if op.is_comp() => {
442                let (left_rk, right_rk) = get_rk();
443                result = self.code_comp(op, result, left_rk, right_rk);
444            }
445            _ => {
446                let (left_rk, right_rk) = get_rk();
447                self.proto().code_bin_op(op, reg, left_rk, right_rk);
448            }
449        };
450
451        Ok(result)
452    }
453
454    fn code_comp(&mut self, op: BinOp, target: ExprResult, left: u32, right: u32) -> ExprResult {
455        match target {
456            ExprResult::Reg(reg) => {
457                // covert >= to <=, > to <
458                let (left, right) = match op {
459                    BinOp::Ge | BinOp::Gt => (right, left),
460                    _ => (left, right),
461                };
462
463                let proto = self.proto();
464                proto.code_comp(op, left, right);
465                let jump = proto.code_jmp(NO_JUMP, 0);
466                ExprResult::new_jump(reg, jump)
467            }
468            _ => unreachable!(),
469        }
470    }
471
    // Generate code for `left_expr and right_expr`.
    //
    // The shape of the left result decides the strategy:
    //   * `True` / constant  -> the whole expression is just the right operand
    //   * pending jump       -> invert its condition, compile the right operand
    //                           into the same register and chain the left jump
    //                           onto the right jump's false list
    //   * register           -> emit a test via `code_test`
    // Remaining combinations are not implemented yet (`todo!`).
    fn code_and(
        &mut self,
        input: Option<u32>,
        left_expr: &Expr,
        right_expr: &Expr,
    ) -> Result<ExprResult, CompileError> {
        // get left expr result
        let mut left = self.expr(left_expr, input)?;
        match &mut left {
            // left is a truthy constant, so the result is the right operand
            ExprResult::True | ExprResult::Const(_) => self.expr(right_expr, input),
            ExprResult::Jump(j) => {
                j.inverse_cond(self.context());
                // reuse the left jump's register for the right operand
                let mut right = self.expr(right_expr, Some(j.reg.reg))?;
                match &mut right {
                    ExprResult::Jump(rj) => rj.concat_false_jumps(j),
                    // right operand that is not a jump: not implemented yet
                    _ => todo!(),
                };
                Ok(right)
            }
            ExprResult::Reg(_reg) => self.code_test(input, left, right_expr),
            // `Nil` / `False` on the left: not implemented yet
            _ => todo!(),
        }
    }
496
497    fn code_test(
498        &mut self,
499        input: Option<u32>,
500        left: ExprResult,
501        right: &Expr,
502    ) -> Result<ExprResult, CompileError> {
503        match &left {
504            ExprResult::Reg(r) => {
505                let proto = self.proto();
506                proto.code_test_set(NO_REG, r.reg, 0);
507                let jump = proto.code_jmp(NO_JUMP, 0);
508                let right_input = self.get_right_input(input, &left);
509                let right_result = self.expr(right, right_input)?;
510                let mut jump = Jump::new(self.alloc_reg(&input), jump);
511                match &right_result {
512                    ExprResult::Reg(r) if r.is_const() => jump.set_reg_should_move(r.reg),
513                    _ => (),
514                };
515                Ok(ExprResult::Jump(jump))
516            }
517            _ => unreachable!(),
518        }
519    }
520
521    fn code_un_op(
522        &mut self,
523        op: UnOp,
524        input: Option<u32>,
525        expr: ExprResult,
526    ) -> Result<ExprResult, CompileError> {
527        let src = expr.get_rk(self.context());
528
529        // resolve previous result
530        expr.resolve(self.context());
531
532        let alloc_reg = self.alloc_reg(&input);
533        let reg = alloc_reg.reg;
534        let result = ExprResult::Reg(alloc_reg);
535
536        // gennerate opcode of unop
537        let proto = self.proto();
538        proto.code_un_op(op, reg, src);
539
540        Ok(result)
541    }
542
543    fn code_not(&mut self, input: Option<u32>, expr: &Expr) -> Result<ExprResult, CompileError> {
544        if let Some(_) = self.try_const_folding(expr)? {
545            Ok(ExprResult::False)
546        } else {
547            let result = self.expr(expr, input)?;
548            match &result {
549                ExprResult::Jump(j) => {
550                    j.inverse_cond(self.context());
551                    Ok(result)
552                }
553                ExprResult::Nil | ExprResult::False => Ok(ExprResult::True),
554                ExprResult::Const(_) | ExprResult::True => Ok(ExprResult::False),
555                _ => self.code_un_op(UnOp::Not, input, result),
556            }
557        }
558    }
559
560    // process expr and save to register
561    fn expr_and_save(&mut self, expr: &Expr, save_reg: Option<u32>) -> Result<u32, CompileError> {
562        let reg = save_reg.unwrap_or_else(|| self.context().reserve_regs(1));
563
564        // use a register to store temp result
565        let temp_reg = if Some(reg) != save_reg {
566            reg
567        } else {
568            self.context().reserve_regs(1)
569        };
570
571        let result = self.expr(expr, Some(temp_reg))?;
572        let proto = self.proto();
573        match result {
574            ExprResult::Const(k) => {
575                let index = proto.add_const(k);
576                proto.code_const(reg, index)
577            }
578            ExprResult::Reg(src) if src.is_const() => proto.code_move(reg, src.reg),
579            ExprResult::Reg(_) => proto.save(reg),
580            ExprResult::True => proto.code_bool(reg, true, 0),
581            ExprResult::False => proto.code_bool(reg, false, 0),
582            ExprResult::Nil => proto.code_nil(reg, 1),
583            ExprResult::Jump(j) => {
584                j.free(self.context());
585                0
586            }
587        };
588
589        if temp_reg != reg {
590            self.context().free_reg(1);
591        }
592
593        Ok(reg)
594    }
595
    // Look up the register of an assignment target.
    //
    // Only plain names bound to a local variable are supported: an unknown
    // name panics on `unwrap`, and suffixed expressions are not implemented
    // yet (`todo!`).
    fn get_assinable_reg(&mut self, assignable: &Assignable) -> u32 {
        match assignable {
            Assignable::Name(name) => self.proto().get_local_var(name).unwrap(),
            Assignable::SuffixedExpr(_) => todo!(),
        }
    }
602
603    debuggable!();
604}
605
606impl AstVisitor<CompileError> for Compiler {
    // error handler: decorates `e` with the source line through the
    // `compile_error!` macro and returns its result
    fn error(&mut self, e: CompileError, source: &Source) -> Result<(), CompileError> {
        compile_error!(self, e, source)
    }
611
612    // compile local stat
613    fn local_stat(&mut self, stat: &LocalStat) -> Result<(), CompileError> {
614        let proto = self.proto();
615        for name in stat.names.iter() {
616            proto.add_local_var(name);
617        }
618        for expr in stat.exprs.iter() {
619            self.expr_and_save(expr, None)?;
620        }
621        self.adjust_assign(stat.names.len(), &stat.exprs);
622        Ok(())
623    }
624
625    // compile assign stat
626    fn assign_stat(&mut self, stat: &AssignStat) -> Result<(), CompileError> {
627        let use_temp_reg = stat.right.len() != stat.left.len();
628        let mut to_move: Vec<(u32, u32)> = Vec::new();
629
630        // move rules:
631        // if num of left != num of right:
632        //      MOVE temp[1..n] right[1..n]
633        //      MOVE left[1..n] temp[1..n]
634        // if num of left == num of right:
635        //      MOVE temp[1..(n-1)] right[1..(n-1)]
636        //      MOVE left[n] right[n]
637        //      MOVE left[1..(n-1)] temp[1..(n-1)]
638        for (i, expr) in stat.right.iter().enumerate() {
639            if i != stat.right.len() - 1 || use_temp_reg {
640                let reg = self.expr_and_save(expr, None)?;
641                if i < stat.left.len() {
642                    let target = self.get_assinable_reg(&stat.left[i]);
643                    to_move.push((target, reg));
644                }
645            } else {
646                let reg = self.get_assinable_reg(&stat.left[i]);
647                self.expr_and_save(expr, Some(reg))?;
648            };
649        }
650
651        // nil move
652        let reg = self.context().get_reg_top();
653        let extra = self.adjust_assign(stat.left.len(), &stat.right);
654        if extra > 0 {
655            let left_start = stat.left.len() as i32 - extra;
656            for i in 0..extra {
657                let target = self.get_assinable_reg(&stat.left[(left_start + i) as usize]);
658                let src = (reg as i32 + i) as u32;
659                to_move.push((target, src));
660            }
661        }
662
663        // apply moves
664        for (target, src) in to_move.iter().rev() {
665            self.proto().code_move(*target, *src);
666            self.context().free_reg(1);
667        }
668
669        // free extra regs
670        if extra < 0 {
671            self.context().free_reg(-extra as u32);
672        }
673
674        Ok(())
675    }
676}