use std::collections::{HashMap, HashSet};
use super::{Chunk, CompileError, Instr, Opcode};
use crate::env::Value;
use crate::eval::{Expr, Op, is_global, is_persistent};
use crate::parser::{Stmt, StmtEntry};
/// Command names that are never lowered to bytecode.
///
/// A statement-level call to any of these, or an assignment whose right-hand
/// side is a direct call to one, makes the statement uncompilable
/// (`stmt_compilable` returns `false` and `compile` reports `Unsupported`).
/// Judging by their names they touch workspace, search-path, display or file
/// state — NOTE(review): confirm against the interpreter's handlers.
const EXEC_INTERCEPTS: &[&str] = &[
    // Script execution and search path.
    "run", "source", "addpath", "rmpath", "path",
    // Workspace and environment state.
    "clear", "remove", "format", "save", "load", "ws", "wl",
];
/// Builtin functions that the compiler may call natively.
///
/// A call to one of these names is lowered to a `CallBuiltin` instruction
/// only when all of its arguments are themselves pure (see
/// `Compiler::is_pure`). Calls to any other name are evaluated through
/// `EvalExpr` instead.
pub const COMPILABLE_BUILTINS: &[&str] = &[
    // Elementary math.
    "abs", "sqrt", "exp", "log", "log2", "log10",
    // Trigonometry.
    "sin", "cos", "tan", "asin", "acos", "atan", "atan2",
    // Rounding and modular arithmetic.
    "floor", "ceil", "round", "fix", "sign", "mod", "rem",
    // Complex parts.
    "real", "imag", "conj", "angle",
    // Numeric predicates.
    "isreal", "isnan", "isinf", "isfinite",
    // Reductions and cumulative operations.
    "sum", "prod", "mean", "norm", "max", "min", "any", "all", "cumsum", "cumprod",
    // Shape queries.
    "size", "length", "numel",
    // Construction and rearrangement.
    "zeros", "ones", "eye", "reshape", "fliplr", "flipud",
    // Sorting and searching.
    "sort", "unique", "find",
    // String/number conversion.
    "num2str", "int2str", "str2num", "str2double",
    // Type predicates.
    "ischar", "iscell", "isstruct",
];
/// Returns `true` when every statement in `stmts` can be lowered to bytecode.
///
/// Nested loop and `if` bodies are checked recursively by `stmt_compilable`.
/// An empty statement list is trivially compilable.
pub fn is_compilable(stmts: &[StmtEntry]) -> bool {
    for (stmt, _, _) in stmts {
        if !stmt_compilable(stmt) {
            return false;
        }
    }
    true
}
/// Decides whether a single statement can be lowered by `compile`.
///
/// - Control transfer, function definitions and indexed assignments are
///   always accepted.
/// - A plain assignment is rejected only when its right-hand side is a direct
///   call to an [`EXEC_INTERCEPTS`] command.
/// - A call used as a statement is rejected when it names an intercepted
///   command, or when it is `eval` with one or two arguments.
/// - `for`, `while` and `if` are accepted exactly when all their bodies are.
/// - Every other statement kind is rejected.
fn stmt_compilable(stmt: &Stmt) -> bool {
    match stmt {
        Stmt::Break
        | Stmt::Continue
        | Stmt::Return
        | Stmt::FunctionDef { .. }
        | Stmt::IndexSet { .. } => true,
        Stmt::Assign(_, rhs) => !is_exec_intercepted_call(rhs),
        Stmt::Expr(Expr::Call(callee, args)) => {
            let intercepted = EXEC_INTERCEPTS.contains(&callee.as_str());
            let is_eval = callee == "eval" && matches!(args.len(), 1 | 2);
            !intercepted && !is_eval
        }
        Stmt::Expr(_) => true,
        Stmt::For { body, .. } => is_compilable(body),
        Stmt::While { body, .. } => is_compilable(body),
        Stmt::If {
            body,
            elseif_branches,
            else_body,
            ..
        } => {
            let elseifs_ok = elseif_branches
                .iter()
                .all(|(_, branch)| is_compilable(branch));
            is_compilable(body)
                && elseifs_ok
                && else_body.as_ref().is_none_or(|b| is_compilable(b))
        }
        _ => false,
    }
}
/// Compiles `stmts` into a bytecode [`Chunk`].
///
/// Variables assigned in `stmts` are gathered in first-seen order by
/// `collect_candidates`; this covers plain `=` assignments and `for` loop
/// variables. Each such variable gets a dedicated VM slot, and the slot names
/// are stored in `Chunk::slot_names` in slot order.
///
/// A variable is not slotted when any of these hold:
/// - it is `ans`;
/// - `collect_env_required` reports it (the variable is read by an expression
///   evaluated through the environment, or it is an indexed-assignment target);
/// - `is_global` or `is_persistent` returns `true` for it;
/// - the `u16` slot space is already exhausted.
///
/// Names that are not slotted are read and written by name through the
/// environment.
///
/// # Errors
///
/// Returns `CompileError::Unsupported` when a statement cannot be lowered.
/// Examples are an intercepted command, statement-level `eval`, a
/// `break`/`continue` outside a loop, or an unsupported statement kind.
pub fn compile(stmts: &[StmtEntry]) -> Result<Chunk, CompileError> {
    let mut candidates: Vec<String> = Vec::new();
    collect_candidates(stmts, &mut candidates);

    // `ans` is always kept out of the slot table.
    // NOTE(review): presumably because `UpdateAns` writes it through the environment — confirm.
    let mut env_required: HashSet<String> = HashSet::new();
    env_required.insert("ans".to_string());
    collect_env_required(stmts, &mut env_required);

    let mut slot_map: HashMap<String, u16> = HashMap::new();
    let mut slot_names: Vec<String> = Vec::new();
    for name in &candidates {
        if env_required.contains(name) || is_global(name) || is_persistent(name) {
            continue;
        }
        // Slot operands are u16. A plain `as u16` would wrap and make distinct
        // variables share a slot. Once the table is full, leave the remaining
        // names in the environment instead; that path is always correct.
        let Ok(slot) = u16::try_from(slot_names.len()) else {
            break;
        };
        slot_map.insert(name.clone(), slot);
        slot_names.push(name.clone());
    }

    let mut compiler = Compiler {
        chunk: Chunk::new(),
        loop_stack: Vec::new(),
        current_line: 0,
        slots: slot_map,
    };
    compiler.chunk.slot_names = slot_names;
    compiler.compile_stmts(stmts)?;
    Ok(compiler.chunk)
}
/// Bookkeeping for one enclosing `for`/`while` loop while its body is compiled.
struct LoopFrame {
    /// `true` for `for` loops, whose iterator is popped (`PopIter`) on `break`.
    is_for: bool,
    /// Instruction index that `continue` jumps to.
    ///
    /// For a `for` loop this is its `IterNext*` instruction; for a `while`
    /// loop it is the first instruction of the condition.
    continue_target: usize,
    /// Indices of `break` jumps, patched to the loop exit once it is known.
    break_patches: Vec<usize>,
}
/// Mutable state threaded through a single `compile` run.
struct Compiler {
    /// Variable name → VM slot index, for the variables `compile` chose to slot.
    slots: HashMap<String, u16>,
    /// Enclosing loops, innermost last; consulted by `break` and `continue`.
    loop_stack: Vec<LoopFrame>,
    /// Source line attached to each instruction as it is emitted.
    current_line: usize,
    /// The bytecode chunk being built.
    chunk: Chunk,
}
impl Compiler {
fn emit(&mut self, instr: Instr) {
self.chunk.lines.push(self.current_line);
self.chunk.code.push(instr);
}
fn compile_stmts(&mut self, stmts: &[StmtEntry]) -> Result<(), CompileError> {
for (stmt, silent, line) in stmts {
self.current_line = *line;
self.compile_stmt(stmt, *silent)?;
}
Ok(())
}
fn compile_stmt(&mut self, stmt: &Stmt, silent: bool) -> Result<(), CompileError> {
match stmt {
Stmt::Assign(name, expr) => {
if is_exec_intercepted_call(expr) {
return Err(CompileError::Unsupported);
}
self.compile_expr_push(expr);
if let Some(&slot) = self.slots.get(name) {
self.emit(Instr::with_u16_u8(
Opcode::StoreSlot,
slot,
u8::from(silent),
));
} else {
let idx = self.chunk.name_idx(name);
self.emit(Instr::with_u16_u8(Opcode::StoreVar, idx, u8::from(silent)));
}
Ok(())
}
Stmt::Expr(expr) => {
if let Expr::Call(name, _) = expr
&& EXEC_INTERCEPTS.contains(&name.as_str())
{
return Err(CompileError::Unsupported);
}
if let Expr::Call(name, args) = expr
&& name == "eval"
&& (args.len() == 1 || args.len() == 2)
{
return Err(CompileError::Unsupported);
}
self.compile_expr_push(expr);
self.emit(Instr::no_arg(Opcode::UpdateAns));
if silent {
self.emit(Instr::no_arg(Opcode::Pop));
} else {
self.emit(Instr::no_arg(Opcode::Print));
}
Ok(())
}
Stmt::For {
var,
range_expr,
body,
} => {
self.compile_expr_push(range_expr);
self.emit(Instr::no_arg(Opcode::PushIter));
let iter_next_pos = self.chunk.code.len();
if let Some(&slot) = self.slots.get(var) {
self.chunk
.code
.push(Instr::with_u16_i32(Opcode::IterNextSlot, slot, 0));
self.loop_stack.push(LoopFrame {
continue_target: iter_next_pos,
break_patches: Vec::new(),
is_for: true,
});
self.compile_stmts(body)?;
let back_off = iter_next_pos as i32 - self.chunk.code.len() as i32 - 1;
self.emit(Instr::with_i32(Opcode::Jump, back_off));
let exit_pos = self.chunk.code.len();
let exit_off = exit_pos as i32 - iter_next_pos as i32 - 1;
self.chunk.code[iter_next_pos].set_u16_i32(slot, exit_off);
let frame = self.loop_stack.pop().unwrap();
for p in frame.break_patches {
let off = exit_pos as i32 - p as i32 - 1;
self.chunk.code[p].set_i32(off);
}
} else {
let var_idx = self.chunk.name_idx(var);
self.chunk
.code
.push(Instr::with_u16_i32(Opcode::IterNext, var_idx, 0));
self.loop_stack.push(LoopFrame {
continue_target: iter_next_pos,
break_patches: Vec::new(),
is_for: true,
});
self.compile_stmts(body)?;
let back_off = iter_next_pos as i32 - self.chunk.code.len() as i32 - 1;
self.emit(Instr::with_i32(Opcode::Jump, back_off));
let exit_pos = self.chunk.code.len();
let exit_off = exit_pos as i32 - iter_next_pos as i32 - 1;
self.chunk.code[iter_next_pos].set_u16_i32(var_idx, exit_off);
let frame = self.loop_stack.pop().unwrap();
for p in frame.break_patches {
let off = exit_pos as i32 - p as i32 - 1;
self.chunk.code[p].set_i32(off);
}
}
Ok(())
}
Stmt::While { cond, body } => {
let cond_pos = self.chunk.code.len();
self.compile_expr_push(cond);
let jf_idx = self.chunk.code.len();
self.emit(Instr::with_i32(Opcode::JumpFalsy, 0));
self.loop_stack.push(LoopFrame {
continue_target: cond_pos,
break_patches: Vec::new(),
is_for: false,
});
self.compile_stmts(body)?;
let back_off = cond_pos as i32 - self.chunk.code.len() as i32 - 1;
self.emit(Instr::with_i32(Opcode::Jump, back_off));
let exit_pos = self.chunk.code.len();
let jf_off = exit_pos as i32 - jf_idx as i32 - 1;
self.chunk.code[jf_idx].set_i32(jf_off);
let frame = self.loop_stack.pop().unwrap();
for p in frame.break_patches {
let off = exit_pos as i32 - p as i32 - 1;
self.chunk.code[p].set_i32(off);
}
Ok(())
}
Stmt::If {
cond,
body,
elseif_branches,
else_body,
} => {
let mut end_patches: Vec<usize> = Vec::new();
self.compile_expr_push(cond);
let jf_idx = self.chunk.code.len();
self.emit(Instr::with_i32(Opcode::JumpFalsy, 0));
self.compile_stmts(body)?;
if elseif_branches.is_empty() && else_body.is_none() {
let exit_pos = self.chunk.code.len();
let off = exit_pos as i32 - jf_idx as i32 - 1;
self.chunk.code[jf_idx].set_i32(off);
} else {
let ej_idx = self.chunk.code.len();
self.emit(Instr::with_i32(Opcode::Jump, 0));
end_patches.push(ej_idx);
let next = self.chunk.code.len();
let off = next as i32 - jf_idx as i32 - 1;
self.chunk.code[jf_idx].set_i32(off);
for (ei_cond, ei_body) in elseif_branches {
self.compile_expr_push(ei_cond);
let ei_jf = self.chunk.code.len();
self.emit(Instr::with_i32(Opcode::JumpFalsy, 0));
self.compile_stmts(ei_body)?;
let ej2 = self.chunk.code.len();
self.emit(Instr::with_i32(Opcode::Jump, 0));
end_patches.push(ej2);
let next2 = self.chunk.code.len();
let off2 = next2 as i32 - ei_jf as i32 - 1;
self.chunk.code[ei_jf].set_i32(off2);
}
if let Some(else_stmts) = else_body {
self.compile_stmts(else_stmts)?;
}
let end_pos = self.chunk.code.len();
for p in end_patches {
let off = end_pos as i32 - p as i32 - 1;
self.chunk.code[p].set_i32(off);
}
}
Ok(())
}
Stmt::Break => {
let Some(frame) = self.loop_stack.last_mut() else {
return Err(CompileError::Unsupported);
};
if frame.is_for {
self.emit(Instr::no_arg(Opcode::PopIter));
}
let j_idx = self.chunk.code.len();
self.emit(Instr::with_i32(Opcode::Jump, 0));
let j_idx_owned = j_idx;
self.loop_stack
.last_mut()
.unwrap()
.break_patches
.push(j_idx_owned);
Ok(())
}
Stmt::Continue => {
let Some(frame) = self.loop_stack.last() else {
return Err(CompileError::Unsupported);
};
let target = frame.continue_target;
let off = target as i32 - self.chunk.code.len() as i32 - 1;
self.emit(Instr::with_i32(Opcode::Jump, off));
Ok(())
}
Stmt::Return => {
self.emit(Instr::no_arg(Opcode::Return));
Ok(())
}
Stmt::FunctionDef {
name,
outputs,
params,
body_source,
doc,
} => {
use crate::env::{FunctionData, Value};
use indexmap::IndexMap;
let func = Value::Function(Box::new(FunctionData {
outputs: outputs.clone(),
params: params.clone(),
body_source: body_source.clone(),
locals: IndexMap::new(),
doc: doc.clone(),
}));
let const_idx = self.chunk.add_const(func);
let name_idx = self.chunk.name_idx(name);
self.emit(Instr::with_u16_u16(Opcode::DefineFunc, name_idx, const_idx));
Ok(())
}
Stmt::IndexSet {
name,
indices,
value,
} => {
self.compile_expr_push(value);
let iset_idx = self.chunk.index_sets.len() as u16;
self.chunk.index_sets.push(indices.clone());
let name_idx = self.chunk.name_idx(name);
self.emit(Instr::with_u16_u16_u8(
Opcode::IndexSetOp,
name_idx,
iset_idx,
u8::from(silent),
));
Ok(())
}
_ => Err(CompileError::Unsupported),
}
}
fn compile_expr_push(&mut self, expr: &Expr) {
if Self::is_pure(expr) {
self.compile_native(expr);
} else {
let idx = self.chunk.add_expr(expr.clone());
self.emit(Instr::with_u16(Opcode::EvalExpr, idx));
}
}
fn is_pure(expr: &Expr) -> bool {
match expr {
Expr::Number(_) | Expr::Var(_) => true,
Expr::UnaryMinus(e) | Expr::UnaryNot(e) => Self::is_pure(e),
Expr::BinOp(a, op, b) => {
!matches!(op, Op::ElemAnd | Op::ElemOr | Op::LDiv)
&& Self::is_pure(a)
&& Self::is_pure(b)
}
Expr::Call(name, args) => {
COMPILABLE_BUILTINS.contains(&name.as_str()) && args.iter().all(Self::is_pure)
}
_ => false,
}
}
fn compile_native(&mut self, expr: &Expr) {
match expr {
Expr::Number(f) => {
let idx = self.chunk.add_const(Value::Scalar(*f));
self.emit(Instr::with_u16(Opcode::PushConst, idx));
}
Expr::Var(name) => {
if let Some(&slot) = self.slots.get(name) {
self.emit(Instr::with_u16(Opcode::LoadSlot, slot));
} else {
let idx = self.chunk.name_idx(name);
self.emit(Instr::with_u16(Opcode::LoadVar, idx));
}
}
Expr::UnaryMinus(inner) => {
self.compile_native(inner);
self.emit(Instr::no_arg(Opcode::Neg));
}
Expr::UnaryNot(inner) => {
self.compile_native(inner);
self.emit(Instr::no_arg(Opcode::Not));
}
Expr::BinOp(left, op, right) => {
self.compile_native(left);
self.compile_native(right);
let opcode = match op {
Op::Add => Opcode::Add,
Op::Sub => Opcode::Sub,
Op::Mul => Opcode::Mul,
Op::Div => Opcode::Div,
Op::Pow => Opcode::Pow,
Op::ElemMul => Opcode::ElemMul,
Op::ElemDiv => Opcode::ElemDiv,
Op::ElemPow => Opcode::ElemPow,
Op::Eq => Opcode::Eq,
Op::NotEq => Opcode::Ne,
Op::Lt => Opcode::Lt,
Op::LtEq => Opcode::Le,
Op::Gt => Opcode::Gt,
Op::GtEq => Opcode::Ge,
Op::And => Opcode::And,
Op::Or => Opcode::Or,
Op::ElemAnd | Op::ElemOr | Op::LDiv => {
unreachable!("compile_native: is_pure should have excluded this op")
}
};
self.emit(Instr::no_arg(opcode));
}
Expr::Call(name, args) => {
for arg in args {
self.compile_native(arg);
}
let name_idx = self.chunk.name_idx(name);
self.emit(Instr::with_u16_u8(
Opcode::CallBuiltin,
name_idx,
args.len() as u8,
));
}
_ => unreachable!("compile_native called on non-pure expression"),
}
}
}
/// Returns `true` when `expr` is itself a call to one of the [`EXEC_INTERCEPTS`] commands.
///
/// Only the outermost expression is inspected; an intercepted call nested
/// inside a larger expression is not detected.
fn is_exec_intercepted_call(expr: &Expr) -> bool {
    match expr {
        Expr::Call(callee, _) => EXEC_INTERCEPTS.contains(&callee.as_str()),
        _ => false,
    }
}
/// Appends to `out` every variable that `stmts` assigns, which are the slot candidates.
///
/// Both plain assignment targets and `for` loop variables count. Names are
/// appended in first-seen order and never duplicated, including names already
/// present in `out` on entry. Loop and `if`/`elseif`/`else` bodies are visited
/// recursively. Indexed-assignment targets and all other statement kinds
/// contribute nothing.
fn collect_candidates(stmts: &[StmtEntry], out: &mut Vec<String>) {
    // Records `name` unless it is already listed.
    fn remember(name: &str, out: &mut Vec<String>) {
        if !out.iter().any(|known| known == name) {
            out.push(name.to_owned());
        }
    }

    for (stmt, _, _) in stmts {
        match stmt {
            Stmt::Assign(target, _) => remember(target, out),
            Stmt::For { var, body, .. } => {
                remember(var, out);
                collect_candidates(body, out);
            }
            Stmt::While { body, .. } => collect_candidates(body, out),
            Stmt::If {
                body,
                elseif_branches,
                else_body,
                ..
            } => {
                collect_candidates(body, out);
                for (_, branch) in elseif_branches {
                    collect_candidates(branch, out);
                }
                if let Some(branch) = else_body {
                    collect_candidates(branch, out);
                }
            }
            _ => {}
        }
    }
}
fn collect_env_required(stmts: &[StmtEntry], out: &mut HashSet<String>) {
for (stmt, _, _) in stmts {
match stmt {
Stmt::Assign(_, expr) if !Compiler::is_pure(expr) => {
free_vars_in_expr(expr, out);
}
Stmt::Assign(_, _) => {}
Stmt::Expr(expr) if !Compiler::is_pure(expr) => {
free_vars_in_expr(expr, out);
}
Stmt::Expr(_) => {}
Stmt::For {
range_expr, body, ..
} => {
if !Compiler::is_pure(range_expr) {
free_vars_in_expr(range_expr, out);
}
collect_env_required(body, out);
}
Stmt::While { cond, body } => {
if !Compiler::is_pure(cond) {
free_vars_in_expr(cond, out);
}
collect_env_required(body, out);
}
Stmt::If {
cond,
body,
elseif_branches,
else_body,
} => {
if !Compiler::is_pure(cond) {
free_vars_in_expr(cond, out);
}
collect_env_required(body, out);
for (ei_cond, ei_body) in elseif_branches {
if !Compiler::is_pure(ei_cond) {
free_vars_in_expr(ei_cond, out);
}
collect_env_required(ei_body, out);
}
if let Some(b) = else_body {
collect_env_required(b, out);
}
}
Stmt::IndexSet {
name,
indices,
value,
} => {
out.insert(name.clone());
for idx in indices {
free_vars_in_expr(idx, out);
}
if !Compiler::is_pure(value) {
free_vars_in_expr(value, out);
}
}
_ => {}
}
}
}
/// Inserts into `out` every name that `expr` may read from the environment.
///
/// This is a deliberate over-approximation:
/// - every `Var` is collected, and so is the callee name of every `Call`,
///   which may be an indexed variable rather than a function;
/// - for `Lambda`, only `body` is walked and the remaining fields are ignored;
///   NOTE(review): presumably these are the parameters, so a parameter used in
///   the body is collected too;
/// - the first field of `DotCall` is ignored.
///
/// Literals, `:`, `NaT` and function handles contribute nothing.
fn free_vars_in_expr(expr: &Expr, out: &mut HashSet<String>) {
    match expr {
        Expr::Number(_)
        | Expr::StrLiteral(_)
        | Expr::StringObjLiteral(_)
        | Expr::Colon
        | Expr::NaT
        | Expr::FuncHandle(_) => {}
        Expr::Var(var_name) => {
            out.insert(var_name.clone());
        }
        Expr::UnaryMinus(inner)
        | Expr::UnaryNot(inner)
        | Expr::Transpose(inner)
        | Expr::PlainTranspose(inner) => free_vars_in_expr(inner, out),
        Expr::BinOp(lhs, _, rhs) => {
            free_vars_in_expr(lhs, out);
            free_vars_in_expr(rhs, out);
        }
        Expr::Call(callee, call_args) => {
            out.insert(callee.clone());
            call_args.iter().for_each(|arg| free_vars_in_expr(arg, out));
        }
        Expr::CellLiteral(items) => {
            items.iter().for_each(|item| free_vars_in_expr(item, out));
        }
        Expr::Matrix(rows) => {
            rows.iter()
                .flatten()
                .for_each(|cell| free_vars_in_expr(cell, out));
        }
        Expr::Range(start, step, stop) => {
            free_vars_in_expr(start, out);
            if let Some(increment) = step {
                free_vars_in_expr(increment, out);
            }
            free_vars_in_expr(stop, out);
        }
        Expr::CellIndex(container, index) => {
            free_vars_in_expr(container, out);
            free_vars_in_expr(index, out);
        }
        Expr::FieldGet(container, _) => free_vars_in_expr(container, out),
        Expr::DynFieldGet(container, field_expr) => {
            free_vars_in_expr(container, out);
            free_vars_in_expr(field_expr, out);
        }
        Expr::DotCall(_, call_args) => {
            call_args.iter().for_each(|arg| free_vars_in_expr(arg, out));
        }
        Expr::Lambda { body, .. } => free_vars_in_expr(body, out),
    }
}