use super::MonomorphizationVisitor;
use crate::{ConstPropagationVisitor, Replacer};
use leo_ast::{
AstReconstructor,
CallExpression,
CompositeExpression,
CompositeFieldInitializer,
CompositeType,
Expression,
Identifier,
Node as _,
ProgramReconstructor,
Type,
Variant,
};
use indexmap::IndexMap;
use itertools::{Either, Itertools};
impl<'a> MonomorphizationVisitor<'a> {
fn try_evaluate_const_args(&mut self, const_args: &[Expression]) -> Option<Vec<Expression>> {
let mut const_evaluator = ConstPropagationVisitor::new(self.state, self.program);
let (evaluated_const_args, non_const_args): (Vec<_>, Vec<_>) = const_args
.iter()
.map(|arg| const_evaluator.reconstruct_expression(arg.clone(), &()))
.partition_map(|(evaluated_arg, evaluated_value)| match (evaluated_value, evaluated_arg) {
(Some(_), expr @ Expression::Literal(_)) => Either::Left(expr),
_ => Either::Right(()),
});
if !non_const_args.is_empty() { None } else { Some(evaluated_const_args) }
}
}
impl AstReconstructor for MonomorphizationVisitor<'_> {
    type AdditionalInput = ();
    type AdditionalOutput = ();

    /// Replaces a const-generic composite type with its monomorphized counterpart.
    ///
    /// Types without const arguments are returned unchanged. If the const arguments cannot all be
    /// reduced to literals yet, the type is pushed onto `unresolved_composite_types` and returned
    /// unchanged. Otherwise the path is rewritten via `monomorphize_composite`, the const arguments
    /// are dropped, and `changed` is set.
    fn reconstruct_composite_type(&mut self, input: leo_ast::CompositeType) -> (leo_ast::Type, Self::AdditionalOutput) {
        if input.const_arguments.is_empty() {
            return (Type::Composite(input), Default::default());
        }

        let Some(evaluated_const_args) = self.try_evaluate_const_args(&input.const_arguments) else {
            self.unresolved_composite_types.push(input.clone());
            return (Type::Composite(input), Default::default());
        };

        self.changed = true;
        (
            Type::Composite(CompositeType {
                path: self.monomorphize_composite(&input.path, &evaluated_const_args),
                // The const arguments are now encoded in the monomorphized path.
                const_arguments: vec![],
            }),
            Default::default(),
        )
    }

    /// Dispatches to the per-variant `reconstruct_*` method.
    ///
    /// If the input expression had an entry in the type table, that type is also recorded under the
    /// id of the reconstructed expression. Reconstruction may produce a node with a different id,
    /// and later lookups would otherwise miss it.
    fn reconstruct_expression(&mut self, input: Expression, _additional: &()) -> (Expression, Self::AdditionalOutput) {
        let opt_old_type = self.state.type_table.get(&input.id());
        let (new_expr, opt_value) = match input {
            Expression::Array(array) => self.reconstruct_array(array, &()),
            Expression::ArrayAccess(access) => self.reconstruct_array_access(*access, &()),
            Expression::Intrinsic(intr) => self.reconstruct_intrinsic(*intr, &()),
            Expression::Async(async_) => self.reconstruct_async(async_, &()),
            Expression::Binary(binary) => self.reconstruct_binary(*binary, &()),
            Expression::Call(call) => self.reconstruct_call(*call, &()),
            Expression::Cast(cast) => self.reconstruct_cast(*cast, &()),
            Expression::Composite(composite) => self.reconstruct_composite_init(composite, &()),
            Expression::Err(err) => self.reconstruct_err(err, &()),
            Expression::Path(path) => self.reconstruct_path(path, &()),
            Expression::Literal(value) => self.reconstruct_literal(value, &()),
            Expression::MemberAccess(access) => self.reconstruct_member_access(*access, &()),
            Expression::Repeat(repeat) => self.reconstruct_repeat(*repeat, &()),
            Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary, &()),
            Expression::Tuple(tuple) => self.reconstruct_tuple(tuple, &()),
            Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access, &()),
            Expression::Unary(unary) => self.reconstruct_unary(*unary, &()),
            Expression::Unit(unit) => self.reconstruct_unit(unit, &()),
        };

        if let Some(old_type) = opt_old_type {
            self.state.type_table.insert(new_expr.id(), old_type);
        }

        (new_expr, opt_value)
    }

    /// Redirects a call to a const-generic `inline` function to a specialized copy of that function.
    ///
    /// The call's arguments are always reconstructed first. This ensures that const-generic calls
    /// and composite initializers nested inside them are monomorphized too.
    ///
    /// The call is then returned with only its arguments rewritten when any of these hold:
    /// - the callee belongs to another program;
    /// - the call has no const arguments;
    /// - the const arguments cannot all be reduced to literals yet (the call is then pushed onto
    ///   `unresolved_calls`);
    /// - the callee is not `inline`.
    ///
    /// Otherwise a specialized copy of the callee is created, once per distinct set of const
    /// arguments, and the call is rewritten to target it.
    fn reconstruct_call(
        &mut self,
        mut input_call: CallExpression,
        _additional: &(),
    ) -> (Expression, Self::AdditionalOutput) {
        // This override replaces the default traversal, so arguments must be visited explicitly.
        // Otherwise, e.g., `foo::[1u32](bar::[2u32](x))` would never specialize `bar`.
        input_call.arguments = std::mem::take(&mut input_call.arguments)
            .into_iter()
            .map(|arg| self.reconstruct_expression(arg, &()).0)
            .collect();

        if input_call.function.expect_global_location().program != self.program {
            return (input_call.into(), Default::default());
        }

        if input_call.const_arguments.is_empty() {
            return (input_call.into(), Default::default());
        }

        let Some(evaluated_const_args) = self.try_evaluate_const_args(&input_call.const_arguments) else {
            self.unresolved_calls.push(input_call.clone());
            return (input_call.into(), Default::default());
        };

        let callee_fn = self
            .reconstructed_functions
            .get(input_call.function.expect_global_location())
            .expect("Callee should already be reconstructed (post-order traversal).");

        // Only `inline` functions are specialized here.
        if !matches!(callee_fn.variant, Variant::Inline) {
            return (input_call.into(), Default::default());
        }

        // Name the specialization after the original name plus its const arguments, wrapped in
        // double quotes, e.g. `"foo::[1u32, 2u32]"`. The quotes presumably keep it from colliding
        // with any user-written identifier.
        let new_callee_path = input_call.function.clone().with_updated_last_symbol(leo_span::Symbol::intern(&format!(
            "\"{}::[{}]\"",
            input_call.function.identifier().name,
            evaluated_const_args.iter().format(", ")
        )));

        // Create the specialization only once per distinct set of const arguments.
        if self.reconstructed_functions.get(new_callee_path.expect_global_location()).is_none() {
            // Map each const parameter name to its literal argument. `zip_eq` panics on a count
            // mismatch; earlier passes are presumably responsible for ruling that out.
            let const_param_map: IndexMap<_, _> = callee_fn
                .const_parameters
                .iter()
                .map(|param| param.identifier().name)
                .zip_eq(&evaluated_const_args)
                .collect();

            // Substitute literals for paths naming a const parameter; leave everything else as is.
            let replace_identifier = |expr: &Expression| match expr {
                Expression::Path(path) => const_param_map
                    .get(&path.identifier().name)
                    .map_or(Expression::Path(path.clone()), |&expr| expr.clone()),
                _ => expr.clone(),
            };

            // NOTE(review): the `true` flag presumably asks the replacer to assign fresh node IDs;
            // confirm against `Replacer::new`.
            let mut replacer = Replacer::new(replace_identifier, true, self.state);
            let mut function = replacer.reconstruct_function(callee_fn.clone());

            // Run this visitor over the substituted body so const-generic calls and composites
            // inside it are monomorphized as well.
            function = self.reconstruct_function(function);

            function.identifier = Identifier {
                name: new_callee_path.identifier().name,
                span: leo_span::Span::default(),
                id: self.state.node_builder.next_id(),
            };
            function.const_parameters = vec![];
            function.id = self.state.node_builder.next_id();

            self.reconstructed_functions.insert(new_callee_path.expect_global_location().clone(), function);
            self.monomorphized_functions.insert(input_call.function.expect_global_location().clone());
        }

        self.changed = true;

        (
            CallExpression {
                function: new_callee_path,
                const_arguments: vec![],
                arguments: input_call.arguments,
                // Keep the original call's span and id.
                span: input_call.span,
                id: input_call.id,
            }
            .into(),
            Default::default(),
        )
    }

    /// Reconstructs a composite initializer expression, monomorphizing it if it has const arguments.
    ///
    /// Member expressions are always reconstructed first. Without const arguments, the initializer
    /// is returned with only its members rewritten. If the const arguments cannot all be reduced to
    /// literals yet, the initializer (with its original members) is pushed onto
    /// `unresolved_composite_exprs` and returned with rewritten members. Otherwise its path is
    /// replaced via `monomorphize_composite`, its const arguments are dropped, and `changed` is set.
    fn reconstruct_composite_init(
        &mut self,
        mut input: CompositeExpression,
        _additional: &(),
    ) -> (Expression, Self::AdditionalOutput) {
        // Cloned because the unresolved path below records `input` with its original members.
        let members = input
            .members
            .clone()
            .into_iter()
            .map(|member| CompositeFieldInitializer {
                identifier: member.identifier,
                expression: member.expression.map(|expr| self.reconstruct_expression(expr, &()).0),
                span: member.span,
                id: member.id,
            })
            .collect();

        if input.const_arguments.is_empty() {
            input.members = members;
            return (input.into(), Default::default());
        }

        let Some(evaluated_const_args) = self.try_evaluate_const_args(&input.const_arguments) else {
            self.unresolved_composite_exprs.push(input.clone());
            input.members = members;
            return (input.into(), Default::default());
        };

        self.changed = true;
        (
            CompositeExpression {
                path: self.monomorphize_composite(&input.path, &evaluated_const_args),
                members,
                const_arguments: vec![],
                span: input.span,
                id: input.id,
            }
            .into(),
            Default::default(),
        )
    }
}