1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
use crate::errors::*;
use parity_wasm::elements::{
    CodeSection, ElementSection, ExportSection, FuncBody, Instruction, Instructions, Internal,
    Module, Section,
};

/// Adds `shift` to the target index of every direct `call` instruction found
/// in any function body of `code_section`.
///
/// Only `Instruction::Call` is rewritten; other instructions are left as-is.
/// The function currently never fails and always returns `Ok(())`.
fn shift_function_ids_in_code_section(
    code_section: &mut CodeSection,
    shift: u32,
) -> Result<(), WError> {
    code_section
        .bodies_mut()
        .iter_mut()
        .flat_map(|body| body.code_mut().elements_mut().iter_mut())
        .for_each(|instruction| {
            if let Instruction::Call(ref mut callee) = *instruction {
                *callee += shift;
            }
        });
    Ok(())
}

/// Adds `shift` to the function index of every export that refers to a
/// function; exports of other kinds are left untouched.
fn shift_function_ids_in_exports_section(export_section: &mut ExportSection, shift: u32) {
    export_section
        .entries_mut()
        .iter_mut()
        .map(|entry| entry.internal_mut())
        .for_each(|internal| {
            if let Internal::Function(ref mut index) = *internal {
                *index += shift;
            }
        });
}

/// Adds `shift` to every function index listed as a member of any element
/// segment in `elements_section`.
fn shift_function_ids_in_elements_section(elements_section: &mut ElementSection, shift: u32) {
    for segment in elements_section.entries_mut().iter_mut() {
        segment
            .members_mut()
            .iter_mut()
            .for_each(|member| *member += shift);
    }
}

pub fn shift_function_ids(module: &mut Module, shift: u32) -> Result<(), WError> {
    shift_function_ids_in_code_section(module.code_section_mut().expect("No code section"), shift)?;
    if let Some(export_section) = module.export_section_mut() {
        shift_function_ids_in_exports_section(export_section, shift)
    }
    if let Some(elements_section) = module.elements_section_mut() {
        shift_function_ids_in_elements_section(elements_section, shift)
    }
    Ok(())
}

/// Redirects every direct `call` instruction whose target is `before` so that
/// it calls `after` instead, across all function bodies of `code_section`.
fn replace_function_id_in_code_section(code_section: &mut CodeSection, before: u32, after: u32) {
    let instructions = code_section
        .bodies_mut()
        .iter_mut()
        .flat_map(|body| body.code_mut().elements_mut().iter_mut());
    for instruction in instructions {
        if let Instruction::Call(ref mut callee) = *instruction {
            if *callee == before {
                *callee = after;
            }
        }
    }
}

/// Replaces every occurrence of function index `before` with `after` in the
/// member lists of all element segments in `elements_section`.
fn replace_function_id_in_elements_section(
    elements_section: &mut ElementSection,
    before: u32,
    after: u32,
) {
    elements_section
        .entries_mut()
        .iter_mut()
        .flat_map(|segment| segment.members_mut().iter_mut())
        .filter(|member| **member == before)
        .for_each(|member| *member = after);
}

/// Rewrites references to function index `before` so they point at `after`.
///
/// Only direct `call` instructions in the code section and element segment
/// members are updated. Exports are not touched. Missing sections are
/// skipped. Always returns `Ok(())`.
pub fn replace_function_id(module: &mut Module, before: u32, after: u32) -> Result<(), WError> {
    if let Some(code) = module.code_section_mut() {
        replace_function_id_in_code_section(code, before, after);
    }
    if let Some(elements) = module.elements_section_mut() {
        replace_function_id_in_elements_section(elements, before, after);
    }
    Ok(())
}

/// Replaces the body of the function at index `function_id` with a stub that
/// traps (`unreachable`) as soon as it is called.
///
/// `function_id` is an index into the module's function index space, where
/// imported functions come before locally defined ones. Only *function*
/// imports occupy that space, so memory, table and global imports are not
/// counted when locating the body in the code section.
///
/// # Panics
///
/// Panics in any of these cases:
/// - `function_id` refers to an imported function,
/// - `function_id` is past the last defined function,
/// - the module has no code section.
// NOTE(review): nothing in this file calls this function; the allow keeps it
// available without a warning.
#[allow(dead_code)]
pub fn disable_function_id(module: &mut Module, function_id: u32) -> Result<(), WError> {
    // Number of imported functions = index of the first locally defined one.
    let imported_functions = module
        .import_section()
        .map_or(0, |import_section| import_section.functions() as u32);
    // Checked subtraction: an imported function has no body to replace, and
    // plain `-` would wrap to a huge index in release builds.
    let body_index = function_id
        .checked_sub(imported_functions)
        .expect("Cannot disable an imported function") as usize;
    let code_section = module.code_section_mut().expect("No code section");
    let code_body = code_section
        .bodies_mut()
        .get_mut(body_index)
        .expect("Function id out of range");
    // `unreachable; end` type-checks for any signature, so the stub needs no
    // locals and no result value.
    let opcodes = Instructions::new(vec![Instruction::Unreachable, Instruction::End]);
    *code_body = FuncBody::new(vec![], opcodes);
    Ok(())
}