use super::MonomorphizationVisitor;
use crate::Replacer;
use leo_ast::{
AstReconstructor,
CallExpression,
CompositeType,
Expression,
Identifier,
ProgramReconstructor,
StructExpression,
StructVariableInitializer,
Type,
Variant,
};
use indexmap::IndexMap;
use itertools::Itertools;
impl AstReconstructor for MonomorphizationVisitor<'_> {
type AdditionalOutput = ();
fn reconstruct_composite_type(&mut self, input: leo_ast::CompositeType) -> (leo_ast::Type, Self::AdditionalOutput) {
if input.const_arguments.is_empty() {
return (Type::Composite(input), Default::default());
}
if input.const_arguments.iter().any(|arg| !matches!(arg, Expression::Literal(_))) {
self.unresolved_struct_types.push(input.clone());
return (Type::Composite(input), Default::default());
}
self.changed = true;
(
Type::Composite(CompositeType {
id: Identifier {
name: self.monomorphize_struct(&input.id.name, &input.const_arguments), span: input.id.span,
id: self.state.node_builder.next_id(),
},
const_arguments: vec![], program: input.program,
}),
Default::default(),
)
}
fn reconstruct_call(&mut self, input_call: CallExpression) -> (Expression, Self::AdditionalOutput) {
if input_call.program.unwrap() != self.program {
return (input_call.into(), Default::default());
}
let callee_fn = self
.reconstructed_functions
.get(&input_call.function.name)
.expect("Callee should already be reconstructed (post-order traversal).");
if !matches!(callee_fn.variant, Variant::Inline) || input_call.const_arguments.is_empty() {
return (input_call.into(), Default::default());
}
if input_call.const_arguments.iter().any(|arg| !matches!(arg, Expression::Literal(_))) {
self.unresolved_calls.push(input_call.clone());
return (input_call.into(), Default::default());
}
let new_callee_name = leo_span::Symbol::intern(&format!(
"\"{}::[{}]\"",
input_call.function.name,
input_call.const_arguments.iter().format(", ")
));
if self.reconstructed_functions.get(&new_callee_name).is_none() {
let const_param_map: IndexMap<_, _> = callee_fn
.const_parameters
.iter()
.map(|param| param.identifier().name)
.zip_eq(&input_call.const_arguments)
.collect();
let replace_identifier = |ident: &Identifier| {
const_param_map.get(&ident.name).map_or(Expression::Identifier(*ident), |&expr| expr.clone())
};
let mut replacer = Replacer::new(replace_identifier, &self.state.node_builder);
let mut function = replacer.reconstruct_function(callee_fn.clone());
function = self.reconstruct_function(function);
function.identifier = Identifier {
name: new_callee_name,
span: leo_span::Span::default(),
id: self.state.node_builder.next_id(),
};
function.const_parameters = vec![];
function.id = self.state.node_builder.next_id();
self.reconstructed_functions.insert(new_callee_name, function);
self.monomorphized_functions.insert(input_call.function.name);
}
self.changed = true;
(
CallExpression {
function: Identifier {
name: new_callee_name, span: leo_span::Span::default(),
id: self.state.node_builder.next_id(),
},
const_arguments: vec![], arguments: input_call.arguments,
program: input_call.program,
span: input_call.span, id: input_call.id,
}
.into(),
Default::default(),
)
}
fn reconstruct_struct_init(&mut self, mut input: StructExpression) -> (Expression, Self::AdditionalOutput) {
let members = input
.members
.clone()
.into_iter()
.map(|member| StructVariableInitializer {
identifier: member.identifier,
expression: member.expression.map(|expr| self.reconstruct_expression(expr).0),
span: member.span,
id: member.id,
})
.collect();
if input.const_arguments.is_empty() {
input.members = members;
return (input.into(), Default::default());
}
if input.const_arguments.iter().any(|arg| !matches!(arg, Expression::Literal(_))) {
self.unresolved_struct_exprs.push(input.clone());
input.members = members;
return (input.into(), Default::default());
}
self.changed = true;
(
StructExpression {
name: Identifier {
name: self.monomorphize_struct(&input.name.name, &input.const_arguments),
span: input.name.span,
id: self.state.node_builder.next_id(),
},
members,
const_arguments: vec![], span: input.span, id: input.id,
}
.into(),
Default::default(),
)
}
}