use leo_ast::{
ArrayAccess,
BinaryExpression,
CastExpression,
CoreFunction,
Expression,
ExpressionReconstructor,
MemberAccess,
Node,
StructExpression,
TernaryExpression,
TupleAccess,
Type,
UnaryExpression,
};
use leo_errors::StaticAnalyzerError;
use leo_interpreter::{StructContents, Value};
use leo_span::sym;
use super::{ConstPropagationVisitor, value_to_expression};
/// Panic message passed to `expect` wherever this file converts an evaluated
/// `Value` back into an expression with `value_to_expression`.
const VALUE_ERROR: &str = "A non-future value should always be able to be converted into an expression";
impl ExpressionReconstructor for ConstPropagationVisitor<'_> {
type AdditionalOutput = Option<Value>;
/// Dispatches `input` to the reconstructor for its variant and returns the
/// rebuilt expression together with its compile-time value, if one is known.
///
/// When the rebuilt expression carries a different node id than `input`, the
/// visitor is flagged as changed and the type recorded for the old id is also
/// recorded under the new id.
///
/// # Panics
///
/// Panics if the node was replaced and the type table has no entry for the
/// original id.
fn reconstruct_expression(&mut self, input: Expression) -> (Expression, Self::AdditionalOutput) {
    let original_id = input.id();

    let (rebuilt, value) = match input {
        Expression::Array(inner) => self.reconstruct_array(inner),
        Expression::ArrayAccess(inner) => self.reconstruct_array_access(*inner),
        Expression::AssociatedConstant(inner) => self.reconstruct_associated_constant(inner),
        Expression::AssociatedFunction(inner) => self.reconstruct_associated_function(inner),
        Expression::Binary(inner) => self.reconstruct_binary(*inner),
        Expression::Call(inner) => self.reconstruct_call(*inner),
        Expression::Cast(inner) => self.reconstruct_cast(*inner),
        Expression::Struct(inner) => self.reconstruct_struct_init(inner),
        Expression::Err(inner) => self.reconstruct_err(inner),
        Expression::Identifier(inner) => self.reconstruct_identifier(inner),
        Expression::Literal(inner) => self.reconstruct_literal(inner),
        Expression::Locator(inner) => self.reconstruct_locator(inner),
        Expression::MemberAccess(inner) => self.reconstruct_member_access(*inner),
        Expression::Ternary(inner) => self.reconstruct_ternary(*inner),
        Expression::Tuple(inner) => self.reconstruct_tuple(inner),
        Expression::TupleAccess(inner) => self.reconstruct_tuple_access(*inner),
        Expression::Unary(inner) => self.reconstruct_unary(*inner),
        Expression::Unit(inner) => self.reconstruct_unit(inner),
    };

    let rebuilt_id = rebuilt.id();
    if rebuilt_id != original_id {
        // The node was replaced: note the change and carry the old node's type
        // over to the replacement.
        self.changed = true;
        let carried_type = self
            .state
            .type_table
            .get(&original_id)
            .expect("Type checking guarantees that all expressions have a type.");
        self.state.type_table.insert(rebuilt_id, carried_type);
    }

    (rebuilt, value)
}
/// Reconstructs every member initializer of a struct expression in place.
///
/// A `Value::Struct` is returned only when the number of evaluated
/// initializers equals the number of members; members without an explicit
/// initializer expression are skipped and therefore never counted.
fn reconstruct_struct_init(&mut self, mut input: StructExpression) -> (Expression, Self::AdditionalOutput) {
    let mut evaluated = Vec::new();

    for member in &mut input.members {
        let Some(initializer) = member.expression.take() else {
            continue;
        };
        let (rebuilt, value) = self.reconstruct_expression(initializer);
        member.expression = Some(rebuilt);
        if let Some(known) = value {
            evaluated.push(known);
        }
    }

    let folded = if evaluated.len() == input.members.len() {
        // Every member was evaluated, so values line up with members by position.
        let contents = input.members.iter().map(|member| member.identifier.name).zip(evaluated).collect();
        Some(Value::Struct(StructContents { name: input.name.name, contents }))
    } else {
        None
    };

    (input.into(), folded)
}
/// Reconstructs a ternary expression.
///
/// If the condition evaluates to a boolean, the whole ternary is replaced by
/// the reconstruction of the selected branch. Otherwise the condition and
/// both branches are reconstructed and no value is returned.
fn reconstruct_ternary(&mut self, input: TernaryExpression) -> (Expression, Self::AdditionalOutput) {
    let (condition, condition_value) = self.reconstruct_expression(input.condition);

    if let Some(Value::Bool(take_true)) = condition_value {
        let selected = if take_true { input.if_true } else { input.if_false };
        return self.reconstruct_expression(selected);
    }

    // Condition is unknown: keep the ternary, reconstructing the true branch
    // before the false branch.
    let if_true = self.reconstruct_expression(input.if_true).0;
    let if_false = self.reconstruct_expression(input.if_false).0;
    (TernaryExpression { condition, if_true, if_false, ..input }.into(), None)
}
/// Reconstructs an array access, bounds-checking and folding it when possible.
///
/// * If the index has no known value, its span is stored in
///   `array_index_not_evaluated` and the access is rebuilt unchanged.
/// * If the index is known and out of bounds (negative or oversized values
///   are mapped to the array length), an `array_bounds` error is emitted,
///   but only when the handler has not already recorded errors.
/// * If the index is in bounds and the array value is also known, the access
///   is replaced by the selected element.
///
/// # Panics
///
/// Panics if the array operand's recorded type is not an array type or the
/// evaluated index is not one of the integer `Value` variants.
fn reconstruct_array_access(&mut self, input: ArrayAccess) -> (Expression, Self::AdditionalOutput) {
    let span = input.span();
    let array_id = input.array.id();
    let (array, array_value) = self.reconstruct_expression(input.array);
    let (index, index_value) = self.reconstruct_expression(input.index);

    let Some(index_value) = index_value else {
        self.array_index_not_evaluated = Some(index.span());
        return (ArrayAccess { array, index, ..input }.into(), None);
    };

    let recorded_type = self.state.type_table.get(&array_id);
    let Some(Type::Array(array_ty)) = recorded_type else {
        panic!("Type checking guaranteed that this is an array.");
    };
    let len = array_ty.length();

    // Anything that does not fit in `usize` becomes `len`, which is always
    // out of bounds.
    let position: usize = match index_value {
        Value::U8(x) => x as usize,
        Value::U16(x) => x as usize,
        Value::U32(x) => x.try_into().unwrap_or(len),
        Value::U64(x) => x.try_into().unwrap_or(len),
        Value::U128(x) => x.try_into().unwrap_or(len),
        Value::I8(x) => x.try_into().unwrap_or(len),
        Value::I16(x) => x.try_into().unwrap_or(len),
        Value::I32(x) => x.try_into().unwrap_or(len),
        Value::I64(x) => x.try_into().unwrap_or(len),
        Value::I128(x) => x.try_into().unwrap_or(len),
        _ => panic!("Type checking guarantees this is an integer"),
    };

    if position < len {
        if let Some(Value::Array(elements)) = array_value {
            let element = elements.get(position).expect("We already checked bounds.");
            let folded = value_to_expression(element, input.span, &self.state.node_builder).expect(VALUE_ERROR);
            return (folded, Some(element.clone()));
        }
    } else if !self.state.handler.had_errors() {
        // Report the index exactly as written, not the clamped `position`.
        let shown_index = match index_value {
            Value::U8(x) => format!("{x}"),
            Value::U16(x) => format!("{x}"),
            Value::U32(x) => format!("{x}"),
            Value::U64(x) => format!("{x}"),
            Value::U128(x) => format!("{x}"),
            Value::I8(x) => format!("{x}"),
            Value::I16(x) => format!("{x}"),
            Value::I32(x) => format!("{x}"),
            Value::I64(x) => format!("{x}"),
            Value::I128(x) => format!("{x}"),
            _ => unreachable!("We would have panicked above"),
        };
        self.emit_err(StaticAnalyzerError::array_bounds(shown_index, len, span));
    }

    (ArrayAccess { array, index, ..input }.into(), None)
}
/// Replaces an associated constant with the expression form of
/// `Value::generator()` and returns that value.
///
/// NOTE(review): every associated constant is folded to the generator here;
/// presumably it is the only associated constant the language has — confirm.
fn reconstruct_associated_constant(
    &mut self,
    input: leo_ast::AssociatedConstantExpression,
) -> (Expression, Self::AdditionalOutput) {
    let value = Value::generator();
    let span = input.span();
    let folded = value_to_expression(&value, span, &self.state.node_builder).expect(VALUE_ERROR);
    (folded, Some(value))
}
/// Reconstructs the arguments of an associated (core) function call and, when
/// all of them have known values, tries to evaluate the call at compile time.
///
/// Calls whose variant is `CheatCode` or `Mapping` are never evaluated. An
/// evaluation error is reported through `emit_err` and the call is kept.
///
/// # Panics
///
/// Panics if evaluation is attempted and the variant/name pair does not name
/// a core function.
fn reconstruct_associated_function(
    &mut self,
    mut input: leo_ast::AssociatedFunctionExpression,
) -> (Expression, Self::AdditionalOutput) {
    let mut values = Vec::new();
    for slot in &mut input.arguments {
        let (rebuilt, value) = self.reconstruct_expression(std::mem::take(slot));
        *slot = rebuilt;
        if let Some(known) = value {
            values.push(known);
        }
    }

    let all_known = values.len() == input.arguments.len();
    let excluded = matches!(input.variant.name, sym::CheatCode | sym::Mapping);

    if all_known && !excluded {
        let span = input.span();
        let core_function = CoreFunction::from_symbols(input.variant.name, input.name.name)
            .expect("Type checking guarantees this is valid.");

        match leo_interpreter::evaluate_core_function(&mut values, core_function, &[], span) {
            Ok(Some(value)) => {
                let folded = value_to_expression(&value, span, &self.state.node_builder).expect(VALUE_ERROR);
                return (folded, Some(value));
            }
            // No error, but nothing could be computed: keep the call.
            Ok(None) => {}
            Err(err) => {
                self.emit_err(StaticAnalyzerError::compile_core_function(err, span));
            }
        }
    }

    (input.into(), None)
}
/// Reconstructs a member access; when the inner expression evaluates to a
/// struct value, the access is replaced by the named member's value.
///
/// # Panics
///
/// Panics if the evaluated struct has no member with the accessed name.
fn reconstruct_member_access(&mut self, input: MemberAccess) -> (Expression, Self::AdditionalOutput) {
    let span = input.span();
    let (inner, inner_value) = self.reconstruct_expression(input.inner);

    let Some(Value::Struct(struct_value)) = inner_value else {
        return (MemberAccess { inner, ..input }.into(), None);
    };

    let member_value =
        struct_value.contents.get(&input.name.name).expect("Type checking guarantees the member exists.");
    let folded = value_to_expression(member_value, span, &self.state.node_builder).expect(VALUE_ERROR);
    (folded, Some(member_value.clone()))
}
/// Reconstructs a tuple access; when the tuple operand evaluates to a tuple
/// value, the access is replaced by the element at the accessed index.
///
/// # Panics
///
/// Panics if the index is not present in the evaluated tuple.
fn reconstruct_tuple_access(&mut self, input: TupleAccess) -> (Expression, Self::AdditionalOutput) {
    let span = input.span();
    let (tuple, tuple_value) = self.reconstruct_expression(input.tuple);

    let Some(Value::Tuple(elements)) = tuple_value else {
        return (TupleAccess { tuple, ..input }.into(), None);
    };

    let element = elements.get(input.index.value()).expect("Type checking checked bounds.");
    let folded = value_to_expression(element, span, &self.state.node_builder).expect(VALUE_ERROR);
    (folded, Some(element.clone()))
}
/// Reconstructs each element of an array expression in place and returns a
/// `Value::Array` when every element has a known value.
fn reconstruct_array(&mut self, mut input: leo_ast::ArrayExpression) -> (Expression, Self::AdditionalOutput) {
    let mut values = Vec::new();

    for slot in &mut input.elements {
        let (rebuilt, value) = self.reconstruct_expression(std::mem::take(slot));
        if let Some(known) = value {
            values.push(known);
        }
        *slot = rebuilt;
    }

    // A shorter `values` means at least one element could not be evaluated.
    let folded = if values.len() == input.elements.len() { Some(Value::Array(values)) } else { None };
    (input.into(), folded)
}
/// Reconstructs both operands of a binary expression and, when both have
/// known values, evaluates the operation at compile time.
///
/// An evaluation failure is reported through `emit_err`; the expression is
/// then rebuilt from the reconstructed operands and no value is returned.
fn reconstruct_binary(&mut self, input: leo_ast::BinaryExpression) -> (Expression, Self::AdditionalOutput) {
    let span = input.span();
    let (left, left_value) = self.reconstruct_expression(input.left);
    let (right, right_value) = self.reconstruct_expression(input.right);

    if let Some((lhs, rhs)) = left_value.zip(right_value) {
        match leo_interpreter::evaluate_binary(span, input.op, &lhs, &rhs) {
            Ok(result) => {
                let folded = value_to_expression(&result, span, &self.state.node_builder).expect(VALUE_ERROR);
                return (folded, Some(result));
            }
            Err(err) => {
                self.emit_err(StaticAnalyzerError::compile_time_binary_op(lhs, rhs, input.op, err, span));
            }
        }
    }

    (BinaryExpression { left, right, ..input }.into(), None)
}
/// Reconstructs every argument of a call in place. The call itself is never
/// evaluated here, so no value is returned.
fn reconstruct_call(&mut self, mut input: leo_ast::CallExpression) -> (Expression, Self::AdditionalOutput) {
    for argument in &mut input.arguments {
        let (rebuilt, _) = self.reconstruct_expression(std::mem::take(argument));
        *argument = rebuilt;
    }
    (input.into(), None)
}
/// Reconstructs the operand of a cast and, when its value is known, performs
/// the cast at compile time.
///
/// If the value cannot be cast to the target type, a `compile_time_cast`
/// error is emitted and the cast expression is kept.
fn reconstruct_cast(&mut self, input: leo_ast::CastExpression) -> (Expression, Self::AdditionalOutput) {
    let span = input.span();
    let (expression, operand_value) = self.reconstruct_expression(input.expression);

    if let Some(value) = operand_value {
        match value.cast(&input.type_) {
            Some(cast_value) => {
                let folded = value_to_expression(&cast_value, span, &self.state.node_builder).expect(VALUE_ERROR);
                return (folded, Some(cast_value));
            }
            None => {
                self.emit_err(StaticAnalyzerError::compile_time_cast(value, &input.type_, span));
            }
        }
    }

    (CastExpression { expression, ..input }.into(), None)
}
/// Handler for `ErrExpression` nodes.
///
/// # Panics
///
/// Always panics: error expressions are not expected in the AST when this
/// pass runs.
fn reconstruct_err(&mut self, _input: leo_ast::ErrExpression) -> (Expression, Self::AdditionalOutput) {
panic!("`ErrExpression`s should not be in the AST at this phase of compilation.")
}
/// Substitutes an identifier with its constant when one is known.
///
/// The identifier is looked up as a constant in the symbol table; if found,
/// the constant's expression is reconstructed. The substitution is used only
/// when that reconstruction produces a value; otherwise the identifier is
/// returned unchanged.
fn reconstruct_identifier(&mut self, input: leo_ast::Identifier) -> (Expression, Self::AdditionalOutput) {
    let Some(constant) = self.state.symbol_table.lookup_const(self.program, input.name) else {
        return (input.into(), None);
    };

    let (folded, value) = self.reconstruct_expression(constant);
    if value.is_none() {
        // The constant has no known value yet: leave the identifier in place.
        return (input.into(), None);
    }

    (folded, value)
}
/// Returns the literal unchanged together with its interpreter value.
///
/// # Panics
///
/// Panics if `literal_to_value` fails for this literal.
fn reconstruct_literal(&mut self, input: leo_ast::Literal) -> (Expression, Self::AdditionalOutput) {
    let folded = Some(leo_interpreter::literal_to_value(&input).expect("Should work"));
    (input.into(), folded)
}
/// Returns the locator unchanged; locators are never given a value here.
fn reconstruct_locator(&mut self, input: leo_ast::LocatorExpression) -> (Expression, Self::AdditionalOutput) {
    (input.into(), None)
}
/// Reconstructs each element of a tuple expression in place and returns a
/// `Value::Tuple` when every element has a known value.
fn reconstruct_tuple(&mut self, mut input: leo_ast::TupleExpression) -> (Expression, Self::AdditionalOutput) {
    let element_count = input.elements.len();
    let mut values = Vec::with_capacity(element_count);

    for slot in &mut input.elements {
        let (rebuilt, value) = self.reconstruct_expression(std::mem::take(slot));
        *slot = rebuilt;
        if let Some(known) = value {
            values.push(known);
        }
    }

    if values.len() == element_count {
        (input.into(), Some(Value::Tuple(values)))
    } else {
        (input.into(), None)
    }
}
/// Reconstructs the operand of a unary expression and, when its value is
/// known, evaluates the operation at compile time.
///
/// An evaluation failure is reported through `emit_err`; the expression is
/// then rebuilt from the reconstructed operand and no value is returned.
fn reconstruct_unary(&mut self, input: UnaryExpression) -> (Expression, Self::AdditionalOutput) {
    let (receiver, receiver_value) = self.reconstruct_expression(input.receiver);
    let span = input.span;

    let Some(value) = receiver_value else {
        return (UnaryExpression { receiver, ..input }.into(), None);
    };

    match leo_interpreter::evaluate_unary(span, input.op, &value) {
        Ok(result) => {
            let folded = value_to_expression(&result, span, &self.state.node_builder).expect(VALUE_ERROR);
            (folded, Some(result))
        }
        Err(err) => {
            self.emit_err(StaticAnalyzerError::compile_time_unary_op(value, input.op, err, span));
            (UnaryExpression { receiver, ..input }.into(), None)
        }
    }
}
/// Returns the unit expression unchanged, with no associated value.
fn reconstruct_unit(&mut self, input: leo_ast::UnitExpression) -> (Expression, Self::AdditionalOutput) {
    let unchanged: Expression = input.into();
    (unchanged, None)
}
}