#![deny(missing_docs, missing_debug_implementations)]
use anyhow::Context;
/// Wraps functions that return their results through an out-pointer parameter
/// in shims that return those results as Wasm multi-values instead.
///
/// Each entry of `to_xform` is `(function, index of its return-pointer
/// parameter, result types to load through that pointer)`. The results are
/// read from `memory`. Scratch space for them is reserved by moving the
/// `stack_pointer` global. The `multivalue` target feature is also recorded in
/// the module.
///
/// Returns the newly created wrapper functions, in the same order as
/// `to_xform`. The original functions are kept; each wrapper calls its
/// original.
///
/// # Errors
///
/// Fails if the `target_features` custom section cannot be parsed, or if any
/// function cannot be transformed (see `xform_one`). Processing stops at the
/// first failure. The module may already have been partially modified by then.
pub fn run(
    module: &mut walrus::Module,
    memory: walrus::MemoryId,
    stack_pointer: walrus::GlobalId,
    to_xform: &[(walrus::FunctionId, usize, Vec<walrus::ValType>)],
) -> Result<Vec<walrus::FunctionId>, anyhow::Error> {
    crate::wasm_conventions::insert_target_feature(module, "multivalue")
        .context("failed to parse `target_features` custom section")?;
    // Collecting into `Result` stops at the first error, like a loop using `?`.
    to_xform
        .iter()
        .map(|(func, return_pointer_index, results)| {
            xform_one(
                module,
                memory,
                stack_pointer,
                *func,
                *return_pointer_index,
                results,
            )
        })
        .collect()
}
/// Rounds `n` up to the next multiple of `align`.
///
/// `align` must be a power of two; this is checked only in debug builds.
fn round_up_to_alignment(n: u64, align: u64) -> u64 {
    debug_assert!(align.is_power_of_two());
    // For a power of two, `align - 1` is a mask of the low bits that must be zero.
    let low_bits = align - 1;
    (n + low_bits) & !low_bits
}
fn xform_one(
module: &mut walrus::Module,
memory: walrus::MemoryId,
stack_pointer: walrus::GlobalId,
func: walrus::FunctionId,
return_pointer_index: usize,
results: &[walrus::ValType],
) -> Result<walrus::FunctionId, anyhow::Error> {
let sp_ty = module.globals.get(stack_pointer).ty;
let memory64 = sp_ty == walrus::ValType::I64;
if sp_ty != walrus::ValType::I32 && sp_ty != walrus::ValType::I64 {
anyhow::bail!("stack pointer global has unexpected type `{sp_ty:?}`");
}
let mut results_size = 0;
for ty in results {
results_size = match ty {
walrus::ValType::I32 | walrus::ValType::F32 => {
debug_assert_eq!(results_size % 4, 0);
results_size + 4
}
walrus::ValType::I64 | walrus::ValType::F64 => {
round_up_to_alignment(results_size, 8) + 8
}
walrus::ValType::V128 => round_up_to_alignment(results_size, 16) + 16,
walrus::ValType::Ref(_) => anyhow::bail!(
"cannot multi-value transform functions that return \
reference types, since they can't go into linear memory"
),
};
}
let results_size = round_up_to_alignment(results_size, 16);
let ty = module.funcs.get(func).ty();
let (ty_params, ty_results) = module.types.params_results(ty);
let func_name = module
.funcs
.get(func)
.name
.as_deref()
.unwrap_or("<unnamed>");
if !ty_results.is_empty() {
anyhow::bail!(
"can only multi-value transform functions that don't return any \
results (since they should be returned on the stack via a \
pointer): `{func_name}` has params {ty_params:?} and results \
{ty_results:?}"
);
}
let ptr_ty = if memory64 {
walrus::ValType::I64
} else {
walrus::ValType::I32
};
match ty_params.get(return_pointer_index) {
Some(ty) if *ty == ptr_ty => {}
None => anyhow::bail!("the return pointer parameter doesn't exist"),
Some(ty) => {
anyhow::bail!("the return pointer parameter is not `{ptr_ty:?}` (got `{ty:?}`)")
}
}
let new_params: Vec<_> = ty_params
.iter()
.cloned()
.enumerate()
.filter_map(|(i, ty)| {
if i == return_pointer_index {
None
} else {
Some(ty)
}
})
.collect();
let params: Vec<_> = new_params.iter().map(|ty| module.locals.add(*ty)).collect();
let return_pointer = module.locals.add(ptr_ty);
let mut wrapper = walrus::FunctionBuilder::new(&mut module.types, &new_params, results);
let mut body = wrapper.func_body();
body.global_get(stack_pointer);
if memory64 {
body.i64_const(results_size as i64)
.binop(walrus::ir::BinaryOp::I64Sub);
} else {
body.i32_const(results_size as i32)
.binop(walrus::ir::BinaryOp::I32Sub);
}
body.local_tee(return_pointer).global_set(stack_pointer);
for (i, local) in params.iter().enumerate() {
if i == return_pointer_index {
body.local_get(return_pointer);
}
body.local_get(*local);
}
if return_pointer_index == params.len() {
body.local_get(return_pointer);
}
body.call(func);
let mut offset: u64 = 0;
for ty in results {
debug_assert!(offset < results_size);
body.local_get(return_pointer);
match ty {
walrus::ValType::I32 => {
debug_assert_eq!(offset % 4, 0);
body.load(
memory,
walrus::ir::LoadKind::I32 { atomic: false },
walrus::ir::MemArg { align: 4, offset },
);
offset += 4;
}
walrus::ValType::I64 => {
offset = round_up_to_alignment(offset, 8);
body.load(
memory,
walrus::ir::LoadKind::I64 { atomic: false },
walrus::ir::MemArg { align: 8, offset },
);
offset += 8;
}
walrus::ValType::F32 => {
debug_assert_eq!(offset % 4, 0);
body.load(
memory,
walrus::ir::LoadKind::F32,
walrus::ir::MemArg { align: 4, offset },
);
offset += 4;
}
walrus::ValType::F64 => {
offset = round_up_to_alignment(offset, 8);
body.load(
memory,
walrus::ir::LoadKind::F64,
walrus::ir::MemArg { align: 8, offset },
);
offset += 8;
}
walrus::ValType::V128 => {
offset = round_up_to_alignment(offset, 16);
body.load(
memory,
walrus::ir::LoadKind::V128,
walrus::ir::MemArg { align: 16, offset },
);
offset += 16;
}
walrus::ValType::Ref(_) => unreachable!(),
}
}
body.local_get(return_pointer);
if memory64 {
body.i64_const(results_size as i64)
.binop(walrus::ir::BinaryOp::I64Add);
} else {
body.i32_const(results_size as i32)
.binop(walrus::ir::BinaryOp::I32Add);
}
body.global_set(stack_pointer);
let wrapper = wrapper.finish(params, &mut module.funcs);
if let Some(name) = &module.funcs.get(func).name {
module.funcs.get_mut(wrapper).name = Some(format!("{name} multivalue shim"));
}
Ok(wrapper)
}
#[cfg(test)]
mod tests;