use super::DestructuringVisitor;
use leo_ast::*;
use leo_span::Symbol;
use itertools::{Itertools, izip};
impl AstReconstructor for DestructuringVisitor<'_> {
    type AdditionalInput = ();
    /// Statements the caller must emit immediately before the reconstructed node
    /// (see `reconstruct_block`, which splices them in ahead of each statement).
    type AdditionalOutput = Vec<Statement>;

    /// Reconstructs a binary expression, expanding tuple `==` / `!=` into element-wise comparisons.
    ///
    /// When both operands reconstruct to tuple literals and the operator is `Eq` (resp. `Neq`),
    /// each pair of elements is compared with that operator and the boolean results are folded
    /// with `And` (resp. `Or`). Statements produced while reconstructing either operand are
    /// always returned so that they are emitted ahead of the expression.
    fn reconstruct_binary(
        &mut self,
        input: BinaryExpression,
        _additional: &Self::AdditionalInput,
    ) -> (Expression, Self::AdditionalOutput) {
        let (left, mut statements) = self.reconstruct_expression_tuple(input.left);
        let (right, statements2) = self.reconstruct_expression_tuple(input.right);
        statements.extend(statements2);

        use BinaryOperation::*;

        if let (Expression::Tuple(tuple_left), Expression::Tuple(tuple_right)) = (&left, &right)
            && matches!(input.op, Eq | Neq)
        {
            assert_eq!(tuple_left.elements.len(), tuple_right.elements.len());

            // One boolean comparison per element pair.
            let pieces: Vec<Expression> = tuple_left
                .elements
                .iter()
                .zip(&tuple_right.elements)
                .map(|(l, r)| {
                    let expr: Expression = BinaryExpression {
                        op: input.op,
                        left: l.clone(),
                        right: r.clone(),
                        span: Default::default(),
                        id: self.state.node_builder.next_id(),
                    }
                    .into();
                    self.state.type_table.insert(expr.id(), Type::Boolean);
                    expr
                })
                .collect();

            // Tuples are equal iff every element is equal, and unequal iff any element differs.
            let op = match input.op {
                Eq => BinaryOperation::And,
                Neq => BinaryOperation::Or,
                _ => unreachable!(),
            };
            return (self.fold_with_op(op, pieces.into_iter()), statements);
        }

        // Not a tuple comparison: rebuild the node, but keep any statements the operands
        // produced so they still precede this expression.
        (BinaryExpression { op: input.op, left, right, ..input }.into(), statements)
    }

    /// Replaces a tuple access `t.i` with the identifier this pass created for element `i` of `t`.
    ///
    /// Element identifiers are looked up in `self.tuples`. If none is recorded, the operand must
    /// have `Future` type; the access is then rewritten as an array access `t[i]` with a `u32`
    /// index literal, reusing the original span and node id.
    ///
    /// # Panics
    /// If the accessed operand is not a path, or if it is neither a recorded tuple nor a `Future`.
    fn reconstruct_tuple_access(
        &mut self,
        input: TupleAccess,
        _additional: &(),
    ) -> (Expression, Self::AdditionalOutput) {
        let Expression::Path(path) = &input.tuple else {
            panic!("SSA guarantees that subexpressions are identifiers or literals.");
        };

        match self
            .tuples
            .get(&path.identifier().name)
            .and_then(|tuple_names| tuple_names.get(input.index.value()))
        {
            Some(id) => (Path::from(*id).to_local().into(), Default::default()),
            None => {
                if !matches!(self.state.type_table.get(&path.id), Some(Type::Future(_))) {
                    panic!("Type checking guarantees that all tuple accesses are declared and indices are valid.");
                }

                // Rewrite `future.i` as `future[i]` with an explicitly typed `u32` index.
                let index = Literal::integer(
                    IntegerType::U32,
                    input.index.to_string(),
                    input.span,
                    self.state.node_builder.next_id(),
                );
                self.state.type_table.insert(index.id(), Type::Integer(IntegerType::U32));

                let expr =
                    ArrayAccess { array: path.clone().into(), index: index.into(), span: input.span, id: input.id }
                        .into();

                (expr, Default::default())
            }
        }
    }

    /// Reconstructs a ternary, splitting it element-wise when both branches are tuple literals.
    ///
    /// For tuple branches the condition is first bound to a fresh `cond` variable (unless it is
    /// already a path) so each element ternary can refer to it by name. Every element pair then
    /// gets its own `ternary_{i} = cond ? lhs : rhs` definition (emitted as additional
    /// statements), and the result is a tuple literal of those fresh variables.
    ///
    /// # Panics
    /// If the type table does not record a tuple type for the `if_true` tuple literal.
    fn reconstruct_ternary(
        &mut self,
        mut input: TernaryExpression,
        _additional: &(),
    ) -> (Expression, Self::AdditionalOutput) {
        let (condition, mut statements) =
            self.reconstruct_expression(std::mem::take(&mut input.condition), &Default::default());
        let (if_true, statements2) = self.reconstruct_expression_tuple(std::mem::take(&mut input.if_true));
        statements.extend(statements2);
        let (if_false, statements3) = self.reconstruct_expression_tuple(std::mem::take(&mut input.if_false));
        statements.extend(statements3);

        match (if_true, if_false) {
            (Expression::Tuple(tuple_true), Expression::Tuple(tuple_false)) => {
                let Some(Type::Tuple(tuple_type)) = self.state.type_table.get(&tuple_true.id()) else {
                    panic!("Should have tuple type");
                };

                // Bind a non-path condition to a fresh boolean variable so it can be shared
                // by every element ternary below.
                let cond = if let Expression::Path(..) = condition {
                    condition
                } else {
                    let place = Identifier::new(
                        self.state.assigner.unique_symbol("cond", "$$"),
                        self.state.node_builder.next_id(),
                    );
                    let definition =
                        self.state.assigner.simple_definition(place, condition, self.state.node_builder.next_id());
                    statements.push(definition);
                    self.state.type_table.insert(place.id(), Type::Boolean);
                    Expression::Path(Path::from(place).to_local())
                };

                let mut elements = Vec::with_capacity(tuple_true.elements.len());

                for (i, (lhs, rhs, ty)) in
                    izip!(tuple_true.elements, tuple_false.elements, tuple_type.elements()).enumerate()
                {
                    let identifier = Identifier::new(
                        self.state.assigner.unique_symbol(format_args!("ternary_{i}"), "$$"),
                        self.state.node_builder.next_id(),
                    );

                    let expression: Expression = TernaryExpression {
                        condition: cond.clone(),
                        if_true: lhs,
                        if_false: rhs,
                        span: Default::default(),
                        id: self.state.node_builder.next_id(),
                    }
                    .into();

                    self.state.type_table.insert(identifier.id(), ty.clone());
                    self.state.type_table.insert(expression.id(), ty.clone());

                    let definition = self.state.assigner.simple_definition(
                        identifier,
                        expression,
                        self.state.node_builder.next_id(),
                    );

                    statements.push(definition);
                    elements.push(Path::from(identifier).to_local().into());
                }

                let expr: Expression =
                    TupleExpression { elements, span: Default::default(), id: self.state.node_builder.next_id() }
                        .into();

                self.state.type_table.insert(expr.id(), Type::Tuple(tuple_type.clone()));

                (expr, statements)
            }
            (if_true, if_false) => {
                (TernaryExpression { condition, if_true, if_false, ..input }.into(), statements)
            }
        }
    }

    /// Reconstructs an assertion, expanding tuple `assert_eq` / `assert_neq` element-wise.
    ///
    /// When both operands of `assert_eq` / `assert_neq` reconstruct to tuple literals, one
    /// assertion per element pair is returned as additional statements and the original
    /// statement is replaced by `Statement::dummy()`. In every case, statements produced while
    /// reconstructing the operands are returned so they precede the assertion.
    fn reconstruct_assert(&mut self, input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
        match input.variant {
            AssertVariant::Assert(expr) => {
                // Keep the statements the operand produced instead of discarding them.
                let (expr, statements) = self.reconstruct_expression(expr, &Default::default());
                (AssertStatement { variant: AssertVariant::Assert(expr), ..input }.into(), statements)
            }
            AssertVariant::AssertEq(ref left, ref right) | AssertVariant::AssertNeq(ref left, ref right) => {
                let (left, mut statements) = self.reconstruct_expression_tuple(left.clone());
                let (right, statements2) = self.reconstruct_expression_tuple(right.clone());
                statements.extend(statements2);

                match (&left, &right) {
                    (Expression::Tuple(tuple_left), Expression::Tuple(tuple_right)) => {
                        assert_eq!(tuple_left.elements.len(), tuple_right.elements.len());

                        for (l, r) in tuple_left.elements.iter().zip(&tuple_right.elements) {
                            let assert_variant = match input.variant {
                                AssertVariant::AssertEq(_, _) => AssertVariant::AssertEq(l.clone(), r.clone()),
                                AssertVariant::AssertNeq(_, _) => AssertVariant::AssertNeq(l.clone(), r.clone()),
                                _ => unreachable!(),
                            };
                            let stmt = AssertStatement { variant: assert_variant, ..input.clone() }.into();
                            statements.push(stmt);
                        }

                        (Statement::dummy(), statements)
                    }
                    _ => {
                        let variant = match input.variant {
                            AssertVariant::AssertEq(_, _) => AssertVariant::AssertEq(left, right),
                            AssertVariant::AssertNeq(_, _) => AssertVariant::AssertNeq(left, right),
                            _ => unreachable!(),
                        };

                        // Keep the operand statements so they still precede the assertion.
                        (AssertStatement { variant, ..input }.into(), statements)
                    }
                }
            }
        }
    }

    /// Reconstructs an assignment, rewriting tuple places to the per-element identifiers.
    ///
    /// A whole-tuple assignment `a = b` (where the value's type is a tuple) becomes one
    /// assignment per element between the identifiers recorded in `self.tuples` for `a` and `b`;
    /// the original statement is replaced by `Statement::dummy()`. Otherwise the place is walked
    /// through array and member accesses, and a tuple access found along the way is replaced by
    /// the identifier of the accessed element.
    ///
    /// # Panics
    /// If a tuple is not recorded in `self.tuples`, a tuple-valued right-hand side is not a path,
    /// or the place contains an expression kind not handled below.
    fn reconstruct_assign(&mut self, mut assign: AssignStatement) -> (Statement, Self::AdditionalOutput) {
        let (value, mut statements) = self.reconstruct_expression(assign.value, &());

        if let Expression::Path(path) = &assign.place
            && let Type::Tuple(..) = self.state.type_table.get(&value.id()).expect("Expressions should have types.")
        {
            let identifiers = self.tuples.get(&path.identifier().name).expect("Tuple should have been encountered.");

            let Expression::Path(rhs) = value else {
                panic!("SSA should have ensured this is an identifier.");
            };

            let rhs_identifiers = self.tuples.get(&rhs.identifier().name).expect("Tuple should have been encountered.");

            // `zip_eq` panics if the two tuples have a different number of elements.
            for (&identifier, &rhs_identifier) in identifiers.iter().zip_eq(rhs_identifiers) {
                let stmt = AssignStatement {
                    place: Path::from(identifier).to_local().into(),
                    value: Path::from(rhs_identifier).to_local().into(),
                    id: self.state.node_builder.next_id(),
                    span: Default::default(),
                }
                .into();
                statements.push(stmt);
            }

            return (Statement::dummy(), statements);
        }

        assign.value = value;
        let mut place = &mut assign.place;

        // Descend through the place until a path or a tuple access is reached.
        loop {
            match place {
                Expression::TupleAccess(access) => {
                    let Expression::Path(path) = &access.tuple else {
                        panic!("SSA should have ensured this is an identifier.");
                    };
                    let tuple_ids =
                        self.tuples.get(&path.identifier().name).expect("Tuple should have been encountered.");
                    let identifier = tuple_ids[access.index.value()];
                    *place = Path::from(identifier).to_local().into();
                    return (assign.into(), statements);
                }

                Expression::ArrayAccess(access) => {
                    place = &mut access.array;
                }

                Expression::MemberAccess(access) => {
                    place = &mut access.inner;
                }

                Expression::Path(..) => {
                    return (assign.into(), statements);
                }

                _ => panic!("Type checking should have prevented this."),
            }
        }
    }

    /// Reconstructs every statement of a block, splicing each statement's additional statements
    /// in front of it and omitting statements for which `is_empty()` holds (presumably the
    /// `Statement::dummy()` placeholders this pass returns — confirm against `Statement::is_empty`).
    fn reconstruct_block(&mut self, block: Block) -> (Block, Self::AdditionalOutput) {
        let mut statements = Vec::with_capacity(block.statements.len());

        for statement in block.statements {
            let (reconstructed_statement, additional_statements) = self.reconstruct_statement(statement);
            statements.extend(additional_statements);
            if !reconstructed_statement.is_empty() {
                statements.push(reconstructed_statement);
            }
        }

        (Block { statements, ..block }, Default::default())
    }

    /// Reconstructs a conditional's condition, `then` block and optional `otherwise` branch,
    /// returning the additional statements of all three (in that order) to precede the conditional.
    fn reconstruct_conditional(&mut self, input: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
        let (condition, mut statements) = self.reconstruct_expression(input.condition, &());
        let (then, statements2) = self.reconstruct_block(input.then);
        statements.extend(statements2);
        let otherwise = input.otherwise.map(|oth| {
            let (expr, statements3) = self.reconstruct_statement(*oth);
            statements.extend(statements3);
            Box::new(expr)
        });
        (ConditionalStatement { condition, then, otherwise, ..input }.into(), statements)
    }

    /// Reconstructs a definition, destructuring tuple-typed definitions into per-element ones.
    ///
    /// - `let x = y` / `let x = (e0, e1, ..)` with tuple type: fresh identifiers `x#tuple{i}` are
    ///   created, each defined from the matching element of `y` (via `self.tuples`) or literal,
    ///   typed in the type table, and recorded as `x`'s elements in `self.tuples`.
    /// - `let x = f(..)` with tuple type: delegated to `assign_tuple`, which yields a
    ///   `Multiple` definition whose identifiers are recorded for `x`.
    /// - `let (a, b, ..) = <tuple literal | path>`: one single definition per element.
    /// - `let (a, b, ..) = f(..)`: kept as a multiple definition.
    /// - Any other single definition is rebuilt unchanged (with its type annotation dropped).
    ///
    /// Destructured definitions are replaced by `Statement::dummy()`; the generated definitions
    /// are returned as additional statements.
    ///
    /// # Panics
    /// If the value has no recorded type, a tuple is not recorded in `self.tuples`, a
    /// tuple-typed value is not a literal/path/call, or a multiple definition is not tuple-typed.
    fn reconstruct_definition(&mut self, definition: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
        use DefinitionPlace::*;

        // Fresh, unique identifiers for the `count` elements of the tuple named `single`.
        let make_identifiers = |slf: &mut Self, single: Symbol, count: usize| -> Vec<Identifier> {
            (0..count)
                .map(|i| {
                    Identifier::new(
                        slf.state.assigner.unique_symbol(format_args!("{single}#tuple{i}"), "$"),
                        slf.state.node_builder.next_id(),
                    )
                })
                .collect()
        };

        let (value, mut statements) = self.reconstruct_expression(definition.value, &());
        let ty = self.state.type_table.get(&value.id()).expect("Expressions should have a type.");

        match (definition.place, value, ty) {
            (Single(identifier), Expression::Path(rhs), Type::Tuple(tuple_type)) => {
                let identifiers = make_identifiers(self, identifier.name, tuple_type.length());

                // Same invariant (and message) as the other `self.tuples` lookups in this pass.
                let rhs_identifiers =
                    self.tuples.get(&rhs.identifier().name).expect("Tuple should have been encountered.");

                for (identifier, rhs_identifier, ty) in izip!(&identifiers, rhs_identifiers, tuple_type.elements()) {
                    let stmt = DefinitionStatement {
                        place: Single(*identifier),
                        type_: Some(ty.clone()),
                        value: Expression::Path(Path::from(*rhs_identifier).to_local()),
                        span: Default::default(),
                        id: self.state.node_builder.next_id(),
                    }
                    .into();
                    statements.push(stmt);
                    self.state.type_table.insert(identifier.id(), ty.clone());
                }

                self.tuples.insert(identifier.name, identifiers);
                (Statement::dummy(), statements)
            }

            (Single(identifier), Expression::Tuple(tuple), Type::Tuple(tuple_type)) => {
                let identifiers = make_identifiers(self, identifier.name, tuple_type.length());

                for (identifier, expr, ty) in izip!(&identifiers, tuple.elements, tuple_type.elements()) {
                    let stmt = DefinitionStatement {
                        place: Single(*identifier),
                        type_: Some(ty.clone()),
                        value: expr,
                        span: Default::default(),
                        id: self.state.node_builder.next_id(),
                    }
                    .into();
                    statements.push(stmt);
                    self.state.type_table.insert(identifier.id(), ty.clone());
                }

                self.tuples.insert(identifier.name, identifiers);
                (Statement::dummy(), statements)
            }

            (Single(identifier), rhs @ Expression::Call(..), Type::Tuple(tuple_type)) => {
                let definition_stmt = self.assign_tuple(rhs, identifier.name);

                let Statement::Definition(DefinitionStatement {
                    place: DefinitionPlace::Multiple(identifiers), ..
                }) = &definition_stmt
                else {
                    panic!("assign_tuple creates `Multiple`.");
                };

                self.tuples.insert(identifier.name, identifiers.clone());
                for (identifier, ty) in identifiers.iter().zip(tuple_type.elements()) {
                    self.state.type_table.insert(identifier.id(), ty.clone());
                }

                (definition_stmt, statements)
            }

            (Multiple(identifiers), Expression::Tuple(tuple), Type::Tuple(..)) => {
                // `zip_eq` panics if the arities differ.
                for (identifier, expr) in identifiers.into_iter().zip_eq(tuple.elements) {
                    let stmt = DefinitionStatement {
                        place: Single(identifier),
                        type_: None,
                        value: expr,
                        span: Default::default(),
                        id: self.state.node_builder.next_id(),
                    }
                    .into();
                    statements.push(stmt);
                }

                (Statement::dummy(), statements)
            }

            (Multiple(identifiers), Expression::Path(rhs), Type::Tuple(..)) => {
                let rhs_identifiers =
                    self.tuples.get(&rhs.identifier().name).expect("We should have encountered this tuple by now");

                for (identifier, rhs_identifier) in identifiers.into_iter().zip_eq(rhs_identifiers.iter()) {
                    let stmt = DefinitionStatement {
                        place: Single(identifier),
                        type_: None,
                        value: Expression::Path(Path::from(*rhs_identifier).to_local()),
                        span: Default::default(),
                        id: self.state.node_builder.next_id(),
                    }
                    .into();
                    statements.push(stmt);
                }

                (Statement::dummy(), statements)
            }

            (m @ Multiple(..), value @ Expression::Call(..), Type::Tuple(..)) => {
                let stmt =
                    DefinitionStatement { place: m, type_: None, value, span: definition.span, id: definition.id }
                        .into();
                (stmt, statements)
            }

            (_, _, Type::Tuple(..)) => {
                panic!("Expressions of tuple type can only be tuple literals, identifiers, or calls.");
            }

            (s @ Single(..), rhs, _) => {
                // Not a tuple: rebuild the definition, keeping its original span and id.
                let stmt = DefinitionStatement {
                    place: s,
                    type_: None,
                    value: rhs,
                    span: definition.span,
                    id: definition.id,
                }
                .into();
                (stmt, statements)
            }

            (Multiple(_), _, _) => panic!("A definition with multiple identifiers must have tuple type"),
        }
    }

    /// Loops must already have been removed before this pass runs.
    ///
    /// # Panics
    /// Always.
    fn reconstruct_iteration(&mut self, _: IterationStatement) -> (Statement, Self::AdditionalOutput) {
        panic!("`IterationStatement`s should not be in the AST at this phase of compilation.");
    }

    /// Reconstructs the returned expression (which may be tuple-valued), returning any
    /// statements it produced so they precede the `return`.
    fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
        let (expression, statements) = self.reconstruct_expression_tuple(input.expression);
        (ReturnStatement { expression, ..input }.into(), statements)
    }
}