use std::{
collections::HashMap,
convert::Infallible,
error::Error,
fmt::{Display, Formatter}
};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::{
CanAllocate as _, Optimizer as _, StandardOptimizer,
ast::{
self, ASTVisitor, ArithmeticExpression, Constant, DiceExpression,
Expression
},
ir::{
AddressingMode, Immediate, Instruction, RegisterIndex,
RollingRecordIndex
},
parser::{self, ParseError}
};
/// Parse `source` and lower it to an intermediate [`Function`] without
/// running any optimization passes.
///
/// # Errors
///
/// Returns [`CompilationError::CompilationFailed`], wrapping the parser's
/// error, if `source` does not parse.
pub fn compile_unoptimized(
	source: &str
) -> Result<Function, CompilationError<'_>>
{
	parser::parse(source)
		.map(|tree| Compiler::compile(&tree))
		.map_err(CompilationError::CompilationFailed)
}
#[cfg_attr(doc, aquamarine::aquamarine)]
pub fn compile(source: &str) -> Result<Function, CompilationError<'_>>
{
let ast =
parser::parse(source).map_err(CompilationError::CompilationFailed)?;
let function = Compiler::compile(&ast);
let optimizer = StandardOptimizer::new(Default::default());
let function = optimizer
.optimize(function)
.map_err(|_| CompilationError::OptimizationFailed)?;
Ok(function)
}
/// The ways that [`compile`] and [`compile_unoptimized`] can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CompilationError<'src>
{
	/// The source text could not be parsed. Wraps the parser's error, which
	/// carries the lifetime of (and may borrow from) the source text.
	CompilationFailed(ParseError<'src>),
	/// The optimizer reported an error while optimizing the lowered function.
	OptimizationFailed
}
impl Display for CompilationError<'_>
{
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result
{
match self
{
CompilationError::CompilationFailed(e) =>
{
write!(f, "{}", e)
},
CompilationError::OptimizationFailed =>
{
write!(f, "optimization failed")
}
}
}
}
impl Error for CompilationError<'_> {}
/// Lowers a parsed AST into intermediate-representation instructions.
///
/// A compiler is single-use: [`Compiler::compile`] builds one, walks the AST
/// with it, and then consumes it to produce a [`Function`].
pub struct Compiler<'src>
{
	/// The instructions emitted so far, in emission order.
	instructions: Vec<Instruction>,
	/// Allocation cursor for registers; its final value is reported as the
	/// function's register count.
	next_register: RegisterIndex,
	/// Allocation cursor for rolling records; its final value is reported as
	/// the function's rolling record count.
	next_rolling_record: RollingRecordIndex,
	/// The number of distinct declared parameters. Variables bound to
	/// registers below this index are parameters; the rest are externals.
	arity: usize,
	/// Maps each variable name (parameter or external) to its register.
	variables: HashMap<&'src str, RegisterIndex>
}
impl<'src> Compiler<'src>
{
pub fn compile(ast: &'src ast::Function<'src>) -> Function
{
let mut compiler = Self {
instructions: Vec::new(),
next_register: RegisterIndex(0),
next_rolling_record: RollingRecordIndex(0),
arity: 0,
variables: HashMap::new()
};
let _ = compiler.visit_function(ast);
compiler.finish()
}
fn finish(self) -> Function
{
let mut parameters = Vec::new();
let mut externals = Vec::new();
for (name, register) in &self.variables
{
match register.0 >= self.arity
{
false => parameters.push((name, register)),
true => externals.push((name, register))
}
}
parameters.sort_by_key(|(_, register)| register.0);
externals.sort_by_key(|(_, register)| register.0);
let parameters = parameters
.into_iter()
.map(|(name, _)| name.to_string())
.collect();
let externals = externals
.into_iter()
.map(|(name, _)| name.to_string())
.collect();
Function {
parameters,
externals,
register_count: self.next_register.0,
rolling_record_count: self.next_rolling_record.0,
instructions: self.instructions
}
}
fn variable(&mut self, name: &'src str) -> RegisterIndex
{
match self.variables.get(name)
{
Some(®ister) => register,
None =>
{
let register = self.allocate_register();
self.variables.insert(name, register);
register
}
}
}
#[inline]
fn allocate_register(&mut self) -> RegisterIndex
{
self.next_register.allocate()
}
#[inline]
fn allocate_rolling_record(&mut self) -> RollingRecordIndex
{
self.next_rolling_record.allocate()
}
#[inline]
fn emit(&mut self, instruction: Instruction)
{
self.instructions.push(instruction);
}
fn accept_expression(
&mut self,
expr: &'src Expression<'src>
) -> AddressingMode
{
let value = expr.accept(self).unwrap();
match value
{
AddressingMode::RollingRecord(record) =>
{
let sum = self.allocate_register();
self.emit(Instruction::sum_rolling_record(sum, record));
sum.into()
},
other => other
}
}
fn generate_binary(
&mut self,
left: &'src Expression<'src>,
right: &'src Expression<'src>,
constructor: fn(
RegisterIndex,
AddressingMode,
AddressingMode
) -> Instruction
) -> AddressingMode
{
let op1 = self.accept_expression(left);
let op2 = self.accept_expression(right);
let dest = self.allocate_register();
self.emit(constructor(dest, op1, op2));
dest.into()
}
}
impl<'src> ASTVisitor<'src> for Compiler<'src>
{
type Output = AddressingMode;
type Error = Infallible;
fn visit_function(
&mut self,
node: &'src ast::Function<'src>
) -> Result<AddressingMode, Infallible>
{
if let Some(ref parameters) = node.parameters
{
for ¶m in parameters
{
self.variable(param);
}
self.arity = self.variables.len();
}
let externals = discover_externals(&node.body);
for external in externals
{
self.variable(external);
}
let return_value = self.accept_expression(&node.body);
self.emit(Instruction::r#return(return_value));
Ok(return_value)
}
fn visit_group(
&mut self,
node: &'src ast::Group<'src>
) -> Result<AddressingMode, Infallible>
{
Ok(self.accept_expression(&node.expression))
}
fn visit_constant(
&mut self,
node: &Constant
) -> Result<AddressingMode, Infallible>
{
Ok(Immediate(node.0).into())
}
fn visit_variable(
&mut self,
node: &'src ast::Variable<'src>
) -> Result<AddressingMode, Infallible>
{
let register = self.variable(node.0);
Ok(register.into())
}
fn visit_range(
&mut self,
node: &'src ast::Range<'src>
) -> Result<AddressingMode, Infallible>
{
let start = self.accept_expression(&node.start);
let end = self.accept_expression(&node.end);
let dest = self.allocate_rolling_record();
self.emit(Instruction::roll_range(dest, start, end));
let sum = self.allocate_register();
self.emit(Instruction::sum_rolling_record(sum, dest));
Ok(sum.into())
}
fn visit_standard_dice(
&mut self,
node: &'src ast::StandardDice<'src>
) -> Result<AddressingMode, Infallible>
{
let count = self.accept_expression(&node.count);
let faces = self.accept_expression(&node.faces);
let dest = self.allocate_rolling_record();
self.emit(Instruction::roll_standard_dice(dest, count, faces));
Ok(dest.into())
}
fn visit_custom_dice(
&mut self,
node: &'src ast::CustomDice<'src>
) -> Result<AddressingMode, Infallible>
{
let count = self.accept_expression(&node.count);
let dest = self.allocate_rolling_record();
self.emit(Instruction::roll_custom_dice(
dest,
count,
node.faces.clone()
));
Ok(dest.into())
}
fn visit_drop_lowest(
&mut self,
node: &'src ast::DropLowest<'src>
) -> Result<AddressingMode, Infallible>
{
let record: RollingRecordIndex = node
.dice
.accept(self)
.unwrap()
.try_into()
.expect("dice visitor must return RollingRecord");
let count = match &node.drop
{
Some(expr) => self.accept_expression(expr),
None => Immediate(1).into()
};
self.emit(Instruction::drop_lowest(record, count));
Ok(record.into())
}
fn visit_drop_highest(
&mut self,
node: &'src ast::DropHighest<'src>
) -> Result<AddressingMode, Infallible>
{
let record: RollingRecordIndex = node
.dice
.accept(self)
.unwrap()
.try_into()
.expect("dice visitor must return RollingRecord");
let count = match &node.drop
{
Some(expr) => self.accept_expression(expr),
None => Immediate(1).into()
};
self.emit(Instruction::drop_highest(record, count));
Ok(record.into())
}
fn visit_add(
&mut self,
node: &'src ast::Add<'src>
) -> Result<AddressingMode, Infallible>
{
Ok(self.generate_binary(&node.left, &node.right, Instruction::add))
}
fn visit_sub(
&mut self,
node: &'src ast::Sub<'src>
) -> Result<AddressingMode, Infallible>
{
Ok(self.generate_binary(&node.left, &node.right, Instruction::sub))
}
fn visit_mul(
&mut self,
node: &'src ast::Mul<'src>
) -> Result<AddressingMode, Infallible>
{
Ok(self.generate_binary(&node.left, &node.right, Instruction::mul))
}
fn visit_div(
&mut self,
node: &'src ast::Div<'src>
) -> Result<AddressingMode, Infallible>
{
Ok(self.generate_binary(&node.left, &node.right, Instruction::div))
}
fn visit_mod(
&mut self,
node: &'src ast::Mod<'src>
) -> Result<AddressingMode, Infallible>
{
Ok(self.generate_binary(&node.left, &node.right, Instruction::r#mod))
}
fn visit_exp(
&mut self,
node: &'src ast::Exp<'src>
) -> Result<AddressingMode, Infallible>
{
Ok(self.generate_binary(&node.left, &node.right, Instruction::exp))
}
fn visit_neg(
&mut self,
node: &'src ast::Neg<'src>
) -> Result<AddressingMode, Infallible>
{
if let Expression::Constant(Constant(n)) = node.operand.as_ref()
{
return Ok(Immediate(n.saturating_neg()).into());
}
let op = self.accept_expression(&node.operand);
let dest = self.allocate_register();
self.emit(Instruction::neg(dest, op));
Ok(dest.into())
}
}
/// Gather the name of every variable referenced anywhere in `expr`.
///
/// Despite the name, the result is not limited to externals. It lists every
/// occurrence in traversal order (left operands before right), including
/// parameters and repeats. The caller must skip parameters and duplicates;
/// the compiler does this by binding each name through a lookup table.
fn discover_externals<'src>(expr: &'src Expression<'src>) -> Vec<&'src str>
{
	let mut names: Vec<&'src str> = Vec::new();
	collect_variables(expr, &mut names);
	names
}
/// Append the name of every variable referenced by `expr` to `out`.
///
/// The tree is walked depth-first, left to right. A name is pushed once per
/// occurrence, so repeats are kept.
fn collect_variables<'src>(
	expr: &'src Expression<'src>,
	out: &mut Vec<&'src str>
)
{
	match expr
	{
		Expression::Constant(_) =>
		{},
		Expression::Variable(variable) => out.push(variable.0),
		Expression::Group(group) => collect_variables(&group.expression, out),
		Expression::Range(range) =>
		{
			collect_variables(&range.start, out);
			collect_variables(&range.end, out);
		},
		Expression::Dice(dice) => collect_dice_variables(dice, out),
		Expression::Arithmetic(arithmetic) =>
		{
			collect_arithmetic_variables(arithmetic, out)
		}
	}
}
/// Append the names of the variables referenced by a dice expression to
/// `out`, in traversal order.
///
/// For drop modifiers, the underlying dice are visited before the optional
/// drop count.
fn collect_dice_variables<'src>(
	dice: &'src DiceExpression<'src>,
	out: &mut Vec<&'src str>
)
{
	match dice
	{
		DiceExpression::Standard(standard) =>
		{
			collect_variables(&standard.count, out);
			collect_variables(&standard.faces, out);
		},
		// Custom faces are not expressions (the compiler copies them
		// verbatim), so only the count can mention a variable.
		DiceExpression::Custom(custom) => collect_variables(&custom.count, out),
		DiceExpression::DropLowest(lowest) =>
		{
			collect_dice_variables(&lowest.dice, out);
			if let Some(drop) = &lowest.drop
			{
				collect_variables(drop, out);
			}
		},
		DiceExpression::DropHighest(highest) =>
		{
			collect_dice_variables(&highest.dice, out);
			if let Some(drop) = &highest.drop
			{
				collect_variables(drop, out);
			}
		}
	}
}
fn collect_arithmetic_variables<'src>(
arith: &'src ArithmeticExpression<'src>,
out: &mut Vec<&'src str>
)
{
match arith
{
ArithmeticExpression::Add(a) =>
{
collect_variables(&a.left, out);
collect_variables(&a.right, out);
},
ArithmeticExpression::Sub(s) =>
{
collect_variables(&s.left, out);
collect_variables(&s.right, out);
},
ArithmeticExpression::Mul(m) =>
{
collect_variables(&m.left, out);
collect_variables(&m.right, out);
},
ArithmeticExpression::Div(d) =>
{
collect_variables(&d.left, out);
collect_variables(&d.right, out);
},
ArithmeticExpression::Mod(m) =>
{
collect_variables(&m.left, out);
collect_variables(&m.right, out);
},
ArithmeticExpression::Exp(e) =>
{
collect_variables(&e.left, out);
collect_variables(&e.right, out);
},
ArithmeticExpression::Neg(n) =>
{
collect_variables(&n.operand, out);
}
}
}
/// A compiled function in intermediate representation.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Function
{
	/// The declared parameter names. When produced by [`Compiler`], these are
	/// ordered by register, so parameter `i` is bound to register `i`.
	pub parameters: Vec<String>,
	/// The names of variables used by the body but not declared as
	/// parameters. When produced by [`Compiler`], these are ordered by
	/// register and follow the parameters in register numbering.
	pub externals: Vec<String>,
	/// The number of registers the function uses.
	pub register_count: usize,
	/// The number of rolling records the function uses.
	pub rolling_record_count: usize,
	/// The function body, in execution order.
	pub instructions: Vec<Instruction>
}
impl Function
{
	/// Answer the number of declared parameters, that is, the length of
	/// [`Function::parameters`].
	#[inline]
	pub fn arity(&self) -> usize
	{
		let parameters = &self.parameters;
		parameters.len()
	}
}
impl Display for Function
{
	/// Render a multi-line listing of the function.
	///
	/// The output has three parts:
	/// - A header line showing each parameter as `name@index`, followed by
	///   the register and rolling record counts.
	/// - An `extern[...]` line listing each external the same way, numbered
	///   from the end of the parameters.
	/// - The body, one instruction per line.
	fn fmt(&self, f: &mut Formatter) -> std::fmt::Result
	{
		/// Write `names` as a comma-separated list of `name@index`, with
		/// indices counting up from `first`.
		fn bindings(
			f: &mut Formatter,
			names: &[String],
			first: usize
		) -> std::fmt::Result
		{
			for (offset, name) in names.iter().enumerate()
			{
				if offset > 0
				{
					f.write_str(", ")?;
				}
				write!(f, "{}@{}", name, first + offset)?;
			}
			Ok(())
		}

		f.write_str("Function(")?;
		bindings(f, &self.parameters, 0)?;
		writeln!(
			f,
			") r#{} ⚅#{}",
			self.register_count, self.rolling_record_count
		)?;
		f.write_str("\textern[")?;
		bindings(f, &self.externals, self.parameters.len())?;
		f.write_str("]\n")?;
		f.write_str("\tbody:\n")?;
		self.instructions
			.iter()
			.try_for_each(|instruction| writeln!(f, "\t\t{}", instruction))
	}
}