use crate::{Function, Handle, Int, Resolve, Type, TypeDefKind};
use alloc::vec::Vec;
/// Core WebAssembly signature produced for a component-level function by
/// `Resolve::wasm_signature`.
///
/// Field order matters: `PartialOrd`/`Ord` are derived and therefore compare
/// fields lexicographically in declaration order.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct WasmSignature {
    /// Flattened core wasm parameter types.
    pub params: Vec<WasmType>,
    /// Flattened core wasm result types.
    pub results: Vec<WasmType>,
    /// `true` when the parameters did not fit the flat limit and are passed
    /// through a single `Pointer` parameter instead.
    pub indirect_params: bool,
    /// `true` when the function's result travels through a pointer (an extra
    /// pointer parameter or a pointer result) rather than as flat results.
    pub retptr: bool,
}
/// A core WebAssembly value type as it appears in a flattened signature.
///
/// Besides the four numeric core types this distinguishes pointer- and
/// length-like slots. Variant order is significant because `Ord` is derived.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum WasmType {
    /// 32-bit integer.
    I32,
    /// 64-bit integer.
    I64,
    /// 32-bit float.
    F32,
    /// 64-bit float.
    F64,
    /// A pointer value, e.g. the data pointer of a flattened string or list.
    Pointer,
    /// A slot that may hold either a `Pointer` or a 64-bit value; `join`
    /// produces it when those share a slot.
    PointerOrI64,
    /// A length value, e.g. the length paired with a string or list `Pointer`.
    Length,
}
/// Merges two flat types that must share one slot into the type used for it.
///
/// `push_flat_variants` uses this when overlaying the payloads of different
/// variant cases. The result is symmetric in `a`/`b` and `join(t, t) == t`.
/// For distinct inputs the first matching rule wins:
///
/// 1. either is `PointerOrI64` → `PointerOrI64`
/// 2. either is `Pointer` → `PointerOrI64` if the other is `I64`/`F64`,
///    otherwise `Pointer`
/// 3. either is `I64`/`F64` → `I64`
/// 4. either is `Length` → `Length`
/// 5. otherwise (only `I32`/`F32` remain) → `I32`
fn join(a: WasmType, b: WasmType) -> WasmType {
    use WasmType::*;

    // Identical types need no unification.
    if a == b {
        return a;
    }

    let has = |t: WasmType| a == t || b == t;
    let has_wide = has(I64) || has(F64);

    if has(PointerOrI64) {
        // `PointerOrI64` absorbs everything else.
        PointerOrI64
    } else if has(Pointer) {
        // A pointer sharing a slot with a 64-bit value needs the combined type.
        if has_wide {
            PointerOrI64
        } else {
            Pointer
        }
    } else if has_wide {
        I64
    } else if has(Length) {
        Length
    } else {
        // The only remaining distinct pair is `I32`/`F32`.
        I32
    }
}
impl From<Int> for WasmType {
    /// Maps an integer representation to the core type that carries it:
    /// `U64` needs an `I64`, every narrower width fits in an `I32`.
    fn from(int: Int) -> Self {
        match int {
            Int::U64 => Self::I64,
            Int::U8 | Int::U16 | Int::U32 => Self::I32,
        }
    }
}
/// Selects which lowering `Resolve::wasm_signature` performs: the direction
/// of the function relative to the guest and its calling convention.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AbiVariant {
    /// Synchronous function imported by the guest.
    GuestImport,
    /// Synchronous function exported by the guest.
    GuestExport,
    /// Async function imported by the guest.
    GuestImportAsync,
    /// Async function exported by the guest.
    GuestExportAsync,
    /// Async export whose core signature has no results at all.
    GuestExportAsyncStackful,
}
impl AbiVariant {
pub fn is_async(&self) -> bool {
match self {
Self::GuestImport | Self::GuestExport => false,
Self::GuestImportAsync | Self::GuestExportAsync | Self::GuestExportAsyncStackful => {
true
}
}
}
}
/// Fixed-capacity accumulator of flattened `WasmType`s.
///
/// Entries are written into a caller-provided slice. Once the slice is full,
/// further pushes fail and the overflow flag is set.
pub struct FlatTypes<'a> {
    /// Caller-provided backing storage; its length is the capacity.
    types: &'a mut [WasmType],
    /// Number of leading entries of `types` that have been filled.
    cur: usize,
    /// Set when a push did not fit in `types`.
    overflow: bool,
}
impl<'a> FlatTypes<'a> {
pub fn new(types: &'a mut [WasmType]) -> FlatTypes<'a> {
FlatTypes {
types,
cur: 0,
overflow: false,
}
}
pub fn push(&mut self, ty: WasmType) -> bool {
match self.types.get_mut(self.cur) {
Some(next) => {
*next = ty;
self.cur += 1;
true
}
None => {
self.overflow = true;
false
}
}
}
pub fn to_vec(&self) -> Vec<WasmType> {
self.types[..self.cur].to_vec()
}
}
impl Resolve {
/// Flat-parameter limit for every `AbiVariant` except `GuestImportAsync`;
/// beyond it `wasm_signature` passes the parameters indirectly.
pub const MAX_FLAT_PARAMS: usize = 16;
/// Flat-parameter limit used for `AbiVariant::GuestImportAsync`.
pub const MAX_FLAT_ASYNC_PARAMS: usize = 4;
/// Number of flat result slots available to `GuestImport`/`GuestExport`
/// signatures before the result is routed through a pointer instead.
pub const MAX_FLAT_RESULTS: usize = 1;
/// Computes the core wasm signature for `func` lowered with ABI `variant`.
///
/// Parameters are flattened into a buffer of `MAX_FLAT_PARAMS + 1` slots. If
/// flattening overflows that buffer, or the flat count exceeds the variant's
/// limit (`MAX_FLAT_ASYNC_PARAMS` for `GuestImportAsync`, `MAX_FLAT_PARAMS`
/// otherwise), the parameters collapse to a single `Pointer` and
/// `indirect_params` is set. For exported methods that are not indirect, the
/// first flat parameter is retyped from `I32` to `Pointer`.
///
/// Results depend on the variant:
/// - `GuestImport` / `GuestExport`: the result is flattened into at most
///   `MAX_FLAT_RESULTS` slots. If it does not fit, `retptr` is set and either a
///   `Pointer` parameter is appended (import) or the only result becomes a
///   `Pointer` (export).
/// - `GuestImportAsync`: a `Pointer` parameter is appended when `func` has a
///   result (setting `retptr`); the only core result is an `I32`.
/// - `GuestExportAsync`: the only core result is an `I32`; `func.result` is
///   not consulted.
/// - `GuestExportAsyncStackful`: no core results.
///
/// # Panics
///
/// Panics via `assert!` if an internal invariant fails, e.g. an exported
/// method whose first flat parameter is not `I32`. It can also reach the
/// `todo!()`/`unreachable!()` arms of `push_flat`.
pub fn wasm_signature(&self, variant: AbiVariant, func: &Function) -> WasmSignature {
// The `+ 1` leaves room for the return pointer that may be pushed onto
// `params` further down, so those pushes cannot fail.
let mut storage = [WasmType::I32; Self::MAX_FLAT_PARAMS + 1];
let mut params = FlatTypes::new(&mut storage);
let ok = self.push_flat_list(func.params.iter().map(|p| &p.ty), &mut params);
// Flattening reports failure exactly when the buffer overflowed.
assert_eq!(ok, !params.overflow);
// Async imports use a tighter flat-parameter limit.
let max = match variant {
AbiVariant::GuestImport
| AbiVariant::GuestExport
| AbiVariant::GuestExportAsync
| AbiVariant::GuestExportAsyncStackful => Self::MAX_FLAT_PARAMS,
AbiVariant::GuestImportAsync => Self::MAX_FLAT_ASYNC_PARAMS,
};
// Too many flat params (or buffer overflow): replace them all with a single pointer.
let indirect_params = !ok || params.cur > max;
if indirect_params {
params.types[0] = WasmType::Pointer;
params.cur = 1;
} else {
// Exported methods: the first flat param must be an `I32` and is retyped
// as `Pointer`.
// NOTE(review): presumably this is the `self` handle, which exports
// receive in a pointer-typed representation — confirm.
if matches!(
(&func.kind, variant),
(
crate::FunctionKind::Method(_) | crate::FunctionKind::AsyncMethod(_),
AbiVariant::GuestExport
| AbiVariant::GuestExportAsync
| AbiVariant::GuestExportAsyncStackful
)
) {
assert!(matches!(params.types[0], WasmType::I32));
params.types[0] = WasmType::Pointer;
}
}
// Results get only `MAX_FLAT_RESULTS` slots; overflow is detected via the flag.
let mut storage = [WasmType::I32; Self::MAX_FLAT_RESULTS];
let mut results = FlatTypes::new(&mut storage);
let mut retptr = false;
match variant {
AbiVariant::GuestImport | AbiVariant::GuestExport => {
// The boolean from `push_flat` is ignored; `results.overflow` records failure.
if let Some(ty) = &func.result {
self.push_flat(ty, &mut results);
}
retptr = results.overflow;
// The result didn't fit: discard the partial flat results and route
// the value through a pointer instead.
if retptr {
results.cur = 0;
match variant {
AbiVariant::GuestImport => {
// Imports take the pointer as an extra trailing parameter.
assert!(params.push(WasmType::Pointer));
}
AbiVariant::GuestExport => {
// Exports return the pointer as their single result.
assert!(results.push(WasmType::Pointer));
}
_ => unreachable!(),
}
}
}
AbiVariant::GuestImportAsync => {
// A function with a result takes a trailing pointer parameter for it.
if func.result.is_some() {
assert!(params.push(WasmType::Pointer));
retptr = true;
}
// The only core result is an `I32` (presumably a status code — confirm).
assert!(results.push(WasmType::I32));
}
AbiVariant::GuestExportAsync => {
// `func.result` is not reflected in the core signature; the only core
// result is an `I32`.
assert!(results.push(WasmType::I32));
}
AbiVariant::GuestExportAsyncStackful => {
// No core results at all.
}
}
WasmSignature {
params: params.to_vec(),
indirect_params,
results: results.to_vec(),
retptr,
}
}
/// Flattens each type in `list`, in order, onto `result`.
///
/// Stops at the first type that does not fit. Returns `true` only if every
/// type was flattened completely.
fn push_flat_list<'a>(
&self,
mut list: impl Iterator<Item = &'a Type>,
result: &mut FlatTypes<'_>,
) -> bool {
list.all(|ty| self.push_flat(ty, result))
}
/// Appends the flat core wasm types representing `ty` onto `result`.
///
/// Returns `false` if `result` ran out of space, in which case
/// `result.overflow` is set.
///
/// # Panics
///
/// Panics on `TypeDefKind::Resource` (`todo!()`) and `TypeDefKind::Unknown`
/// (`unreachable!()`).
pub fn push_flat(&self, ty: &Type, result: &mut FlatTypes<'_>) -> bool {
match ty {
// Scalars of 32 bits or fewer (and error contexts) use one `I32` slot.
Type::Bool
| Type::S8
| Type::U8
| Type::S16
| Type::U16
| Type::S32
| Type::U32
| Type::Char
| Type::ErrorContext => result.push(WasmType::I32),
Type::U64 | Type::S64 => result.push(WasmType::I64),
Type::F32 => result.push(WasmType::F32),
Type::F64 => result.push(WasmType::F64),
// Strings flatten to a (pointer, length) pair.
Type::String => result.push(WasmType::Pointer) && result.push(WasmType::Length),
// Named types are resolved through `self.types` and flattened by kind.
Type::Id(id) => match &self.types[*id].kind {
// Aliases flatten like their target type.
TypeDefKind::Type(t) => self.push_flat(t, result),
TypeDefKind::Handle(Handle::Own(_) | Handle::Borrow(_)) => {
result.push(WasmType::I32)
}
TypeDefKind::Resource => todo!(),
// Records and tuples flatten to the concatenation of their members.
TypeDefKind::Record(r) => {
self.push_flat_list(r.fields.iter().map(|f| &f.ty), result)
}
TypeDefKind::Tuple(t) => self.push_flat_list(t.types.iter(), result),
// Flags use `repr().count()` slots, each flattened as a `u32`.
TypeDefKind::Flags(r) => {
self.push_flat_list((0..r.repr().count()).map(|_| &Type::U32), result)
}
// Lists and maps are a (pointer, length) pair, like strings.
TypeDefKind::List(_) => {
result.push(WasmType::Pointer) && result.push(WasmType::Length)
}
TypeDefKind::Map(_, _) => {
result.push(WasmType::Pointer) && result.push(WasmType::Length)
}
// Fixed-length lists repeat the element's flat types `size` times.
TypeDefKind::FixedLengthList(ty, size) => {
self.push_flat_list((0..*size).map(|_| ty), result)
}
// Variant-like types: a discriminant slot followed by the joined
// payloads of all cases.
TypeDefKind::Variant(v) => {
result.push(v.tag().into())
&& self.push_flat_variants(v.cases.iter().map(|c| c.ty.as_ref()), result)
}
TypeDefKind::Enum(e) => result.push(e.tag().into()),
TypeDefKind::Option(t) => {
result.push(WasmType::I32) && self.push_flat_variants([None, Some(t)], result)
}
TypeDefKind::Result(r) => {
result.push(WasmType::I32)
&& self.push_flat_variants([r.ok.as_ref(), r.err.as_ref()], result)
}
TypeDefKind::Future(_) => result.push(WasmType::I32),
TypeDefKind::Stream(_) => result.push(WasmType::I32),
TypeDefKind::Unknown => unreachable!(),
},
}
}
/// Flattens the payloads of a set of variant cases onto `result`.
///
/// Every case is laid out starting at the same position (`result.cur` on
/// entry). A slot already filled by an earlier case is merged with `join`.
/// Slots past the longest case so far are appended. `None` entries (cases
/// without a payload) contribute nothing. Returns `false` and sets
/// `result.overflow` if a case does not fit.
fn push_flat_variants<'a>(
&self,
tys: impl IntoIterator<Item = Option<&'a Type>>,
result: &mut FlatTypes<'_>,
) -> bool {
// Scratch buffer sized to `result`'s remaining capacity; each case is
// flattened here first and then merged into `result`.
let mut temp = result.types[result.cur..].to_vec();
let mut temp = FlatTypes::new(&mut temp);
let start = result.cur;
for ty in tys {
if let Some(ty) = ty {
if !self.push_flat(ty, &mut temp) {
result.overflow = true;
return false;
}
for (i, ty) in temp.types[..temp.cur].iter().enumerate() {
let i = i + start;
if i < result.cur {
// Slot already used by an earlier case: unify the two types.
result.types[i] = join(result.types[i], *ty);
} else if result.cur == result.types.len() {
// Likely unreachable, since `temp` has exactly the remaining
// capacity of `result`; kept as a guard.
result.overflow = true;
return false;
} else {
// New slot beyond every earlier case: append it.
result.types[i] = *ty;
result.cur += 1;
}
}
// Reuse the scratch buffer for the next case.
temp.cur = 0;
}
}
true
}
}