use super::MonomorphizationVisitor;
use leo_ast::{AstReconstructor, Module, Program, ProgramReconstructor, ProgramScope, Statement, Variant};
use leo_span::sym;
impl ProgramReconstructor for MonomorphizationVisitor<'_> {
    /// Reconstructs a single program scope, monomorphizing its structs and functions.
    ///
    /// Structs are reconstructed in the order produced by the struct graph's post order, and the
    /// functions of this program in the order produced by the call graph's filtered post order.
    /// Each item is moved out of `struct_map` / `function_map` into `reconstructed_structs` /
    /// `reconstructed_functions`. Mappings, storage variables, consts and the constructor are
    /// reconstructed afterwards, in that order.
    ///
    /// Only single-segment (top-level) structs and functions are emitted into the returned scope.
    /// Transitions are emitted before all other functions; each group keeps the map's order.
    ///
    /// # Panics
    ///
    /// Panics if either graph post-order computation returns an error, or if
    /// `reconstruct_const` returns anything other than `Statement::Const`.
    fn reconstruct_program_scope(&mut self, input: ProgramScope) -> ProgramScope {
        self.program = input.program_id.name.name;

        // Reconstruct structs following the struct graph's post order.
        let struct_order = self.state.struct_graph.post_order().unwrap();
        for struct_name in &struct_order {
            if let Some(r#struct) = self.struct_map.swap_remove(struct_name) {
                let reconstructed_struct = self.reconstruct_struct(r#struct);
                self.reconstructed_structs.insert(struct_name.clone(), reconstructed_struct);
            }
        }
        // Some structs were not reached by the post order above and remain pending.
        // NOTE(review): presumably `changed` requests another iteration of this pass — confirm.
        if !self.struct_map.is_empty() {
            self.changed = true;
        }

        // Call-graph post order, filtered to this program's constructor and to its functions
        // whose variant is an (async) transition, a function, or a script.
        let order = self
            .state
            .call_graph
            .post_order_with_filter(|location| {
                if location.program != self.program {
                    return false;
                }
                // The constructor is stored in `ProgramScope::constructor`, not among the
                // functions, so it is matched by its path rather than looked up.
                if location.path == [sym::constructor] {
                    return true;
                }
                self.function_map
                    .get(&location.path)
                    .map(|f| {
                        matches!(
                            f.variant,
                            Variant::AsyncTransition | Variant::Transition | Variant::Function | Variant::Script
                        )
                    })
                    .unwrap_or(false)
            })
            .unwrap()
            .into_iter()
            // Keep only locations of this program; the filter above is not relied upon to
            // exclude every foreign location from the resulting order.
            .filter(|location| location.program == self.program)
            .collect::<Vec<_>>();

        for function_name in &order {
            if let Some(function) = self.function_map.swap_remove(&function_name.path) {
                let reconstructed_function = self.reconstruct_function(function);
                self.reconstructed_functions.insert(function_name.path.clone(), reconstructed_function);
            }
        }

        let mappings =
            input.mappings.into_iter().map(|(id, mapping)| (id, self.reconstruct_mapping(mapping))).collect();
        let storage_variables = input
            .storage_variables
            .into_iter()
            .map(|(id, storage_variable)| (id, self.reconstruct_storage_variable(storage_variable)))
            .collect();
        let consts = input
            .consts
            .into_iter()
            .map(|(i, c)| match self.reconstruct_const(c) {
                (Statement::Const(declaration), _) => (i, declaration),
                _ => panic!("`reconstruct_const` can only return `Statement::Const`"),
            })
            .collect();
        let constructor = input.constructor.map(|c| self.reconstruct_constructor(c));

        // Drop functions recorded in `monomorphized_functions` unless an unresolved call still
        // targets them.
        self.reconstructed_functions.retain(|f, _| {
            let is_monomorphized = self.monomorphized_functions.contains(f);
            let is_still_called = self.unresolved_calls.iter().any(|c| &c.function.absolute_path() == f);
            !is_monomorphized || is_still_called
        });

        // Only top-level (single-segment) items belong to the program scope itself; items
        // nested under a module path are emitted by `reconstruct_module`.
        let structs = self
            .reconstructed_structs
            .iter()
            .filter_map(|(path, c)| name_under(path, &[]).map(|name| (name, c.clone())))
            .collect();

        // Transitions first, then every other function. Iterating twice by reference keeps the
        // relative order within each group without cloning the whole map.
        let functions = self
            .reconstructed_functions
            .iter()
            .filter(|(_, f)| f.variant.is_transition())
            .chain(self.reconstructed_functions.iter().filter(|(_, f)| !f.variant.is_transition()))
            .filter_map(|(path, f)| name_under(path, &[]).map(|name| (name, f.clone())))
            .collect();

        ProgramScope {
            program_id: input.program_id,
            structs,
            mappings,
            storage_variables,
            functions,
            constructor,
            consts,
            span: input.span,
        }
    }

    /// Reconstructs the whole program.
    ///
    /// First records every function and struct in `function_map` / `struct_map`. Module items
    /// are keyed by the module path followed by the item name; program-scope items are keyed by
    /// their name alone. Program scopes are then reconstructed before modules, because
    /// `reconstruct_module` only reads the `reconstructed_*` maps that
    /// `reconstruct_program_scope` fills in.
    fn reconstruct_program(&mut self, input: Program) -> Program {
        for (module_path, module) in input.modules.iter() {
            for (name, function) in module.functions.iter() {
                self.function_map.insert(
                    module_path.iter().cloned().chain(std::iter::once(*name)).collect(),
                    function.clone(),
                );
            }
            for (name, r#struct) in module.structs.iter() {
                self.struct_map.insert(
                    module_path.iter().cloned().chain(std::iter::once(*name)).collect(),
                    r#struct.clone(),
                );
            }
        }
        for (_, scope) in input.program_scopes.iter() {
            for (name, function) in scope.functions.iter() {
                self.function_map.insert(vec![*name], function.clone());
            }
            for (name, r#struct) in scope.structs.iter() {
                self.struct_map.insert(vec![*name], r#struct.clone());
            }
        }

        // Struct-expression fields are evaluated in the order written: program scopes first,
        // then modules.
        Program {
            program_scopes: input
                .program_scopes
                .into_iter()
                .map(|(id, scope)| (id, self.reconstruct_program_scope(scope)))
                .collect(),
            modules: input.modules.into_iter().map(|(id, module)| (id, self.reconstruct_module(module))).collect(),
            ..input
        }
    }

    /// Rebuilds a module from the already-reconstructed structs and functions whose parent path
    /// equals the module's path. All other module fields are kept unchanged.
    fn reconstruct_module(&mut self, input: Module) -> Module {
        let structs = self
            .reconstructed_structs
            .iter()
            .filter_map(|(path, c)| name_under(path, &input.path).map(|name| (name, c.clone())))
            .collect();
        let functions = self
            .reconstructed_functions
            .iter()
            .filter_map(|(path, f)| name_under(path, &input.path).map(|name| (name, f.clone())))
            .collect();
        Module { structs, functions, ..input }
    }
}

/// Returns the last segment of `path` if the segments before it are exactly `parent`.
///
/// With an empty `parent` this selects single-segment (top-level) paths. An empty `path`
/// yields `None`.
fn name_under<T: Copy + PartialEq>(path: &[T], parent: &[T]) -> Option<T> {
    match path.split_last() {
        Some((last, rest)) if rest == parent => Some(*last),
        _ => None,
    }
}