use crate::spliterror::{Result, SplitError};
use std::collections::HashSet;
use parity_wasm::elements;
use crate::callgraph::CallGraph;
/// Helper queries and mutations on a parity-wasm [`elements::Module`].
pub trait ParityWasmExt {
    // ---- lookups -------------------------------------------------------

    /// Returns the function signature of the function referenced by export `f`.
    ///
    /// # Errors
    /// Fails with `SplitError::NotAFunction` when `f` does not export a function,
    /// and with a section/index error when the signature cannot be located.
    fn find_function_type_for_export<'a>(
        &'a self,
        f: &'a elements::ExportEntry,
    ) -> Result<&'a elements::FunctionType>;

    /// Returns the code-section body stored at position `func_idx`.
    ///
    /// # Errors
    /// Fails when the module has no code section or `func_idx` is out of range.
    fn func_by_idx<'a>(&'a self, func_idx: u32) -> Result<&'a elements::FuncBody>;

    /// Lists `(export name, function index)` for every function export.
    ///
    /// # Errors
    /// Fails when the module has no export section.
    fn exported_funcs(&self) -> Result<Vec<(&str, u32)>>;

    // ---- call graph ----------------------------------------------------

    /// Returns the indices of the functions that `body` calls directly.
    fn call_graph_edges_for_func(&self, body: &elements::FuncBody) -> Result<HashSet<u32>>;

    /// Builds a call graph covering every body in the code section.
    ///
    /// # Errors
    /// Fails when the module has no code section.
    fn call_graph(&self) -> Result<CallGraph>;

    // ---- section management --------------------------------------------

    /// Stably reorders the module's sections into standard wasm section order.
    fn sort_sections(&mut self);

    /// Returns the function section, creating an empty one first if absent.
    fn ensure_function_section(&mut self) -> &mut elements::FunctionSection;

    /// Returns the element section, creating an empty one first if absent.
    fn ensure_elements_section(&mut self) -> &mut elements::ElementSection;

    /// Returns the import section, creating an empty one first if absent.
    fn ensure_import_section(&mut self) -> &mut elements::ImportSection;
}
impl ParityWasmExt for elements::Module {
    /// Resolves the signature of the function that export `f` refers to.
    ///
    /// The export's index is a *function index*. In that index space, imported
    /// functions come first, followed by the functions declared in the function
    /// section. The type index is therefore read from the matching import entry
    /// or function-section entry, and only then looked up in the type section.
    ///
    /// # Errors
    /// - `SplitError::NotAFunction` if `f` is not a function export, or the type
    ///   entry is not a function type.
    /// - `SplitError::MissingTypeSection` if the module has no type section.
    /// - `SplitError::NoFunctionWithIndex` if the function index is out of range.
    /// - `SplitError::NoTypeWithIndex` if the resolved type index is out of range.
    fn find_function_type_for_export<'a>(
        &'a self,
        f: &'a elements::ExportEntry,
    ) -> Result<&'a elements::FunctionType> {
        let func_id = match f.internal() {
            elements::Internal::Function(func_id) => *func_id,
            _ => return Err(SplitError::NotAFunction),
        };
        let type_section = self.type_section().ok_or(SplitError::MissingTypeSection)?;

        // Type indices of imported functions, in import order. These occupy the
        // lowest function indices.
        let imported: Vec<u32> = self
            .import_section()
            .map(|section| {
                section
                    .entries()
                    .iter()
                    .filter_map(|entry| match entry.external() {
                        elements::External::Function(type_idx) => Some(*type_idx),
                        _ => None,
                    })
                    .collect()
            })
            .unwrap_or_default();

        let type_idx = match imported.get(func_id as usize) {
            Some(type_idx) => *type_idx,
            None => {
                // Not an import, so func_id >= imported.len(): this cannot underflow.
                let local_idx = func_id as usize - imported.len();
                self.function_section()
                    .and_then(|section| section.entries().get(local_idx))
                    .map(|func| func.type_ref())
                    .ok_or(SplitError::NoFunctionWithIndex(func_id))?
            }
        };

        match type_section
            .types()
            .get(type_idx as usize)
            .ok_or(SplitError::NoTypeWithIndex(type_idx))?
        {
            elements::Type::Function(func_type) => Ok(func_type),
            _ => Err(SplitError::NotAFunction),
        }
    }

    /// Collects the targets of the direct `call` instructions in `body`.
    ///
    /// Targets are the raw immediates of the `call` instructions. `call_indirect`
    /// targets cannot be resolved statically: a notice is printed for each one
    /// encountered and no edge is recorded for it.
    ///
    /// This never returns `Err` today; the `Result` is part of the trait signature.
    fn call_graph_edges_for_func(&self, body: &elements::FuncBody) -> Result<HashSet<u32>> {
        let mut deps = HashSet::new();
        for instruction in body.code().elements() {
            match instruction {
                elements::Instruction::Call(func_idx) => {
                    deps.insert(*func_idx);
                }
                elements::Instruction::CallIndirect(_, _) => {
                    // NOTE(review): the wording may have been meant as "might not be".
                    println!(
                        "This module has indirect call. This module might be code-splittable."
                    );
                }
                _ => {}
            }
        }
        Ok(deps)
    }

    /// Returns the code-section body at position `func_idx`.
    ///
    /// The index counts code-section bodies only; it is not offset by imports.
    ///
    /// # Errors
    /// `SplitError::MissingCodeSection` or `SplitError::NoFunctionWithIndex`.
    fn func_by_idx<'a>(&'a self, func_idx: u32) -> Result<&'a elements::FuncBody> {
        let code_section = self.code_section().ok_or(SplitError::MissingCodeSection)?;
        code_section
            .bodies()
            .get(func_idx as usize)
            .ok_or(SplitError::NoFunctionWithIndex(func_idx))
    }

    /// Builds the call graph of every body in the code section.
    ///
    /// Each entry maps a body's position in the code section to the set of its
    /// direct `call` targets plus the body itself.
    ///
    /// NOTE(review): keys are code-section positions, while `call` targets are
    /// function indices, which also count imported functions. The two agree only
    /// when the module imports no functions. Confirm callers rely on that.
    ///
    /// # Errors
    /// `SplitError::MissingCodeSection`, or the first error from
    /// [`ParityWasmExt::call_graph_edges_for_func`].
    fn call_graph(&self) -> Result<CallGraph> {
        let code_section = self.code_section().ok_or(SplitError::MissingCodeSection)?;
        let mut graph = CallGraph::new();
        for (idx, body) in code_section.bodies().iter().enumerate() {
            let idx = idx as u32;
            let mut edges = self.call_graph_edges_for_func(body)?;
            // Each function is treated as depending on itself.
            edges.insert(idx);
            graph.0.insert(idx, edges);
        }
        Ok(graph)
    }

    /// Lists `(export name, function index)` for every function export.
    ///
    /// # Errors
    /// `SplitError::MissingExportSection` if the module has no export section.
    fn exported_funcs(&self) -> Result<Vec<(&str, u32)>> {
        let export_section = self
            .export_section()
            .ok_or(SplitError::MissingExportSection)?;
        Ok(export_section
            .entries()
            .iter()
            .filter_map(|export| maybe_exported_function_id(export).map(|id| (export.field(), id)))
            .collect())
    }

    /// Stably sorts sections by `section_cmp`. Sections that compare equal keep
    /// their relative order.
    fn sort_sections(&mut self) {
        self.sections_mut().sort_by(section_cmp);
    }

    /// Returns the function section, appending an empty one first if absent.
    /// Sections are re-sorted on every call.
    fn ensure_function_section(&mut self) -> &mut elements::FunctionSection {
        if self.function_section_mut().is_none() {
            self.sections_mut().push(elements::Section::Function(
                elements::FunctionSection::with_entries(vec![]),
            ));
        }
        self.sort_sections();
        self.function_section_mut()
            .expect("function section exists: it was present or was just inserted")
    }

    /// Returns the element section, appending an empty one first if absent.
    /// Sections are re-sorted on every call.
    fn ensure_elements_section(&mut self) -> &mut elements::ElementSection {
        if self.elements_section_mut().is_none() {
            self.sections_mut().push(elements::Section::Element(
                elements::ElementSection::with_entries(vec![]),
            ));
        }
        self.sort_sections();
        self.elements_section_mut()
            .expect("element section exists: it was present or was just inserted")
    }

    /// Returns the import section, appending an empty one first if absent.
    /// Sections are re-sorted on every call.
    fn ensure_import_section(&mut self) -> &mut elements::ImportSection {
        if self.import_section_mut().is_none() {
            self.sections_mut().push(elements::Section::Import(
                elements::ImportSection::with_entries(vec![]),
            ));
        }
        self.sort_sections();
        self.import_section_mut()
            .expect("import section exists: it was present or was just inserted")
    }
}
/// Returns the function index referenced by `entry`.
///
/// Returns `None` when the export refers to anything other than a function.
pub fn maybe_exported_function_id(entry: &elements::ExportEntry) -> Option<u32> {
    if let elements::Internal::Function(func_index) = *entry.internal() {
        Some(func_index)
    } else {
        None
    }
}
fn section_order(section: &elements::Section) -> isize {
match section {
elements::Section::Type(_) => 0,
elements::Section::Import(_) => 1,
elements::Section::Function(_) => 2,
elements::Section::Table(_) => 3,
elements::Section::Memory(_) => 4,
elements::Section::Global(_) => 5,
elements::Section::Export(_) => 6,
elements::Section::Start(_) => 7,
elements::Section::Element(_) => 8,
elements::Section::Code(_) => 9,
elements::Section::Data(_) => 10,
_ => -1,
}
}
/// Compares two sections by their `section_order` key, for use with `sort_by`.
fn section_cmp(section_a: &elements::Section, section_b: &elements::Section) -> std::cmp::Ordering {
    // The original text contained a mis-encoded `§ion_order`, which did not
    // compile; `Ord::cmp` takes the right-hand key by reference.
    section_order(section_a).cmp(&section_order(section_b))
}