use crate::CodeGenerator;
use leo_ast::{
AssignStatement, Block, ConditionalStatement, ConsoleFunction, ConsoleStatement, DecrementStatement,
DefinitionStatement, Expression, FinalizeStatement, IncrementStatement, IterationStatement, Mode, Output,
ReturnStatement, Statement,
};
use itertools::Itertools;
use std::fmt::Write as _;
impl<'a> CodeGenerator<'a> {
    /// Dispatches a statement to the matching `visit_*` method and returns the
    /// instruction text generated for it.
    fn visit_statement(&mut self, input: &'a Statement) -> String {
        match input {
            Statement::Assign(stmt) => self.visit_assign(stmt),
            Statement::Block(stmt) => self.visit_block(stmt),
            Statement::Conditional(stmt) => self.visit_conditional(stmt),
            Statement::Console(stmt) => self.visit_console(stmt),
            Statement::Decrement(stmt) => self.visit_decrement(stmt),
            Statement::Definition(stmt) => self.visit_definition(stmt),
            Statement::Finalize(stmt) => self.visit_finalize(stmt),
            Statement::Increment(stmt) => self.visit_increment(stmt),
            Statement::Iteration(stmt) => self.visit_iteration(stmt),
            Statement::Return(stmt) => self.visit_return(stmt),
        }
    }

    /// Lowers a `return` statement into one `output` instruction per returned operand.
    ///
    /// The returned expression is evaluated first. Its operand text is split on `'\n'`,
    /// and each piece is paired with the corresponding declared output of the current
    /// function, or of its `finalize` block when `self.in_finalize` is set.
    ///
    /// Returning an empty tuple emits nothing.
    ///
    /// # Panics
    /// - If there is no current function.
    /// - If `in_finalize` is set but the current function has no `finalize` block.
    /// - If the number of operands differs from the number of declared outputs
    ///   (via `zip_eq`).
    fn visit_return(&mut self, input: &'a ReturnStatement) -> String {
        // A unit return `()` has no outputs to emit.
        if matches!(&input.expression, Expression::Tuple(tuple) if tuple.elements.is_empty()) {
            return String::new();
        }

        let (operand, mut expression_instructions) = self.visit_expression(&input.expression);

        let function = self
            .current_function
            .expect("a `return` statement must be visited inside a function");
        let outputs = if self.in_finalize {
            function
                .finalize
                .as_ref()
                .expect("`in_finalize` is set, so the current function must have a finalize block")
                .output
                .iter()
        } else {
            function.output.iter()
        };

        let instructions = operand
            .split('\n')
            .zip_eq(outputs)
            .map(|(operand, output)| match output {
                Output::Internal(output) => {
                    // In a transition, an unspecified mode defaults to `public` inside
                    // `finalize` and to `private` otherwise. Non-transition functions
                    // never annotate visibility.
                    let visibility = match (self.is_transition_function, self.in_finalize, output.mode) {
                        (false, _, _) => Mode::None,
                        (true, true, Mode::None) => Mode::Public,
                        (true, false, Mode::None) => Mode::Private,
                        (true, _, mode) => mode,
                    };
                    format!(
                        " output {} as {};\n",
                        operand,
                        self.visit_type_with_visibility(&output.type_, visibility)
                    )
                }
                Output::External(output) => {
                    format!(" output {} as {}.aleo/{}.record;\n", operand, output.program_name, output.record)
                }
            })
            .join("");

        expression_instructions.push_str(&instructions);
        expression_instructions
    }

    /// Definition statements are expected to have been removed before code generation
    /// (the program is assumed to be in SSA form here), so reaching this is a compiler bug.
    fn visit_definition(&mut self, _input: &'a DefinitionStatement) -> String {
        unreachable!("DefinitionStatement's should not exist in SSA form.")
    }

    /// Lowers `increment mapping[index] by amount`.
    ///
    /// Emits the instructions for `index`, then those for `amount`, then the
    /// `increment` instruction itself.
    fn visit_increment(&mut self, input: &'a IncrementStatement) -> String {
        self.lower_mapping_update("increment", &input.mapping, &input.index, &input.amount)
    }

    /// Lowers `decrement mapping[index] by amount`.
    ///
    /// Emits the instructions for `index`, then those for `amount`, then the
    /// `decrement` instruction itself.
    fn visit_decrement(&mut self, input: &'a DecrementStatement) -> String {
        self.lower_mapping_update("decrement", &input.mapping, &input.index, &input.amount)
    }

    /// Shared lowering for mapping updates.
    ///
    /// Evaluates `index` and then `amount`, and appends
    /// ` {opcode} {mapping}[{index}] by {amount};\n` after their instructions.
    fn lower_mapping_update(
        &mut self,
        opcode: &str,
        mapping: impl std::fmt::Display,
        index: &'a Expression,
        amount: &'a Expression,
    ) -> String {
        let (index, mut instructions) = self.visit_expression(index);
        let (amount, amount_instructions) = self.visit_expression(amount);
        instructions.push_str(&amount_instructions);
        writeln!(instructions, " {opcode} {mapping}[{index}] by {amount};").expect("failed to write to string");
        instructions
    }

    /// Lowers a `finalize` statement.
    ///
    /// Each argument is evaluated in order. The instructions that compute the arguments
    /// come first in the output, followed by a single `finalize` instruction that lists
    /// the resulting operands.
    fn visit_finalize(&mut self, input: &'a FinalizeStatement) -> String {
        let mut instructions = String::new();
        let mut finalize_instruction = " finalize".to_string();

        for argument in input.arguments.iter() {
            let (argument, argument_instructions) = self.visit_expression(argument);
            write!(finalize_instruction, " {argument}").expect("failed to write to string");
            instructions.push_str(&argument_instructions);
        }
        writeln!(finalize_instruction, ";").expect("failed to write to string");

        // The argument instructions must precede the `finalize` line that consumes their
        // operands. Previously they were collected but never returned.
        instructions.push_str(&finalize_instruction);
        instructions
    }

    /// Lowers an assignment whose left-hand side is an identifier.
    ///
    /// No instruction is emitted for the assignment itself. Instead, the value's operand
    /// is recorded in `variable_mapping` under the identifier's name, and only the
    /// instructions that compute the value are returned.
    ///
    /// # Panics
    /// Panics via `unimplemented!` for any other left-hand side.
    fn visit_assign(&mut self, input: &'a AssignStatement) -> String {
        match &input.place {
            Expression::Identifier(identifier) => {
                let (operand, expression_instructions) = self.visit_expression(&input.value);
                self.variable_mapping.insert(&identifier.name, operand);
                expression_instructions
            }
            _ => unimplemented!(
                "Code generation for the left-hand side of an assignment is only implemented for `Identifier`s."
            ),
        }
    }

    /// Conditional statements are expected to have been eliminated before code generation,
    /// so reaching this is a compiler bug.
    fn visit_conditional(&mut self, _input: &'a ConditionalStatement) -> String {
        unreachable!("`ConditionalStatement`s should not be in the AST at this phase of compilation.")
    }

    /// Iteration statements are expected to have been eliminated before code generation,
    /// so reaching this is a compiler bug.
    fn visit_iteration(&mut self, _input: &'a IterationStatement) -> String {
        unreachable!("`IterationStatement`s should not be in the AST at this phase of compilation.")
    }

    /// Lowers console assertions.
    ///
    /// - `assert(e)` becomes `assert.eq e true`.
    /// - `assert_eq` becomes `assert.eq`.
    /// - `assert_neq` becomes `assert.neq`.
    fn visit_console(&mut self, input: &'a ConsoleStatement) -> String {
        match &input.function {
            ConsoleFunction::Assert(expr) => {
                let (operand, mut instructions) = self.visit_expression(expr);
                let assert_instruction = format!(" assert.eq {operand} true;\n");
                instructions.push_str(&assert_instruction);
                instructions
            }
            ConsoleFunction::AssertEq(left, right) => self.lower_binary_assert("assert.eq", left, right),
            ConsoleFunction::AssertNeq(left, right) => self.lower_binary_assert("assert.neq", left, right),
        }
    }

    /// Emits the instructions for `left`, then those for `right`, then ` {name} left right;\n`.
    fn lower_binary_assert(&mut self, name: &str, left: &'a Expression, right: &'a Expression) -> String {
        let (left_operand, left_instructions) = self.visit_expression(left);
        let (right_operand, right_instructions) = self.visit_expression(right);
        let assert_instruction = format!(" {name} {left_operand} {right_operand};\n");
        let mut instructions = left_instructions;
        instructions.push_str(&right_instructions);
        instructions.push_str(&assert_instruction);
        instructions
    }

    /// Lowers every statement of a block in order and concatenates their instructions.
    pub(crate) fn visit_block(&mut self, input: &'a Block) -> String {
        input.statements.iter().map(|stmt| self.visit_statement(stmt)).join("")
    }
}