use anyhow::Result;
use rand::Rng;
use rand::RngCore;
use rand::SeedableRng;
use wirm::ir::function::FunctionBuilder;
use wirm::ir::id::{FunctionID, GlobalID, LocalID};
use wirm::ir::module::module_functions::FuncKind;
use wirm::ir::types::{InitExpr, Instructions, Value};
use wirm::module_builder::AddLocal;
use wirm::wasmparser::{MemArg, Operator, Validator};
use wirm::{DataType, InitInstr, Module, Opcode};
use crate::constants::{AFL_COVERAGE_MAP_SIZE, API_VERSION_IC0, COVERAGE_FN_EXPORT_NAME};
/// Inputs consumed by [`instrument_wasm_for_fuzzing`].
pub struct InstrumentationArgs {
/// Raw bytes of the Wasm module to instrument.
pub wasm_bytes: Vec<u8>,
/// Number of previous-location globals injected into the module.
/// `instrument_wasm_for_fuzzing` asserts that this is 1, 2, 4, or 8.
pub history_size: usize,
/// How the RNG that assigns location IDs to instrumentation sites is seeded.
pub seed: Seed,
}
/// Selects how the instrumentation RNG is seeded.
#[derive(Copy, Clone, Debug)]
pub enum Seed {
/// Draw a fresh 32-bit seed from `rand::thread_rng()` for each instrumentation run.
Random,
/// Use the given fixed 32-bit seed, making the generated location IDs reproducible.
Static(u32),
}
/// Process-wide coverage buffer.
///
/// Starts out as a zeroed slice of `AFL_COVERAGE_MAP_SIZE` bytes and is replaced by
/// `instrument_wasm_for_fuzzing` with a leaked, zeroed buffer of
/// `AFL_COVERAGE_MAP_SIZE * history_size` bytes.
///
/// Being a `static mut`, every access requires `unsafe`; nothing in this file
/// synchronizes concurrent access to it.
pub static mut COVERAGE_MAP: &mut [u8] = &mut [0; AFL_COVERAGE_MAP_SIZE as usize];
/// Instruments a Wasm module for AFL-style coverage tracking and returns the encoded bytes.
///
/// Steps, in order: parse the module, inject the instrumentation, point [`COVERAGE_MAP`]
/// at a fresh zeroed buffer of `AFL_COVERAGE_MAP_SIZE * history_size` bytes, encode the
/// module, and validate the encoded output.
///
/// # Panics
///
/// Panics if `history_size` is not 1, 2, 4, or 8, if the module cannot be parsed or
/// instrumented, or if the instrumented output fails validation.
pub fn instrument_wasm_for_fuzzing(instrumentation_args: InstrumentationArgs) -> Vec<u8> {
    let history_size = instrumentation_args.history_size;
    assert!(
        matches!(history_size, 1 | 2 | 4 | 8),
        "History size must be 1, 2, 4, or 8"
    );

    let mut module = Module::parse(&instrumentation_args.wasm_bytes, false)
        .expect("Failed to parse module with wirm");
    instrument_for_afl(&mut module, &instrumentation_args)
        .expect("Unable to instrument wasm module for AFL");

    // The buffer is intentionally leaked so that it can live in a `'static` slot.
    let map_len = AFL_COVERAGE_MAP_SIZE as usize * history_size;
    let fresh_map: &'static mut [u8] = vec![0u8; map_len].leak();
    // SAFETY: NOTE(review) — this is a plain store to a `static mut`; it is only sound if
    // no other thread reads or writes `COVERAGE_MAP` concurrently, which is not enforced here.
    unsafe {
        COVERAGE_MAP = fresh_map;
    }

    let encoded = module.encode();
    validate_wasm(&encoded).expect("Wasm is not valid");
    encoded
}
/// Runs the instrumentation passes on `module`, logging progress to stdout.
///
/// Passes, in order: inject the tracking globals, inject the exported coverage-dump
/// function, then instrument the bodies of the module's local functions.
///
/// # Errors
///
/// Propagates any error returned by `inject_afl_coverage_export`.
fn instrument_for_afl(
    module: &mut Module<'_>,
    instrumentation_args: &InstrumentationArgs,
) -> Result<()> {
    let history_size = instrumentation_args.history_size;
    let seed = instrumentation_args.seed;

    // Pass 1: globals holding the previous locations and the map base pointer.
    let (afl_prev_loc_indices, afl_mem_ptr_idx) = inject_globals(module, history_size);
    println!(
        " -> Injected globals: prev_locs @ indices {afl_prev_loc_indices:?}, mem_ptr @ index {afl_mem_ptr_idx:?}"
    );

    // Pass 2: exported function that replies with the coverage map.
    inject_afl_coverage_export(module, history_size, afl_mem_ptr_idx)?;
    println!(" -> Injected `canister_update __export_coverage_for_afl` function.");

    // Pass 3: per-function instrumentation.
    instrument_branches(module, &afl_prev_loc_indices, afl_mem_ptr_idx, seed);
    println!(" -> Instrumented branch instructions in all functions.");

    Ok(())
}
/// Adds the globals used by the instrumentation to `module`.
///
/// Creates `history_size` i32 globals for the previous-location history, followed by one
/// i32 global for the coverage-map base pointer. All of them are initialised to 0.
///
/// Returns the history globals (in creation order) and the base-pointer global.
fn inject_globals(module: &mut Module<'_>, history_size: usize) -> (Vec<GlobalID>, GlobalID) {
    // Every injected global starts from the constant `i32 0`.
    let zero_init = || InitExpr::new(vec![InitInstr::Value(Value::I32(0))]);

    // The history globals are written by the instrumentation function, so they are created
    // with the third argument set to `true` (presumably the "mutable" flag — confirm
    // against the wirm `add_global` signature).
    let prev_loc_globals: Vec<GlobalID> = (0..history_size)
        .map(|_| module.add_global(zero_init(), DataType::I32, true, false))
        .collect();

    // The base pointer is only ever read by the injected code.
    let mem_ptr_global = module.add_global(zero_init(), DataType::I32, false, false);

    (prev_loc_globals, mem_ptr_global)
}
/// Adds an exported `canister_update` function that replies with the coverage map.
///
/// The generated function appends `AFL_COVERAGE_MAP_SIZE * history_size` bytes starting at
/// the address held in `afl_mem_ptr_idx` to the reply, sends the reply, and then zero-fills
/// that same memory range.
///
/// # Errors
///
/// Propagates any error returned by `ensure_ic0_imports`.
fn inject_afl_coverage_export<'a>(
    module: &mut Module<'a>,
    history_size: usize,
    afl_mem_ptr_idx: GlobalID,
) -> Result<()> {
    let (msg_reply_data_append_idx, msg_reply_idx) = ensure_ic0_imports(module)?;
    let map_len = AFL_COVERAGE_MAP_SIZE * history_size as i32;

    let mut body = FunctionBuilder::new(&[], &[]);

    // msg_reply_data_append(map_ptr, map_len)
    body.global_get(afl_mem_ptr_idx)
        .i32_const(map_len)
        .call(msg_reply_data_append_idx);

    // msg_reply()
    body.call(msg_reply_idx);

    // memory.fill(map_ptr, 0, map_len): reset the counters after they have been reported.
    body.global_get(afl_mem_ptr_idx)
        .i32_const(0)
        .i32_const(map_len)
        .memory_fill(0);

    let coverage_fn = body.finish_module(module);
    module.exports.add_export_func(
        format!("canister_update {COVERAGE_FN_EXPORT_NAME}"),
        coverage_fn.0,
    );
    Ok(())
}
/// Inserts coverage probes into every local function of `module`.
///
/// A probe is `i32.const <random location>` followed by a call to the instrumentation
/// function created by `afl_instrumentation_slice`. Probes are placed:
///
/// - at the start of each function body,
/// - immediately after every `block`, `loop`, `if`, and `else`,
/// - immediately before every `br`, `br_if`, `br_table`, and `return`.
///
/// Location IDs are drawn from `0..AFL_COVERAGE_MAP_SIZE` using a `StdRng` seeded from
/// `seed`; the seed actually used is printed to stdout. The instrumentation function
/// itself is skipped; every other local function is rewritten.
fn instrument_branches(
    module: &mut Module<'_>,
    afl_prev_loc_indices: &[GlobalID],
    afl_mem_ptr_idx: GlobalID,
    seed: Seed,
) {
    let instrumentation_function =
        afl_instrumentation_slice(module, afl_prev_loc_indices, afl_mem_ptr_idx);

    let seed = match seed {
        Seed::Static(fixed) => fixed,
        Seed::Random => rand::thread_rng().next_u32(),
    };
    println!("The seed used for instrumentation is {seed}");
    let mut rng = rand::rngs::StdRng::seed_from_u64(seed as u64);

    // Appends one probe (location constant + call) to `out`, drawing a fresh location ID.
    let mut emit_probe = |out: &mut Vec<Operator>| {
        let location = rng.gen_range(0..AFL_COVERAGE_MAP_SIZE);
        out.push(Operator::I32Const { value: location });
        out.push(Operator::Call {
            function_index: instrumentation_function.0,
        });
    };

    for (position, function) in module.functions.iter_mut().enumerate() {
        let is_local = matches!(function.kind(), FuncKind::Local(_));
        if !is_local || FunctionID(position as u32) == instrumentation_function {
            continue;
        }

        let local_function = function.unwrap_local_mut();
        let mut rewritten = Vec::with_capacity(local_function.body.num_instructions * 2);

        // Function entry probe.
        emit_probe(&mut rewritten);

        for op in local_function.body.instructions.get_ops() {
            // Control transfers are probed before the instruction executes.
            let probe_before = matches!(
                op,
                Operator::Br { .. }
                    | Operator::BrIf { .. }
                    | Operator::BrTable { .. }
                    | Operator::Return
            );
            // Newly entered regions are probed right after the opening instruction.
            let probe_after = matches!(
                op,
                Operator::Block { .. } | Operator::Loop { .. } | Operator::If { .. } | Operator::Else
            );

            if probe_before {
                emit_probe(&mut rewritten);
            }
            rewritten.push(op.clone());
            if probe_after {
                emit_probe(&mut rewritten);
            }
        }

        // NOTE(review): `body.num_instructions` is read above for the capacity hint but is
        // not updated after the body is replaced — confirm wirm does not rely on it later.
        local_function.body.instructions = Instructions::new(rewritten);
    }
}
/// Builds the shared instrumentation function and adds it to `module`.
///
/// The generated function takes the current location as its single i32 parameter and:
///
/// 1. XORs it with every previous-location global and adds the map base pointer to form
///    an address,
/// 2. increments the byte stored at that address (8-bit load, add 1, 8-bit store),
/// 3. shifts the history: each `prev[i]` becomes `prev[i - 1] >> 1` (highest index first),
///    and `prev[0]` becomes `curr_location >> 1`.
///
/// # Panics
///
/// Panics if `afl_prev_loc_indices` is empty.
fn afl_instrumentation_slice(
    module: &mut Module<'_>,
    afl_prev_loc_indices: &[GlobalID],
    afl_mem_ptr_idx: GlobalID,
) -> FunctionID {
    // Memory argument shared by the 8-bit load and store: memory 0, zero offset.
    let byte_access = || MemArg {
        offset: 0,
        align: 0,
        memory: 0,
        max_align: 0,
    };

    let mut builder = FunctionBuilder::new(&[DataType::I32], &[]);
    let curr_location = LocalID(0);
    let slot_addr = builder.add_local(DataType::I32);

    // addr = (curr_location ^ prev[0] ^ ... ^ prev[n - 1]) + map_ptr
    builder.local_get(curr_location);
    for &prev in afl_prev_loc_indices {
        builder.global_get(prev).i32_xor();
    }
    builder.global_get(afl_mem_ptr_idx).i32_add();

    // mem[addr] = mem[addr] + 1, using `slot_addr` to keep a second copy of the address.
    builder
        .local_tee(slot_addr)
        .local_get(slot_addr)
        .i32_load8_u(byte_access());
    builder.i32_const(1).i32_add().i32_store8(byte_access());

    // Age the history from the oldest slot down, so no value is overwritten before it is read.
    for pair in afl_prev_loc_indices.windows(2).rev() {
        builder
            .global_get(pair[0])
            .i32_const(1)
            .i32_shr_unsigned()
            .global_set(pair[1]);
    }

    // prev[0] = curr_location >> 1
    builder
        .local_get(curr_location)
        .i32_const(1)
        .i32_shr_unsigned()
        .global_set(afl_prev_loc_indices[0]);

    builder.finish_module(module)
}
/// Looks up the `msg_reply_data_append` and `msg_reply` imports under the
/// `API_VERSION_IC0` module name, adding whichever of them is missing.
///
/// `msg_reply_data_append` is added with type `(i32, i32) -> ()` and `msg_reply` with type
/// `() -> ()`. Both lookups happen before any import is added.
///
/// Returns `(msg_reply_data_append, msg_reply)` function IDs. This function currently has
/// no failing path of its own; it always returns `Ok`.
fn ensure_ic0_imports(module: &mut Module<'_>) -> Result<(FunctionID, FunctionID)> {
    const DATA_APPEND: &str = "msg_reply_data_append";
    const REPLY: &str = "msg_reply";

    let existing_append = module
        .imports
        .get_func(API_VERSION_IC0.to_string(), DATA_APPEND.to_string());
    let existing_reply = module
        .imports
        .get_func(API_VERSION_IC0.to_string(), REPLY.to_string());

    let data_append_idx = match existing_append {
        Some(found) => found,
        None => {
            let signature = module
                .types
                .add_func_type(&[DataType::I32, DataType::I32], &[]);
            module
                .add_import_func(
                    API_VERSION_IC0.to_string(),
                    DATA_APPEND.to_string(),
                    signature,
                )
                .0
        }
    };

    let reply_idx = match existing_reply {
        Some(found) => found,
        None => {
            let signature = module.types.add_func_type(&[], &[]);
            module
                .add_import_func(API_VERSION_IC0.to_string(), REPLY.to_string(), signature)
                .0
        }
    };

    Ok((data_append_idx, reply_idx))
}
/// Validates `wasm_bytes` with a default-configured `Validator`.
///
/// Prints a confirmation line to stdout on success.
///
/// # Errors
///
/// Returns the validator's error if the bytes are not a valid Wasm binary.
fn validate_wasm(wasm_bytes: &[u8]) -> Result<()> {
    // Only the success/failure outcome matters here; the returned type info is discarded.
    Validator::new().validate_all(wasm_bytes)?;
    println!("Validation of instrumented Wasm successful.");
    Ok(())
}