use hashbrown::HashMap;
use crate::{
alloc::{Box, String, ToOwned},
executable::{Atom, Command, CompiledExpr, Executable, ExecutableModule, FieldName, Registers},
Error, ErrorKind, ModuleId, Value,
};
use arithmetic_parser::{
grammars::Grammar, BinaryOp, Block, Destructure, FnDefinition, InputSpan, Lvalue,
ObjectDestructure, Spanned, SpannedLvalue, UnaryOp,
};
mod captures;
mod expr;
use self::captures::{CapturesExtractor, CompilerExtTarget};
pub(crate) type ImportSpans<'a> = HashMap<String, Spanned<'a>>;
/// State carried while translating a parsed block into executable commands.
#[derive(Debug)]
pub(crate) struct Compiler {
/// Maps a variable name to the index of the register currently bound to it.
vars_to_registers: HashMap<String, usize>,
/// Scope counter; `insert_var` emits `Command::Annotate` only while this is zero.
// NOTE(review): this file never modifies the field — presumably it is adjusted
// by the `expr` submodule when entering / leaving inner scopes; confirm there.
scope_depth: usize,
/// Number of registers handed out so far; also the index of the next register
/// returned by `push_assignment`.
register_count: usize,
/// Identifier of the module being compiled; passed to `Error::new` by `create_error`.
module_id: Box<dyn ModuleId>,
}
impl Compiler {
/// Creates a compiler that knows no variables and has handed out no registers.
fn new(module_id: Box<dyn ModuleId>) -> Self {
    let vars_to_registers = HashMap::new();
    Self {
        module_id,
        vars_to_registers,
        register_count: 0,
        scope_depth: 0,
    }
}
/// Creates a compiler seeded from `env`: the variable map is cloned and the
/// register counter starts at the register count of `env`. The scope depth
/// starts at zero.
fn from_env<T>(module_id: Box<dyn ModuleId>, env: &Registers<'_, T>) -> Self {
    let vars_to_registers = env.variables_map().clone();
    let register_count = env.register_count();
    Self {
        module_id,
        scope_depth: 0,
        register_count,
        vars_to_registers,
    }
}
/// Produces a copy of this compiler: the variable map is cloned, the counters
/// are copied and the module ID is duplicated via `clone_boxed()`.
// NOTE(review): nothing in the body mutates `self`; `&self` would be enough.
// The receiver is kept as is so that the signature stays unchanged.
fn backup(&mut self) -> Self {
    let module_id = self.module_id.clone_boxed();
    let vars_to_registers = self.vars_to_registers.clone();
    Self {
        module_id,
        vars_to_registers,
        register_count: self.register_count,
        scope_depth: self.scope_depth,
    }
}
fn create_error<'a, T>(&self, span: &Spanned<'a, T>, err: ErrorKind) -> Error<'a> {
Error::new(self.module_id.as_ref(), span, err)
}
/// Returns the unary operation if it is `Neg` or `Not`.
///
/// Any other operation results in an "unsupported" error spanning `op`.
fn check_unary_op<'a>(&self, op: &Spanned<'a, UnaryOp>) -> Result<UnaryOp, Error<'a>> {
    let unary = op.extra;
    if matches!(unary, UnaryOp::Neg | UnaryOp::Not) {
        Ok(unary)
    } else {
        Err(self.create_error(op, ErrorKind::unsupported(unary)))
    }
}
/// Returns the binary operation if it is one of the arithmetic (`Add`, `Sub`,
/// `Mul`, `Div`, `Power`), boolean (`And`, `Or`), equality (`Eq`, `NotEq`) or
/// ordering (`Gt`, `Ge`, `Lt`, `Le`) operations.
///
/// Any other operation results in an "unsupported" error spanning `op`.
fn check_binary_op<'a>(&self, op: &Spanned<'a, BinaryOp>) -> Result<BinaryOp, Error<'a>> {
    let binary = op.extra;
    let is_supported = matches!(
        binary,
        BinaryOp::Add
            | BinaryOp::Sub
            | BinaryOp::Mul
            | BinaryOp::Div
            | BinaryOp::Power
            | BinaryOp::And
            | BinaryOp::Or
            | BinaryOp::Eq
            | BinaryOp::NotEq
            | BinaryOp::Gt
            | BinaryOp::Ge
            | BinaryOp::Lt
            | BinaryOp::Le
    );

    if is_supported {
        Ok(binary)
    } else {
        Err(self.create_error(op, ErrorKind::unsupported(binary)))
    }
}
/// Looks up the register currently bound to the variable `name`.
///
/// # Panics
///
/// Panics if `name` is not present in the variable map.
fn get_var(&self, name: &str) -> usize {
    let register = self.vars_to_registers.get(name).copied();
    register.expect("Captures must created during module compilation")
}
/// Appends a `Command::Push` carrying `rhs` to `executable`, positioned at
/// `rhs_span`, and reserves one register for it.
///
/// Returns the index of the reserved register, i.e. the value of
/// `register_count` before the call.
fn push_assignment<'a, T, U>(
    &mut self,
    executable: &mut Executable<'a, T>,
    rhs: CompiledExpr<'a, T>,
    rhs_span: &Spanned<'a, U>,
) -> usize {
    let reserved = self.register_count;
    self.register_count = reserved + 1;

    let push = rhs_span.copy_with_extra(Command::Push(rhs));
    executable.push_command(push);
    reserved
}
/// Compiles `block` into an executable module identified by `module_id`.
///
/// The captures of the block are extracted first and seed the compiler state.
/// After the block is compiled, its return atom (`Atom::Void` if there is none)
/// is pushed into one more register with an empty span, and the executable is
/// finalized with the resulting register count.
///
/// Returns the module together with the spans recorded for the captured
/// variables.
///
/// # Errors
///
/// Propagates errors from capture extraction and from block compilation.
pub fn compile_module<'a, Id: ModuleId, T: Grammar<'a>>(
    module_id: Id,
    block: &Block<'a, T>,
) -> Result<(ExecutableModule<'a, T::Lit>, ImportSpans<'a>), Error<'a>> {
    let module_id: Box<dyn ModuleId> = Box::new(module_id);
    let (captures, import_spans) = Self::extract_captures(module_id.clone_boxed(), block)?;

    let mut compiler = Self::from_env(module_id.clone_boxed(), &captures);
    let mut executable = Executable::new(module_id);

    let return_atom = match compiler.compile_block_inner(&mut executable, block)? {
        Some(spanned) => spanned.extra,
        None => Atom::Void,
    };

    // The return value gets a register of its own; there is no source code
    // to attribute this command to, hence the empty span.
    let return_span = InputSpan::new("");
    compiler.push_assignment(
        &mut executable,
        CompiledExpr::Atom(return_atom),
        &return_span.into(),
    );
    executable.finalize_block(compiler.register_count);

    Ok((
        ExecutableModule::from_parts(executable, captures),
        import_spans,
    ))
}
fn extract_captures<'a, T: Grammar<'a>>(
module_id: Box<dyn ModuleId>,
block: &Block<'a, T>,
) -> Result<(Registers<'a, T::Lit>, ImportSpans<'a>), Error<'a>> {
let mut extractor = CapturesExtractor::new(module_id);
extractor.eval_block(&block)?;
let mut captures = Registers::new();
for &var_name in extractor.captures.keys() {
captures.insert_var(var_name, Value::void());
}
let import_spans = extractor
.captures
.into_iter()
.map(|(var_name, var_span)| (var_name.to_owned(), var_span))
.collect();
Ok((captures, import_spans))
}
/// Binds the value held in `rhs_register` to the lvalue `lhs`.
///
/// A plain variable is recorded via `insert_var`; tuple and object lvalues are
/// delegated to `destructure` and `destructure_object` respectively.
///
/// # Errors
///
/// Returns an "unsupported" error spanning `lhs` for any other lvalue kind, and
/// propagates errors from nested destructuring.
fn assign<'a, T, Ty>(
    &mut self,
    executable: &mut Executable<'a, T>,
    lhs: &SpannedLvalue<'a, Ty>,
    rhs_register: usize,
) -> Result<(), Error<'a>> {
    match &lhs.extra {
        Lvalue::Variable { .. } => {
            self.insert_var(executable, lhs.with_no_extra(), rhs_register);
            Ok(())
        }
        Lvalue::Tuple(tuple) => {
            self.destructure(executable, tuple, lhs.with_no_extra(), rhs_register)
        }
        Lvalue::Object(object) => {
            self.destructure_object(executable, object, lhs.with_no_extra(), rhs_register)
        }
        _ => Err(self.create_error(lhs, ErrorKind::unsupported(lhs.extra.ty()))),
    }
}
/// Records that the variable named by the fragment of `var_span` lives in
/// `register`.
///
/// The name `_` is not recorded at all. For any other name the variable map is
/// updated, and, if `scope_depth` is zero, a `Command::Annotate` with the
/// register and the name is appended to `executable` at `var_span`.
fn insert_var<'a, T>(
    &mut self,
    executable: &mut Executable<'a, T>,
    var_span: Spanned<'a>,
    register: usize,
) {
    let var_name = *var_span.fragment();
    if var_name == "_" {
        return;
    }
    self.vars_to_registers.insert(var_name.to_owned(), register);

    if self.scope_depth != 0 {
        return;
    }
    let annotation = Command::Annotate {
        register,
        name: var_name.to_owned(),
    };
    executable.push_command(var_span.copy_with_extra(annotation));
}
/// Compiles destructuring of the value in `rhs_register` into a tuple lvalue.
///
/// A checked `Command::Destructure` is appended at `span`, and
/// `start.len() + end.len() + 1` registers are reserved: one per start element,
/// then a single register for the middle, then one per end element. Each lvalue
/// is then assigned from its register; the middle register is assigned only if
/// a middle is present and yields an lvalue.
///
/// # Errors
///
/// Propagates errors from assigning the nested lvalues.
fn destructure<'a, T, Ty>(
    &mut self,
    executable: &mut Executable<'a, T>,
    destructure: &Destructure<'a, Ty>,
    span: Spanned<'a>,
    rhs_register: usize,
) -> Result<(), Error<'a>> {
    let start_len = destructure.start.len();
    let end_len = destructure.end.len();

    let command = Command::Destructure {
        source: rhs_register,
        start_len,
        end_len,
        lvalue_len: destructure.len(),
        unchecked: false,
    };
    executable.push_command(span.copy_with_extra(command));

    // Register layout: [start elements..., middle, end elements...].
    let first_register = self.register_count;
    let middle_register = first_register + start_len;
    let tail_register = middle_register + 1;
    self.register_count += start_len + end_len + 1;

    for (offset, lvalue) in destructure.start.iter().enumerate() {
        self.assign(executable, lvalue, first_register + offset)?;
    }

    let middle_lvalue = destructure
        .middle
        .as_ref()
        .and_then(|middle| middle.extra.to_lvalue());
    if let Some(lvalue) = middle_lvalue {
        self.assign(executable, &lvalue, middle_register)?;
    }

    for (offset, lvalue) in destructure.end.iter().enumerate() {
        self.assign(executable, lvalue, tail_register + offset)?;
    }
    Ok(())
}
/// Compiles destructuring of the value in `rhs_register` into an object lvalue.
///
/// For every field, a field access on the source register is pushed into
/// a fresh register (positioned at the field name). That register is then
/// assigned to the field binding if there is one; otherwise, it is recorded as
/// a variable named after the field.
///
/// # Errors
///
/// Propagates errors from assigning the field bindings.
fn destructure_object<'a, T, Ty>(
    &mut self,
    executable: &mut Executable<'a, T>,
    destructure: &ObjectDestructure<'a, Ty>,
    span: Spanned<'a>,
    rhs_register: usize,
) -> Result<(), Error<'a>> {
    for field in &destructure.fields {
        let name_span = field.field_name;
        let access = CompiledExpr::FieldAccess {
            receiver: span.copy_with_extra(Atom::Register(rhs_register)).into(),
            field: FieldName::Name((*name_span.fragment()).to_owned()),
        };
        let field_register = self.push_assignment(executable, access, &name_span);

        match &field.binding {
            Some(binding) => self.assign(executable, binding, field_register)?,
            None => self.insert_var(executable, name_span, field_register),
        }
    }
    Ok(())
}
}
/// Extension trait for AST nodes allowing to find out which variables they use
/// without defining them.
///
/// Implemented in this file for `Block` and `FnDefinition`.
pub trait CompilerExt<'a> {
/// Returns the undefined variables of this node as a map from the variable
/// name to a span at which the variable is used.
///
/// # Errors
///
/// Returns an error if the underlying analysis fails.
// NOTE(review): the exact failure conditions are determined by the `captures`
// submodule and are not visible here.
fn undefined_variables(&self) -> Result<HashMap<&'a str, Spanned<'a>>, Error<'a>>;
}
/// Undefined variables of a block of statements.
impl<'a, T: Grammar<'a>> CompilerExt<'a> for Block<'a, T> {
// Delegates to the shared analysis in the `captures` submodule, wrapping the
// block into the corresponding `CompilerExtTarget` variant.
fn undefined_variables(&self) -> Result<HashMap<&'a str, Spanned<'a>>, Error<'a>> {
CompilerExtTarget::Block(self).get_undefined_variables()
}
}
/// Undefined variables of a function definition.
impl<'a, T: Grammar<'a>> CompilerExt<'a> for FnDefinition<'a, T> {
// Delegates to the shared analysis in the `captures` submodule, wrapping the
// definition into the corresponding `CompilerExtTarget` variant.
fn undefined_variables(&self) -> Result<HashMap<&'a str, Spanned<'a>>, Error<'a>> {
CompilerExtTarget::FnDefinition(self).get_undefined_variables()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Value, WildcardId};
    use arithmetic_parser::grammars::{F32Grammar, Parse, ParseLiteral, Typed, Untyped};
    use arithmetic_parser::{Expr, NomResult};

    /// A program with a nested block compiles and evaluates to `true`.
    #[test]
    fn compilation_basics() {
        let program = "x = 3; 1 + { y = 2; y * x } == 7";
        let parsed = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let (module, _) = Compiler::compile_module(WildcardId, &parsed).unwrap();
        assert_eq!(module.run().unwrap(), Value::Bool(true));
    }

    /// A function defined and called within the same program works.
    #[test]
    fn compiled_function() {
        let program = "add = |x, y| x + y; add(2, 3) == 5";
        let parsed = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let (module, _) = Compiler::compile_module(WildcardId, &parsed).unwrap();
        let outcome = module.run().unwrap();
        assert_eq!(outcome, Value::Bool(true));
    }

    /// A function can use a variable defined before it in the same program.
    #[test]
    fn compiled_function_with_capture() {
        let program = "A = 2; add = |x, y| x + y / A; add(2, 3) == 3.5";
        let parsed = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let (module, _) = Compiler::compile_module(WildcardId, &parsed).unwrap();
        let outcome = module.run().unwrap();
        assert_eq!(outcome, Value::Bool(true));
    }

    /// `y` is reported as undefined for the function; the locally assigned `x` is not.
    #[test]
    fn variable_extraction() {
        let code = "|a, b| ({ x = a * b + y; x - 2 }, a / b)";
        let parsed = Untyped::<F32Grammar>::parse_statements(code).unwrap();
        let return_value = parsed.return_value.unwrap();
        let fn_def = match return_value.extra {
            Expr::FnDefinition(fn_def) => fn_def,
            other => panic!("Unexpected function parsing result: {:?}", other),
        };

        let undefined = fn_def.undefined_variables().unwrap();
        assert!(undefined.get("x").is_none());
        assert_eq!(undefined["y"].location_offset(), 22);
    }

    /// `x` assigned in the inner block is still undefined where it is used
    /// outside of that block.
    #[test]
    fn variable_extraction_with_scoping() {
        let code = "|a, b| ({ x = a * b + y; x - 2 }, a / x)";
        let parsed = Untyped::<F32Grammar>::parse_statements(code).unwrap();
        let return_value = parsed.return_value.unwrap();
        let fn_def = match return_value.extra {
            Expr::FnDefinition(fn_def) => fn_def,
            other => panic!("Unexpected function parsing result: {:?}", other),
        };

        let undefined = fn_def.undefined_variables().unwrap();
        assert_eq!(undefined["y"].location_offset(), 22);
        assert_eq!(undefined["x"].location_offset(), 38);
    }

    /// Only `x` is captured by the module; it gets a void value and the span of
    /// its first mention.
    #[test]
    fn extracting_captures() {
        let program = "y = 5 * x; y - 3 + x";
        let parsed = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let (registers, import_spans) =
            Compiler::extract_captures(Box::new(WildcardId), &parsed).unwrap();

        assert_eq!(registers.register_count(), 1);
        assert_eq!(*registers.get_var("x").unwrap(), Value::void());

        assert_eq!(import_spans.len(), 1);
        assert_eq!(import_spans["x"], Spanned::from_str(program, 8..9));
    }

    /// Captures of inner functions which are not defined in the module are
    /// captures of the module as well.
    #[test]
    fn extracting_captures_with_inner_fns() {
        let program = r#"
y = 5 * x; // x is a capture
fun = |z| { // z is not a capture
z * x + y * PI // y is not a capture for the entire module, PI is
};
"#;
        let parsed = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let (registers, import_spans) =
            Compiler::extract_captures(Box::new(WildcardId), &parsed).unwrap();

        assert_eq!(registers.register_count(), 2);
        let variables = registers.variables_map();
        assert!(variables.contains_key("x"));
        assert!(variables.contains_key("PI"));
        assert_eq!(import_spans["x"].location_line(), 2);
    }

    /// Programs with `as` type casts compile and evaluate to the same result
    /// as their cast-free counterparts.
    #[test]
    fn type_casts_are_ignored() {
        /// Grammar with `f32` literals and a single type annotation, `Num`.
        struct TypedGrammar;

        impl ParseLiteral for TypedGrammar {
            type Lit = f32;

            fn parse_literal(input: InputSpan<'_>) -> NomResult<'_, Self::Lit> {
                F32Grammar::parse_literal(input)
            }
        }

        impl Grammar<'_> for TypedGrammar {
            type Type = ();

            fn parse_type(input: InputSpan<'_>) -> NomResult<'_, Self::Type> {
                use nom::{bytes::complete::tag, combinator::map};
                map(tag("Num"), drop)(input)
            }
        }

        let program = "x = 3 as Num; 1 + { y = 2; y * x as Num } == 7";
        let parsed = Typed::<TypedGrammar>::parse_statements(program).unwrap();
        let (module, _) = Compiler::compile_module(WildcardId, &parsed).unwrap();
        assert_eq!(module.run().unwrap(), Value::Bool(true));
    }
}