use super::MonomorphizationVisitor;
use leo_ast::{AstReconstructor, Composite, Function, ProgramReconstructor, ProgramScope, Statement, Variant};
use leo_span::Symbol;
use indexmap::IndexMap;
impl ProgramReconstructor for MonomorphizationVisitor<'_> {
    /// Reconstructs a single program scope, monomorphizing its structs and functions.
    ///
    /// Structs are reconstructed in the post-order of `self.state.struct_graph`. Functions are
    /// reconstructed in the post-order of `self.state.call_graph`, seeded only from entry points
    /// (`async transition`, `transition`, `function`, `script`) defined in this scope. Mappings and
    /// consts are reconstructed afterwards.
    ///
    /// Side effects on `self`:
    /// - `program` is set to this scope's program name.
    /// - `changed` is set to `true` when some input struct does not appear in the struct graph's
    ///   post-order; such structs are not reconstructed and are dropped from the output.
    /// - `reconstructed_functions` is filtered (see below), then both `reconstructed_structs` and
    ///   `reconstructed_functions` are drained into the returned scope.
    ///
    /// # Panics
    /// - If either graph's post-order computation returns an error (the `unwrap`s below).
    /// - If `reconstruct_const` returns anything other than `Statement::Const`.
    fn reconstruct_program_scope(&mut self, input: ProgramScope) -> ProgramScope {
        self.program = input.program_id.name.name;

        // `input` is owned and `structs` is not read again below, so consume it directly instead of
        // deep-cloning every struct definition.
        let mut struct_map: IndexMap<Symbol, Composite> = input.structs.into_iter().collect();

        // NOTE(review): assumes `post_order` only fails on a cyclic graph, which earlier passes
        // presumably rule out — confirm.
        let struct_order = self.state.struct_graph.post_order().unwrap();
        for struct_name in &struct_order {
            if let Some(r#struct) = struct_map.swap_remove(struct_name) {
                let reconstructed_struct = self.reconstruct_struct(r#struct);
                self.reconstructed_structs.insert(*struct_name, reconstructed_struct);
            }
        }

        // Anything left here was absent from the struct graph's post-order. It is not reconstructed,
        // so it disappears from the output: record that the program changed.
        if !struct_map.is_empty() {
            self.changed = true;
        }

        let mut function_map: IndexMap<Symbol, Function> = input.functions.into_iter().collect();

        // Only names defined in this scope with an entry-point variant seed the traversal; names with
        // no definition here are never treated as entry points.
        // NOTE(review): the order is presumably callees-before-callers — confirm edge direction.
        let order = self
            .state
            .call_graph
            .post_order_from_entry_points(|node| {
                function_map
                    .get(node)
                    .map(|f| {
                        matches!(
                            f.variant,
                            Variant::AsyncTransition | Variant::Transition | Variant::Function | Variant::Script
                        )
                    })
                    .unwrap_or(false)
            })
            .unwrap();

        // Functions that are not in `order` stay in `function_map` and are not emitted.
        for function_name in &order {
            if let Some(function) = function_map.swap_remove(function_name) {
                let reconstructed_function = self.reconstruct_function(function);
                self.reconstructed_functions.insert(*function_name, reconstructed_function);
            }
        }

        let mappings =
            input.mappings.into_iter().map(|(id, mapping)| (id, self.reconstruct_mapping(mapping))).collect();

        let consts = input
            .consts
            .into_iter()
            .map(|(i, c)| match self.reconstruct_const(c) {
                (Statement::Const(declaration), _) => (i, declaration),
                _ => panic!("`reconstruct_const` can only return `Statement::Const`"),
            })
            .collect();

        // Names still targeted by unresolved calls, gathered once so the `retain` below does a set
        // lookup per function instead of scanning every unresolved call for every function.
        let still_called: std::collections::HashSet<Symbol> =
            self.unresolved_calls.iter().map(|call| call.function.name).collect();

        // Drop functions recorded in `monomorphized_functions`, unless an unresolved call still names them.
        self.reconstructed_functions
            .retain(|name, _| !self.monomorphized_functions.contains(name) || still_called.contains(name));

        let structs = core::mem::take(&mut self.reconstructed_structs).into_iter().collect::<Vec<_>>();

        // Emit transitions first, then every other function; `partition` keeps each group's relative order.
        let (transitions, mut non_transitions): (Vec<_>, Vec<_>) = core::mem::take(&mut self.reconstructed_functions)
            .into_iter()
            .partition(|(_, f)| f.variant.is_transition());
        let mut all_functions = transitions;
        all_functions.append(&mut non_transitions);

        ProgramScope {
            program_id: input.program_id,
            structs,
            mappings,
            functions: all_functions,
            consts,
            span: input.span,
        }
    }
}