extern crate parity_wasm;
mod callgraph;
mod parity_wasm_ext;
mod spliterror;
use parity_wasm::elements;
use parity_wasm_ext::*;
pub use spliterror::{Result, SplitError};
use std::collections::{HashMap, HashSet};
/// Splits `module` into a "main" module and a "side" module.
///
/// The function set is partitioned by `split_funcs`, starting from the export
/// named `entry_name`. Both outputs start as clones of `module`:
///
/// * In the main module the side functions are stubbed out and their exports
///   removed, and the cross-called functions are placed in a table exported
///   under `field_name`.
/// * In the side module the main functions are stubbed out and their exports
///   removed, direct calls to cross-called functions are rewritten into
///   indirect calls, and its own table is replaced by an import of
///   `module_name`.`field_name`.
///
/// # Errors
///
/// Propagates any `SplitError` reported by the helper passes (for example a
/// missing section or an unknown entry export).
pub fn split_module(
    module: &elements::Module,
    entry_name: &str,
    module_name: &str,
    field_name: &str,
) -> Result<(elements::Module, elements::Module)> {
    let (main_funcs, side_funcs, cross_calls, _call_graph) = split_funcs(module, entry_name)?;

    // Phase 1: the main module. `offset` is the table index at which the
    // cross-call entries start; the side module needs it for its rewrites.
    let (main_module, offset) = {
        let mut main = module.clone();
        truncate_funcs(&mut main, &side_funcs)?;
        remove_func_exports(&mut main, &side_funcs)?;
        let table_offset = expose_cross_calls(&mut main, &cross_calls, field_name)?;
        main.sort_sections();
        (main, table_offset)
    };

    // Phase 2: the side module, which imports the table exported by phase 1.
    let side_module = {
        let mut side = module.clone();
        truncate_funcs(&mut side, &main_funcs)?;
        remove_func_exports(&mut side, &main_funcs)?;
        rewrite_cross_calls(&mut side, &cross_calls, offset)?;
        remove_table(&mut side);
        add_table_import(&mut side, module_name, field_name);
        side.sort_sections();
        side
    };

    Ok((main_module, side_module))
}
/// Strips every table section from `module` and, if an export section is
/// present, every export entry that refers to a table.
fn remove_table(module: &mut elements::Module) {
    module.sections_mut().retain(|section| {
        if let elements::Section::Table(_) = section {
            false
        } else {
            true
        }
    });

    if let Some(exports) = module.export_section_mut() {
        exports.entries_mut().retain(|export| {
            if let elements::Internal::Table(_) = export.internal() {
                false
            } else {
                true
            }
        });
    }
}
/// Appends a table import named `module_name`.`field_name` to `module`,
/// creating the import section first if necessary.
///
/// The imported table type is constructed with the arguments `(0, None)`.
fn add_table_import(module: &mut elements::Module, module_name: &str, field_name: &str) {
    let table_import = elements::ImportEntry::new(
        module_name.to_owned(),
        field_name.to_owned(),
        elements::External::Table(elements::TableType::new(0, None)),
    );
    module
        .ensure_import_section()
        .entries_mut()
        .push(table_import);
}
/// Replaces direct calls to cross-called functions with indirect calls.
///
/// Each function id in `cross_calls` is assigned the table slot
/// `offset + position`, where `position` is its position in the set's
/// iteration order. Every `Call(id)` whose `id` is in `cross_calls` becomes
/// the pair `I32Const(slot)`, `CallIndirect(id, 0)`; all other instructions
/// are copied unchanged.
///
/// # Errors
///
/// Returns `SplitError::MissingCodeSection` if `module` has no code section.
fn rewrite_cross_calls(
    module: &mut elements::Module,
    cross_calls: &HashSet<u32>,
    offset: u32,
) -> Result<()> {
    // function id -> table slot
    let slot_of: HashMap<u32, u32> = cross_calls
        .iter()
        .enumerate()
        .map(|(position, func_id)| (*func_id, position as u32 + offset))
        .collect();

    let bodies = module
        .code_section_mut()
        .ok_or(SplitError::MissingCodeSection)?
        .bodies_mut();

    for body in bodies.iter_mut() {
        let instructions = body.code_mut().elements_mut();
        let mut rewritten = Vec::with_capacity(instructions.len());
        for instruction in instructions.iter() {
            let slot = match instruction {
                elements::Instruction::Call(id) => slot_of.get(id).map(|slot| (*id, *slot)),
                _ => None,
            };
            match slot {
                Some((id, slot)) => {
                    rewritten.push(elements::Instruction::I32Const(slot as i32));
                    // NOTE(review): the first operand of `CallIndirect` is
                    // passed the *function* id here; parity-wasm appears to
                    // expect a type (signature) index - confirm.
                    rewritten.push(elements::Instruction::CallIndirect(id, 0));
                }
                None => rewritten.push(instruction.clone()),
            }
        }
        *instructions = rewritten;
    }
    Ok(())
}
/// Partitions the functions of `module` relative to the export `entry_name`.
///
/// Returns, in order:
///
/// 1. the main set: the flattened call-graph entry of the entry function,
/// 2. the side set: every function known to the call graph that is not in
///    the main set,
/// 3. the cross calls computed by `determine_cross_calls`,
/// 4. the flattened call graph itself.
///
/// # Errors
///
/// Returns `SplitError::NoFunctionWithName` if no exported function is named
/// `entry_name`, and propagates errors from building the call graph, listing
/// the exports and determining the cross calls.
///
/// # Panics
///
/// Panics if the call graph has no entry for the entry function.
fn split_funcs(
    module: &elements::Module,
    entry_name: &str,
) -> Result<(
    HashSet<u32>,
    HashSet<u32>,
    HashSet<u32>,
    callgraph::CallGraph,
)> {
    let call_graph = module.call_graph()?.flatten();

    let exported_funcs = module.exported_funcs()?;
    let entry_func_id = exported_funcs
        .iter()
        .find(|func| func.0 == entry_name)
        .map(|(_, id)| *id)
        .ok_or_else(|| SplitError::NoFunctionWithName(String::from(entry_name)))?;

    let main_funcs = call_graph.get(entry_func_id).unwrap().clone();
    let side_funcs: HashSet<u32> = call_graph
        .all_funcs()
        .difference(&main_funcs)
        .cloned()
        .collect();

    let cross_calls = determine_cross_calls(module, &main_funcs, &side_funcs)?;
    Ok((main_funcs, side_funcs, cross_calls, call_graph))
}
/// Makes the functions in `cross_calls` reachable through the module's table.
///
/// The table is grown by `cross_calls.len()` entries, exported under
/// `field_name`, and an element segment listing the cross-called function
/// ids (in the set's iteration order) is added at the old table size.
///
/// Returns that old table size, i.e. the table index of the first
/// cross-call entry.
///
/// # Errors
///
/// Returns `SplitError::MissingExportSection` if `module` has no export
/// section, and propagates errors from `increase_table_size`.
fn expose_cross_calls(
    module: &mut elements::Module,
    cross_calls: &HashSet<u32>,
    field_name: &str,
) -> Result<u32> {
    let offset = increase_table_size(module, cross_calls.len())?;

    let exports = module
        .export_section_mut()
        .ok_or(SplitError::MissingExportSection)?
        .entries_mut();
    exports.push(elements::ExportEntry::new(
        String::from(field_name),
        elements::Internal::Table(0),
    ));

    // An init expression is serialised instruction by instruction, so it must
    // carry its own terminating `End`; without it the emitted element segment
    // is malformed.
    let init_expr = elements::InitExpr::new(vec![
        elements::Instruction::I32Const(offset as i32),
        elements::Instruction::End,
    ]);

    let element_entries = module.ensure_elements_section().entries_mut();
    // NOTE(review): the last argument (`true`) looks like a "passive" flag,
    // yet the segment carries an offset like an active one - confirm against
    // the parity-wasm version in use.
    element_entries.push(elements::ElementSegment::new(
        0,
        Some(init_expr),
        cross_calls.iter().cloned().collect(),
        true,
    ));
    Ok(offset)
}
/// Grows the module's single table by `delta` entries.
///
/// The existing table section (if any) is replaced by a new one holding one
/// table whose initial size, and maximum size when one is set, are increased
/// by `delta`. A module without a table, or with an empty table section, is
/// treated as having a table of initial size 0 and no maximum.
///
/// Returns the table's initial size before the increase.
///
/// # Errors
///
/// Returns `SplitError::TooManyTables` if the table section declares more
/// than one table.
fn increase_table_size(module: &mut elements::Module, delta: usize) -> Result<u32> {
    let existing_limits = match module.table_section() {
        Some(table_section) if table_section.entries().len() > 1 => {
            return Err(SplitError::TooManyTables);
        }
        // `first()` instead of `[0]`: a table section with no entries must
        // not panic; it falls back to the default limits below.
        Some(table_section) => table_section
            .entries()
            .first()
            .map(|table| table.limits().clone()),
        None => None,
    };
    let old_limits = existing_limits.unwrap_or_else(|| elements::ResizableLimits::new(0, None));

    let sections = module.sections_mut();
    sections.retain(|section| match section {
        elements::Section::Table(_) => false,
        _ => true,
    });
    // The new section is appended at the end; callers are expected to
    // restore section order afterwards.
    sections.push(elements::Section::Table(
        elements::TableSection::with_entries(vec![elements::TableType::new(
            old_limits.initial() + delta as u32,
            old_limits.maximum().map(|max| max + delta as u32),
        )]),
    ));
    Ok(old_limits.initial())
}
/// Collects the ids of all main functions that are the target of a direct
/// `Call` instruction in the body of at least one side function.
///
/// Function ids in `side_funcs` are used directly as indices into the code
/// section's bodies.
///
/// # Errors
///
/// Returns `SplitError::MissingCodeSection` if `module` has no code section.
///
/// # Panics
///
/// Panics if an id in `side_funcs` is not a valid index into the bodies.
fn determine_cross_calls(
    module: &elements::Module,
    main_funcs: &HashSet<u32>,
    side_funcs: &HashSet<u32>,
) -> Result<HashSet<u32>> {
    let bodies = module
        .code_section()
        .ok_or(SplitError::MissingCodeSection)?
        .bodies();

    let mut cross_calls: HashSet<u32> = HashSet::new();
    for &side_func in side_funcs {
        let body = &bodies[side_func as usize];
        cross_calls.extend(body.code().elements().iter().filter_map(|instruction| {
            match instruction {
                elements::Instruction::Call(id) if main_funcs.contains(id) => Some(*id),
                _ => None,
            }
        }));
    }
    Ok(cross_calls)
}
/// Turns every function whose index is in `funcs` into an empty stub.
///
/// Each affected function section entry is re-pointed at the empty function
/// type (obtained from `inject_empty_function_type`), and the matching code
/// body loses its locals and is reduced to a single `End` instruction.
///
/// Positions in the function section and in the code section are compared
/// directly against the ids in `funcs`.
///
/// # Errors
///
/// Returns `SplitError::MissingFunctionSection` or
/// `SplitError::MissingCodeSection` if the respective section is absent, and
/// propagates errors from `inject_empty_function_type`.
fn truncate_funcs(module: &mut elements::Module, funcs: &HashSet<u32>) -> Result<()> {
    let empty_func_id = inject_empty_function_type(module)?;

    let function_entries = module
        .function_section_mut()
        .ok_or(SplitError::MissingFunctionSection)?
        .entries_mut();
    for (idx, func) in function_entries.iter_mut().enumerate() {
        if funcs.contains(&(idx as u32)) {
            *func.type_ref_mut() = empty_func_id;
        }
    }

    let function_bodies = module
        .code_section_mut()
        .ok_or(SplitError::MissingCodeSection)?
        .bodies_mut();
    for (idx, body) in function_bodies.iter_mut().enumerate() {
        if funcs.contains(&(idx as u32)) {
            body.locals_mut().clear();
            let ops = body.code_mut().elements_mut();
            // clear + push rather than truncate(1) + ops[0]: the latter
            // panics when the instruction list happens to be empty.
            ops.clear();
            ops.push(elements::Instruction::End);
        }
    }
    Ok(())
}
/// Drops every export entry that exports a function whose id is in `funcs`.
///
/// Entries for which `maybe_exported_function_id` yields `None` are kept.
///
/// # Errors
///
/// Returns `SplitError::MissingExportSection` if `module` has no export
/// section.
fn remove_func_exports(module: &mut elements::Module, funcs: &HashSet<u32>) -> Result<()> {
    let export_section = module
        .export_section_mut()
        .ok_or(SplitError::MissingExportSection)?;
    export_section
        .entries_mut()
        .retain(|entry| maybe_exported_function_id(entry).map_or(true, |id| !funcs.contains(&id)));
    Ok(())
}
/// Returns the index of a function type with no parameters and no return
/// value, appending such a type to the type section if none exists yet.
///
/// # Errors
///
/// Returns `SplitError::MissingTypeSection` if `module` has no type section.
fn inject_empty_function_type(module: &mut elements::Module) -> spliterror::Result<u32> {
    let types = module
        .type_section_mut()
        .ok_or(SplitError::MissingTypeSection)?
        .types_mut();

    // Look for the first already-present `() -> ()` signature.
    let existing = types.iter().position(|typ| match typ {
        elements::Type::Function(ftype) => {
            ftype.params().is_empty() && ftype.return_type().is_none()
        }
        _ => false,
    });

    match existing {
        Some(idx) => Ok(idx as u32),
        None => {
            types.push(elements::Type::Function(elements::FunctionType::new(
                vec![],
                None,
            )));
            // The freshly pushed type is the last one.
            Ok(types.len() as u32 - 1)
        }
    }
}