// runmat-vm 0.4.4
//
// RunMat virtual machine and bytecode interpreter
// Documentation
//! Statement lowering.

use crate::compiler::core::{Compiler, LoopLabels};
use crate::compiler::idioms;
use crate::compiler::CompileError;
use crate::instr::{EmitLabel, Instr};
use runmat_hir::{HirExpr, HirExprKind, HirStmt};

/// Reports whether the value of an expression statement may be echoed inline.
///
/// Every expression kind qualifies except a function call whose callee is
/// reported by `runmat_builtins::suppresses_auto_output` as suppressing
/// automatic output.
fn expr_supports_inline_emit(expr: &HirExpr) -> bool {
    if let HirExprKind::FuncCall(callee, _) = &expr.kind {
        return !runmat_builtins::suppresses_auto_output(callee);
    }
    true
}

/// Chooses the label attached to an echoed expression value.
///
/// A bare variable reference is labelled with that variable's slot index;
/// every other expression kind is labelled `EmitLabel::Ans`.
fn label_for_expr(expr: &HirExpr) -> EmitLabel {
    match &expr.kind {
        HirExprKind::Var(slot) => EmitLabel::Var(slot.0),
        _ => EmitLabel::Ans,
    }
}

impl Compiler {
    pub(crate) fn compile_stmt_impl(&mut self, stmt: &HirStmt) -> Result<(), CompileError> {
        if idioms::try_lower_stmt_idiom(self, stmt)? {
            return Ok(());
        }
        match stmt {
            HirStmt::ExprStmt(expr, suppressed, _) => {
                self.compile_expr(expr)?;
                if !suppressed && expr_supports_inline_emit(expr) {
                    let label = label_for_expr(expr);
                    self.emit(Instr::EmitStackTop { label });
                }
                self.emit(Instr::Pop);
            }
            HirStmt::Assign(id, expr, suppressed, _) => {
                self.compile_expr(expr)?;
                self.emit(Instr::StoreVar(id.0));
                if !suppressed {
                    self.emit(Instr::EmitVar {
                        var_index: id.0,
                        label: EmitLabel::Var(id.0),
                    });
                }
            }
            HirStmt::If {
                cond,
                then_body,
                elseif_blocks,
                else_body,
                ..
            } => {
                self.compile_expr(cond)?;
                let mut last_jump = self.emit(Instr::JumpIfFalse(usize::MAX));
                for s in then_body {
                    self.compile_stmt(s)?;
                }
                let mut end_jumps = Vec::new();
                end_jumps.push(self.emit(Instr::Jump(usize::MAX)));
                for (c, b) in elseif_blocks {
                    self.patch(last_jump, Instr::JumpIfFalse(self.instructions.len()));
                    self.compile_expr(c)?;
                    last_jump = self.emit(Instr::JumpIfFalse(usize::MAX));
                    for s in b {
                        self.compile_stmt(s)?;
                    }
                    end_jumps.push(self.emit(Instr::Jump(usize::MAX)));
                }
                self.patch(last_jump, Instr::JumpIfFalse(self.instructions.len()));
                if let Some(body) = else_body {
                    for s in body {
                        self.compile_stmt(s)?;
                    }
                }
                let end = self.instructions.len();
                for j in end_jumps {
                    self.patch(j, Instr::Jump(end));
                }
            }
            HirStmt::While { cond, body, .. } => {
                let start = self.instructions.len();
                self.compile_expr(cond)?;
                let jump_end = self.emit(Instr::JumpIfFalse(usize::MAX));
                let labels = LoopLabels {
                    break_jumps: Vec::new(),
                    continue_jumps: Vec::new(),
                };
                self.loop_stack.push(labels);
                for s in body {
                    self.compile_stmt(s)?;
                }
                let labels = self.loop_stack.pop().unwrap();
                for j in labels.continue_jumps {
                    self.patch(j, Instr::Jump(start));
                }
                self.emit(Instr::Jump(start));
                let end = self.instructions.len();
                self.patch(jump_end, Instr::JumpIfFalse(end));
                for j in labels.break_jumps {
                    self.patch(j, Instr::Jump(end));
                }
            }
            HirStmt::For {
                var, expr, body, ..
            } => {
                if let HirExprKind::Range(start, step, end) = &expr.kind {
                    self.compile_expr(start)?;
                    self.emit(Instr::StoreVar(var.0));
                    self.compile_expr(end)?;
                    let end_var = self.alloc_temp();
                    self.emit(Instr::StoreVar(end_var));
                    let step_var = self.alloc_temp();
                    if let Some(step_expr) = step {
                        self.compile_expr(step_expr)?;
                        self.emit(Instr::StoreVar(step_var));
                    } else {
                        self.emit(Instr::LoadConst(1.0));
                        self.emit(Instr::StoreVar(step_var));
                    }

                    let loop_start = self.instructions.len();

                    self.emit(Instr::LoadVar(step_var));
                    self.emit(Instr::LoadConst(0.0));
                    self.emit(Instr::Equal);
                    let j_step_zero_skip = self.emit(Instr::JumpIfFalse(usize::MAX));
                    let jump_end = self.emit(Instr::Jump(usize::MAX));
                    let after_step_zero_check = self.instructions.len();
                    self.patch(j_step_zero_skip, Instr::JumpIfFalse(after_step_zero_check));

                    self.emit(Instr::LoadVar(step_var));
                    self.emit(Instr::LoadConst(0.0));
                    self.emit(Instr::GreaterEqual);
                    let j_neg_branch = self.emit(Instr::JumpIfFalse(usize::MAX));
                    self.emit(Instr::LoadVar(var.0));
                    self.emit(Instr::LoadVar(end_var));
                    self.emit(Instr::LessEqual);
                    let j_exit_pos = self.emit(Instr::JumpIfFalse(usize::MAX));
                    let j_cond_done = self.emit(Instr::Jump(usize::MAX));
                    let neg_branch = self.instructions.len();
                    self.patch(j_neg_branch, Instr::JumpIfFalse(neg_branch));
                    self.emit(Instr::LoadVar(var.0));
                    self.emit(Instr::LoadVar(end_var));
                    self.emit(Instr::GreaterEqual);
                    let j_exit_neg = self.emit(Instr::JumpIfFalse(usize::MAX));
                    let cond_done = self.instructions.len();
                    self.patch(j_cond_done, Instr::Jump(cond_done));

                    self.loop_stack.push(LoopLabels {
                        break_jumps: Vec::new(),
                        continue_jumps: Vec::new(),
                    });
                    for s in body {
                        self.compile_stmt(s)?;
                    }
                    let labels = self.loop_stack.pop().unwrap();
                    for j in labels.continue_jumps {
                        self.patch(j, Instr::Jump(self.instructions.len()));
                    }
                    self.emit(Instr::LoadVar(var.0));
                    self.emit(Instr::LoadVar(step_var));
                    self.emit(Instr::Add);
                    self.emit(Instr::StoreVar(var.0));
                    self.emit(Instr::Jump(loop_start));

                    let end_pc = self.instructions.len();
                    self.patch(jump_end, Instr::Jump(end_pc));
                    self.patch(j_exit_pos, Instr::JumpIfFalse(end_pc));
                    self.patch(j_exit_neg, Instr::JumpIfFalse(end_pc));
                    for j in labels.break_jumps {
                        self.patch(j, Instr::Jump(end_pc));
                    }
                } else {
                    return Err(self.compile_error("for loop expects range"));
                }
            }
            HirStmt::Break(_) => {
                if self.loop_stack.is_empty() {
                    return Err(self.compile_error("break outside loop"));
                }
                let idx = self.emit(Instr::Jump(usize::MAX));
                if let Some(labels) = self.loop_stack.last_mut() {
                    labels.break_jumps.push(idx);
                }
            }
            HirStmt::Continue(_) => {
                if self.loop_stack.is_empty() {
                    return Err(self.compile_error("continue outside loop"));
                }
                let idx = self.emit(Instr::Jump(usize::MAX));
                if let Some(labels) = self.loop_stack.last_mut() {
                    labels.continue_jumps.push(idx);
                }
            }
            HirStmt::Return(_) => {
                self.emit(Instr::Return);
            }
            HirStmt::Function {
                name,
                params,
                outputs,
                body,
                has_varargin,
                has_varargout,
                ..
            } => self.compile_function_stmt(
                name,
                params,
                outputs,
                body,
                *has_varargin,
                *has_varargout,
            )?,
            HirStmt::Switch {
                expr,
                cases,
                otherwise,
                ..
            } => {
                let temp_id = self.alloc_temp();
                self.compile_expr(expr)?;
                self.emit(Instr::StoreVar(temp_id));
                let mut end_jumps: Vec<usize> = Vec::new();
                let mut next_case_jump_to_here: Option<usize> = None;
                for (case_expr, body) in cases {
                    if let Some(j) = next_case_jump_to_here.take() {
                        self.patch(j, Instr::JumpIfFalse(self.instructions.len()));
                    }
                    self.emit(Instr::LoadVar(temp_id));
                    self.compile_expr(case_expr)?;
                    self.emit(Instr::Equal);
                    let jmp = self.emit(Instr::JumpIfFalse(usize::MAX));
                    for s in body {
                        self.compile_stmt(s)?;
                    }
                    end_jumps.push(self.emit(Instr::Jump(usize::MAX)));
                    next_case_jump_to_here = Some(jmp);
                }
                let otherwise_start = self.instructions.len();
                if let Some(j) = next_case_jump_to_here.take() {
                    self.patch(j, Instr::JumpIfFalse(otherwise_start));
                }
                if let Some(body) = otherwise {
                    for s in body {
                        self.compile_stmt(s)?;
                    }
                }
                let end = self.instructions.len();
                for j in end_jumps {
                    self.patch(j, Instr::Jump(end));
                }
            }
            HirStmt::TryCatch {
                try_body,
                catch_var,
                catch_body,
                ..
            } => {
                let enter_idx = self.emit(Instr::EnterTry(usize::MAX, catch_var.map(|v| v.0)));
                for s in try_body {
                    self.compile_stmt(s)?;
                }
                self.emit(Instr::PopTry);
                let jmp_end = self.emit(Instr::Jump(usize::MAX));
                let catch_pc = self.instructions.len();
                self.patch(enter_idx, Instr::EnterTry(catch_pc, catch_var.map(|v| v.0)));
                for s in catch_body {
                    self.compile_stmt(s)?;
                }
                let end_pc = self.instructions.len();
                self.patch(jmp_end, Instr::Jump(end_pc));
            }
            HirStmt::AssignLValue(lv, rhs, _, _) => self.compile_assign_lvalue(lv, rhs)?,
            HirStmt::Global(vars, _) => {
                let ids: Vec<usize> = vars.iter().map(|(v, _n)| v.0).collect();
                let names: Vec<String> = vars.iter().map(|(_v, n)| n.clone()).collect();
                self.emit(Instr::DeclareGlobalNamed(ids, names));
            }
            HirStmt::Persistent(vars, _) => {
                let ids: Vec<usize> = vars.iter().map(|(v, _n)| v.0).collect();
                let names: Vec<String> = vars.iter().map(|(_v, n)| n.clone()).collect();
                self.emit(Instr::DeclarePersistentNamed(ids, names));
            }
            HirStmt::Import { path, wildcard, .. } => {
                self.emit(Instr::RegisterImport {
                    path: path.clone(),
                    wildcard: *wildcard,
                });
            }
            HirStmt::ClassDef {
                name,
                super_class,
                members,
                ..
            } => self.compile_class_def(name, super_class, members)?,
            HirStmt::MultiAssign(vars, expr, suppressed, _) => {
                match &expr.kind {
                    HirExprKind::FuncCall(name, args) => {
                        let call_arg_spans: Vec<runmat_hir::Span> =
                            args.iter().map(|a| a.span).collect();
                        if self.functions.contains_key(name) {
                            for arg in args {
                                self.compile_expr(arg)?;
                            }
                            self.emit_call_with_arg_spans(
                                Instr::CallFunctionMulti(name.clone(), args.len(), vars.len()),
                                &call_arg_spans,
                            );
                            self.emit(Instr::Unpack(vars.len()));
                            for (_i, var) in vars.iter().enumerate().rev() {
                                if let Some(v) = var {
                                    self.emit(Instr::StoreVar(v.0));
                                } else {
                                    self.emit(Instr::Pop);
                                }
                            }
                        } else {
                            for arg in args {
                                self.compile_expr(arg)?;
                            }
                            self.emit_call_with_arg_spans(
                                Instr::CallBuiltin(name.clone(), args.len()),
                                &call_arg_spans,
                            );
                            self.emit(Instr::Unpack(vars.len()));
                            for (_i, var) in vars.iter().enumerate().rev() {
                                if let Some(v) = var {
                                    self.emit(Instr::StoreVar(v.0));
                                } else {
                                    self.emit(Instr::Pop);
                                }
                            }
                        }
                    }
                    HirExprKind::IndexCell(base, indices) => {
                        self.compile_expr(base)?;
                        for index in indices {
                            self.compile_expr(index)?;
                        }
                        self.emit(Instr::IndexCellExpand(indices.len(), vars.len()));
                        for (_i, var) in vars.iter().enumerate().rev() {
                            if let Some(v) = var {
                                self.emit(Instr::StoreVar(v.0));
                            } else {
                                self.emit(Instr::Pop);
                            }
                        }
                    }
                    _ => {
                        let first_real = vars.iter().position(|v| v.is_some());
                        if let Some(first_idx) = first_real {
                            self.compile_expr(expr)?;
                            if let Some(Some(first_var)) = vars.get(first_idx) {
                                self.emit(Instr::StoreVar(first_var.0));
                            }
                        }
                        for (i, var) in vars.iter().enumerate() {
                            if Some(i) == first_real {
                                continue;
                            }
                            if let Some(v) = var {
                                self.emit(Instr::LoadConst(0.0));
                                self.emit(Instr::StoreVar(v.0));
                            }
                        }
                    }
                }
                if !suppressed {
                    self.emit_multiassign_outputs(vars);
                }
            }
        }
        Ok(())
    }
}