use super::FunctionInliningVisitor;
use crate::{Replacer, SsaFormingInput, static_single_assignment::visitor::SsaFormingVisitor};
use leo_ast::*;
use indexmap::IndexMap;
use itertools::Itertools;
impl AstReconstructor for FunctionInliningVisitor<'_> {
    type AdditionalInput = ();
    /// Statements produced while reconstructing a node (e.g. the body of an inlined call) that must be
    /// emitted immediately before the statement containing that node.
    type AdditionalOutput = Vec<Statement>;

    /// Inlines calls to `inline` functions of the program currently being compiled.
    ///
    /// Calls into other programs, and calls to any non-`inline` variant, are returned unchanged. For an
    /// `inline` callee, every path naming one of its parameters inside the callee's (already
    /// reconstructed) body is replaced by the corresponding call argument. The body is then run through
    /// SSA forming with `rename_defs: true`, and the resulting statements are returned as additional
    /// output. A trailing `return` supplies the value of the call; otherwise the call evaluates to `()`.
    ///
    /// # Panics
    /// If the callee is not among `reconstructed_functions`, or if the number of arguments differs
    /// from the number of callee parameters (`zip_eq`).
    fn reconstruct_call(&mut self, input: CallExpression, _additional: &()) -> (Expression, Self::AdditionalOutput) {
        let function_location = input.function.expect_global_location();
        // Only functions of the program being compiled are candidates for inlining.
        if function_location.program != self.program {
            return (input.into(), Default::default());
        }

        let (_, callee) = self
            .reconstructed_functions
            .iter()
            .find(|(path, _)| *path == function_location.path)
            .expect("guaranteed to exist due to post-order traversal of the call graph.");

        match callee.variant {
            Variant::Inline => {
                // `zip_eq` panics on an arity mismatch instead of silently truncating.
                let parameter_to_argument = callee
                    .input
                    .iter()
                    .map(|parameter| parameter.identifier().name)
                    .zip_eq(input.arguments)
                    .collect::<IndexMap<_, _>>();

                // Substitute paths that name a parameter with the matching argument; clone anything else.
                let replace_path = |expr: &Expression| {
                    let argument = match expr {
                        Expression::Path(path) => parameter_to_argument.get(&path.identifier().name),
                        _ => None,
                    };
                    argument.cloned().unwrap_or_else(|| expr.clone())
                };

                let reconstructed_block =
                    Replacer::new(replace_path, false, self.state).reconstruct_block(callee.block.clone()).0;

                // Re-run SSA forming with definitions renamed so the inlined body gets fresh names.
                let mut inlined_statements =
                    SsaFormingVisitor::new(self.state, SsaFormingInput { rename_defs: true }, self.program)
                        .consume_block(reconstructed_block);

                let result = match inlined_statements.pop() {
                    // A trailing `return` supplies the value of the call expression.
                    Some(Statement::Return(ReturnStatement { expression, .. })) => expression,
                    // Otherwise put back whatever was popped (if anything) and use `()` as the value.
                    last => {
                        inlined_statements.extend(last);
                        let id = self.state.node_builder.next_id();
                        self.state.type_table.insert(id, Type::Unit);
                        UnitExpression { span: Default::default(), id }.into()
                    }
                };
                (result, inlined_statements)
            }
            Variant::Function
            | Variant::Script
            | Variant::AsyncFunction
            | Variant::Transition
            | Variant::AsyncTransition => (input.into(), Default::default()),
        }
    }

    /// # Panics
    /// Always; assignments are not expected to reach this pass.
    fn reconstruct_assign(&mut self, _input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
        panic!("`AssignStatement`s should not exist in the AST at this phase of compilation.")
    }

    /// Reconstructs every statement in the block, splicing any statements produced for a statement
    /// (e.g. an inlined function body) immediately before it.
    fn reconstruct_block(&mut self, block: Block) -> (Block, Self::AdditionalOutput) {
        let mut statements = Vec::with_capacity(block.statements.len());
        for statement in block.statements {
            let (reconstructed_statement, additional_statements) = self.reconstruct_statement(statement);
            statements.extend(additional_statements);
            statements.push(reconstructed_statement);
        }
        (Block { span: block.span, statements, id: block.id }, Default::default())
    }

    /// Reconstructs a conditional; only permitted when `self.is_async` is set.
    ///
    /// Statements produced while reconstructing the condition are returned as additional output
    /// (they precede the conditional). Statements produced for an `else` branch statement are kept
    /// inside that branch.
    ///
    /// # Panics
    /// If `self.is_async` is false.
    fn reconstruct_conditional(&mut self, input: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
        if !self.is_async {
            panic!("`ConditionalStatement`s should not be in the AST at this phase of compilation.")
        }

        // The condition is always evaluated, so anything produced while reconstructing it (e.g. an
        // inlined call) can safely be emitted ahead of the conditional instead of being dropped.
        let (condition, statements) = self.reconstruct_expression(input.condition, &());
        let then = self.reconstruct_block(input.then).0;
        let otherwise = input.otherwise.map(|otherwise| {
            let (statement, prefix) = self.reconstruct_statement(*otherwise);
            if prefix.is_empty() {
                Box::new(statement)
            } else {
                // Statements produced for an `else if` condition may only run on the `else` path, so
                // wrap them with the nested statement in a block rather than hoisting them.
                let mut branch = prefix;
                branch.push(statement);
                let id = self.state.node_builder.next_id();
                Box::new(Statement::Block(Block { span: Default::default(), statements: branch, id }))
            }
        });

        (ConditionalStatement { condition, then, otherwise, span: input.span, id: input.id }.into(), statements)
    }

    /// Reconstructs a definition.
    ///
    /// A multi-place definition whose value reconstructs to a tuple literal is split into one
    /// single-place definition per element. These are returned as additional output in place of
    /// `Statement::dummy()`, which presumably is removed later (TODO confirm).
    ///
    /// # Panics
    /// If the number of places differs from the number of tuple elements.
    fn reconstruct_definition(&mut self, input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
        let (value, mut statements) = self.reconstruct_expression(input.value, &());
        match (input.place, value) {
            (DefinitionPlace::Multiple(left), Expression::Tuple(right)) => {
                assert_eq!(left.len(), right.elements.len());
                for (identifier, rhs_value) in left.into_iter().zip(right.elements) {
                    let stmt = DefinitionStatement {
                        place: DefinitionPlace::Single(identifier),
                        type_: None,
                        value: rhs_value,
                        span: Default::default(),
                        id: self.state.node_builder.next_id(),
                    }
                    .into();
                    statements.push(stmt);
                }
                (Statement::dummy(), statements)
            }
            (place, value) => (DefinitionStatement { place, value, ..input }.into(), statements),
        }
    }

    /// Reconstructs an expression statement.
    ///
    /// If the expression reconstructs to `()` (e.g. an inlined call without a trailing `return`), the
    /// statement becomes `Statement::dummy()` and only the additional statements are kept.
    fn reconstruct_expression_statement(&mut self, input: ExpressionStatement) -> (Statement, Self::AdditionalOutput) {
        let (expression, additional_statements) = self.reconstruct_expression(input.expression, &());
        let statement = match expression {
            Expression::Unit(_) => Statement::dummy(),
            _ => ExpressionStatement { expression, ..input }.into(),
        };
        (statement, additional_statements)
    }

    /// # Panics
    /// Always; loops are not expected to reach this pass.
    fn reconstruct_iteration(&mut self, _: IterationStatement) -> (Statement, Self::AdditionalOutput) {
        panic!("`IterationStatement`s should not be in the AST at this phase of compilation.");
    }
}