use crate::{Assigner, SymbolTable};
use leo_ast::{
AccessExpression,
BinaryExpression,
BinaryOperation,
Block,
Expression,
ExpressionReconstructor,
Identifier,
Member,
ReturnStatement,
Statement,
TernaryExpression,
TupleExpression,
Type,
};
use leo_span::Symbol;
use indexmap::IndexMap;
/// State used by the flattener while it rewrites code.
///
/// Tracks which variables hold structs, the conditions currently in effect, early returns
/// (each paired with an optional guard), and a map of tuple expressions.
pub struct Flattener<'a> {
/// The symbol table; `lookup_struct_symbol` queries it for struct definitions by name.
pub(crate) symbol_table: &'a SymbolTable,
/// Produces fresh symbols and assignment statements for the code this flattener emits.
pub(crate) assigner: Assigner,
/// Maps a variable's name to the name of the struct it holds; populated by `update_structs`.
pub(crate) structs: IndexMap<Symbol, Symbol>,
/// Conditions currently in effect; `construct_guard` conjoins them with `&&`.
pub(crate) condition_stack: Vec<Expression>,
/// Early returns, each paired with the guard expression (if any) associated with it.
/// Drained by `clear_early_returns`.
pub(crate) returns: Vec<(Option<Expression>, ReturnStatement)>,
/// NOTE(review): not read or written in this file; presumably maps a variable name to the
/// tuple expression bound to it — confirm against the code that populates it.
pub(crate) tuples: IndexMap<Symbol, TupleExpression>,
}
impl<'a> Flattener<'a> {
pub(crate) fn new(symbol_table: &'a SymbolTable, assigner: Assigner) -> Self {
Self {
symbol_table,
assigner,
structs: IndexMap::new(),
condition_stack: Vec::new(),
returns: Vec::new(),
tuples: IndexMap::new(),
}
}
pub(crate) fn clear_early_returns(&mut self) -> Vec<(Option<Expression>, ReturnStatement)> {
core::mem::take(&mut self.returns)
}
pub(crate) fn construct_guard(&mut self) -> Option<Expression> {
match self.condition_stack.is_empty() {
true => None,
false => {
let (first, rest) = self.condition_stack.split_first().unwrap();
Some(rest.iter().cloned().fold(first.clone(), |acc, condition| {
Expression::Binary(BinaryExpression {
op: BinaryOperation::And,
left: Box::new(acc),
right: Box::new(condition),
span: Default::default(),
})
}))
}
}
}
pub(crate) fn fold_guards(
&mut self,
prefix: &str,
mut guards: Vec<(Option<Expression>, Expression)>,
) -> (Expression, Vec<Statement>) {
let (_, last_expression) = guards.pop().unwrap();
let mut statements = Vec::with_capacity(guards.len());
let mut construct_ternary_assignment = |guard: Expression, if_true: Expression, if_false: Expression| {
let place = Identifier { name: self.assigner.unique_symbol(prefix, "$"), span: Default::default() };
let (value, stmts) = self.reconstruct_ternary(TernaryExpression {
condition: Box::new(guard),
if_true: Box::new(if_true),
if_false: Box::new(if_false),
span: Default::default(),
});
statements.extend(stmts);
match &value {
Expression::Tuple(_) => value,
_ => {
statements.push(self.simple_assign_statement(place, value));
Expression::Identifier(place)
}
}
};
let expression = guards.into_iter().rev().fold(last_expression, |acc, (guard, expr)| match guard {
None => unreachable!("All expressions except for the last one must have a guard."),
Some(guard) => construct_ternary_assignment(guard, expr, acc),
});
(expression, statements)
}
pub(crate) fn lookup_struct_symbol(&self, expression: &Expression) -> Option<Symbol> {
match expression {
Expression::Identifier(identifier) => self.structs.get(&identifier.name).copied(),
Expression::Access(AccessExpression::Member(access)) => {
let name = self.lookup_struct_symbol(&access.inner).unwrap();
let struct_ = self.symbol_table.lookup_struct(name).unwrap();
let Member { type_, .. } =
struct_.members.iter().find(|member| member.name() == access.name.name).unwrap();
match type_ {
Type::Identifier(identifier) => Some(identifier.name),
_ => None,
}
}
_ => None,
}
}
pub(crate) fn update_structs(&mut self, lhs: &Identifier, rhs: &Expression) {
match rhs {
Expression::Struct(rhs) => {
self.structs.insert(lhs.name, rhs.name.name);
}
Expression::Identifier(rhs) if self.structs.contains_key(&rhs.name) => {
self.structs.insert(lhs.name, *self.structs.get(&rhs.name).unwrap());
}
_ => (),
}
}
pub(crate) fn unique_simple_assign_statement(&mut self, expr: Expression) -> (Identifier, Statement) {
let (place, statement) = self.assigner.unique_simple_assign_statement(expr);
match &statement {
Statement::Assign(assign) => {
self.update_structs(&place, &assign.value);
}
_ => unreachable!("`assigner.unique_simple_assign_statement` always returns an assignment statement."),
}
(place, statement)
}
pub(crate) fn simple_assign_statement(&mut self, lhs: Identifier, rhs: Expression) -> Statement {
self.update_structs(&lhs, &rhs);
self.assigner.simple_assign_statement(lhs, rhs)
}
pub(crate) fn fold_returns(&mut self, block: &mut Block, returns: Vec<(Option<Expression>, ReturnStatement)>) {
if !returns.is_empty() {
let mut return_expressions = Vec::with_capacity(returns.len());
let (has_finalize, number_of_finalize_arguments) = match &returns[0].1.finalize_arguments {
None => (false, 0),
Some(args) => (true, args.len()),
};
let mut finalize_arguments: Vec<Vec<(Option<Expression>, Expression)>> =
vec![Vec::with_capacity(returns.len()); number_of_finalize_arguments];
for (guard, return_statement) in returns {
return_expressions.push((guard.clone(), return_statement.expression));
if let Some(arguments) = return_statement.finalize_arguments {
for (i, argument) in arguments.into_iter().enumerate() {
finalize_arguments[i].push((guard.clone(), argument));
}
}
}
let (expression, stmts) = self.fold_guards("$ret", return_expressions);
block.statements.extend(stmts);
let finalize_arguments = match has_finalize {
false => None,
true => Some(
finalize_arguments
.into_iter()
.enumerate()
.map(|(i, arguments)| {
let (expression, stmts) = self.fold_guards(&format!("finalize${i}$"), arguments);
block.statements.extend(stmts);
expression
})
.collect(),
),
};
block.statements.push(Statement::Return(ReturnStatement {
expression,
finalize_arguments,
span: Default::default(),
}));
}
}
}