#![deny(missing_docs, missing_debug_implementations)]
/// Applies the multi-value transformation to every function in `to_xform`.
///
/// Each entry is `(func, return_pointer_index, results)`:
///
/// - `func` must return nothing and must take an `i32` return-pointer
///   parameter at position `return_pointer_index`.
/// - `results` lists the value types `func` writes through that pointer.
///
/// For each entry a new wrapper function is added to `module`. At runtime the
/// wrapper does the following:
///
/// 1. Reserves space on the shadow stack by adjusting the
///    `shadow_stack_pointer` global.
/// 2. Calls `func`.
/// 3. Loads the results back out of `memory`.
/// 4. Returns them as multiple values.
///
/// The original functions are left in place.
///
/// Returns the ids of the new wrapper functions, in the same order as
/// `to_xform`.
///
/// # Errors
///
/// Stops at the first entry that cannot be transformed. This happens when:
///
/// - the shadow stack pointer global is not `i32`;
/// - `results` contains a reference type;
/// - `func` already returns results;
/// - the return pointer parameter is missing or is not `i32`.
///
/// Wrappers already created for earlier entries remain in `module`.
pub fn run(
    module: &mut walrus::Module,
    memory: walrus::MemoryId,
    shadow_stack_pointer: walrus::GlobalId,
    to_xform: &[(walrus::FunctionId, usize, Vec<walrus::ValType>)],
) -> Result<Vec<walrus::FunctionId>, anyhow::Error> {
    // Collecting into `Result` short-circuits on the first error, exactly like
    // pushing with `?` in a loop would.
    to_xform
        .iter()
        .map(|(func, return_pointer_index, results)| {
            xform_one(
                module,
                memory,
                shadow_stack_pointer,
                *func,
                *return_pointer_index,
                results,
            )
        })
        .collect()
}
/// Rounds `n` up to the next multiple of `align`.
///
/// If `n` is already a multiple of `align`, it is returned unchanged.
///
/// `align` must be a power of two. This is checked only with a debug
/// assertion.
fn round_up_to_alignment(n: u32, align: u32) -> u32 {
    debug_assert!(align.is_power_of_two());
    // Bump `n` past the next boundary, then clear the low bits to snap back
    // down onto it.
    let bumped = n + align - 1;
    let low_bits = align - 1;
    bumped & !low_bits
}
/// Creates a multi-value shim for one function that returns through an
/// out-pointer.
///
/// `func` must return nothing and must take an `i32` return pointer at
/// parameter position `return_pointer_index`. `results` lists the values
/// `func` writes through that pointer.
///
/// The shim takes `func`'s parameters minus the return pointer and returns
/// `results` directly. At runtime it:
///
/// 1. Moves the `shadow_stack_pointer` global down by the size of the results
///    area and uses the new value as the return pointer.
/// 2. Calls `func`, splicing the return pointer back in at
///    `return_pointer_index`.
/// 3. Loads each result from `memory`. Offsets are aligned by type:
///    - 4-byte values at 4-byte aligned offsets;
///    - `i64`/`f64` at 8-byte aligned offsets;
///    - `v128` at 16-byte aligned offsets.
/// 4. Moves the shadow stack pointer back up by the same amount.
///
/// The results area is the packed size of `results`, rounded up to a multiple
/// of 16 bytes. If `func` has a name, the shim is named
/// `"<name> multivalue shim"`.
///
/// # Errors
///
/// Fails before anything is added to `module` when:
///
/// - the shadow stack pointer global is not `i32`;
/// - `results` contains a reference type (these cannot be stored in linear
///   memory);
/// - `func` already has results;
/// - the return pointer parameter is missing or is not `i32`.
fn xform_one(
    module: &mut walrus::Module,
    memory: walrus::MemoryId,
    shadow_stack_pointer: walrus::GlobalId,
    func: walrus::FunctionId,
    return_pointer_index: usize,
    results: &[walrus::ValType],
) -> Result<walrus::FunctionId, anyhow::Error> {
    let stack_pointer_ty = module.globals.get(shadow_stack_pointer).ty;
    if stack_pointer_ty != walrus::ValType::I32 {
        anyhow::bail!("shadow stack pointer global does not have type `i32`");
    }

    // Packed size of the results: each value is padded up to its own
    // alignment. Four-byte values never need padding, because every size
    // added here is a multiple of four.
    let packed_size = results.iter().try_fold(0, |size, ty| {
        let next = match ty {
            walrus::ValType::I32 | walrus::ValType::F32 => {
                debug_assert_eq!(size % 4, 0);
                size + 4
            }
            walrus::ValType::I64 | walrus::ValType::F64 => round_up_to_alignment(size, 8) + 8,
            walrus::ValType::V128 => round_up_to_alignment(size, 16) + 16,
            walrus::ValType::Externref | walrus::ValType::Funcref => anyhow::bail!(
                "cannot multi-value transform functions that return \
                 reference types, since they can't go into linear memory"
            ),
        };
        Ok::<u32, anyhow::Error>(next)
    })?;
    // NOTE(review): the 16-byte rounding presumably keeps the shadow stack
    // pointer 16-byte aligned across the call. Confirm against the toolchain
    // that produced `func`.
    let results_size = round_up_to_alignment(packed_size, 16);

    let func_ty = module.funcs.get(func).ty();
    let (ty_params, ty_results) = module.types.params_results(func_ty);
    if !ty_results.is_empty() {
        anyhow::bail!(
            "can only multi-value transform functions that don't return any \
             results (since they should be returned on the stack via a pointer)"
        );
    }
    match ty_params.get(return_pointer_index) {
        None => anyhow::bail!("the return pointer parameter doesn't exist"),
        Some(walrus::ValType::I32) => {}
        Some(_) => anyhow::bail!("the return pointer parameter is not `i32`"),
    }

    // The shim's signature is `func`'s parameter list without the return
    // pointer.
    let new_params: Vec<walrus::ValType> = ty_params
        .iter()
        .enumerate()
        .filter(|&(i, _)| i != return_pointer_index)
        .map(|(_, &param_ty)| param_ty)
        .collect();

    let param_locals: Vec<_> = new_params
        .iter()
        .map(|&param_ty| module.locals.add(param_ty))
        .collect();
    let return_pointer = module.locals.add(walrus::ValType::I32);

    let mut builder = walrus::FunctionBuilder::new(&mut module.types, &new_params, results);
    let mut body = builder.func_body();

    // Prologue: carve out the results area on the shadow stack. The lowered
    // stack pointer is kept in `return_pointer` for the call and the loads.
    body.global_get(shadow_stack_pointer)
        .i32_const(results_size as i32)
        .binop(walrus::ir::BinaryOp::I32Sub)
        .local_tee(return_pointer)
        .global_set(shadow_stack_pointer);

    // Forward the shim's parameters, splicing the return pointer back in where
    // `func` expects it (possibly after the last parameter).
    for position in 0..=param_locals.len() {
        if position == return_pointer_index {
            body.local_get(return_pointer);
        }
        if let Some(&local) = param_locals.get(position) {
            body.local_get(local);
        }
    }
    body.call(func);

    // Push each result by loading it from its slot in the results area.
    let mut offset = 0;
    for ty in results {
        debug_assert!(offset < results_size);
        body.local_get(return_pointer);
        let (kind, width) = match ty {
            walrus::ValType::I32 => (walrus::ir::LoadKind::I32 { atomic: false }, 4),
            walrus::ValType::F32 => (walrus::ir::LoadKind::F32, 4),
            walrus::ValType::I64 => (walrus::ir::LoadKind::I64 { atomic: false }, 8),
            walrus::ValType::F64 => (walrus::ir::LoadKind::F64, 8),
            walrus::ValType::V128 => (walrus::ir::LoadKind::V128, 16),
            // Reference types were already rejected while computing
            // `packed_size`.
            walrus::ValType::Externref | walrus::ValType::Funcref => unreachable!(),
        };
        if width == 4 {
            debug_assert_eq!(offset % 4, 0);
        } else {
            offset = round_up_to_alignment(offset, width);
        }
        body.load(memory, kind, walrus::ir::MemArg { align: width, offset });
        offset += width;
    }

    // Epilogue: release the results area.
    body.local_get(return_pointer)
        .i32_const(results_size as i32)
        .binop(walrus::ir::BinaryOp::I32Add)
        .global_set(shadow_stack_pointer);

    let shim = builder.finish(param_locals, &mut module.funcs);
    let shim_name = module
        .funcs
        .get(func)
        .name
        .as_ref()
        .map(|name| format!("{} multivalue shim", name));
    if let Some(shim_name) = shim_name {
        module.funcs.get_mut(shim).name = Some(shim_name);
    }
    Ok(shim)
}
#[cfg(test)]
mod tests {
    /// Checks `round_up_to_alignment` against hand-computed values.
    ///
    /// Covers both exact multiples and values just past a boundary. It also
    /// includes the 8- and 16-byte alignments used to lay out 64-bit and
    /// `v128` results.
    #[test]
    fn round_up_to_alignment_works() {
        // A borrowed array literal avoids the needless heap allocation of
        // `vec!`.
        for &(n, align, expected) in &[
            (0, 1, 0),
            (1, 1, 1),
            (2, 1, 2),
            (0, 2, 0),
            (1, 2, 2),
            (2, 2, 2),
            (3, 2, 4),
            (0, 4, 0),
            (1, 4, 4),
            (2, 4, 4),
            (3, 4, 4),
            (4, 4, 4),
            (5, 4, 8),
            (0, 8, 0),
            (1, 8, 8),
            (8, 8, 8),
            (9, 8, 16),
            (0, 16, 0),
            (1, 16, 16),
            (16, 16, 16),
            (17, 16, 32),
        ] {
            let actual = super::round_up_to_alignment(n, align);
            assert_eq!(
                actual, expected,
                "round_up_to_alignment(n = {}, align = {})",
                n, align
            );
        }
    }
}