use super::StorageLoweringVisitor;
use leo_ast::*;
use leo_span::{Span, Symbol, sym};
/// Rewrites storage operations into explicit mapping operations.
///
/// Every `reconstruct_*` method returns the rewritten node together with a list of statements
/// (`AdditionalOutput`) that must run *before* the statement containing that node.
/// `reconstruct_block` splices those hoisted statements directly in front of the statement that
/// produced them.
///
/// As far as this impl shows:
/// - a `Vector` is backed by two mappings obtained from `generate_vector_mapping_exprs`: one
///   holding the elements keyed by their `u32` index, and one holding the length under the key
///   `false`;
/// - assignments to a global storage variable become `_mapping_set` / `_mapping_remove` calls on
///   a mapping named `<var>__`, keyed by `false`.
impl leo_ast::AstReconstructor for StorageLoweringVisitor<'_> {
    type AdditionalInput = ();
    type AdditionalOutput = Vec<Statement>;

    /// Reconstructs an array type, collecting statements hoisted out of both the length
    /// expression and the element type.
    fn reconstruct_array_type(&mut self, input: ArrayType) -> (Type, Self::AdditionalOutput) {
        let (length, mut statements) = self.reconstruct_expression(*input.length, &());
        // The element type may itself contain expressions (e.g. a nested array length), so its
        // hoisted statements must be kept rather than dropped.
        let (element_type, element_statements) = self.reconstruct_type(*input.element_type);
        statements.extend(element_statements);
        (Type::Array(ArrayType { element_type: Box::new(element_type), length: Box::new(length) }), statements)
    }

    /// Reconstructs the const arguments of a composite type, collecting their hoisted statements.
    fn reconstruct_composite_type(&mut self, input: CompositeType) -> (Type, Self::AdditionalOutput) {
        let mut statements = Vec::new();
        let const_arguments = input
            .const_arguments
            .into_iter()
            .map(|arg| {
                let (expr, stmts) = self.reconstruct_expression(arg, &Default::default());
                statements.extend(stmts);
                expr
            })
            .collect();
        (Type::Composite(CompositeType { const_arguments, ..input }), statements)
    }

    /// Reconstructs the array and index of an array access; array statements come first.
    fn reconstruct_array_access(
        &mut self,
        mut input: ArrayAccess,
        _additional: &(),
    ) -> (Expression, Self::AdditionalOutput) {
        let (array, mut stmts_array) = self.reconstruct_expression(input.array, &());
        let (index, mut stmts_index) = self.reconstruct_expression(input.index, &());
        input.array = array;
        input.index = index;
        stmts_array.append(&mut stmts_index);
        (input.into(), stmts_array)
    }

    /// Lowers the `Vector` intrinsics into reads/writes on the vector's backing mappings.
    /// Every other intrinsic only has its arguments reconstructed.
    ///
    /// # Panics
    ///
    /// Panics if a `Vector` intrinsic has the wrong number of arguments, if its first argument
    /// is not typed as a `Vector`, or if that argument is not an `Expression::Path`.
    fn reconstruct_intrinsic(
        &mut self,
        mut input: IntrinsicExpression,
        _additional: &Self::AdditionalInput,
    ) -> (Expression, Self::AdditionalOutput) {
        match Intrinsic::from_symbol(input.name, &input.type_parameters) {
            // `v.push(x)` becomes:
            //   $len = <length of v>;
            //   len_mapping[false] = $len + 1;
            //   elem_mapping[$len] = x;          // returned expression
            Some(Intrinsic::VectorPush) => {
                let [vector_expr, value_expr] = &mut input.arguments[..] else {
                    panic!("Vector::push should have 2 arguments");
                };
                assert!(
                    matches!(self.state.type_table.get(&vector_expr.id()), Some(Type::Vector(_))),
                    "argument to Vector::push should be of type `Vector`."
                );
                let Expression::Path(path_to_vector) = vector_expr else {
                    panic!("Vector::push can only be called with `Expression::Path`");
                };
                // `input` is discarded after lowering, so the argument can be moved out.
                let (value, stmts) = self.reconstruct_expression(std::mem::take(value_expr), &());
                let (vec_path_expr, len_path_expr) = self.generate_vector_mapping_exprs(path_to_vector);
                let len_var_sym = self.state.assigner.unique_symbol("$len_var", "$");
                let len_var_ident =
                    Identifier { name: len_var_sym, span: Default::default(), id: self.state.node_builder.next_id() };
                let get_len_expr = self.get_vector_len_expr(len_path_expr.clone(), input.span);
                let len_stmt = self.state.assigner.simple_definition(
                    len_var_ident,
                    get_len_expr,
                    self.state.node_builder.next_id(),
                );
                let len_var_expr: Expression = len_var_ident.into();
                let literal_one = self.literal_one_u32();
                let increment_expr = self.binary_expr(len_var_expr.clone(), BinaryOperation::Add, literal_one);
                let set_vec_stmt_expr = self.set_mapping_expr(vec_path_expr, len_var_expr, value, input.span);
                let literal_false = self.literal_false();
                let set_len_stmt = Statement::Expression(ExpressionStatement {
                    expression: self.set_mapping_expr(len_path_expr, literal_false, increment_expr, input.span),
                    span: input.span,
                    id: self.state.node_builder.next_id(),
                });
                (set_vec_stmt_expr, [stmts, vec![len_stmt, set_len_stmt]].concat())
            }
            // `v.len()` becomes the length expression from `get_vector_len_expr`.
            Some(Intrinsic::VectorLen) => {
                let [vector_expr] = &mut input.arguments[..] else {
                    panic!("Vector::len should have 1 argument");
                };
                assert!(
                    matches!(self.state.type_table.get(&vector_expr.id()), Some(Type::Vector(_))),
                    "argument to Vector::len should be of type `Vector`."
                );
                let Expression::Path(path_to_vector) = vector_expr else {
                    panic!("Vector::len can only be called with `Expression::Path`");
                };
                let (_vec_path_expr, len_path_expr) = self.generate_vector_mapping_exprs(path_to_vector);
                let get_len_expr = self.get_vector_len_expr(len_path_expr, input.span);
                (get_len_expr, vec![])
            }
            // `v.pop()` becomes:
            //   $len = <length of v>;
            //   len_mapping[false] = $len > 0 ? $len -wrapped 1 : $len;
            //   returns: $len > 0 ? elem_mapping.get_or_use($len -wrapped 1, zero) : none
            Some(Intrinsic::VectorPop) => {
                let [vector_expr] = &mut input.arguments[..] else {
                    panic!("Vector::pop should have 1 argument");
                };
                let Some(Type::Vector(VectorType { element_type })) = self.state.type_table.get(&vector_expr.id())
                else {
                    panic!("argument to Vector::pop should be of type `Vector`.");
                };
                let Expression::Path(path_to_vector) = vector_expr else {
                    panic!("Vector::pop can only be called with `Expression::Path`");
                };
                let (vec_path_expr, len_path_expr) = self.generate_vector_mapping_exprs(path_to_vector);
                let len_var_sym = self.state.assigner.unique_symbol("$len_var", "$");
                let len_var_ident =
                    Identifier { name: len_var_sym, span: Default::default(), id: self.state.node_builder.next_id() };
                let get_len_expr = self.get_vector_len_expr(len_path_expr.clone(), input.span);
                let len_stmt = self.state.assigner.simple_definition(
                    len_var_ident,
                    get_len_expr,
                    self.state.node_builder.next_id(),
                );
                let len_var_expr: Expression = len_var_ident.into();
                let literal_zero = self.literal_zero_u32();
                let len_gt_zero_expr = self.binary_expr(len_var_expr.clone(), BinaryOperation::Gt, literal_zero);
                let literal_one = self.literal_one_u32();
                // Wrapping subtraction, so an empty vector does not trap here; in that case the
                // wrapped value is never selected by the ternaries below.
                let len_minus_one_expr =
                    self.binary_expr(len_var_expr.clone(), BinaryOperation::SubWrapped, literal_one);
                let new_len_expr =
                    self.ternary_expr(len_gt_zero_expr.clone(), len_minus_one_expr.clone(), len_var_expr, input.span);
                let literal_false = self.literal_false();
                let set_len_stmt = Statement::Expression(ExpressionStatement {
                    expression: self.set_mapping_expr(len_path_expr, literal_false, new_len_expr, input.span),
                    span: input.span,
                    id: self.state.node_builder.next_id(),
                });
                let zero = self.zero(&element_type);
                let get_or_use_expr = self.get_or_use_mapping_expr(vec_path_expr, len_minus_one_expr, zero, input.span);
                let none_expr: Expression = Literal::none(Span::default(), self.state.node_builder.next_id()).into();
                let ternary_expr = self.ternary_expr(len_gt_zero_expr, get_or_use_expr, none_expr, input.span);
                (ternary_expr, vec![len_stmt, set_len_stmt])
            }
            // `v.get(i)` becomes:
            //   $len = <length of v>;
            //   returns: i < $len ? elem_mapping.get_or_use(i, zero) : none
            Some(Intrinsic::VectorGet) => {
                let [vector_expr, key_expr] = &mut input.arguments[..] else {
                    panic!("Vector::get should have 2 arguments");
                };
                let Some(Type::Vector(VectorType { element_type })) = self.state.type_table.get(&vector_expr.id())
                else {
                    panic!("argument to Vector::get should be of type `Vector`.");
                };
                let Expression::Path(path_to_vector) = vector_expr else {
                    panic!("Vector::get can only be called with `Expression::Path`");
                };
                let (reconstructed_key_expr, key_stmts) =
                    self.reconstruct_expression(std::mem::take(key_expr), &Default::default());
                let (vec_path_expr, len_path_expr) = self.generate_vector_mapping_exprs(path_to_vector);
                let len_var_sym = self.state.assigner.unique_symbol("$len_var", "$");
                let len_var_ident =
                    Identifier { name: len_var_sym, span: Default::default(), id: self.state.node_builder.next_id() };
                let get_len_expr = self.get_vector_len_expr(len_path_expr, input.span);
                let len_stmt = self.state.assigner.simple_definition(
                    len_var_ident,
                    get_len_expr,
                    self.state.node_builder.next_id(),
                );
                let len_var_expr: Expression = len_var_ident.into();
                let index_lt_len_expr =
                    self.binary_expr(reconstructed_key_expr.clone(), BinaryOperation::Lt, len_var_expr);
                let zero = self.zero(&element_type);
                let get_or_use_expr =
                    self.get_or_use_mapping_expr(vec_path_expr, reconstructed_key_expr, zero, input.span);
                let none_expr: Expression = Literal::none(Span::default(), self.state.node_builder.next_id()).into();
                let ternary_expr = self.ternary_expr(index_lt_len_expr, get_or_use_expr, none_expr, input.span);
                (ternary_expr, [key_stmts, vec![len_stmt]].concat())
            }
            // `v.set(i, x)` becomes:
            //   $len = <length of v>;
            //   assert(i < $len);
            //   elem_mapping[i] = x;             // returned expression
            Some(Intrinsic::VectorSet) => {
                let [vector_expr, index_expr, value_expr] = &mut input.arguments[..] else {
                    panic!("Vector::set should have 3 arguments");
                };
                assert!(
                    matches!(self.state.type_table.get(&vector_expr.id()), Some(Type::Vector(_))),
                    "argument to Vector::set should be of type `Vector`."
                );
                let Expression::Path(path_to_vector) = vector_expr else {
                    panic!("Vector::set can only be called with `Expression::Path`");
                };
                let (reconstructed_key_expr, key_stmts) =
                    self.reconstruct_expression(std::mem::take(index_expr), &Default::default());
                let (reconstructed_value_expr, value_stmts) =
                    self.reconstruct_expression(std::mem::take(value_expr), &Default::default());
                let (vec_path_expr, len_path_expr) = self.generate_vector_mapping_exprs(path_to_vector);
                let len_var_sym = self.state.assigner.unique_symbol("$len_var", "$");
                let len_var_ident =
                    Identifier { name: len_var_sym, span: Default::default(), id: self.state.node_builder.next_id() };
                let get_len_expr = self.get_vector_len_expr(len_path_expr, input.span);
                let len_stmt = self.state.assigner.simple_definition(
                    len_var_ident,
                    get_len_expr,
                    self.state.node_builder.next_id(),
                );
                let len_var_expr: Expression = len_var_ident.into();
                let index_lt_len_expr =
                    self.binary_expr(reconstructed_key_expr.clone(), BinaryOperation::Lt, len_var_expr);
                let set_stmt_expr =
                    self.set_mapping_expr(vec_path_expr, reconstructed_key_expr, reconstructed_value_expr, input.span);
                // Use the call's span (as `swap_remove` does) so a failed bound check points at the call.
                let assert_stmt = Statement::Assert(AssertStatement {
                    variant: AssertVariant::Assert(index_lt_len_expr),
                    span: input.span,
                    id: self.state.node_builder.next_id(),
                });
                (set_stmt_expr, [key_stmts, value_stmts, vec![len_stmt, assert_stmt]].concat())
            }
            // `v.clear()` becomes `len_mapping[false] = 0` (returned expression); element entries
            // are left in place.
            Some(Intrinsic::VectorClear) => {
                let [vector_expr] = &mut input.arguments[..] else {
                    panic!("Vector::clear should have 1 argument");
                };
                assert!(
                    matches!(self.state.type_table.get(&vector_expr.id()), Some(Type::Vector(_))),
                    "argument to Vector::clear should be of type `Vector`."
                );
                let Expression::Path(path_to_vector) = vector_expr else {
                    panic!("Vector::clear can only be called with `Expression::Path`");
                };
                let (_vec_path_expr, len_path_expr) = self.generate_vector_mapping_exprs(path_to_vector);
                let literal_false = self.literal_false();
                let literal_zero = self.literal_zero_u32();
                let set_len_stmt_expr = self.set_mapping_expr(len_path_expr, literal_false, literal_zero, input.span);
                (set_len_stmt_expr, vec![])
            }
            // `v.swap_remove(i)` becomes:
            //   $len = <length of v>;
            //   assert(i < $len);
            //   $removed = elem_mapping.get(i);
            //   elem_mapping[i] = elem_mapping.get($len - 1);
            //   len_mapping[false] = $len - 1;
            //   returns: $removed
            Some(Intrinsic::VectorSwapRemove) => {
                let [vector_expr, index_expr] = &mut input.arguments[..] else {
                    panic!("Vector::swap_remove should have 2 arguments");
                };
                assert!(
                    matches!(self.state.type_table.get(&vector_expr.id()), Some(Type::Vector(_))),
                    "argument to Vector::swap_remove should be of type `Vector`."
                );
                let Expression::Path(path_to_vector) = vector_expr else {
                    panic!("Vector::swap_remove can only be called with `Expression::Path`");
                };
                let (reconstructed_index_expr, index_stmts) =
                    self.reconstruct_expression(std::mem::take(index_expr), &Default::default());
                let (vec_path_expr, len_path_expr) = self.generate_vector_mapping_exprs(path_to_vector);
                let len_var_sym = self.state.assigner.unique_symbol("$len_var", "$");
                let len_var_ident =
                    Identifier { name: len_var_sym, span: Default::default(), id: self.state.node_builder.next_id() };
                let get_len_expr = self.get_vector_len_expr(len_path_expr.clone(), input.span);
                let len_stmt = self.state.assigner.simple_definition(
                    len_var_ident,
                    get_len_expr,
                    self.state.node_builder.next_id(),
                );
                let len_var_expr: Expression = len_var_ident.into();
                let index_lt_len_expr =
                    self.binary_expr(reconstructed_index_expr.clone(), BinaryOperation::Lt, len_var_expr.clone());
                let assert_stmt = Statement::Assert(AssertStatement {
                    variant: AssertVariant::Assert(index_lt_len_expr),
                    span: input.span,
                    id: self.state.node_builder.next_id(),
                });
                let get_elem_expr =
                    self.get_mapping_expr(vec_path_expr.clone(), reconstructed_index_expr.clone(), input.span);
                let removed_sym = self.state.assigner.unique_symbol("$removed", "$");
                let removed_ident =
                    Identifier { name: removed_sym, span: Default::default(), id: self.state.node_builder.next_id() };
                let removed_stmt = Statement::Definition(DefinitionStatement {
                    place: DefinitionPlace::Single(removed_ident),
                    type_: None,
                    value: get_elem_expr,
                    span: input.span,
                    id: self.state.node_builder.next_id(),
                });
                let literal_one = self.literal_one_u32();
                // Checked subtraction is safe here: the assert above guarantees `$len > i >= 0`.
                let len_minus_one_expr = self.binary_expr(len_var_expr, BinaryOperation::Sub, literal_one);
                let get_last_expr =
                    self.get_mapping_expr(vec_path_expr.clone(), len_minus_one_expr.clone(), input.span);
                let set_swap_stmt = Statement::Expression(ExpressionStatement {
                    expression: self.set_mapping_expr(
                        vec_path_expr,
                        reconstructed_index_expr,
                        get_last_expr,
                        input.span,
                    ),
                    span: input.span,
                    id: self.state.node_builder.next_id(),
                });
                let literal_false = self.literal_false();
                let set_len_stmt = Statement::Expression(ExpressionStatement {
                    expression: self.set_mapping_expr(len_path_expr, literal_false, len_minus_one_expr, input.span),
                    span: input.span,
                    id: self.state.node_builder.next_id(),
                });
                (
                    removed_ident.into(),
                    [index_stmts, vec![len_stmt, assert_stmt, removed_stmt, set_swap_stmt, set_len_stmt]].concat(),
                )
            }
            // Any other intrinsic: reconstruct its arguments in place, in order.
            _ => {
                let statements: Vec<_> = input
                    .arguments
                    .iter_mut()
                    .flat_map(|arg| {
                        let (expr, stmts) = self.reconstruct_expression(std::mem::take(arg), &());
                        *arg = expr;
                        stmts
                    })
                    .collect();
                (input.into(), statements)
            }
        }
    }

    /// Reconstructs the inner expression of a member access.
    fn reconstruct_member_access(
        &mut self,
        mut input: MemberAccess,
        _additional: &(),
    ) -> (Expression, Self::AdditionalOutput) {
        let (inner, stmts_inner) = self.reconstruct_expression(input.inner, &());
        input.inner = inner;
        (input.into(), stmts_inner)
    }

    /// Reconstructs the repeated expression and the count; expression statements come first.
    fn reconstruct_repeat(
        &mut self,
        mut input: RepeatExpression,
        _additional: &(),
    ) -> (Expression, Self::AdditionalOutput) {
        let (expr, mut stmts_expr) = self.reconstruct_expression(input.expr, &());
        let (count, mut stmts_count) = self.reconstruct_expression(input.count, &());
        input.expr = expr;
        input.count = count;
        stmts_expr.append(&mut stmts_count);
        (input.into(), stmts_expr)
    }

    /// Reconstructs the tuple operand of a tuple access.
    fn reconstruct_tuple_access(
        &mut self,
        mut input: TupleAccess,
        _additional: &(),
    ) -> (Expression, Self::AdditionalOutput) {
        let (tuple, stmts) = self.reconstruct_expression(input.tuple, &());
        input.tuple = tuple;
        (input.into(), stmts)
    }

    /// Reconstructs each array element in order, concatenating their hoisted statements.
    fn reconstruct_array(
        &mut self,
        mut input: ArrayExpression,
        _additional: &(),
    ) -> (Expression, Self::AdditionalOutput) {
        let mut all_stmts = Vec::new();
        let mut new_elements = Vec::with_capacity(input.elements.len());
        for element in input.elements.into_iter() {
            let (expr, mut stmts) = self.reconstruct_expression(element, &());
            all_stmts.append(&mut stmts);
            new_elements.push(expr);
        }
        input.elements = new_elements;
        (input.into(), all_stmts)
    }

    /// Reconstructs both operands; left-operand statements come first.
    fn reconstruct_binary(
        &mut self,
        mut input: BinaryExpression,
        _additional: &(),
    ) -> (Expression, Self::AdditionalOutput) {
        let (left, mut stmts_left) = self.reconstruct_expression(input.left, &());
        let (right, mut stmts_right) = self.reconstruct_expression(input.right, &());
        input.left = left;
        input.right = right;
        stmts_left.append(&mut stmts_right);
        (input.into(), stmts_left)
    }

    /// Reconstructs each call argument in place, in order.
    fn reconstruct_call(&mut self, mut input: CallExpression, _additional: &()) -> (Expression, Self::AdditionalOutput) {
        let mut statements = Vec::new();
        for arg in input.arguments.iter_mut() {
            let (expr, arg_statements) = self.reconstruct_expression(std::mem::take(arg), &());
            statements.extend(arg_statements);
            *arg = expr;
        }
        (input.into(), statements)
    }

    /// Reconstructs the operand of a cast.
    fn reconstruct_cast(&mut self, input: CastExpression, _additional: &()) -> (Expression, Self::AdditionalOutput) {
        let (expression, statements) = self.reconstruct_expression(input.expression, &());
        (CastExpression { expression, ..input }.into(), statements)
    }

    /// Reconstructs the const arguments and then the member initializers of a composite literal.
    ///
    /// # Panics
    ///
    /// Panics if a member has no initializer expression (shorthand members are expected to have
    /// been expanded before this pass).
    fn reconstruct_composite_init(
        &mut self,
        mut input: CompositeExpression,
        _additional: &(),
    ) -> (Expression, Self::AdditionalOutput) {
        let mut statements = Vec::new();
        for const_arg in input.const_arguments.iter_mut() {
            // Move the argument out instead of cloning it; it is overwritten right after.
            let (expr, const_statements) = self.reconstruct_expression(std::mem::take(const_arg), &());
            statements.extend(const_statements);
            *const_arg = expr;
        }
        for member in input.members.iter_mut() {
            let expression =
                member.expression.take().expect("composite member initializers must have an expression at this point");
            let (expr, member_statements) = self.reconstruct_expression(expression, &());
            statements.extend(member_statements);
            member.expression = Some(expr);
        }
        (input.into(), statements)
    }

    /// Delegates to `reconstruct_path_or_locator`; paths never produce hoisted statements.
    fn reconstruct_path(&mut self, input: Path, _additional: &()) -> (Expression, Self::AdditionalOutput) {
        (self.reconstruct_path_or_locator(input.into()), vec![])
    }

    /// Reconstructs the condition and both branches of a ternary.
    ///
    /// NOTE(review): statements hoisted out of `if_true` / `if_false` are emitted before the
    /// ternary and therefore run regardless of `condition`. This is only sound if such
    /// statements are side-effect free (or both branches are always evaluated anyway) — confirm.
    fn reconstruct_ternary(
        &mut self,
        input: TernaryExpression,
        _additional: &(),
    ) -> (Expression, Self::AdditionalOutput) {
        let (condition, mut statements) = self.reconstruct_expression(input.condition, &());
        let (if_true, true_statements) = self.reconstruct_expression(input.if_true, &());
        let (if_false, false_statements) = self.reconstruct_expression(input.if_false, &());
        statements.extend(true_statements);
        statements.extend(false_statements);
        (TernaryExpression { condition, if_true, if_false, ..input }.into(), statements)
    }

    /// Reconstructs each tuple element in order, concatenating their hoisted statements.
    fn reconstruct_tuple(
        &mut self,
        input: leo_ast::TupleExpression,
        _additional: &(),
    ) -> (Expression, Self::AdditionalOutput) {
        let mut statements = Vec::new();
        let elements = input
            .elements
            .into_iter()
            .map(|element| {
                let (expr, element_statements) = self.reconstruct_expression(element, &());
                statements.extend(element_statements);
                expr
            })
            .collect();
        (TupleExpression { elements, ..input }.into(), statements)
    }

    /// Reconstructs the receiver of a unary expression.
    fn reconstruct_unary(
        &mut self,
        input: leo_ast::UnaryExpression,
        _additional: &(),
    ) -> (Expression, Self::AdditionalOutput) {
        let (receiver, statements) = self.reconstruct_expression(input.receiver, &());
        (UnaryExpression { receiver, ..input }.into(), statements)
    }

    /// Reconstructs the operand(s) of an `assert`, `assert_eq` or `assert_neq` statement.
    fn reconstruct_assert(&mut self, input: leo_ast::AssertStatement) -> (Statement, Self::AdditionalOutput) {
        let mut statements = Vec::new();
        let stmt = AssertStatement {
            variant: match input.variant {
                AssertVariant::Assert(expr) => {
                    let (expr, expr_statements) = self.reconstruct_expression(expr, &());
                    statements.extend(expr_statements);
                    AssertVariant::Assert(expr)
                }
                AssertVariant::AssertEq(left, right) => {
                    let (left, left_statements) = self.reconstruct_expression(left, &());
                    statements.extend(left_statements);
                    let (right, right_statements) = self.reconstruct_expression(right, &());
                    statements.extend(right_statements);
                    AssertVariant::AssertEq(left, right)
                }
                AssertVariant::AssertNeq(left, right) => {
                    let (left, left_statements) = self.reconstruct_expression(left, &());
                    statements.extend(left_statements);
                    let (right, right_statements) = self.reconstruct_expression(right, &());
                    statements.extend(right_statements);
                    AssertVariant::AssertNeq(left, right)
                }
            },
            ..input
        }
        .into();
        (stmt, statements)
    }

    /// Lowers assignments to global storage variables; other assignments are only reconstructed.
    ///
    /// When `place` is a path with a global location, the variable must have an optional type.
    /// The assignment becomes `_mapping_remove(<var>__, false)` if the new value is the literal
    /// `none`, and `_mapping_set(<var>__, false, value)` otherwise.
    ///
    /// # Panics
    ///
    /// Panics if the global cannot be found in the symbol table, or if its type is unknown or not
    /// optional.
    fn reconstruct_assign(&mut self, input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
        let AssignStatement { place, value, span, .. } = input;
        let mut statements = vec![];
        if let Expression::Path(path) = &place {
            if let Some(global_location) = path.try_global_location() {
                let var = self
                    .state
                    .symbol_table
                    .lookup_global(self.program, global_location)
                    .expect("A global path must point to a global");
                assert!(
                    var.type_.as_ref().expect("must be known by now").is_optional(),
                    "Only storage variables that are not vectors or mappings are expected here."
                );
                let (new_value, mut value_stmts) = self.reconstruct_expression(value, &());
                statements.append(&mut value_stmts);
                let id = || self.state.node_builder.next_id();
                let var_name = path.identifier().name;
                // The singleton storage variable `x` is backed by the mapping `x__`, keyed by `false`.
                let mapping_symbol = Symbol::intern(&format!("{var_name}__"));
                let mapping_ident = Identifier::new(mapping_symbol, id());
                let mapping_expr: Expression =
                    Path::from(mapping_ident).to_global(Location::new(self.program, vec![mapping_symbol])).into();
                let false_literal: Expression = Literal::boolean(false, Span::default(), id()).into();
                let stmt = if matches!(new_value, Expression::Literal(Literal { variant: LiteralVariant::None, .. })) {
                    // Assigning `none` clears the variable by removing its mapping entry.
                    let remove_expr: Expression = IntrinsicExpression {
                        name: sym::_mapping_remove,
                        type_parameters: vec![],
                        arguments: vec![mapping_expr, false_literal],
                        span,
                        id: id(),
                    }
                    .into();
                    Statement::Expression(ExpressionStatement { expression: remove_expr, span, id: id() })
                } else {
                    let set_expr: Expression = IntrinsicExpression {
                        name: sym::_mapping_set,
                        type_parameters: vec![],
                        arguments: vec![mapping_expr, false_literal, new_value],
                        span,
                        id: id(),
                    }
                    .into();
                    Statement::Expression(ExpressionStatement { expression: set_expr, span, id: id() })
                };
                return (stmt, statements);
            }
        }
        let (new_place, mut place_stmts) = self.reconstruct_expression(place, &());
        let (new_value, mut value_stmts) = self.reconstruct_expression(value, &());
        statements.append(&mut place_stmts);
        statements.append(&mut value_stmts);
        let stmt =
            AssignStatement { place: new_place, value: new_value, span, id: self.state.node_builder.next_id() }.into();
        (stmt, statements)
    }

    /// Reconstructs each statement of the block and splices its hoisted statements directly in
    /// front of it. The block itself hoists nothing and receives a fresh node id.
    fn reconstruct_block(&mut self, block: Block) -> (Block, Self::AdditionalOutput) {
        let mut statements = Vec::with_capacity(block.statements.len());
        for statement in block.statements {
            let (reconstructed_statement, additional_statements) = self.reconstruct_statement(statement);
            statements.extend(additional_statements);
            statements.push(reconstructed_statement);
        }
        (Block { span: block.span, statements, id: self.state.node_builder.next_id() }, Default::default())
    }

    /// Reconstructs a conditional statement.
    ///
    /// Only statements hoisted out of the condition are returned to the caller. Statements
    /// hoisted out of the `otherwise` branch (e.g. from an `else if` condition) are kept inside
    /// that branch, so they only run when the branch is taken.
    fn reconstruct_conditional(&mut self, input: leo_ast::ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
        let (condition, mut statements) = self.reconstruct_expression(input.condition, &());
        // `reconstruct_block` splices hoisted statements inside the block, so this is empty.
        let (then, then_statements) = self.reconstruct_block(input.then);
        statements.extend(then_statements);
        let otherwise = input.otherwise.map(|oth| {
            let (reconstructed, mut hoisted) = self.reconstruct_statement(*oth);
            if hoisted.is_empty() {
                Box::new(reconstructed)
            } else {
                // Hoisting these before the outer `if` would execute them (e.g. a vector length
                // update or bound assert) even when the `else` branch is not taken.
                hoisted.push(reconstructed);
                Box::new(Statement::Block(Block {
                    span: Span::default(),
                    statements: hoisted,
                    id: self.state.node_builder.next_id(),
                }))
            }
        });
        (ConditionalStatement { condition, then, otherwise, ..input }.into(), statements)
    }

    /// Reconstructs the type and value of a const declaration; type statements come first.
    fn reconstruct_const(&mut self, input: ConstDeclaration) -> (Statement, Self::AdditionalOutput) {
        let (type_expr, type_statements) = self.reconstruct_type(input.type_);
        let (value_expr, value_statements) = self.reconstruct_expression(input.value, &Default::default());
        let mut statements = Vec::new();
        statements.extend(type_statements);
        statements.extend(value_statements);
        (ConstDeclaration { type_: type_expr, value: value_expr, ..input }.into(), statements)
    }

    /// Reconstructs the value and the optional type annotation of a definition.
    /// Statements hoisted out of the type come before those from the value, matching
    /// `reconstruct_const`.
    fn reconstruct_definition(&mut self, mut input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
        let (new_value, value_statements) = self.reconstruct_expression(input.value, &());
        let mut statements = Vec::new();
        input.type_ = input.type_.map(|ty| {
            let (ty, type_statements) = self.reconstruct_type(ty);
            statements.extend(type_statements);
            ty
        });
        statements.extend(value_statements);
        input.value = new_value;
        (input.into(), statements)
    }

    /// Reconstructs an expression statement.
    ///
    /// Only calls and intrinsics are kept as the statement's expression. Any other reconstructed
    /// expression (e.g. the ternary produced by lowering `Vector::pop`) is replaced by a unit
    /// expression, while its hoisted statements are still returned.
    fn reconstruct_expression_statement(&mut self, input: ExpressionStatement) -> (Statement, Self::AdditionalOutput) {
        let (reconstructed_expression, statements) = self.reconstruct_expression(input.expression, &Default::default());
        let expression = if matches!(reconstructed_expression, Expression::Call(_) | Expression::Intrinsic(_)) {
            reconstructed_expression
        } else {
            Expression::Unit(UnitExpression { span: Span::default(), id: self.state.node_builder.next_id() })
        };
        (ExpressionStatement { expression, ..input }.into(), statements)
    }

    /// # Panics
    ///
    /// Always panics: loops are not expected at this point in the pipeline (presumably they are
    /// removed by an earlier pass).
    fn reconstruct_iteration(&mut self, _input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
        panic!("`IterationStatement`s should not be in the AST at this point.");
    }

    /// Reconstructs the returned expression.
    fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
        let (expression, statements) = self.reconstruct_expression(input.expression, &());
        (ReturnStatement { expression, ..input }.into(), statements)
    }
}