use std::string::String;
use std::vec::Vec;
use swasm::elements::{self, Type};
use swasm::builder;
/// Expands to the instruction sequence that replaces a single `call` instruction:
///
/// 1. `stack_height += callee_stack_cost`
/// 2. `if stack_height > stack_limit { unreachable }` (unsigned comparison)
/// 3. the original `call`
/// 4. `stack_height -= callee_stack_cost`
///
/// Every argument expression is evaluated exactly once, even though its value
/// appears several times in the generated sequence.
macro_rules! instrument_call {
    ($callee_idx: expr, $callee_stack_cost: expr, $stack_height_global_idx: expr, $stack_limit: expr) => {{
        use $crate::swasm::elements::Instruction::*;
        // Bind the arguments once so that side-effecting or expensive expressions
        // (e.g. accessor calls) are not re-evaluated for every use below.
        let callee = $callee_idx;
        let cost = $callee_stack_cost;
        let global_idx = $stack_height_global_idx;
        let limit = $stack_limit as i32;
        [
            // stack_height += cost
            GetGlobal(global_idx),
            I32Const(cost),
            I32Add,
            SetGlobal(global_idx),
            // if stack_height > limit: trap
            GetGlobal(global_idx),
            I32Const(limit),
            I32GtU,
            If(elements::BlockType::NoResult),
            Unreachable,
            End,
            // the original call
            Call(callee),
            // stack_height -= cost
            GetGlobal(global_idx),
            I32Const(cost),
            I32Sub,
            SetGlobal(global_idx),
        ]
    }};
}
mod max_height;
mod thunk;
/// Failure raised while injecting the stack height limiter.
///
/// Carries a human-readable description of the problem.
#[derive(Debug)]
pub struct Error(
    String,
);
/// State shared between the instrumentation passes.
///
/// `inject_limiter` creates it with only `stack_limit` set. The other fields are
/// filled in by `generate_stack_height_global` and `compute_stack_costs`.
pub(crate) struct Context {
    /// Global index of the injected stack height counter, once generated.
    stack_height_global_idx: Option<u32>,
    /// Stack cost per entry of the function index space, once computed.
    func_stack_costs: Option<Vec<u32>>,
    /// Value the stack height counter may not exceed (checked with `I32GtU`).
    stack_limit: u32,
}

impl Context {
    /// Returns the global index of the stack height counter.
    ///
    /// # Panics
    ///
    /// Panics if `generate_stack_height_global` has not been run yet.
    fn stack_height_global_idx(&self) -> u32 {
        self.stack_height_global_idx.expect(
            "stack_height_global_idx isn't yet generated;\nDid you call `generate_stack_height_global`?",
        )
    }

    /// Returns the stack cost of the function at `func_idx`, or `None` if the
    /// index is outside the function index space.
    ///
    /// # Panics
    ///
    /// Panics if `compute_stack_costs` has not been run yet.
    fn stack_cost(&self, func_idx: u32) -> Option<u32> {
        let costs = self
            .func_stack_costs
            .as_ref()
            .expect("func_stack_costs isn't yet computed;\nDid you call `compute_stack_costs`?");
        costs.get(func_idx as usize).cloned()
    }

    /// Returns the configured stack limit.
    fn stack_limit(&self) -> u32 {
        self.stack_limit
    }
}
/// Instruments `module` so that its call depth is bounded by `stack_limit`.
///
/// The passes run in this order:
///
/// 1. A mutable `i32` stack height global is added.
/// 2. A stack cost is computed for every function.
/// 3. Every `call` to a function with a non-zero cost is wrapped. The counter is
///    raised before the call, the code traps via `unreachable` if the counter
///    exceeds `stack_limit`, and the counter is lowered after the call.
/// 4. The result is passed through `thunk::generate_thunks`.
///
/// # Errors
///
/// Returns an [`Error`] if any pass fails.
pub fn inject_limiter(
    mut module: elements::Module,
    stack_limit: u32,
) -> Result<elements::Module, Error> {
    let mut ctx = Context {
        stack_limit,
        func_stack_costs: None,
        stack_height_global_idx: None,
    };

    generate_stack_height_global(&mut ctx, &mut module);
    compute_stack_costs(&mut ctx, &module)?;
    instrument_functions(&mut ctx, &mut module)?;

    let instrumented = thunk::generate_thunks(&mut ctx, module)?;
    Ok(instrumented)
}
/// Adds a mutable `i32` global initialized to `0` that tracks the stack height.
/// Its index is recorded in `ctx.stack_height_global_idx`.
///
/// The entry is appended to the existing global section. If the module has no
/// global section, a new one is pushed.
///
/// Imported globals occupy the first indices of the global index space, so the
/// recorded index is offset by the number of global imports. Without the offset,
/// the instrumentation would address an imported global instead of the new one.
fn generate_stack_height_global(ctx: &mut Context, module: &mut elements::Module) {
    let imported_globals = module.import_count(elements::ImportCountType::Global) as u32;
    let global_entry = builder::global()
        .value_type()
        .i32()
        .mutable()
        .init_expr(elements::Instruction::I32Const(0))
        .build();

    for section in module.sections_mut() {
        if let elements::Section::Global(ref mut gs) = *section {
            gs.entries_mut().push(global_entry);
            // Position of the new entry among the globals *defined* by the module.
            let defined_idx = (gs.entries().len() as u32) - 1;
            ctx.stack_height_global_idx = Some(imported_globals + defined_idx);
            return;
        }
    }

    module.sections_mut().push(elements::Section::Global(
        elements::GlobalSection::with_entries(vec![global_entry]),
    ));
    // The new section holds only our global, i.e. the first defined global.
    ctx.stack_height_global_idx = Some(imported_globals);
}
fn compute_stack_costs(ctx: &mut Context, module: &elements::Module) -> Result<(), Error> {
let func_imports = module.import_count(elements::ImportCountType::Function);
let mut func_stack_costs = vec![0; module.functions_space()];
for (func_idx, func_stack_cost) in func_stack_costs.iter_mut().enumerate() {
if func_idx >= func_imports {
*func_stack_cost = compute_stack_cost(func_idx as u32, &module)?;
}
}
ctx.func_stack_costs = Some(func_stack_costs);
Ok(())
}
fn compute_stack_cost(func_idx: u32, module: &elements::Module) -> Result<u32, Error> {
let func_imports = module.import_count(elements::ImportCountType::Function) as u32;
let defined_func_idx = func_idx.checked_sub(func_imports).ok_or_else(|| {
Error("This should be a index of a defined function".into())
})?;
let code_section = module.code_section().ok_or_else(|| {
Error("Due to validation code section should exists".into())
})?;
let body = &code_section
.bodies()
.get(defined_func_idx as usize)
.ok_or_else(|| Error("Function body is out of bounds".into()))?;
let locals_count = body.locals().len() as u32;
let max_stack_height =
max_height::compute(
defined_func_idx,
module
)?;
Ok(locals_count + max_stack_height)
}
/// Instruments the call sites of every function body found in the module's
/// code section(s) using `instrument_function`.
///
/// # Errors
///
/// Propagates the first error returned by `instrument_function`.
fn instrument_functions(ctx: &mut Context, module: &mut elements::Module) -> Result<(), Error> {
    for section in module.sections_mut() {
        if let elements::Section::Code(ref mut code_section) = *section {
            for func_body in code_section.bodies_mut() {
                // `code_mut` already yields `&mut Instructions`; no mutable binding needed.
                instrument_function(ctx, func_body.code_mut())?;
            }
        }
    }
    Ok(())
}
fn instrument_function(
ctx: &mut Context,
instructions: &mut elements::Instructions,
) -> Result<(), Error> {
use swasm::elements::Instruction::*;
let mut cursor = 0;
loop {
if cursor >= instructions.elements().len() {
break;
}
enum Action {
InstrumentCall {
callee_idx: u32,
callee_stack_cost: u32,
},
Nop,
}
let action: Action = {
let instruction = &instructions.elements()[cursor];
match *instruction {
Call(ref callee_idx) => {
let callee_stack_cost = ctx
.stack_cost(*callee_idx)
.ok_or_else(||
Error(
format!("Call to function that out-of-bounds: {}", callee_idx)
)
)?;
if callee_stack_cost > 0 {
Action::InstrumentCall {
callee_idx: *callee_idx,
callee_stack_cost,
}
} else {
Action::Nop
}
},
_ => Action::Nop,
}
};
match action {
Action::InstrumentCall { callee_idx, callee_stack_cost } => {
let new_seq = instrument_call!(
callee_idx,
callee_stack_cost as i32,
ctx.stack_height_global_idx(),
ctx.stack_limit()
);
let _ = instructions
.elements_mut()
.splice(cursor..(cursor + 1), new_seq.iter().cloned())
.count();
cursor += new_seq.len();
}
_ => {
cursor += 1;
}
}
}
Ok(())
}
fn resolve_func_type(
func_idx: u32,
module: &elements::Module,
) -> Result<&elements::FunctionType, Error> {
let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]);
let functions = module
.function_section()
.map(|fs| fs.entries())
.unwrap_or(&[]);
let func_imports = module.import_count(elements::ImportCountType::Function);
let sig_idx = if func_idx < func_imports as u32 {
module
.import_section()
.expect("function import count is not zero; import section must exists; qed")
.entries()
.iter()
.filter_map(|entry| match *entry.external() {
elements::External::Function(ref idx) => Some(*idx),
_ => None,
})
.nth(func_idx as usize)
.expect(
"func_idx is less than function imports count;
nth function import must be `Some`;
qed",
)
} else {
functions
.get(func_idx as usize - func_imports)
.ok_or_else(|| Error(format!("Function at index {} is not defined", func_idx)))?
.type_ref()
};
let Type::Function(ref ty) = *types.get(sig_idx as usize).ok_or_else(|| {
Error(format!(
"Signature {} (specified by func {}) isn't defined",
sig_idx, func_idx
))
})?;
Ok(ty)
}
#[cfg(test)]
mod tests {
    extern crate wabt;
    use swasm::elements;
    use super::*;

    /// Compiles WebAssembly text into a deserialized module.
    fn parse_wat(source: &str) -> elements::Module {
        let binary = wabt::wat2swasm(source).expect("Failed to wat2swasm");
        elements::deserialize_buffer(&binary).expect("Failed to deserialize the module")
    }

    /// Serializes `module` and asserts that wabt accepts it as a valid module.
    fn validate_module(module: elements::Module) {
        let binary = elements::serialize(module).expect("Failed to serialize");
        let wabt_module = wabt::Module::read_binary(&binary, &Default::default())
            .expect("Wabt failed to read final binary");
        wabt_module.validate().expect("Invalid module");
    }

    #[test]
    fn test_with_params_and_result() {
        let module = parse_wat(
            r#"
(module
(func (export "i32.add") (param i32 i32) (result i32)
get_local 0
get_local 1
i32.add
)
)
"#,
        );
        let instrumented =
            inject_limiter(module, 1024).expect("Failed to inject stack counter");
        validate_module(instrumented);
    }
}