use super::FunctionInliningVisitor;
use leo_ast::{AstReconstructor, Constructor, Function, Program, ProgramReconstructor, ProgramScope};
use snarkvm::prelude::Itertools;
impl ProgramReconstructor for FunctionInliningVisitor<'_> {
    /// Reconstructs all functions of the program named by `input.program_id`, then rebuilds
    /// the scope around the results.
    ///
    /// Functions are pulled out of `self.function_map` following the post-order of
    /// `self.state.call_graph`, restricted to locations belonging to the current program.
    /// Every reconstructed function is pushed onto `self.reconstructed_functions`; once the
    /// constructor has been reconstructed, that list is drained and only entries whose path
    /// has exactly one segment are placed in the returned scope's `functions`.
    ///
    /// # Panics
    ///
    /// Panics if `post_order` returns an error, or if `self.function_map` still holds
    /// entries after the ordered traversal.
    fn reconstruct_program_scope(&mut self, input: ProgramScope) -> ProgramScope {
        // Remember which program is being processed; the filter below compares against it.
        self.program = input.program_id.name.name;

        // Paths of this program's functions, in call-graph post-order.
        // NOTE(review): the `unwrap` presumably relies on an earlier pass having rejected
        // cyclic call graphs -- confirm against the pass that builds `call_graph`.
        let order = self
            .state
            .call_graph
            .post_order()
            .unwrap()
            .into_iter()
            .filter_map(|location| (location.program == self.program).then_some(location.path))
            .collect_vec();

        // Reconstruct in that order. Paths in the ordering that have no entry in
        // `function_map` are skipped.
        for function_name in order {
            if let Some(function) = self.function_map.shift_remove(&function_name) {
                let reconstructed_function = self.reconstruct_function(function);
                // `function_name` is not needed afterwards, so it is moved rather than cloned.
                self.reconstructed_functions.push((function_name, reconstructed_function));
            }
        }

        // Anything left over was not part of the ordering computed above.
        assert!(self.function_map.is_empty(), "All functions in the program should have been processed.");

        // The constructor is reconstructed while `reconstructed_functions` is still populated;
        // the list is only drained afterwards.
        let constructor = input.constructor.map(|constructor| self.reconstruct_constructor(constructor));

        // Drain the reconstructed functions by value (no per-function clone) and keep only
        // those whose path is a single segment, keyed by that segment.
        let functions = core::mem::take(&mut self.reconstructed_functions)
            .into_iter()
            .filter_map(|(path, function)| {
                path.split_last().filter(|(_, rest)| rest.is_empty()).map(|(last, _)| (*last, function))
            })
            .collect();

        ProgramScope {
            program_id: input.program_id,
            structs: input.structs,
            mappings: input.mappings,
            storage_variables: input.storage_variables,
            constructor,
            functions,
            consts: input.consts,
            span: input.span,
        }
    }

    /// Rebuilds `input` with its block reconstructed; every other field is carried over as-is.
    ///
    /// While the block is being reconstructed, `self.is_async` holds the result of
    /// `input.variant.is_async_function()`; it is reset to `false` afterwards.
    fn reconstruct_function(&mut self, input: Function) -> Function {
        Function {
            annotations: input.annotations,
            variant: input.variant,
            identifier: input.identifier,
            const_parameters: input.const_parameters,
            input: input.input,
            output: input.output,
            output_type: input.output_type,
            block: {
                // Set the flag only for the duration of the block reconstruction.
                self.is_async = input.variant.is_async_function();
                let block = self.reconstruct_block(input.block).0;
                self.is_async = false;
                block
            },
            span: input.span,
            id: input.id,
        }
    }

    /// Rebuilds `input` with its block reconstructed; annotations, span and id are carried over.
    ///
    /// `self.is_async` is `true` while the block is being reconstructed and is reset to
    /// `false` afterwards.
    fn reconstruct_constructor(&mut self, input: Constructor) -> Constructor {
        Constructor {
            annotations: input.annotations,
            block: {
                // Constructors are always reconstructed with the async flag set.
                self.is_async = true;
                let block = self.reconstruct_block(input.block).0;
                self.is_async = false;
                block
            },
            span: input.span,
            id: input.id,
        }
    }

    /// Populates `self.function_map` with every function in `input`, then reconstructs each
    /// program scope.
    ///
    /// Module functions are keyed by their module path followed by the function name;
    /// program-scope functions are keyed by a single-segment path holding just the name.
    /// All fields of `input` other than `program_scopes` are returned unchanged.
    fn reconstruct_program(&mut self, input: Program) -> Program {
        // The functions are cloned here because `input` is still needed below: the modules
        // are returned untouched and the scopes are consumed during reconstruction.
        input
            .modules
            .iter()
            .flat_map(|(module_path, m)| {
                m.functions.iter().map(move |(name, f)| {
                    (module_path.iter().cloned().chain(std::iter::once(*name)).collect(), f.clone())
                })
            })
            .chain(
                input
                    .program_scopes
                    .iter()
                    .flat_map(|(_, scope)| scope.functions.iter().map(|(name, f)| (vec![*name], f.clone()))),
            )
            .for_each(|(full_name, f)| {
                self.function_map.insert(full_name, f);
            });

        Program {
            program_scopes: input
                .program_scopes
                .into_iter()
                .map(|(id, scope)| (id, self.reconstruct_program_scope(scope)))
                .collect(),
            ..input
        }
    }
}