wasm-split 0.1.0

Code splitting for WebAssembly
Documentation
/**
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *     http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
extern crate parity_wasm;

mod callgraph;
mod parity_wasm_ext;
mod spliterror;

use parity_wasm::elements;
use parity_wasm_ext::*;
pub use spliterror::{Result, SplitError};
use std::collections::{HashMap, HashSet};

/// Splits `module` into a main module and a side module around the exported
/// function `entry_name`.
///
/// Both results start out as clones of `module`, so function indices are the
/// same in both halves. In each clone the functions that belong to the other
/// half are emptied and their exports are removed. Main-module functions that
/// are called from the side module ("cross calls") are appended to the main
/// module's table, which is exported under `field_name`; the side module
/// imports that table as `module_name`.`field_name` and reaches those functions
/// through `call_indirect`.
///
/// Returns `(main_module, side_module)`.
///
/// # Errors
///
/// Propagates any [`SplitError`] produced by the individual steps, e.g. when no
/// exported function is called `entry_name` or a required section is missing.
pub fn split_module(
    module: &elements::Module,
    entry_name: &str,
    module_name: &str,
    field_name: &str,
) -> Result<(elements::Module, elements::Module)> {
    let (main_funcs, side_funcs, cross_calls, _call_graph) = split_funcs(module, entry_name)?;

    let mut main_module = module.clone();
    truncate_funcs(&mut main_module, &side_funcs)?;
    remove_func_exports(&mut main_module, &side_funcs)?;
    let offset = expose_cross_calls(&mut main_module, &cross_calls, field_name)?;
    main_module.sort_sections();

    let mut side_module = module.clone();
    // Rewrite the calls *before* truncating: `truncate_funcs` overwrites the
    // type of every main function, and `rewrite_cross_calls` needs the original
    // types to emit correctly typed `call_indirect` instructions.
    rewrite_cross_calls(&mut side_module, &cross_calls, offset)?;
    truncate_funcs(&mut side_module, &main_funcs)?;
    remove_func_exports(&mut side_module, &main_funcs)?;
    remove_table(&mut side_module);
    add_table_import(&mut side_module, module_name, field_name);
    side_module.sort_sections();

    Ok((main_module, side_module))
}

/// Removes every table section from `module` and, if an export section exists,
/// every export entry that refers to a table.
fn remove_table(module: &mut elements::Module) {
    module.sections_mut().retain(|section| match section {
        elements::Section::Table(_) => false,
        _ => true,
    });
    let export_section = module.export_section_mut();
    if let Some(export_section) = export_section {
        export_section
            .entries_mut()
            .retain(|entry| match entry.internal() {
                elements::Internal::Table(_) => false,
                _ => true,
            });
    };
}

/// Appends a table import named `module_name`.`field_name` to `module`,
/// creating the import section if necessary.
///
/// The imported table type is declared with limits `(0, None)`.
fn add_table_import(module: &mut elements::Module, module_name: &str, field_name: &str) {
    let import_section = module.ensure_import_section();
    import_section
        .entries_mut()
        .push(elements::ImportEntry::new(
            String::from(module_name),
            String::from(field_name),
            elements::External::Table(elements::TableType::new(0, None)),
        ));
}

/// Replaces every direct `call` to a function in `cross_calls` with an
/// `i32.const <table slot>` followed by a `call_indirect`.
///
/// Table slots are assigned by enumerating `cross_calls` and adding `offset`.
/// `expose_cross_calls` fills the table by iterating the same set, so both must
/// be given the same, unmodified `HashSet` for the slots to line up.
///
/// `module` must still carry the original function types when this runs (i.e.
/// it must be called before `truncate_funcs`), because the type index of each
/// callee is read from the function section.
///
/// # Errors
///
/// Returns `SplitError::MissingFunctionSection` or
/// `SplitError::MissingCodeSection` if the respective section is absent.
///
/// # Panics
///
/// Panics if an id in `cross_calls` is not a valid index into the function
/// section. Like the rest of this file, this assumes function indices map
/// directly onto function section entries (no imported functions).
fn rewrite_cross_calls(
    module: &mut elements::Module,
    cross_calls: &HashSet<u32>,
    offset: u32,
) -> Result<()> {
    let func_entries = module
        .function_section()
        .ok_or(SplitError::MissingFunctionSection)?
        .entries();
    // function id -> (table slot, type index of the callee)
    let cross_call_map: HashMap<u32, (u32, u32)> = cross_calls
        .iter()
        .enumerate()
        .map(|(idx, &fid)| {
            let slot = idx as u32 + offset;
            let type_idx = func_entries[fid as usize].type_ref();
            (fid, (slot, type_idx))
        })
        .collect();

    let func_bodies = module
        .code_section_mut()
        .ok_or(SplitError::MissingCodeSection)?
        .bodies_mut();
    for func_body in func_bodies {
        let instructions = func_body.code_mut().elements_mut();
        let mut rewritten = Vec::with_capacity(instructions.len());
        for instruction in instructions.drain(..) {
            match instruction {
                elements::Instruction::Call(id) => match cross_call_map.get(&id) {
                    Some(&(slot, type_idx)) => {
                        rewritten.push(elements::Instruction::I32Const(slot as i32));
                        // The first operand of `call_indirect` is the *type*
                        // index of the callee, not its function index.
                        // The current iteration of WebAssembly only supports one table
                        // and implicitly works on table idx 0
                        rewritten.push(elements::Instruction::CallIndirect(type_idx, 0));
                    }
                    None => rewritten.push(elements::Instruction::Call(id)),
                },
                other => rewritten.push(other),
            }
        }
        *instructions = rewritten;
    }
    Ok(())
}

/// Partitions the functions of `module` around the exported function
/// `entry_name`.
///
/// Returns `(main_funcs, side_funcs, cross_calls, call_graph)`:
/// * `main_funcs` is the set the flattened call graph stores for the entry
///   function (presumably everything reachable from it; confirm against
///   `callgraph`),
/// * `side_funcs` is every other function known to the call graph,
/// * `cross_calls` is the result of `determine_cross_calls` for these two sets,
/// * `call_graph` is the flattened call graph itself.
///
/// # Errors
///
/// Returns `SplitError::NoFunctionWithName` if no exported function is called
/// `entry_name`, and propagates errors from building the call graph, listing
/// the exports and determining the cross calls.
///
/// # Panics
///
/// Panics if the call graph has no entry for the entry function.
fn split_funcs(
    module: &elements::Module,
    entry_name: &str,
) -> Result<(
    HashSet<u32>,
    HashSet<u32>,
    HashSet<u32>,
    callgraph::CallGraph,
)> {
    let call_graph = module.call_graph()?.flatten();

    let exports = module.exported_funcs()?;
    let (_, entry_func_id) = exports
        .iter()
        .find(|export| export.0 == entry_name)
        .ok_or_else(|| SplitError::NoFunctionWithName(entry_name.to_owned()))?;

    let main_funcs = call_graph.get(*entry_func_id).unwrap().clone();
    let side_funcs: HashSet<u32> = call_graph
        .all_funcs()
        .difference(&main_funcs)
        .cloned()
        .collect();

    let cross_calls = determine_cross_calls(module, &main_funcs, &side_funcs)?;

    Ok((main_funcs, side_funcs, cross_calls, call_graph))
}

/// Makes the functions in `cross_calls` reachable through the table of
/// `module`.
///
/// The table is grown by `cross_calls.len()` slots, exported as `field_name`,
/// and an element segment is appended that places the cross-call functions
/// into the new slots, starting at the returned offset.
///
/// The segment members are produced by iterating `cross_calls`;
/// `rewrite_cross_calls` derives its slot numbers from the iteration order of
/// the same set, so both must be given the same, unmodified `HashSet`.
///
/// Returns the table index of the first newly added slot.
///
/// # Errors
///
/// Returns `SplitError::MissingExportSection` if `module` has no export
/// section, and propagates errors from `increase_table_size`.
fn expose_cross_calls(
    module: &mut elements::Module,
    cross_calls: &HashSet<u32>,
    field_name: &str,
) -> Result<u32> {
    let offset = increase_table_size(module, cross_calls.len())?;

    let exports = module
        .export_section_mut()
        .ok_or(SplitError::MissingExportSection)?
        .entries_mut();
    // You can export the same table multiple times with different names.
    exports.push(elements::ExportEntry::new(
        String::from(field_name),
        elements::Internal::Table(0),
    ));

    // `InitExpr::new` does not append the terminator itself: the instruction
    // list has to end with `End`, otherwise the serialized offset expression
    // is not terminated.
    let init_expr = elements::InitExpr::new(vec![
        elements::Instruction::I32Const(offset as i32),
        elements::Instruction::End,
    ]);
    let members: Vec<u32> = cross_calls.iter().cloned().collect();
    // NOTE(review): the trailing `true` is passed through unchanged; confirm
    // its meaning against the `parity_wasm` version this crate builds with.
    module
        .ensure_elements_section()
        .entries_mut()
        .push(elements::ElementSegment::new(
            0,
            Some(init_expr),
            members,
            true,
        ));

    Ok(offset)
}

/// Grows the table of `module` by `delta` slots and returns the previous
/// initial size, i.e. the index of the first new slot.
///
/// Any existing table section is replaced by a new one holding a single table
/// whose initial size (and maximum, if one was set) is increased by `delta`.
/// A module without a table section, or with a table section that declares no
/// table, is treated as having a table with limits `(0, None)`.
///
/// The new section is pushed at the end of the section list; callers are
/// expected to restore section order afterwards.
///
/// # Errors
///
/// Returns `SplitError::TooManyTables` if the table section declares more than
/// one table.
fn increase_table_size(module: &mut elements::Module, delta: usize) -> Result<u32> {
    let declared_limits = match module.table_section() {
        Some(table_section) => {
            let tables = table_section.entries();
            // Current iteration of WebAssembly allows at most one table.
            if tables.len() > 1 {
                return Err(SplitError::TooManyTables);
            }
            // A table section may legitimately be empty; do not index into it.
            tables.first().map(|table| table.limits().clone())
        }
        None => None,
    };
    let old_limits = declared_limits.unwrap_or_else(|| elements::ResizableLimits::new(0, None));

    let sections = module.sections_mut();
    // Remove old table section
    sections.retain(|section| match section {
        elements::Section::Table(_) => false,
        _ => true,
    });
    sections.push(elements::Section::Table(
        elements::TableSection::with_entries(vec![elements::TableType::new(
            old_limits.initial() + delta as u32,
            old_limits.maximum().map(|max| max + delta as u32),
        )]),
    ));
    Ok(old_limits.initial())
}

/// Collects the ids of all functions in `main_funcs` that are the target of a
/// direct `call` instruction inside any function listed in `side_funcs`.
///
/// # Errors
///
/// Returns `SplitError::MissingCodeSection` if `module` has no code section.
///
/// # Panics
///
/// Panics if an id in `side_funcs` is not a valid index into the code
/// section's function bodies.
fn determine_cross_calls(
    module: &elements::Module,
    main_funcs: &HashSet<u32>,
    side_funcs: &HashSet<u32>,
) -> Result<HashSet<u32>> {
    let bodies = module
        .code_section()
        .ok_or(SplitError::MissingCodeSection)?
        .bodies();

    let cross_calls: HashSet<u32> = side_funcs
        .iter()
        // Walk the instructions of every side function ...
        .flat_map(move |&func_idx| bodies[func_idx as usize].code().elements().iter())
        // ... and keep the callees that live in the main module.
        .filter_map(|instruction| match instruction {
            elements::Instruction::Call(callee) if main_funcs.contains(callee) => Some(*callee),
            _ => None,
        })
        .collect();

    Ok(cross_calls)
}

/// Empties every function of `module` whose index is contained in `funcs`.
///
/// The affected functions keep their index, but their type is switched to the
/// empty function type (no parameters, no result), their locals are dropped
/// and their body is reduced to a single `End` instruction.
///
/// # Errors
///
/// Returns `SplitError::MissingFunctionSection` or
/// `SplitError::MissingCodeSection` if the respective section is absent, and
/// propagates errors from `inject_empty_function_type`.
fn truncate_funcs(module: &mut elements::Module, funcs: &HashSet<u32>) -> Result<()> {
    let empty_func_id = inject_empty_function_type(module)?;

    let function_entries = module
        .function_section_mut()
        .ok_or(SplitError::MissingFunctionSection)?
        .entries_mut();
    for (idx, func) in function_entries.iter_mut().enumerate() {
        if funcs.contains(&(idx as u32)) {
            *func.type_ref_mut() = empty_func_id;
        }
    }

    let function_bodies = module
        .code_section_mut()
        .ok_or(SplitError::MissingCodeSection)?
        .bodies_mut();
    for (idx, body) in function_bodies.iter_mut().enumerate() {
        if funcs.contains(&(idx as u32)) {
            // Make function empty, which is almost as good as removing it but leaves the
            // indices in place. `wasm-opt` and similar tools can do the
            // remaining optimizations.
            body.locals_mut().clear();
            let ops = body.code_mut().elements_mut();
            // Rebuild the instruction list instead of writing to `ops[0]`, so a
            // body that happens to hold no instructions cannot cause a panic.
            ops.clear();
            ops.push(elements::Instruction::End);
        }
    }
    Ok(())
}

/// Drops every export of `module` that exports a function whose id is
/// contained in `funcs`. Exports for which `maybe_exported_function_id`
/// yields no function id are kept.
///
/// # Errors
///
/// Returns `SplitError::MissingExportSection` if `module` has no export
/// section.
fn remove_func_exports(module: &mut elements::Module, funcs: &HashSet<u32>) -> Result<()> {
    let exports = module
        .export_section_mut()
        .ok_or(SplitError::MissingExportSection)?
        .entries_mut();

    exports.retain(|export| {
        maybe_exported_function_id(export).map_or(true, |func_id| !funcs.contains(&func_id))
    });

    Ok(())
}

/// Returns the index of a function type without parameters and without a
/// return value in the type section of `module`.
///
/// The first matching type is reused if one exists; otherwise a new one is
/// appended to the type section and its index is returned.
///
/// # Errors
///
/// Returns `SplitError::MissingTypeSection` if `module` has no type section.
fn inject_empty_function_type(module: &mut elements::Module) -> spliterror::Result<u32> {
    let types = module
        .type_section_mut()
        .ok_or(SplitError::MissingTypeSection)?
        .types_mut();

    let existing = types.iter().position(|typ| match typ {
        elements::Type::Function(ftype) => {
            ftype.params().is_empty() && ftype.return_type().is_none()
        }
        _ => false,
    });

    match existing {
        Some(idx) => Ok(idx as u32),
        None => {
            types.push(elements::Type::Function(elements::FunctionType::new(
                vec![],
                None,
            )));
            // The type that was just pushed sits at the last index.
            Ok(types.len() as u32 - 1)
        }
    }
}