wasm-split 0.1.0

Code splitting for WebAssembly
Documentation
/**
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *     http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
use crate::spliterror::{Result, SplitError};

use std::collections::HashSet;

use parity_wasm::elements;

use crate::callgraph::CallGraph;

/// Extension methods on a `parity_wasm` module used by the splitter: read-only
/// queries over exports, types and function bodies, plus helpers that make
/// sure a given section exists before it is mutated.
pub trait ParityWasmExt {
    /// Lists `(export name, function index)` for every function export.
    fn exported_funcs(&self) -> Result<Vec<(&str, u32)>>;

    /// Looks up the signature of the function exported by `f`.
    fn find_function_type_for_export<'a>(
        &'a self,
        f: &'a elements::ExportEntry,
    ) -> Result<&'a elements::FunctionType>;

    /// Returns the code-section body stored at position `func_idx`.
    fn func_by_idx<'a>(&'a self, func_idx: u32) -> Result<&'a elements::FuncBody>;

    /// Collects the targets of every direct `call` instruction in `body`.
    fn call_graph_edges_for_func(&self, body: &elements::FuncBody) -> Result<HashSet<u32>>;

    /// Builds the direct-call graph over all bodies of the code section.
    fn call_graph(&self) -> Result<CallGraph>;

    /// Reorders the module's sections by their section kind.
    fn sort_sections(&mut self);

    /// Returns the function section, inserting an empty one if absent.
    fn ensure_function_section(&mut self) -> &mut elements::FunctionSection;

    /// Returns the element section, inserting an empty one if absent.
    fn ensure_elements_section(&mut self) -> &mut elements::ElementSection;

    /// Returns the import section, inserting an empty one if absent.
    fn ensure_import_section(&mut self) -> &mut elements::ImportSection;
}

impl ParityWasmExt for elements::Module {
    fn find_function_type_for_export<'a>(
        &'a self,
        f: &'a elements::ExportEntry,
    ) -> Result<&'a elements::FunctionType> {
        let func_id = match f.internal() {
            elements::Internal::Function(func_id) => Ok(func_id),
            _ => Err(SplitError::NotAFunction),
        }?;

        let type_section = self.type_section().ok_or(SplitError::MissingTypeSection)?;
        let typ = &type_section
            .types()
            .get(*func_id as usize)
            .ok_or(SplitError::NoTypeWithIndex(*func_id))?;
        match typ {
            elements::Type::Function(func_type) => Ok(func_type),
            _ => Err(SplitError::NotAFunction),
        }
    }

    fn call_graph_edges_for_func(&self, body: &elements::FuncBody) -> Result<HashSet<u32>> {
        let mut deps = HashSet::new();
        let failure = body
            .code()
            .elements()
            .iter()
            .map(|instruction| -> Result<()> {
                match instruction {
                    elements::Instruction::Call(func_idx) => {
                        deps.insert(*func_idx);
                        Ok(())
                    }
                    elements::Instruction::CallIndirect(_, _) => {
                        println!(
                            "This module has indirect call. This module might be code-splittable."
                        );
                        Ok(())
                    }
                    _ => Ok(()),
                }
            })
            .find(|r| r.is_err());
        failure.unwrap_or(Ok(())).and(Ok(deps))
    }

    fn func_by_idx<'a>(&'a self, func_idx: u32) -> Result<&'a elements::FuncBody> {
        let code_section = self.code_section().ok_or(SplitError::MissingCodeSection)?;
        code_section
            .bodies()
            .get(func_idx as usize)
            .ok_or(SplitError::NoFunctionWithIndex(func_idx))
    }

    fn call_graph(&self) -> Result<CallGraph> {
        let code_section = self.code_section().ok_or(SplitError::MissingCodeSection)?;
        code_section
            .bodies()
            .iter()
            .map(|func_body| self.call_graph_edges_for_func(func_body))
            .enumerate()
            .fold(Ok(CallGraph::new()), |deps_map, (idx, edges)| match edges {
                Ok(mut edges) => deps_map.map(|mut map| {
                    // Each function has itself as a dependency
                    edges.insert(idx as u32);
                    map.0.insert(idx as u32, edges);
                    map
                }),
                Err(err) => Err(deps_map.err().unwrap_or(err)),
            })
    }

    fn exported_funcs(&self) -> Result<Vec<(&str, u32)>> {
        let export_section = self
            .export_section()
            .ok_or(SplitError::MissingExportSection)?;

        Ok(export_section
            .entries()
            .iter()
            .filter_map(|export| match export.internal() {
                elements::Internal::Function(id) => Some((export.field(), *id)),
                _ => None,
            })
            .collect())
    }

    fn sort_sections(&mut self) {
        self.sections_mut().sort_by(section_cmp);
    }

    fn ensure_function_section(&mut self) -> &mut elements::FunctionSection {
        if self.function_section_mut().is_none() {
            let sections = self.sections_mut();
            sections.push(elements::Section::Function(
                elements::FunctionSection::with_entries(vec![]),
            ));
        }

        self.sort_sections();
        self.function_section_mut().unwrap()
    }

    fn ensure_elements_section(&mut self) -> &mut elements::ElementSection {
        if self.elements_section_mut().is_none() {
            let sections = self.sections_mut();
            sections.push(elements::Section::Element(
                elements::ElementSection::with_entries(vec![]),
            ));
        }

        self.sort_sections();
        self.elements_section_mut().unwrap()
    }

    fn ensure_import_section(&mut self) -> &mut elements::ImportSection {
        if self.import_section_mut().is_none() {
            let sections = self.sections_mut();
            sections.push(elements::Section::Import(
                elements::ImportSection::with_entries(vec![]),
            ));
        }

        self.sort_sections();
        self.import_section_mut().unwrap()
    }
}

/// Returns the function index targeted by `entry`, or `None` when the export
/// refers to anything other than a function.
pub fn maybe_exported_function_id(entry: &elements::ExportEntry) -> Option<u32> {
    if let elements::Internal::Function(func_idx) = entry.internal() {
        Some(*func_idx)
    } else {
        None
    }
}

/// Rank used by `section_cmp` to sort a module's sections.
///
/// Known sections get their position in the binary format's required
/// ordering. The name section, whether already parsed (`Section::Name`) or
/// still a raw custom section called `"name"`, ranks after the data section.
/// The spec only allows it there; giving it `-1` would move it in front of
/// the type section. Every other section (other custom sections, unparsed
/// sections, ...) ranks `-1`, so a stable sort moves them to the front while
/// keeping their relative order.
fn section_order(section: &elements::Section) -> isize {
    match section {
        elements::Section::Type(_) => 0,
        elements::Section::Import(_) => 1,
        elements::Section::Function(_) => 2,
        elements::Section::Table(_) => 3,
        elements::Section::Memory(_) => 4,
        elements::Section::Global(_) => 5,
        elements::Section::Export(_) => 6,
        elements::Section::Start(_) => 7,
        elements::Section::Element(_) => 8,
        elements::Section::Code(_) => 9,
        elements::Section::Data(_) => 10,
        elements::Section::Name(_) => 11,
        elements::Section::Custom(custom) if custom.name() == "name" => 11,
        _ => -1,
    }
}

/// Comparator for sorting sections: orders them by their `section_order` rank.
fn section_cmp(section_a: &elements::Section, section_b: &elements::Section) -> std::cmp::Ordering {
    let rank_a = section_order(section_a);
    let rank_b = section_order(section_b);
    rank_a.cmp(&rank_b)
}