use std::collections::HashMap;
use compact_str::{CompactString, ToCompactString};
use itertools::Itertools;
use crate::ast::{ProcedureKind, TypeAnnotation};
use crate::decorator::Decorator;
use crate::interpreter::{
Interpreter, InterpreterResult, InterpreterSettings, Result, RuntimeError, RuntimeErrorKind,
};
use crate::name_resolution::LAST_RESULT_IDENTIFIERS;
use crate::prefix::Prefix;
use crate::prefix_parser::AcceptsPrefix;
use crate::pretty_print::PrettyPrint;
use crate::typechecker::TypeChecker;
use crate::typed_ast;
use crate::typed_ast::{
BinaryOperator, DefineVariable, Expression, Statement, StringPart, UnaryOperator,
};
use crate::unit::{CanonicalName, Unit};
use crate::unit_registry::{UnitMetadata, UnitRegistry};
use crate::value::{FunctionReference, Value};
use crate::vm::{Constant, ExecutionContext, FfiCallArg, FfiCallArgs, Op, Vm};
use crate::{Type, decorator};
/// Descriptive metadata attached to a [`Local`].
///
/// `compile_define_variable` fills this in from the decorators of a variable
/// definition. Function parameters receive the `Default` (all-empty) value.
#[derive(Debug, Clone, Default)]
pub struct LocalMetadata {
/// Result of `decorator::name` for the definition, if any.
pub name: Option<CompactString>,
/// Result of `decorator::url` for the definition, if any.
pub url: Option<CompactString>,
/// Result of `decorator::description` for the definition, if any.
pub description: Option<CompactString>,
/// Every identifier yielded by `decorator::name_and_aliases` for the
/// definition (a copy of the owning `Local`'s identifier list).
pub aliases: Vec<CompactString>,
}
/// One slot in a scope of the compiler's `locals` table.
///
/// The position of a `Local` inside its scope vector is the operand emitted
/// for `Op::GetLocal` / `Op::GetUpvalue` when one of its identifiers is
/// referenced.
#[derive(Debug, Clone)]
pub struct Local {
// All names that resolve to this slot; identifier lookup matches against
// any of them.
identifiers: Vec<CompactString>,
/// Descriptive metadata for this slot (empty for function parameters).
pub metadata: LocalMetadata,
}
/// Compiles typed AST statements to bytecode and runs them on a [`Vm`].
#[derive(Clone)]
pub struct BytecodeInterpreter {
// The virtual machine that receives the emitted bytecode and executes it.
vm: Vm,
// Stack of scopes. Index 0 is the global scope (created in `new`); a new
// scope is pushed while a function body is compiled and popped afterwards.
locals: Vec<Vec<Local>>,
// Maps every unit name and alias to the index of its `Constant::Unit`
// entry in the VM's constant table.
unit_name_to_constant_index: HashMap<CompactString, u16>,
// Known function names. The value is `true` for foreign functions
// (declared without a body) and `false` for functions with a body.
functions: HashMap<CompactString, bool>,
}
impl BytecodeInterpreter {
/// Wraps `kind` into a [`RuntimeError`] by delegating to the underlying VM.
pub fn runtime_error(&self, kind: RuntimeErrorKind) -> RuntimeError {
    let vm = &self.vm;
    vm.runtime_error(kind)
}
/// Emits the bytecode that evaluates `expr`.
///
/// Sub-expressions are always compiled before the instruction that consumes
/// them, so the order of the calls below determines the order of the emitted
/// instructions and must not be changed.
///
/// # Panics
///
/// Panics when an invariant that earlier compiler stages are expected to
/// uphold is violated: an unknown identifier or unit, a missing `second`
/// unit for date subtraction, an unknown struct or struct field, or a
/// typed hole.
fn compile_expression(&mut self, expr: &Expression) {
    match expr {
        Expression::Scalar { span, value, .. } => {
            let constant_idx = self.vm.add_constant(Constant::Scalar(value.to_f64()));
            self.vm.add_op1(Op::LoadConstant, constant_idx, *span);
        }
        Expression::Identifier {
            span,
            name: identifier,
            ..
        } => {
            // `rposition` searches from the back, so the most recent
            // definition of a name shadows earlier ones.
            let innermost = self.current_depth();
            let local_slot = self.locals[innermost]
                .iter()
                .rposition(|l| l.identifiers.iter().any(|n| n == identifier));

            if let Some(slot) = local_slot {
                self.vm.add_op1(Op::GetLocal, slot as u16, *span);
            } else if let Some(global_slot) = self.locals[0]
                .iter()
                .rposition(|l| l.identifiers.iter().any(|n| n == identifier))
            {
                // Not found in the current scope: fall back to the global scope.
                self.vm.add_op1(Op::GetUpvalue, global_slot as u16, *span);
            } else if LAST_RESULT_IDENTIFIERS.contains(identifier) {
                self.vm.add_op(Op::GetLastResult, *span);
            } else if let Some(is_foreign) = self.functions.get(*identifier) {
                // A bare function name evaluates to a function reference.
                let reference = if *is_foreign {
                    FunctionReference::Foreign(identifier.to_compact_string())
                } else {
                    FunctionReference::Normal(identifier.to_compact_string())
                };
                let constant_idx = self
                    .vm
                    .add_constant(Constant::FunctionReference(reference));
                self.vm.add_op1(Op::LoadConstant, constant_idx, *span);
            } else {
                unreachable!("Unknown identifier '{identifier}'")
            }
        }
        Expression::UnitIdentifier {
            span,
            prefix,
            name: unit_name,
            ..
        } => {
            let constant_idx = *self
                .unit_name_to_constant_index
                .get(unit_name)
                .expect("unit should already exist");
            self.vm.add_op1(Op::LoadConstant, constant_idx, *span);

            // Only emit a prefix instruction when there actually is a prefix.
            if *prefix != Prefix::none() {
                let prefix_idx = self.vm.add_prefix(*prefix);
                self.vm.add_op1(Op::ApplyPrefix, prefix_idx, *span);
            }
        }
        Expression::UnaryOperator {
            span,
            op,
            expr: operand,
            ..
        } => {
            // The operand goes first, then the operator that consumes it.
            self.compile_expression(operand);
            match op {
                UnaryOperator::Negate => self.vm.add_op(Op::Negate, *span),
                UnaryOperator::Factorial(order) => {
                    self.vm.add_op1(Op::Factorial, order.get() as u16, *span)
                }
                UnaryOperator::LogicalNeg => self.vm.add_op(Op::LogicalNeg, *span),
            }
        }
        Expression::BinaryOperator {
            op_span,
            op: operator,
            lhs,
            rhs,
            ..
        } => {
            self.compile_expression(lhs);
            self.compile_expression(rhs);

            let opcode = match operator {
                BinaryOperator::Add => Op::Add,
                BinaryOperator::Sub => Op::Subtract,
                BinaryOperator::Mul => Op::Multiply,
                BinaryOperator::Div => Op::Divide,
                BinaryOperator::Power => Op::Power,
                BinaryOperator::ConvertTo => Op::ConvertTo,
                BinaryOperator::LessThan => Op::LessThan,
                BinaryOperator::GreaterThan => Op::GreaterThan,
                BinaryOperator::LessOrEqual => Op::LessOrEqual,
                BinaryOperator::GreaterOrEqual => Op::GreatorOrEqual,
                BinaryOperator::Equal => Op::Equal,
                BinaryOperator::NotEqual => Op::NotEqual,
                BinaryOperator::LogicalAnd => Op::LogicalAnd,
                BinaryOperator::LogicalOr => Op::LogicalOr,
            };

            // Without an explicit operator span, use the region between
            // the two operands.
            let span = op_span.unwrap_or_else(|| {
                crate::span::Span::in_between(lhs.full_span(), rhs.full_span())
            });
            self.vm.add_op(opcode, span);
        }
        Expression::BinaryOperatorForDate {
            op_span,
            op: operator,
            lhs,
            rhs,
            type_scheme,
        } => {
            self.compile_expression(lhs);
            self.compile_expression(rhs);

            let span = op_span.unwrap_or_else(|| {
                crate::span::Span::in_between(lhs.full_span(), rhs.full_span())
            });

            let opcode = if type_scheme.is_dtype() {
                // A dimension-typed result is handled as a date difference:
                // the `second` unit is loaded as an extra operand for
                // `DiffDateTime`.
                let second_idx = *self.unit_name_to_constant_index.get("second").unwrap();
                self.vm.add_op1(Op::LoadConstant, second_idx, span);
                Op::DiffDateTime
            } else {
                match operator {
                    BinaryOperator::Add => Op::AddToDateTime,
                    BinaryOperator::Sub => Op::SubFromDateTime,
                    _ => unreachable!("{operator:?} is not valid with a DateTime"),
                }
            };
            self.vm.add_op(opcode, span);
        }
        Expression::FunctionCall {
            full_span,
            name,
            args,
            type_scheme,
            ..
        } => {
            // Arguments are evaluated left to right before the call.
            for argument in args {
                self.compile_expression(argument);
            }
            let arg_count = args.len() as u16;

            if let Some(ffi_idx) = self.vm.get_ffi_callable_idx(name) {
                // Foreign functions additionally receive the span and type
                // of every argument plus the return type.
                let call_args = FfiCallArgs {
                    args: args
                        .iter()
                        .map(|a: &typed_ast::Expression| FfiCallArg {
                            span: a.full_span(),
                            type_: a.get_type_scheme(),
                        })
                        .collect(),
                    return_type: Some(type_scheme.clone()),
                };
                let call_args_idx = self.vm.add_ffi_call_args(call_args);
                self.vm.add_op3(
                    Op::FFICallFunction,
                    ffi_idx,
                    arg_count,
                    call_args_idx,
                    *full_span,
                );
            } else {
                let function_idx = self.vm.get_function_idx(name);
                self.vm
                    .add_op2(Op::Call, function_idx, arg_count, *full_span);
            }
        }
        Expression::InstantiateStruct {
            span,
            fields,
            struct_info,
        } => {
            // Field values are emitted in reverse declaration order,
            // independent of the order in which they were written.
            let by_declaration_order = fields
                .iter()
                .sorted_by_key(|(n, _)| struct_info.fields.get_index_of(*n).unwrap());
            for (_, field_expr) in by_declaration_order.rev() {
                self.compile_expression(field_expr);
            }

            let struct_info_idx = self.vm.get_structinfo_idx(&struct_info.name).unwrap() as u16;
            let field_count = fields.len() as u16;
            self.vm.add_op2(
                Op::BuildStructInstance,
                struct_info_idx,
                field_count,
                *span,
            );
        }
        Expression::AccessField {
            full_span,
            expr,
            field_name,
            struct_type,
            ..
        } => {
            self.compile_expression(expr);

            // The field is addressed by its position in the struct definition.
            let field_idx = match struct_type.to_concrete_type() {
                Type::Struct(ref struct_info) => {
                    struct_info.fields.get_index_of(*field_name).unwrap()
                }
                _ => unreachable!(
                    "Field access of non-struct type should be prevented by the type checker"
                ),
            };
            self.vm
                .add_op1(Op::AccessStructField, field_idx as u16, *full_span);
        }
        Expression::CallableCall {
            full_span,
            callable,
            args,
            type_scheme,
        } => {
            // Arguments first, then the callable itself.
            for argument in args {
                self.compile_expression(argument);
            }
            self.compile_expression(callable);

            let call_args = FfiCallArgs {
                args: args
                    .iter()
                    .map(|a| FfiCallArg {
                        span: a.full_span(),
                        type_: a.get_type_scheme(),
                    })
                    .collect(),
                return_type: Some(type_scheme.clone()),
            };
            let call_args_idx = self.vm.add_ffi_call_args(call_args);
            let arg_count = args.len() as u16;
            self.vm
                .add_op2(Op::CallCallable, arg_count, call_args_idx, *full_span);
        }
        Expression::Boolean(span, val) => {
            let constant_idx = self.vm.add_constant(Constant::Boolean(*val));
            self.vm.add_op1(Op::LoadConstant, constant_idx, *span);
        }
        Expression::String(span, string_parts) => {
            for part in string_parts {
                // Every part ends with a `LoadConstant`: the literal text
                // for fixed parts, the format specifiers (after the
                // interpolated expression) for interpolations.
                let constant = match part {
                    StringPart::Fixed(s) => Constant::String(s.clone()),
                    StringPart::Interpolation {
                        expr,
                        span: _,
                        format_specifiers,
                    } => {
                        self.compile_expression(expr);
                        Constant::FormatSpecifiers(
                            format_specifiers.map(|s| s.to_compact_string()),
                        )
                    }
                };
                let constant_idx = self.vm.add_constant(constant);
                self.vm.add_op1(Op::LoadConstant, constant_idx, *span);
            }
            let part_count = string_parts.len() as u16;
            self.vm.add_op1(Op::JoinString, part_count, *span);
        }
        Expression::Condition {
            span,
            condition,
            then_expr,
            else_expr,
        } => {
            self.compile_expression(condition);

            // Both jumps are emitted with a placeholder operand (0xffff)
            // and patched once the jump target is known. The `+ 1`
            // presumably skips the opcode so the offset addresses the u16
            // operand; the `+ 2` in the patched distance then accounts for
            // that operand's width.
            let cond_jump_operand = self.vm.current_offset() + 1;
            self.vm.add_op1(Op::JumpIfFalse, 0xffff, *span);

            self.compile_expression(then_expr);

            let skip_else_operand = self.vm.current_offset() + 1;
            self.vm.add_op1(Op::Jump, 0xffff, *span);

            let else_start = self.vm.current_offset();
            self.vm
                .patch_u16_value_at(cond_jump_operand, else_start - (cond_jump_operand + 2));

            self.compile_expression(else_expr);

            let after_else = self.vm.current_offset();
            self.vm
                .patch_u16_value_at(skip_else_operand, after_else - (skip_else_operand + 2));
        }
        Expression::List { span, elements, .. } => {
            for element in elements {
                self.compile_expression(element);
            }
            let element_count = elements.len() as u16;
            self.vm.add_op1(Op::BuildList, element_count, *span);
        }
        Expression::TypedHole(_, _) => {
            unreachable!("Typed holes cause type inference errors")
        }
    };
}
fn compile_define_variable(&mut self, define_variable: &DefineVariable) {
let DefineVariable {
name: identifier,
decorators,
expr,
..
} = define_variable;
let current_depth = self.current_depth();
let identifiers = crate::decorator::name_and_aliases(identifier, decorators)
.map(|(name, _)| name.to_compact_string())
.collect::<Vec<_>>();
let metadata = LocalMetadata {
name: crate::decorator::name(decorators).map(CompactString::from),
url: crate::decorator::url(decorators).map(CompactString::from),
description: crate::decorator::description(decorators),
aliases: identifiers.clone(),
};
self.compile_expression(expr);
self.locals[current_depth].push(Local {
identifiers,
metadata,
});
}
/// Compiles a single statement into bytecode and compiler bookkeeping.
///
/// * Expressions are compiled and followed by `Op::Return`.
/// * Variable definitions are delegated to `compile_define_variable`.
/// * Function definitions with a body are compiled into their own VM
///   function with a fresh scope; definitions without a body are registered
///   as foreign functions.
/// * Dimension definitions produce no bytecode.
/// * Unit definitions register the unit and record the constant index for
///   the unit name and all of its aliases.
/// * Procedure calls are compiled to `PrintString` (for `type`) or to an FFI
///   procedure call.
/// * Struct definitions are registered with the VM.
///
/// # Errors
///
/// Returns a runtime error of kind `UnitRegistryError` when a base unit
/// cannot be added to the unit registry.
///
/// # Panics
///
/// Panics if a `type` procedure call does not have exactly one argument, or
/// if a procedure has no registered FFI callable.
fn compile_statement(&mut self, stmt: &Statement, typechecker: &TypeChecker) -> Result<()> {
    match stmt {
        Statement::Expression(expr) => {
            self.compile_expression(expr);
            self.vm.add_op(Op::Return, expr.full_span());
        }
        Statement::DefineVariable(define_variable) => {
            self.compile_define_variable(define_variable);
        }
        Statement::DefineFunction {
            function_name: name,
            parameters,
            body: Some(expr),
            local_variables,
            ..
        } => {
            self.vm.begin_function(name);

            // The function body gets its own scope; parameters occupy the
            // first slots, followed by the function's local variables.
            self.locals.push(vec![]);
            let current_depth = self.current_depth();
            for parameter in parameters {
                self.locals[current_depth].push(Local {
                    identifiers: [parameter.1.to_compact_string()].into(),
                    metadata: LocalMetadata::default(),
                });
            }
            for local_variable in local_variables {
                self.compile_define_variable(local_variable);
            }

            self.compile_expression(expr);
            self.vm.add_op(Op::Return, expr.full_span());

            self.locals.pop();
            self.vm.end_function();

            self.functions.insert(name.to_compact_string(), false);
        }
        Statement::DefineFunction {
            function_name: name,
            parameters,
            body: None,
            ..
        } => {
            // A declaration without a body is a foreign function. No
            // bytecode is emitted; only its name and exact arity are
            // registered.
            self.vm
                .add_foreign_function(name, parameters.len()..=parameters.len());
            self.functions.insert(name.to_compact_string(), true);
        }
        Statement::DefineDimension(_name, _dexprs) => {
            // Dimension definitions produce no bytecode.
        }
        Statement::DefineBaseUnit {
            name: unit_name,
            identifier_span: span,
            decorators,
            type_annotation,
            type_scheme,
        } => {
            let aliases = decorator::name_and_aliases(unit_name, decorators)
                .map(|(name, ap)| (name.to_compact_string(), ap))
                .collect();

            self.vm
                .unit_registry
                .add_base_unit(
                    unit_name,
                    UnitMetadata {
                        type_: type_scheme.to_concrete_type(),
                        // Prefer the annotation as written by the user. The
                        // fallback is built lazily so it is only computed
                        // when there is no annotation.
                        readable_type: type_annotation
                            .as_ref()
                            .map(|a: &TypeAnnotation| a.pretty_print())
                            .unwrap_or_else(|| {
                                type_scheme.to_readable_type(typechecker.registry(), false)
                            }),
                        aliases,
                        name: decorator::name(decorators).map(CompactString::from),
                        canonical_name: decorator::get_canonical_unit_name(
                            unit_name, decorators,
                        ),
                        url: decorator::url(decorators).map(CompactString::from),
                        description: decorator::description(decorators),
                        binary_prefixes: decorators.contains(&Decorator::BinaryPrefixes),
                        metric_prefixes: decorators.contains(&Decorator::MetricPrefixes),
                        is_abbreviation: decorator::contains_abbreviation(decorators),
                        code_source_id: span.code_source_id,
                    },
                )
                .map_err(|e| {
                    self.vm
                        .runtime_error(RuntimeErrorKind::UnitRegistryError(e))
                })?;

            let constant_idx = self.vm.add_constant(Constant::Unit(Unit::new_base(
                unit_name.to_compact_string(),
                decorator::get_canonical_unit_name(unit_name, decorators),
            )));

            // The unit name and every alias resolve to the same constant.
            for (name, _) in decorator::name_and_aliases(unit_name, decorators) {
                self.unit_name_to_constant_index
                    .insert(name.into(), constant_idx);
            }
        }
        Statement::DefineDerivedUnit {
            name: unit_name,
            identifier_span: full_span,
            expr,
            decorators,
            type_annotation,
            type_scheme,
            ..
        } => {
            let aliases = decorator::name_and_aliases(unit_name, decorators)
                .map(|(name, ap)| (name.to_compact_string(), ap))
                .collect();

            // Reserve a constant slot with a placeholder unit. The
            // `SetUnitConstant` instruction emitted below refers to this
            // slot, presumably to fill in the real unit at run time.
            let constant_idx = self.vm.add_constant(Constant::Unit(Unit::new_base(
                CompactString::const_new("<dummy>"),
                CanonicalName {
                    name: CompactString::const_new("<dummy>"),
                    accepts_prefix: AcceptsPrefix::both(),
                },
            )));

            let unit_information_idx = self.vm.add_unit_information(
                unit_name,
                UnitMetadata {
                    type_: type_scheme.to_concrete_type(),
                    // Same as for base units: the fallback is built lazily.
                    readable_type: type_annotation
                        .as_ref()
                        .map(|a: &TypeAnnotation| a.pretty_print())
                        .unwrap_or_else(|| {
                            type_scheme.to_readable_type(typechecker.registry(), false)
                        }),
                    aliases,
                    name: decorator::name(decorators).map(CompactString::from),
                    canonical_name: decorator::get_canonical_unit_name(unit_name, decorators),
                    url: decorator::url(decorators).map(CompactString::from),
                    description: decorator::description(decorators),
                    binary_prefixes: decorators.contains(&Decorator::BinaryPrefixes),
                    metric_prefixes: decorators.contains(&Decorator::MetricPrefixes),
                    is_abbreviation: decorator::contains_abbreviation(decorators),
                    code_source_id: full_span.code_source_id,
                },
            );

            // The defining expression is compiled first; `SetUnitConstant`
            // then consumes its value.
            self.compile_expression(expr);
            self.vm.add_op2(
                Op::SetUnitConstant,
                unit_information_idx,
                constant_idx,
                *full_span,
            );

            // The unit name and every alias resolve to the same constant.
            for (name, _) in decorator::name_and_aliases(unit_name, decorators) {
                self.unit_name_to_constant_index
                    .insert(name.into(), constant_idx);
            }
        }
        Statement::ProcedureCall {
            kind: ProcedureKind::Type,
            span: full_span,
            args,
        } => {
            // `type(expr)` never evaluates its argument. The type is known
            // at compile time and rendered into a string constant here.
            assert_eq!(args.len(), 1);
            let arg = &args[0];

            use crate::markup as m;
            let idx = self.vm.add_string(
                m::dimmed("=") + m::whitespace(" ") + arg.get_type_scheme().pretty_print(),
            );
            self.vm.add_op1(Op::PrintString, idx, *full_span);
        }
        Statement::ProcedureCall {
            kind,
            span: full_span,
            args,
        } => {
            // Arguments are evaluated left to right before the call.
            for arg in args {
                self.compile_expression(arg);
            }

            let name = kind.name();
            let callable_idx = self.vm.get_ffi_callable_idx(name).unwrap();

            // Procedures have no return value, hence `return_type: None`.
            let call_args = FfiCallArgs {
                args: args
                    .iter()
                    .map(|a: &typed_ast::Expression| FfiCallArg {
                        span: a.full_span(),
                        type_: a.get_type_scheme(),
                    })
                    .collect(),
                return_type: None,
            };
            let call_args_idx = self.vm.add_ffi_call_args(call_args);
            self.vm.add_op3(
                Op::FFICallProcedure,
                callable_idx,
                args.len() as u16,
                call_args_idx,
                *full_span,
            );
        }
        Statement::DefineStruct(struct_info) => {
            self.vm.add_struct_info(struct_info);
        }
    }
    Ok(())
}
/// Executes the bytecode compiled so far and returns the result.
///
/// A quantity result is passed through the VM's `simplify_quantity` before
/// it is returned; every other outcome (including errors) is returned
/// unchanged. `disassemble` is called on the VM before execution and
/// `debug` after it.
fn run(
    &mut self,
    settings: &mut InterpreterSettings,
    prefix_transformer: &crate::prefix_transformer::Transformer,
    typechecker: &TypeChecker,
) -> Result<InterpreterResult> {
    let mut execution_context = ExecutionContext {
        print_fn: &mut settings.print_fn,
        unit_name_to_constant_idx: &self.unit_name_to_constant_index,
        prefix_transformer,
        typechecker,
    };

    self.vm.disassemble();

    let raw_outcome = self.vm.run(&mut execution_context);
    let outcome = match raw_outcome {
        Ok(InterpreterResult::Value(Value::Quantity(quantity))) => {
            let simplified = self
                .vm
                .simplify_quantity(&quantity, &self.unit_name_to_constant_index);
            Ok(InterpreterResult::Value(Value::Quantity(simplified)))
        }
        unchanged => unchanged,
    };

    self.vm.debug();

    outcome
}
/// Forwards the debug flag `activate` to the underlying VM.
pub(crate) fn set_debug(&mut self, activate: bool) {
    let vm = &mut self.vm;
    vm.set_debug(activate);
}
/// Index of the innermost scope in `self.locals` (0 is the global scope).
fn current_depth(&self) -> usize {
    let scope_count = self.locals.len();
    scope_count - 1
}
/// Returns the unit stored in the VM's constant table for `unit_name`.
///
/// Yields `None` when the name is unknown, when the recorded index is out
/// of range, or when the constant at that index is not a unit.
pub fn get_defining_unit(&self, unit_name: &str) -> Option<&Unit> {
    let constant_idx = *self.unit_name_to_constant_index.get(unit_name)?;
    match self.vm.constants.get(constant_idx as usize)? {
        Constant::Unit(unit) => Some(unit),
        _ => None,
    }
}
/// Looks up a global variable by any of its identifiers.
///
/// The search runs from the most recent definition backwards. Identifier
/// resolution in `compile_expression` uses `rposition` in the same way, so a
/// redefined global reports the metadata of the definition that is actually
/// in effect rather than that of the definition it shadows.
///
/// Returns `None` if no global carries the given name.
pub fn lookup_global(&self, name: &str) -> Option<&Local> {
    self.locals[0]
        .iter()
        .rfind(|l| l.identifiers.iter().any(|n| n == name))
}
}
impl Interpreter for BytecodeInterpreter {
fn new() -> Self {
Self {
vm: Vm::new(),
locals: vec![vec![]],
unit_name_to_constant_index: HashMap::new(),
functions: HashMap::new(),
}
}
fn interpret_statements(
&mut self,
settings: &mut InterpreterSettings,
statements: &[Statement],
prefix_transformer: &crate::prefix_transformer::Transformer,
typechecker: &TypeChecker,
) -> Result<InterpreterResult> {
for statement in statements {
self.compile_statement(statement, typechecker)?;
}
self.run(settings, prefix_transformer, typechecker)
}
fn get_unit_registry(&self) -> &UnitRegistry {
&self.vm.unit_registry
}
}