use crate::component::func::{Func, Memory, MemoryMut, Options};
use crate::component::storage::{storage_as_slice, storage_as_slice_mut};
use crate::store::StoreOpaque;
use crate::{AsContext, AsContextMut, StoreContext, StoreContextMut, ValRaw};
use anyhow::{anyhow, bail, Context, Result};
use std::borrow::Cow;
use std::fmt;
use std::marker;
use std::mem::{self, MaybeUninit};
use std::str;
use wasmtime_environ::component::{
CanonicalAbiInfo, ComponentTypes, InterfaceType, StringEncoding, VariantInfo, MAX_FLAT_PARAMS,
MAX_FLAT_RESULTS,
};
/// A statically-typed handle to a component function.
///
/// `Params` is the parameter list lowered into the callee and `Return` is the
/// value lifted out of its result. The phantom marker only records these
/// types; the wrapped [`Func`] is the actual function value.
pub struct TypedFunc<Params, Return> {
    func: Func,
    _marker: marker::PhantomData<(Params, Return)>,
}

// `Copy`/`Clone` are written by hand rather than derived. A derive would add
// `Params: Copy`/`Return: Copy` bounds, but only the `Func` value is copied.
impl<Params, Return> Copy for TypedFunc<Params, Return> {}

impl<Params, Return> Clone for TypedFunc<Params, Return> {
    fn clone(&self) -> Self {
        Self {
            func: self.func,
            _marker: marker::PhantomData,
        }
    }
}
impl<Params, Return> TypedFunc<Params, Return>
where
Params: ComponentNamedList + Lower,
Return: Lift,
{
pub unsafe fn new_unchecked(func: Func) -> TypedFunc<Params, Return> {
TypedFunc {
_marker: marker::PhantomData,
func,
}
}
pub fn func(&self) -> &Func {
&self.func
}
pub fn call(&self, store: impl AsContextMut, params: Params) -> Result<Return> {
assert!(
!store.as_context().async_support(),
"must use `call_async` when async support is enabled on the config"
);
self.call_impl(store, params)
}
#[cfg(feature = "async")]
#[cfg_attr(nightlydoc, doc(cfg(feature = "async")))]
pub async fn call_async<T>(
&self,
mut store: impl AsContextMut<Data = T>,
params: Params,
) -> Result<Return>
where
T: Send,
Params: Send + Sync,
Return: Send + Sync,
{
let mut store = store.as_context_mut();
assert!(
store.0.async_support(),
"cannot use `call_async` when async support is not enabled on the config"
);
store
.on_fiber(|store| self.call_impl(store, params))
.await?
}
fn call_impl(&self, mut store: impl AsContextMut, params: Params) -> Result<Return> {
let store = &mut store.as_context_mut();
if Params::flatten_count() <= MAX_FLAT_PARAMS {
if Return::flatten_count() <= MAX_FLAT_RESULTS {
self.func.call_raw(
store,
¶ms,
Self::lower_stack_args,
Self::lift_stack_result,
)
} else {
self.func.call_raw(
store,
¶ms,
Self::lower_stack_args,
Self::lift_heap_result,
)
}
} else {
if Return::flatten_count() <= MAX_FLAT_RESULTS {
self.func.call_raw(
store,
¶ms,
Self::lower_heap_args,
Self::lift_stack_result,
)
} else {
self.func.call_raw(
store,
¶ms,
Self::lower_heap_args,
Self::lift_heap_result,
)
}
}
}
fn lower_stack_args<T>(
store: &mut StoreContextMut<'_, T>,
options: &Options,
params: &Params,
dst: &mut MaybeUninit<Params::Lower>,
) -> Result<()> {
assert!(Params::flatten_count() <= MAX_FLAT_PARAMS);
params.lower(store, options, dst)?;
Ok(())
}
fn lower_heap_args<T>(
store: &mut StoreContextMut<'_, T>,
options: &Options,
params: &Params,
dst: &mut MaybeUninit<ValRaw>,
) -> Result<()> {
assert!(Params::flatten_count() > MAX_FLAT_PARAMS);
let mut memory = MemoryMut::new(store.as_context_mut(), options);
let ptr = memory.realloc(0, 0, Params::ALIGN32, Params::SIZE32)?;
params.store(&mut memory, ptr)?;
dst.write(ValRaw::i64(ptr as i64));
Ok(())
}
fn lift_stack_result(
store: &StoreOpaque,
options: &Options,
dst: &Return::Lower,
) -> Result<Return> {
assert!(Return::flatten_count() <= MAX_FLAT_RESULTS);
Return::lift(store, options, dst)
}
fn lift_heap_result(store: &StoreOpaque, options: &Options, dst: &ValRaw) -> Result<Return> {
assert!(Return::flatten_count() > MAX_FLAT_RESULTS);
let ptr = usize::try_from(dst.get_u32())?;
if ptr % usize::try_from(Return::ALIGN32)? != 0 {
bail!("return pointer not aligned");
}
let memory = Memory::new(store, options);
let bytes = memory
.as_slice()
.get(ptr..)
.and_then(|b| b.get(..Return::SIZE32))
.ok_or_else(|| anyhow::anyhow!("pointer out of bounds of memory"))?;
Return::load(&memory, bytes)
}
pub fn post_return(&self, store: impl AsContextMut) -> Result<()> {
self.func.post_return(store)
}
#[cfg(feature = "async")]
#[cfg_attr(nightlydoc, doc(cfg(feature = "async")))]
pub async fn post_return_async<T: Send>(
&self,
store: impl AsContextMut<Data = T>,
) -> Result<()> {
self.func.post_return_async(store).await
}
}
/// A list of component values usable as the parameter list of a
/// [`TypedFunc`].
///
/// Implemented in this file for Rust tuples, where each tuple element is one
/// parameter.
///
/// # Safety
///
/// NOTE(review): the invariants implementors must uphold are not stated in
/// this file. Presumably `typecheck_list` must only succeed when `Self`
/// genuinely matches the given parameter types. Confirm against crate docs.
pub unsafe trait ComponentNamedList: ComponentType {
    /// Type-checks `params` against this list's element types. The tuple
    /// implementations require exactly one entry per tuple element.
    #[doc(hidden)]
    fn typecheck_list(params: &[InterfaceType], types: &ComponentTypes) -> Result<()>;
}
/// A Rust type that corresponds to a component-model interface type.
///
/// # Safety
///
/// NOTE(review): lifting/lowering code trusts `Lower` and `ABI`. Implementors
/// presumably must keep them consistent with the interface types accepted by
/// `typecheck`. Confirm against crate docs.
pub unsafe trait ComponentType {
    /// The flat representation of this type passed to and from wasm.
    ///
    /// `flatten_count` requires this to be composed of whole `ValRaw` slots.
    /// Its size must be a multiple of, and its alignment equal to, those of
    /// `ValRaw`.
    #[doc(hidden)]
    type Lower: Copy;

    /// Size and alignment information for this type in linear memory.
    #[doc(hidden)]
    const ABI: CanonicalAbiInfo;

    /// Byte size in linear memory (the `size32` of `ABI`).
    #[doc(hidden)]
    const SIZE32: usize = Self::ABI.size32 as usize;

    /// Byte alignment in linear memory (the `align32` of `ABI`).
    #[doc(hidden)]
    const ALIGN32: u32 = Self::ABI.align32;

    /// Whether this is the Rust unit type `()`.
    ///
    /// `Result<T, E>`'s `typecheck` uses this to accept a missing `ok`/`err`
    /// type when the corresponding Rust type is `()`.
    #[doc(hidden)]
    const IS_RUST_UNIT_TYPE: bool = false;

    /// Number of `ValRaw` slots in `Self::Lower`.
    ///
    /// # Panics
    ///
    /// Panics if `Self::Lower` is not made up of whole, `ValRaw`-aligned
    /// slots.
    #[doc(hidden)]
    fn flatten_count() -> usize {
        assert!(mem::size_of::<Self::Lower>() % mem::size_of::<ValRaw>() == 0);
        assert!(mem::align_of::<Self::Lower>() == mem::align_of::<ValRaw>());
        mem::size_of::<Self::Lower>() / mem::size_of::<ValRaw>()
    }

    /// Checks that `ty` is an interface type this Rust type can represent.
    #[doc(hidden)]
    fn typecheck(ty: &InterfaceType, types: &ComponentTypes) -> Result<()>;
}
/// Describes the cases of a variant-like type (such as `Option` or `Result`)
/// so its payload layout can be computed at compile time.
#[doc(hidden)]
pub unsafe trait ComponentVariant: ComponentType {
    /// ABI info for each case's payload, or `None` for a case without one.
    const CASES: &'static [Option<CanonicalAbiInfo>];
    /// Layout information derived from `CASES`.
    const INFO: VariantInfo = VariantInfo::new_static(Self::CASES);
    /// Byte offset of the payload from the start of the value in memory.
    const PAYLOAD_OFFSET32: usize = Self::INFO.payload_offset32 as usize;
}
/// A type that can be passed from the host into wasm ("lowered").
///
/// `TypedFunc` requires its `Params` to implement this trait.
pub unsafe trait Lower: ComponentType {
    /// Writes the flat representation of `self` into `dst`.
    ///
    /// `store` and `options` are available to types that must allocate in
    /// linear memory, such as strings and lists.
    #[doc(hidden)]
    fn lower<T>(
        &self,
        store: &mut StoreContextMut<T>,
        options: &Options,
        dst: &mut MaybeUninit<Self::Lower>,
    ) -> Result<()>;

    /// Writes `self` into linear memory starting at byte `offset`.
    #[doc(hidden)]
    fn store<T>(&self, memory: &mut MemoryMut<'_, T>, offset: usize) -> Result<()>;
}
/// A type that can be read back out of wasm into the host ("lifted").
///
/// `TypedFunc` requires its `Return` type to implement this trait.
pub unsafe trait Lift: Sized + ComponentType {
    /// Reconstructs a value from its flat representation `src`.
    #[doc(hidden)]
    fn lift(store: &StoreOpaque, options: &Options, src: &Self::Lower) -> Result<Self>;

    /// Reconstructs a value from its in-memory bytes.
    ///
    /// Callers in this file pass a slice of exactly `SIZE32` bytes.
    #[doc(hidden)]
    fn load(memory: &Memory<'_>, bytes: &[u8]) -> Result<Self>;
}
/// Implements `ComponentType` for `$a` by delegating `Lower`, `ABI`, and
/// `typecheck` to `$b`. Each entry `(generics) $a => $b` reads as "`$a` has
/// the same interface type as `$b`".
macro_rules! forward_type_impls {
    ($(($($generics:tt)*) $a:ty => $b:ty,)*) => ($(
        unsafe impl <$($generics)*> ComponentType for $a {
            type Lower = <$b as ComponentType>::Lower;
            const ABI: CanonicalAbiInfo = <$b as ComponentType>::ABI;
            #[inline]
            fn typecheck(ty: &InterfaceType, types: &ComponentTypes) -> Result<()> {
                <$b as ComponentType>::typecheck(ty, types)
            }
        }
    )*)
}

// References, smart pointers, and owned containers share the interface type
// of what they point to or own.
forward_type_impls! {
    (T: ComponentType + ?Sized) &'_ T => T,
    (T: ComponentType + ?Sized) Box<T> => T,
    (T: ComponentType + ?Sized) std::rc::Rc<T> => T,
    (T: ComponentType + ?Sized) std::sync::Arc<T> => T,
    () String => str,
    (T: ComponentType) Vec<T> => [T],
}
/// Implements `Lower` for `$a` by delegating to `$b`'s implementation.
///
/// `self` has type `&$a` and is deref-coerced to `&$b` at each call.
macro_rules! forward_lowers {
    ($(($($generics:tt)*) $a:ty => $b:ty,)*) => ($(
        unsafe impl <$($generics)*> Lower for $a {
            fn lower<U>(
                &self,
                store: &mut StoreContextMut<U>,
                options: &Options,
                dst: &mut MaybeUninit<Self::Lower>,
            ) -> Result<()> {
                <$b as Lower>::lower(self, store, options, dst)
            }
            fn store<U>(&self, memory: &mut MemoryMut<'_, U>, offset: usize) -> Result<()> {
                <$b as Lower>::store(self, memory, offset)
            }
        }
    )*)
}

// Same set of wrappers as `forward_type_impls!`, now for lowering.
forward_lowers! {
    (T: Lower + ?Sized) &'_ T => T,
    (T: Lower + ?Sized) Box<T> => T,
    (T: Lower + ?Sized) std::rc::Rc<T> => T,
    (T: Lower + ?Sized) std::sync::Arc<T> => T,
    () String => str,
    (T: Lower) Vec<T> => [T],
}
/// Implements `Lift` for an owned string type `$a`. It lifts a [`WasmStr`],
/// decodes it, and converts the decoded `Cow<str>` into `$a` with `into()`.
///
/// Every target type listed below is owned, so the result is a copy that no
/// longer refers to linear memory.
macro_rules! forward_string_lifts {
    ($($a:ty,)*) => ($(
        unsafe impl Lift for $a {
            fn lift(store: &StoreOpaque, options: &Options, src: &Self::Lower) -> Result<Self> {
                Ok(<WasmStr as Lift>::lift(store, options, src)?.to_str_from_store(store)?.into())
            }
            fn load(memory: &Memory<'_>, bytes: &[u8]) -> Result<Self> {
                Ok(<WasmStr as Lift>::load(memory, bytes)?.to_str_from_store(&memory.store)?.into())
            }
        }
    )*)
}

forward_string_lifts! {
    Box<str>,
    std::rc::Rc<str>,
    std::sync::Arc<str>,
    String,
}
/// Implements `Lift` for an owned list type `$a`. It lifts a
/// [`WasmList<T>`] and decodes every element into the collection.
///
/// Collecting into `Result<$a>` stops at the first element that fails to
/// load.
macro_rules! forward_list_lifts {
    ($($a:ty,)*) => ($(
        unsafe impl <T: Lift> Lift for $a {
            fn lift(store: &StoreOpaque, options: &Options, src: &Self::Lower) -> Result<Self> {
                let list = <WasmList::<T> as Lift>::lift(store, options, src)?;
                // `get_from_store` returns `None` only for `index >= len`, so
                // this `unwrap` cannot fail.
                (0..list.len).map(|index| list.get_from_store(store, index).unwrap()).collect()
            }
            fn load(memory: &Memory<'_>, bytes: &[u8]) -> Result<Self> {
                let list = <WasmList::<T> as Lift>::load(memory, bytes)?;
                (0..list.len).map(|index| list.get_from_store(&memory.store, index).unwrap()).collect()
            }
        }
    )*)
}

forward_list_lifts! {
    Box<[T]>,
    std::rc::Rc<[T]>,
    std::sync::Arc<[T]>,
    Vec<T>,
}
/// Implements `ComponentType`, `Lower`, and `Lift` for primitive integers.
///
/// Each entry reads `primitive = InterfaceVariant in flat/getter with
/// abi:CONST`. `$field` is both the `ValRaw` constructor and the Rust type
/// the value is cast to when lowered. `$get` is the `ValRaw` accessor used
/// when lifting.
macro_rules! integers {
    ($($primitive:ident = $ty:ident in $field:ident/$get:ident with abi:$abi:ident,)*) => ($(
        unsafe impl ComponentType for $primitive {
            type Lower = ValRaw;
            const ABI: CanonicalAbiInfo = CanonicalAbiInfo::$abi;
            fn typecheck(ty: &InterfaceType, _types: &ComponentTypes) -> Result<()> {
                match ty {
                    InterfaceType::$ty => Ok(()),
                    other => bail!("expected `{}` found `{}`", desc(&InterfaceType::$ty), desc(other))
                }
            }
        }
        unsafe impl Lower for $primitive {
            fn lower<T>(
                &self,
                _store: &mut StoreContextMut<T>,
                _options: &Options,
                dst: &mut MaybeUninit<Self::Lower>,
            ) -> Result<()> {
                // `as` widens to the flat type. It sign-extends signed
                // primitives and zero-extends unsigned ones.
                dst.write(ValRaw::$field(*self as $field));
                Ok(())
            }
            fn store<T>(&self, memory: &mut MemoryMut<'_, T>, offset: usize) -> Result<()> {
                debug_assert!(offset % Self::SIZE32 == 0);
                // Little-endian bytes; the array length selects `get::<N>`.
                *memory.get(offset) = self.to_le_bytes();
                Ok(())
            }
        }
        unsafe impl Lift for $primitive {
            #[inline]
            fn lift(_store: &StoreOpaque, _options: &Options, src: &Self::Lower) -> Result<Self> {
                // Truncating `as` cast: bits above `$primitive`'s width are
                // ignored.
                Ok(src.$get() as $primitive)
            }
            #[inline]
            fn load(_mem: &Memory<'_>, bytes: &[u8]) -> Result<Self> {
                debug_assert!((bytes.as_ptr() as usize) % Self::SIZE32 == 0);
                // Panics unless `bytes` is exactly the primitive's width.
                Ok($primitive::from_le_bytes(bytes.try_into().unwrap()))
            }
        }
    )*)
}

integers! {
    i8 = S8 in i32/get_i32 with abi:SCALAR1,
    u8 = U8 in u32/get_u32 with abi:SCALAR1,
    i16 = S16 in i32/get_i32 with abi:SCALAR2,
    u16 = U16 in u32/get_u32 with abi:SCALAR2,
    i32 = S32 in i32/get_i32 with abi:SCALAR4,
    u32 = U32 in u32/get_u32 with abi:SCALAR4,
    i64 = S64 in i64/get_i64 with abi:SCALAR8,
    u64 = U64 in u64/get_u64 with abi:SCALAR8,
}
/// Implements `ComponentType`, `Lower`, and `Lift` for `f32`/`f64`.
///
/// Every direction (lower, store, lift, load) replaces any NaN with the
/// single `$float::NAN` value via `canonicalize`. Each expansion is wrapped
/// in `const _: () = { ... };` so each float type gets its own private
/// `canonicalize` helper without name collisions.
macro_rules! floats {
    ($($float:ident/$get_float:ident = $ty:ident with abi:$abi:ident)*) => ($(const _: () = {
        // Collapses every NaN bit pattern to `$float::NAN`.
        #[inline]
        fn canonicalize(float: $float) -> $float {
            if float.is_nan() {
                $float::NAN
            } else {
                float
            }
        }
        unsafe impl ComponentType for $float {
            type Lower = ValRaw;
            const ABI: CanonicalAbiInfo = CanonicalAbiInfo::$abi;
            fn typecheck(ty: &InterfaceType, _types: &ComponentTypes) -> Result<()> {
                match ty {
                    InterfaceType::$ty => Ok(()),
                    other => bail!("expected `{}` found `{}`", desc(&InterfaceType::$ty), desc(other))
                }
            }
        }
        unsafe impl Lower for $float {
            fn lower<T>(
                &self,
                _store: &mut StoreContextMut<T>,
                _options: &Options,
                dst: &mut MaybeUninit<Self::Lower>,
            ) -> Result<()> {
                dst.write(ValRaw::$float(canonicalize(*self).to_bits()));
                Ok(())
            }
            fn store<T>(&self, memory: &mut MemoryMut<'_, T>, offset: usize) -> Result<()> {
                debug_assert!(offset % Self::SIZE32 == 0);
                let ptr = memory.get(offset);
                *ptr = canonicalize(*self).to_bits().to_le_bytes();
                Ok(())
            }
        }
        unsafe impl Lift for $float {
            #[inline]
            fn lift(_store: &StoreOpaque, _options: &Options, src: &Self::Lower) -> Result<Self> {
                Ok(canonicalize($float::from_bits(src.$get_float())))
            }
            #[inline]
            fn load(_mem: &Memory<'_>, bytes: &[u8]) -> Result<Self> {
                debug_assert!((bytes.as_ptr() as usize) % Self::SIZE32 == 0);
                Ok(canonicalize($float::from_le_bytes(bytes.try_into().unwrap())))
            }
        }
    };)*)
}

floats! {
    f32/get_f32 = Float32 with abi:SCALAR4
    f64/get_f64 = Float64 with abi:SCALAR8
}
unsafe impl ComponentType for bool {
type Lower = ValRaw;
const ABI: CanonicalAbiInfo = CanonicalAbiInfo::SCALAR1;
fn typecheck(ty: &InterfaceType, _types: &ComponentTypes) -> Result<()> {
match ty {
InterfaceType::Bool => Ok(()),
other => bail!("expected `bool` found `{}`", desc(other)),
}
}
}
unsafe impl Lower for bool {
fn lower<T>(
&self,
_store: &mut StoreContextMut<T>,
_options: &Options,
dst: &mut MaybeUninit<Self::Lower>,
) -> Result<()> {
dst.write(ValRaw::i32(*self as i32));
Ok(())
}
fn store<T>(&self, memory: &mut MemoryMut<'_, T>, offset: usize) -> Result<()> {
debug_assert!(offset % Self::SIZE32 == 0);
memory.get::<1>(offset)[0] = *self as u8;
Ok(())
}
}
unsafe impl Lift for bool {
#[inline]
fn lift(_store: &StoreOpaque, _options: &Options, src: &Self::Lower) -> Result<Self> {
match src.get_i32() {
0 => Ok(false),
_ => Ok(true),
}
}
#[inline]
fn load(_mem: &Memory<'_>, bytes: &[u8]) -> Result<Self> {
match bytes[0] {
0 => Ok(false),
_ => Ok(true),
}
}
}
unsafe impl ComponentType for char {
type Lower = ValRaw;
const ABI: CanonicalAbiInfo = CanonicalAbiInfo::SCALAR4;
fn typecheck(ty: &InterfaceType, _types: &ComponentTypes) -> Result<()> {
match ty {
InterfaceType::Char => Ok(()),
other => bail!("expected `char` found `{}`", desc(other)),
}
}
}
unsafe impl Lower for char {
fn lower<T>(
&self,
_store: &mut StoreContextMut<T>,
_options: &Options,
dst: &mut MaybeUninit<Self::Lower>,
) -> Result<()> {
dst.write(ValRaw::u32(u32::from(*self)));
Ok(())
}
fn store<T>(&self, memory: &mut MemoryMut<'_, T>, offset: usize) -> Result<()> {
debug_assert!(offset % Self::SIZE32 == 0);
*memory.get::<4>(offset) = u32::from(*self).to_le_bytes();
Ok(())
}
}
unsafe impl Lift for char {
#[inline]
fn lift(_store: &StoreOpaque, _options: &Options, src: &Self::Lower) -> Result<Self> {
Ok(char::try_from(src.get_u32())?)
}
#[inline]
fn load(_memory: &Memory<'_>, bytes: &[u8]) -> Result<Self> {
debug_assert!((bytes.as_ptr() as usize) % Self::SIZE32 == 0);
let bits = u32::from_le_bytes(bytes.try_into().unwrap());
Ok(char::try_from(bits)?)
}
}
// Under the `CompactUtf16` (latin1 + UTF-16) encoding, this bit is set in a
// string's length to mark its contents as UTF-16 rather than latin1.
const UTF16_TAG: usize = 1 << 31;
// Largest byte length a lowered string may occupy. It is the largest value
// below `UTF16_TAG`, so a valid length never has the tag bit set.
const MAX_STRING_BYTE_LENGTH: usize = (1 << 31) - 1;
unsafe impl ComponentType for str {
type Lower = [ValRaw; 2];
const ABI: CanonicalAbiInfo = CanonicalAbiInfo::POINTER_PAIR;
fn typecheck(ty: &InterfaceType, _types: &ComponentTypes) -> Result<()> {
match ty {
InterfaceType::String => Ok(()),
other => bail!("expected `string` found `{}`", desc(other)),
}
}
}
unsafe impl Lower for str {
fn lower<T>(
&self,
store: &mut StoreContextMut<T>,
options: &Options,
dst: &mut MaybeUninit<[ValRaw; 2]>,
) -> Result<()> {
let (ptr, len) = lower_string(&mut MemoryMut::new(store.as_context_mut(), options), self)?;
map_maybe_uninit!(dst[0]).write(ValRaw::i64(ptr as i64));
map_maybe_uninit!(dst[1]).write(ValRaw::i64(len as i64));
Ok(())
}
fn store<T>(&self, mem: &mut MemoryMut<'_, T>, offset: usize) -> Result<()> {
debug_assert!(offset % (Self::ALIGN32 as usize) == 0);
let (ptr, len) = lower_string(mem, self)?;
*mem.get(offset + 0) = (ptr as i32).to_le_bytes();
*mem.get(offset + 4) = (len as i32).to_le_bytes();
Ok(())
}
}
/// Copies `string` into linear memory using the string encoding configured
/// for `mem`, and returns the `(pointer, length)` pair to hand to wasm.
///
/// The returned length is in code units of the chosen encoding: bytes for
/// UTF-8 and latin1, 16-bit units for UTF-16. With `CompactUtf16`, a UTF-16
/// result is marked by setting `UTF16_TAG` in the returned length.
///
/// Memory is obtained through `realloc`. An upper-bound allocation is made
/// first, then shrunk when the encoded form turns out to be smaller.
///
/// # Errors
///
/// Fails if the encoded size would exceed `MAX_STRING_BYTE_LENGTH`, or if
/// any `realloc` call fails.
fn lower_string<T>(mem: &mut MemoryMut<'_, T>, string: &str) -> Result<(usize, usize)> {
    match mem.string_encoding() {
        StringEncoding::Utf8 => {
            if string.len() > MAX_STRING_BYTE_LENGTH {
                bail!(
                    "string length of {} too large to copy into wasm",
                    string.len()
                );
            }
            let ptr = mem.realloc(0, 0, 1, string.len())?;
            mem.as_slice_mut()[ptr..][..string.len()].copy_from_slice(string.as_bytes());
            Ok((ptr, string.len()))
        }
        StringEncoding::Utf16 => {
            // A UTF-8 string never has more UTF-16 code units than bytes, so
            // twice its byte length is an upper bound on the UTF-16 size.
            let size = string.len() * 2;
            if size > MAX_STRING_BYTE_LENGTH {
                bail!(
                    "string length of {} too large to copy into wasm",
                    string.len()
                );
            }
            let mut ptr = mem.realloc(0, 0, 2, size)?;
            let mut copied = 0;
            let dst = &mut mem.as_slice_mut()[ptr..][..size];
            for (unit, slot) in string.encode_utf16().zip(dst.chunks_mut(2)) {
                slot.copy_from_slice(&unit.to_le_bytes());
                copied += 1;
            }
            // Shrink the allocation to the bytes actually written.
            if (copied * 2) < size {
                ptr = mem.realloc(ptr, size, 2, copied * 2)?;
            }
            Ok((ptr, copied))
        }
        StringEncoding::CompactUtf16 => {
            let bytes = string.as_bytes();
            // Latin1 never needs more bytes than UTF-8, so the UTF-8 length
            // bounds the latin1 result. Rejecting oversized inputs here keeps
            // a latin1 length from reaching `UTF16_TAG`, where it would be
            // misread as UTF-16. The other encodings enforce the same limit.
            if bytes.len() > MAX_STRING_BYTE_LENGTH {
                bail!(
                    "string length of {} too large to copy into wasm",
                    string.len()
                );
            }
            // Optimistically encode as latin1 (one byte per char), sized by
            // the UTF-8 length as an upper bound.
            let mut ptr = mem.realloc(0, 0, 2, bytes.len())?;
            let mut dst = &mut mem.as_slice_mut()[ptr..][..bytes.len()];
            let mut result = 0;
            for (i, ch) in string.char_indices() {
                if let Ok(byte) = u8::try_from(u32::from(ch)) {
                    dst[result] = byte;
                    result += 1;
                    continue;
                }

                // This char is outside latin1, so switch the whole string to
                // UTF-16 and grow the allocation to its worst-case size.
                let worst_case = bytes
                    .len()
                    .checked_mul(2)
                    .ok_or_else(|| anyhow!("byte length overflow"))?;
                if worst_case > MAX_STRING_BYTE_LENGTH {
                    bail!("byte length too large");
                }
                ptr = mem.realloc(ptr, bytes.len(), 2, worst_case)?;
                dst = &mut mem.as_slice_mut()[ptr..][..worst_case];

                // Widen the latin1 bytes already written to 16-bit units.
                // Iterating backwards means no byte is overwritten before it
                // has been moved.
                for j in (0..result).rev() {
                    dst[2 * j] = dst[j];
                    dst[2 * j + 1] = 0;
                }

                // Encode the rest of the string, starting at this char.
                for (unit, slot) in string[i..]
                    .encode_utf16()
                    .zip(dst[2 * result..].chunks_mut(2))
                {
                    slot.copy_from_slice(&unit.to_le_bytes());
                    result += 1;
                }
                if worst_case > 2 * result {
                    ptr = mem.realloc(ptr, worst_case, 2, 2 * result)?;
                }
                return Ok((ptr, result | UTF16_TAG));
            }
            // Every char fit in latin1; shrink to the bytes actually written.
            if result < bytes.len() {
                ptr = mem.realloc(ptr, bytes.len(), 2, result)?;
            }
            Ok((ptr, result))
        }
    }
}
/// A string stored in a component instance's linear memory.
///
/// Holds the pointer/length pair, validated against memory bounds at
/// construction, plus the canonical options it was lifted with. Those options
/// determine the string encoding. Decoding into a Rust string happens on
/// demand via [`WasmStr::to_str`].
pub struct WasmStr {
    ptr: usize,
    len: usize,
    options: Options,
}

impl WasmStr {
    /// Creates a `WasmStr` after checking that the encoded bytes lie
    /// entirely within `memory`.
    ///
    /// `len` is in code units of the configured encoding. Under
    /// `CompactUtf16`, a set `UTF16_TAG` bit means the remaining bits count
    /// 16-bit units instead of latin1 bytes.
    fn new(ptr: usize, len: usize, memory: &Memory<'_>) -> Result<WasmStr> {
        let byte_len = match memory.string_encoding() {
            StringEncoding::Utf8 => Some(len),
            StringEncoding::Utf16 => len.checked_mul(2),
            StringEncoding::CompactUtf16 if len & UTF16_TAG != 0 => {
                (len ^ UTF16_TAG).checked_mul(2)
            }
            StringEncoding::CompactUtf16 => Some(len),
        };
        let end = byte_len.and_then(|n| ptr.checked_add(n));
        if !matches!(end, Some(end) if end <= memory.as_slice().len()) {
            bail!("string pointer/length out of bounds of memory");
        }
        Ok(WasmStr {
            ptr,
            len,
            options: *memory.options(),
        })
    }

    /// Decodes this string from linear memory.
    ///
    /// UTF-8 contents are borrowed directly from memory. UTF-16 contents are
    /// decoded into a new `String`. Latin1 contents go through
    /// `encoding_rs::mem::decode_latin1`.
    ///
    /// # Errors
    ///
    /// Fails on invalid UTF-8 or on unpaired UTF-16 surrogates.
    pub fn to_str<'a, T: 'a>(&self, store: impl Into<StoreContext<'a, T>>) -> Result<Cow<'a, str>> {
        self.to_str_from_store(store.into().0)
    }

    /// Dispatches to the decoder for the configured string encoding.
    fn to_str_from_store<'a>(&self, store: &'a StoreOpaque) -> Result<Cow<'a, str>> {
        match self.options.string_encoding() {
            StringEncoding::Utf8 => self.decode_utf8(store),
            StringEncoding::Utf16 => self.decode_utf16(store, self.len),
            StringEncoding::CompactUtf16 if self.len & UTF16_TAG != 0 => {
                self.decode_utf16(store, self.len ^ UTF16_TAG)
            }
            StringEncoding::CompactUtf16 => self.decode_latin1(store),
        }
    }

    /// Validates `self.len` bytes as UTF-8 and borrows them.
    fn decode_utf8<'a>(&self, store: &'a StoreOpaque) -> Result<Cow<'a, str>> {
        let memory = self.options.memory(store);
        let bytes = &memory[self.ptr..][..self.len];
        let text = str::from_utf8(bytes)?;
        Ok(Cow::Borrowed(text))
    }

    /// Decodes `len` little-endian 16-bit units into an owned `String`.
    fn decode_utf16<'a>(&self, store: &'a StoreOpaque, len: usize) -> Result<Cow<'a, str>> {
        let memory = self.options.memory(store);
        let units = memory[self.ptr..][..len * 2]
            .chunks(2)
            .map(|pair| u16::from_le_bytes([pair[0], pair[1]]));
        let decoded = std::char::decode_utf16(units).collect::<Result<String, _>>()?;
        Ok(Cow::Owned(decoded))
    }

    /// Decodes `self.len` latin1 bytes, one `char` per byte.
    fn decode_latin1<'a>(&self, store: &'a StoreOpaque) -> Result<Cow<'a, str>> {
        let memory = self.options.memory(store);
        let bytes = &memory[self.ptr..][..self.len];
        Ok(encoding_rs::mem::decode_latin1(bytes))
    }
}
unsafe impl ComponentType for WasmStr {
type Lower = <str as ComponentType>::Lower;
const ABI: CanonicalAbiInfo = CanonicalAbiInfo::POINTER_PAIR;
fn typecheck(ty: &InterfaceType, _types: &ComponentTypes) -> Result<()> {
match ty {
InterfaceType::String => Ok(()),
other => bail!("expected `string` found `{}`", desc(other)),
}
}
}
unsafe impl Lift for WasmStr {
fn lift(store: &StoreOpaque, options: &Options, src: &Self::Lower) -> Result<Self> {
let ptr = src[0].get_u32();
let len = src[1].get_u32();
let (ptr, len) = (usize::try_from(ptr)?, usize::try_from(len)?);
WasmStr::new(ptr, len, &Memory::new(store, options))
}
fn load(memory: &Memory<'_>, bytes: &[u8]) -> Result<Self> {
debug_assert!((bytes.as_ptr() as usize) % (Self::ALIGN32 as usize) == 0);
let ptr = u32::from_le_bytes(bytes[..4].try_into().unwrap());
let len = u32::from_le_bytes(bytes[4..].try_into().unwrap());
let (ptr, len) = (usize::try_from(ptr)?, usize::try_from(len)?);
WasmStr::new(ptr, len, memory)
}
}
unsafe impl<T> ComponentType for [T]
where
T: ComponentType,
{
type Lower = [ValRaw; 2];
const ABI: CanonicalAbiInfo = CanonicalAbiInfo::POINTER_PAIR;
fn typecheck(ty: &InterfaceType, types: &ComponentTypes) -> Result<()> {
match ty {
InterfaceType::List(t) => T::typecheck(&types[*t].element, types),
other => bail!("expected `list` found `{}`", desc(other)),
}
}
}
unsafe impl<T> Lower for [T]
where
T: Lower,
{
fn lower<U>(
&self,
store: &mut StoreContextMut<U>,
options: &Options,
dst: &mut MaybeUninit<[ValRaw; 2]>,
) -> Result<()> {
let (ptr, len) = lower_list(&mut MemoryMut::new(store.as_context_mut(), options), self)?;
map_maybe_uninit!(dst[0]).write(ValRaw::i64(ptr as i64));
map_maybe_uninit!(dst[1]).write(ValRaw::i64(len as i64));
Ok(())
}
fn store<U>(&self, mem: &mut MemoryMut<'_, U>, offset: usize) -> Result<()> {
debug_assert!(offset % (Self::ALIGN32 as usize) == 0);
let (ptr, len) = lower_list(mem, self)?;
*mem.get(offset + 0) = (ptr as i32).to_le_bytes();
*mem.get(offset + 4) = (len as i32).to_le_bytes();
Ok(())
}
}
/// Allocates room for `list` in linear memory and stores every element
/// contiguously, each `T::SIZE32` bytes apart.
///
/// Returns the allocation's address and the number of elements.
///
/// # Errors
///
/// Fails if the total byte size overflows `usize`, or if allocation or any
/// element store fails.
fn lower_list<T, U>(mem: &mut MemoryMut<'_, U>, list: &[T]) -> Result<(usize, usize)>
where
    T: Lower,
{
    let stride = T::SIZE32;
    let total_bytes = list
        .len()
        .checked_mul(stride)
        .ok_or_else(|| anyhow!("size overflow copying a list"))?;
    let base = mem.realloc(0, 0, T::ALIGN32, total_bytes)?;
    for (index, item) in list.iter().enumerate() {
        item.store(mem, base + index * stride)?;
    }
    Ok((base, list.len()))
}
/// A list of `T` stored in a component instance's linear memory.
///
/// Bounds and alignment are validated once at construction. Elements are
/// then decoded lazily through [`WasmList::get`] or [`WasmList::iter`].
pub struct WasmList<T> {
    ptr: usize,
    len: usize,
    options: Options,
    _marker: marker::PhantomData<T>,
}

impl<T: Lift> WasmList<T> {
    /// Checks that `len` elements of `T` starting at `ptr` fit within
    /// `memory`, and that `ptr` is aligned to `T::ALIGN32`.
    fn new(ptr: usize, len: usize, memory: &Memory<'_>) -> Result<WasmList<T>> {
        let end = len
            .checked_mul(T::SIZE32)
            .and_then(|byte_len| ptr.checked_add(byte_len));
        if !matches!(end, Some(end) if end <= memory.as_slice().len()) {
            bail!("list pointer/length out of bounds of memory");
        }
        let align = usize::try_from(T::ALIGN32)?;
        if ptr % align != 0 {
            bail!("list pointer is not aligned");
        }
        Ok(WasmList {
            ptr,
            len,
            options: *memory.options(),
            _marker: marker::PhantomData,
        })
    }

    /// Number of elements in the list.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Decodes the element at `index`, or returns `None` if `index` is out
    /// of range.
    pub fn get(&self, store: impl AsContext, index: usize) -> Option<Result<T>> {
        let context = store.as_context();
        self.get_from_store(context.0, index)
    }

    /// Store-level variant of [`WasmList::get`].
    fn get_from_store(&self, store: &StoreOpaque, index: usize) -> Option<Result<T>> {
        (index < self.len).then(|| {
            let memory = Memory::new(store, &self.options);
            let start = self.ptr + index * T::SIZE32;
            let bytes = &memory.as_slice()[start..][..T::SIZE32];
            T::load(&memory, bytes)
        })
    }

    /// Returns an iterator that decodes each element in order.
    pub fn iter<'a, U: 'a>(
        &'a self,
        store: impl Into<StoreContext<'a, U>>,
    ) -> impl ExactSizeIterator<Item = Result<T>> + 'a {
        let store = store.into().0;
        // Indices come from `0..len`, so `get_from_store` never returns `None`.
        (0..self.len).map(move |index| self.get_from_store(store, index).unwrap())
    }
}
/// Adds a zero-copy `as_le_slice` accessor to `WasmList` of each listed
/// primitive integer type.
macro_rules! raw_wasm_list_accessors {
    ($($i:ident)*) => ($(
        impl WasmList<$i> {
            /// Returns the list's elements as a slice borrowed directly from
            /// linear memory, without copying.
            ///
            /// The bytes are reinterpreted in host byte order. As the name
            /// suggests they hold little-endian values; NOTE(review): on a
            /// big-endian host the caller presumably must byte-swap.
            ///
            /// # Panics
            ///
            /// Panics if the memory range is not aligned for the element type
            /// in host memory.
            pub fn as_le_slice<'a, T: 'a>(&self, store: impl Into<StoreContext<'a, T>>) -> &'a [$i] {
                let byte_size = self.len * mem::size_of::<$i>();
                let bytes = &self.options.memory(store.into().0)[self.ptr..][..byte_size];
                // SAFETY: every bit pattern is a valid value of a primitive
                // integer, so reinterpreting bytes as `$i` is sound. The
                // assert guarantees the whole range is covered by `body`.
                unsafe {
                    let (head, body, tail) = bytes.align_to::<$i>();
                    assert!(head.is_empty() && tail.is_empty());
                    body
                }
            }
        }
    )*)
}

raw_wasm_list_accessors! {
    i8 i16 i32 i64
    u8 u16 u32 u64
}
// `WasmList<T>` has the same interface type and flat layout as `[T]`.
unsafe impl<T: ComponentType> ComponentType for WasmList<T> {
    type Lower = <[T] as ComponentType>::Lower;

    const ABI: CanonicalAbiInfo = CanonicalAbiInfo::POINTER_PAIR;

    fn typecheck(ty: &InterfaceType, types: &ComponentTypes) -> Result<()> {
        <[T] as ComponentType>::typecheck(ty, types)
    }
}

unsafe impl<T: Lift> Lift for WasmList<T> {
    /// Reads the `(pointer, length)` flat pair and validates it.
    fn lift(store: &StoreOpaque, options: &Options, src: &Self::Lower) -> Result<Self> {
        let raw_ptr = src[0].get_u32();
        let raw_len = src[1].get_u32();
        let ptr = usize::try_from(raw_ptr)?;
        let len = usize::try_from(raw_len)?;
        let memory = Memory::new(store, options);
        WasmList::new(ptr, len, &memory)
    }

    /// Reads the pointer and length as two little-endian 32-bit words.
    fn load(memory: &Memory<'_>, bytes: &[u8]) -> Result<Self> {
        debug_assert!((bytes.as_ptr() as usize) % (Self::ALIGN32 as usize) == 0);
        let (ptr_bytes, len_bytes) = bytes.split_at(4);
        let raw_ptr = u32::from_le_bytes(ptr_bytes.try_into().unwrap());
        let raw_len = u32::from_le_bytes(len_bytes.try_into().unwrap());
        let ptr = usize::try_from(raw_ptr)?;
        let len = usize::try_from(raw_len)?;
        WasmList::new(ptr, len, memory)
    }
}
fn typecheck_tuple(
ty: &InterfaceType,
types: &ComponentTypes,
expected: &[fn(&InterfaceType, &ComponentTypes) -> Result<()>],
) -> Result<()> {
match ty {
InterfaceType::Tuple(t) => {
let tuple = &types[*t];
if tuple.types.len() != expected.len() {
bail!(
"expected {}-tuple, found {}-tuple",
expected.len(),
tuple.types.len()
);
}
for (ty, check) in tuple.types.iter().zip(expected) {
check(ty, types)?;
}
Ok(())
}
other => bail!("expected `tuple` found `{}`", desc(other)),
}
}
/// Checks that `ty` is a record whose fields match `expected` in order.
///
/// Each field's type is checked before its name. Type errors are annotated
/// with the expected field name.
pub fn typecheck_record(
    ty: &InterfaceType,
    types: &ComponentTypes,
    expected: &[(&str, fn(&InterfaceType, &ComponentTypes) -> Result<()>)],
) -> Result<()> {
    let fields = match ty {
        InterfaceType::Record(index) => &types[*index].fields,
        other => bail!("expected `record` found `{}`", desc(other)),
    };
    if fields.len() != expected.len() {
        bail!(
            "expected record of {} fields, found {} fields",
            expected.len(),
            fields.len()
        );
    }
    for (field, &(name, check)) in fields.iter().zip(expected) {
        check(&field.ty, types).with_context(|| format!("type mismatch for field {}", name))?;
        if field.name != name {
            bail!("expected record field named {}, found {}", name, field.name);
        }
    }
    Ok(())
}
/// Checks that `ty` is a variant whose cases match `expected` in order.
///
/// Each case must have the expected name. A payload must be present exactly
/// when a check function is given, and that payload must pass the check.
pub fn typecheck_variant(
    ty: &InterfaceType,
    types: &ComponentTypes,
    expected: &[(
        &str,
        Option<fn(&InterfaceType, &ComponentTypes) -> Result<()>>,
    )],
) -> Result<()> {
    let cases = match ty {
        InterfaceType::Variant(index) => &types[*index].cases,
        other => bail!("expected `variant` found `{}`", desc(other)),
    };
    if cases.len() != expected.len() {
        bail!(
            "expected variant of {} cases, found {} cases",
            expected.len(),
            cases.len()
        );
    }
    for (case, &(name, check)) in cases.iter().zip(expected) {
        if case.name != name {
            bail!("expected variant case named {name}, found {}", case.name);
        }
        match (check, &case.ty) {
            (Some(check), Some(payload)) => check(payload, types)
                .with_context(|| format!("type mismatch for case {name}"))?,
            (Some(_), None) => bail!("case `{name}` has no type but one was expected"),
            (None, Some(_)) => bail!("case `{name}` has a type but none was expected"),
            (None, None) => {}
        }
    }
    Ok(())
}
pub fn typecheck_enum(ty: &InterfaceType, types: &ComponentTypes, expected: &[&str]) -> Result<()> {
match ty {
InterfaceType::Enum(index) => {
let names = &types[*index].names;
if names.len() != expected.len() {
bail!(
"expected enum of {} names, found {} names",
expected.len(),
names.len()
);
}
for (name, expected) in names.iter().zip(expected) {
if name != expected {
bail!("expected enum case named {}, found {}", expected, name);
}
}
Ok(())
}
other => bail!("expected `enum` found `{}`", desc(other)),
}
}
/// Checks that `ty` is a union whose case types pass the `expected` checks,
/// in order. Failures are annotated with the case index.
pub fn typecheck_union(
    ty: &InterfaceType,
    types: &ComponentTypes,
    expected: &[fn(&InterfaceType, &ComponentTypes) -> Result<()>],
) -> Result<()> {
    let union_types = match ty {
        InterfaceType::Union(index) => &types[*index].types,
        other => bail!("expected `union` found `{}`", desc(other)),
    };
    if union_types.len() != expected.len() {
        bail!(
            "expected union of {} types, found {} types",
            expected.len(),
            union_types.len()
        );
    }
    union_types
        .iter()
        .zip(expected)
        .enumerate()
        .try_for_each(|(index, (case_ty, check))| {
            check(case_ty, types).with_context(|| format!("type mismatch for case {}", index))
        })
}
pub fn typecheck_flags(
ty: &InterfaceType,
types: &ComponentTypes,
expected: &[&str],
) -> Result<()> {
match ty {
InterfaceType::Flags(index) => {
let names = &types[*index].names;
if names.len() != expected.len() {
bail!(
"expected flags type with {} names, found {} names",
expected.len(),
names.len()
);
}
for (name, expected) in names.iter().zip(expected) {
if name != expected {
bail!("expected flag named {}, found {}", expected, name);
}
}
Ok(())
}
other => bail!("expected `flags` found `{}`", desc(other)),
}
}
/// Writes the names of the set flags as `(a|b|c)`.
///
/// Bit `i` of the flags lives at bit `i % 32` of `bits[i / 32]`, and
/// `names[i]` is its name. With no flags set the output is `()`.
///
/// # Panics
///
/// Panics if `bits` has fewer than `names.len()` bits' worth of words.
pub fn format_flags(bits: &[u32], names: &[&str], f: &mut fmt::Formatter) -> fmt::Result {
    let is_set = |index: usize| ((bits[index / 32] >> (index % 32)) & 1) != 0;
    let mut set_names = names
        .iter()
        .enumerate()
        .filter(|&(index, _)| is_set(index))
        .map(|(_, name)| name);

    f.write_str("(")?;
    if let Some(first) = set_names.next() {
        f.write_str(first)?;
        for name in set_names {
            f.write_str("|")?;
            f.write_str(name)?;
        }
    }
    f.write_str(")")
}
unsafe impl<T> ComponentType for Option<T>
where
T: ComponentType,
{
type Lower = TupleLower2<<u32 as ComponentType>::Lower, T::Lower>;
const ABI: CanonicalAbiInfo = CanonicalAbiInfo::variant_static(&[None, Some(T::ABI)]);
fn typecheck(ty: &InterfaceType, types: &ComponentTypes) -> Result<()> {
match ty {
InterfaceType::Option(t) => T::typecheck(&types[*t].ty, types),
other => bail!("expected `option` found `{}`", desc(other)),
}
}
}
unsafe impl<T> ComponentVariant for Option<T>
where
T: ComponentType,
{
const CASES: &'static [Option<CanonicalAbiInfo>] = &[None, Some(T::ABI)];
}
unsafe impl<T> Lower for Option<T>
where
    T: Lower,
{
    /// Writes the discriminant (0 = `None`, 1 = `Some`) followed by the
    /// payload. For `None` the payload slots are zero-filled.
    fn lower<U>(
        &self,
        store: &mut StoreContextMut<U>,
        options: &Options,
        dst: &mut MaybeUninit<Self::Lower>,
    ) -> Result<()> {
        match self {
            Some(val) => {
                map_maybe_uninit!(dst.A1).write(ValRaw::i32(1));
                val.lower(store, options, map_maybe_uninit!(dst.A2))?;
            }
            None => {
                map_maybe_uninit!(dst.A1).write(ValRaw::i32(0));
                // SAFETY: the pointer comes from a live `&mut MaybeUninit`
                // slot and exactly one element's worth of bytes is written,
                // initializing the payload to zeros.
                unsafe {
                    map_maybe_uninit!(dst.A2).as_mut_ptr().write_bytes(0u8, 1);
                }
            }
        }
        Ok(())
    }

    /// Writes a one-byte discriminant at `offset`. For `Some`, the payload is
    /// stored at the variant's payload offset; for `None` it is left
    /// untouched.
    fn store<U>(&self, mem: &mut MemoryMut<'_, U>, offset: usize) -> Result<()> {
        debug_assert!(offset % (Self::ALIGN32 as usize) == 0);
        let discriminant = match self {
            None => 0,
            Some(_) => 1,
        };
        mem.get::<1>(offset)[0] = discriminant;
        if let Some(val) = self {
            val.store(mem, offset + Self::PAYLOAD_OFFSET32)?;
        }
        Ok(())
    }
}
unsafe impl<T> Lift for Option<T>
where
    T: Lift,
{
    /// Reads the flat discriminant and, for `Some`, lifts the payload.
    ///
    /// # Errors
    ///
    /// Fails on a discriminant other than 0 or 1, or if lifting the payload
    /// fails.
    fn lift(store: &StoreOpaque, options: &Options, src: &Self::Lower) -> Result<Self> {
        Ok(match src.A1.get_i32() {
            0 => None,
            1 => Some(T::lift(store, options, &src.A2)?),
            _ => bail!("invalid option discriminant"),
        })
    }

    /// Reads the one-byte discriminant and, for `Some`, loads the payload.
    ///
    /// The payload slice is bounded to exactly `T::SIZE32` bytes, as in
    /// `Result::load` and the tuple `load`. This way `T::load` never sees
    /// bytes beyond its own value, even if `bytes` is longer than
    /// `Self::SIZE32`. The slice is only taken for `Some`.
    ///
    /// # Errors
    ///
    /// Fails on a discriminant other than 0 or 1, or if loading the payload
    /// fails.
    fn load(memory: &Memory<'_>, bytes: &[u8]) -> Result<Self> {
        debug_assert!((bytes.as_ptr() as usize) % (Self::ALIGN32 as usize) == 0);
        match bytes[0] {
            0 => Ok(None),
            1 => {
                let payload = &bytes[Self::INFO.payload_offset32 as usize..][..T::SIZE32];
                Ok(Some(T::load(memory, payload)?))
            }
            _ => bail!("invalid option discriminant"),
        }
    }
}
/// Flat (lowered) representation of `Result<T, E>`: a discriminant slot
/// followed by the payload of whichever case is active.
///
/// `#[repr(C)]` places `tag` first and `payload` after it, in declaration
/// order.
#[derive(Clone, Copy)]
#[repr(C)]
pub struct ResultLower<T: Copy, E: Copy> {
    // Case discriminant: `Lower for Result` writes 0 for `Ok`, 1 for `Err`.
    tag: ValRaw,
    payload: ResultLowerPayload<T, E>,
}

/// Overlapping storage for the `ok` and `err` payloads of [`ResultLower`].
///
/// Only the active case's field is lowered into. `lower_payload` then sets
/// any remaining `ValRaw` slots to zero.
#[derive(Clone, Copy)]
#[repr(C)]
union ResultLowerPayload<T: Copy, E: Copy> {
    ok: T,
    err: E,
}
// `result<T, E>` is a two-case variant: case 0 is `ok` and case 1 is `err`.
unsafe impl<T, E> ComponentType for Result<T, E>
where
    T: ComponentType,
    E: ComponentType,
{
    type Lower = ResultLower<T::Lower, E::Lower>;

    const ABI: CanonicalAbiInfo = CanonicalAbiInfo::variant_static(&[Some(T::ABI), Some(E::ABI)]);

    /// A component `result` without an `ok` (or `err`) type only matches when
    /// the corresponding Rust type is `()`.
    fn typecheck(ty: &InterfaceType, types: &ComponentTypes) -> Result<()> {
        let result = match ty {
            InterfaceType::Result(r) => &types[*r],
            other => bail!("expected `result` found `{}`", desc(other)),
        };

        if let Some(ok_ty) = &result.ok {
            T::typecheck(ok_ty, types)?;
        } else if !T::IS_RUST_UNIT_TYPE {
            bail!("expected no `ok` type");
        }

        if let Some(err_ty) = &result.err {
            E::typecheck(err_ty, types)?;
        } else if !E::IS_RUST_UNIT_TYPE {
            bail!("expected no `err` type");
        }

        Ok(())
    }
}
/// Lowers one case of a variant-like value into its flat payload storage.
///
/// `typed_payload` projects the whole payload slot `P` onto the slot of the
/// active case `T`, and `lower` fills that slot. Every `ValRaw` of `P`
/// beyond those covered by `T` is then set to zero.
///
/// # Safety
///
/// NOTE(review): the exact contract is not stated in this file. The body
/// views both `payload` and the projected slot as `ValRaw` slices, so callers
/// presumably must ensure `P` and `T` consist solely of `ValRaw` slots and
/// that `typed_payload` returns a prefix of `payload`.
pub unsafe fn lower_payload<P, T>(
    payload: &mut MaybeUninit<P>,
    typed_payload: impl FnOnce(&mut MaybeUninit<P>) -> &mut MaybeUninit<T>,
    lower: impl FnOnce(&mut MaybeUninit<T>) -> Result<()>,
) -> Result<()> {
    let case_slot = typed_payload(payload);
    lower(case_slot)?;
    let used = storage_as_slice(case_slot).len();
    storage_as_slice_mut(payload)[used..].fill(ValRaw::u64(0));
    Ok(())
}
// Case 0 (`ok`) carries a `T` and case 1 (`err`) carries an `E`. These cases
// determine the payload offset used by `Result`'s `store`/`load`.
unsafe impl<T, E> ComponentVariant for Result<T, E>
where
    T: ComponentType,
    E: ComponentType,
{
    const CASES: &'static [Option<CanonicalAbiInfo>] = &[Some(T::ABI), Some(E::ABI)];
}
unsafe impl<T, E> Lower for Result<T, E>
where
    T: Lower,
    E: Lower,
{
    /// Writes the discriminant (0 = `Ok`, 1 = `Err`) and lowers the active
    /// payload. Unused payload slots are zero-filled by `lower_payload`.
    fn lower<U>(
        &self,
        store: &mut StoreContextMut<U>,
        options: &Options,
        dst: &mut MaybeUninit<Self::Lower>,
    ) -> Result<()> {
        match self {
            Ok(value) => {
                map_maybe_uninit!(dst.tag).write(ValRaw::i32(0));
                // SAFETY: `ResultLowerPayload` is a `repr(C)` union of the
                // two flat payload types, and `payload.ok` is a projection of
                // it (see `lower_payload`'s contract).
                unsafe {
                    lower_payload(
                        map_maybe_uninit!(dst.payload),
                        |payload| map_maybe_uninit!(payload.ok),
                        |slot| value.lower(store, options, slot),
                    )
                }
            }
            Err(error) => {
                map_maybe_uninit!(dst.tag).write(ValRaw::i32(1));
                // SAFETY: as above, for the `err` field of the union.
                unsafe {
                    lower_payload(
                        map_maybe_uninit!(dst.payload),
                        |payload| map_maybe_uninit!(payload.err),
                        |slot| error.lower(store, options, slot),
                    )
                }
            }
        }
    }

    /// Writes a one-byte discriminant at `offset` and stores the active
    /// payload at the variant's payload offset.
    fn store<U>(&self, mem: &mut MemoryMut<'_, U>, offset: usize) -> Result<()> {
        debug_assert!(offset % (Self::ALIGN32 as usize) == 0);
        let payload_at = offset + Self::PAYLOAD_OFFSET32;
        match self {
            Ok(value) => {
                mem.get::<1>(offset)[0] = 0;
                value.store(mem, payload_at)
            }
            Err(error) => {
                mem.get::<1>(offset)[0] = 1;
                error.store(mem, payload_at)
            }
        }
    }
}
unsafe impl<T, E> Lift for Result<T, E>
where
    T: Lift,
    E: Lift,
{
    /// Reads the flat discriminant and lifts the corresponding payload.
    ///
    /// # Errors
    ///
    /// Fails on a discriminant other than 0 or 1, or if the payload fails to
    /// lift.
    fn lift(store: &StoreOpaque, options: &Options, src: &Self::Lower) -> Result<Self> {
        // SAFETY (both arms): reading a union field requires `unsafe`. The
        // discriminant selects which field is interpreted.
        // NOTE(review): both payload types are built from `ValRaw` slots, so
        // either view presumably holds valid bits — confirm.
        match src.tag.get_i32() {
            0 => Ok(Ok(T::lift(store, options, unsafe { &src.payload.ok })?)),
            1 => Ok(Err(E::lift(store, options, unsafe { &src.payload.err })?)),
            _ => bail!("invalid expected discriminant"),
        }
    }

    /// Reads the one-byte discriminant and loads the payload of the active
    /// case, bounded to that case's `SIZE32`.
    fn load(memory: &Memory<'_>, bytes: &[u8]) -> Result<Self> {
        debug_assert!((bytes.as_ptr() as usize) % (Self::ALIGN32 as usize) == 0);
        let discriminant = bytes[0];
        let payload = &bytes[Self::PAYLOAD_OFFSET32..];
        match discriminant {
            0 => T::load(memory, &payload[..T::SIZE32]).map(Ok),
            1 => E::load(memory, &payload[..E::SIZE32]).map(Err),
            _ => bail!("invalid expected discriminant"),
        }
    }
}
/// Generates, for one tuple arity, the flat `TupleLowerN` struct plus the
/// `ComponentType`, `Lower`, `Lift`, and `ComponentNamedList` impls for the
/// matching Rust tuple.
///
/// `$n` is the arity and each `$t` is an element type parameter. It is
/// invoked once per arity by `for_each_function_signature!`, which is
/// defined elsewhere.
macro_rules! impl_component_ty_for_tuples {
    ($n:tt $($t:ident)*) => {paste::paste!{
        // Flat representation of an `$n`-tuple: one field per element, laid
        // out in declaration order (`repr(C)`). The zero-length `ValRaw`
        // array gives the struct `ValRaw` alignment even when it has no other
        // fields, as `ComponentType::flatten_count` asserts.
        #[allow(non_snake_case)]
        #[doc(hidden)]
        #[derive(Clone, Copy)]
        #[repr(C)]
        pub struct [<TupleLower$n>]<$($t),*> {
            $($t: $t,)*
            _align_tuple_lower0_correctly: [ValRaw; 0],
        }

        #[allow(non_snake_case)]
        unsafe impl<$($t,)*> ComponentType for ($($t,)*)
            where $($t: ComponentType),*
        {
            type Lower = [<TupleLower$n>]<$($t::Lower),*>;

            const ABI: CanonicalAbiInfo = CanonicalAbiInfo::record_static(&[
                $($t::ABI),*
            ]);

            // `true` only for the zero-element tuple `()`. The repetition
            // expands once per element and each expansion clears the flag;
            // the binding only exists so the repetition mentions `$t`.
            const IS_RUST_UNIT_TYPE: bool = {
                let mut _is_unit = true;
                $(
                    let _anything_to_bind_the_macro_variable = $t::IS_RUST_UNIT_TYPE;
                    _is_unit = false;
                )*
                _is_unit
            };

            fn typecheck(
                ty: &InterfaceType,
                types: &ComponentTypes,
            ) -> Result<()> {
                typecheck_tuple(ty, types, &[$($t::typecheck),*])
            }
        }

        // Leading underscores keep the unused-variable lint quiet for the
        // zero-element tuple, where the repetitions expand to nothing.
        #[allow(non_snake_case)]
        unsafe impl<$($t,)*> Lower for ($($t,)*)
            where $($t: Lower),*
        {
            // Lowers each element into its own field of `TupleLowerN`.
            fn lower<U>(
                &self,
                _store: &mut StoreContextMut<U>,
                _options: &Options,
                _dst: &mut MaybeUninit<Self::Lower>,
            ) -> Result<()> {
                let ($($t,)*) = self;
                $($t.lower(_store, _options, map_maybe_uninit!(_dst.$t))?;)*
                Ok(())
            }

            // Stores each element at the offset produced by
            // `next_field32_size`, which presumably returns the field's
            // aligned offset and advances `_offset` past it.
            fn store<U>(&self, _memory: &mut MemoryMut<'_, U>, mut _offset: usize) -> Result<()> {
                debug_assert!(_offset % (Self::ALIGN32 as usize) == 0);
                let ($($t,)*) = self;
                $($t.store(_memory, $t::ABI.next_field32_size(&mut _offset))?;)*
                Ok(())
            }
        }

        #[allow(non_snake_case)]
        unsafe impl<$($t,)*> Lift for ($($t,)*)
            where $($t: Lift),*
        {
            fn lift(_store: &StoreOpaque, _options: &Options, _src: &Self::Lower) -> Result<Self> {
                Ok(($($t::lift(_store, _options, &_src.$t)?,)*))
            }

            // Loads each element from exactly its own `SIZE32`-byte slice.
            fn load(_memory: &Memory<'_>, bytes: &[u8]) -> Result<Self> {
                debug_assert!((bytes.as_ptr() as usize) % (Self::ALIGN32 as usize) == 0);
                let mut _offset = 0;
                $(let $t = $t::load(_memory, &bytes[$t::ABI.next_field32_size(&mut _offset)..][..$t::SIZE32])?;)*
                Ok(($($t,)*))
            }
        }

        #[allow(non_snake_case)]
        unsafe impl<$($t,)*> ComponentNamedList for ($($t,)*)
            where $($t: ComponentType),*
        {
            // Requires exactly `$n` parameter types and checks each one
            // against the corresponding tuple element.
            fn typecheck_list(
                names: &[InterfaceType],
                _types: &ComponentTypes,
            ) -> Result<()> {
                if names.len() != $n {
                    bail!("expected {} types, found {}", $n, names.len());
                }
                let mut names = names.iter();
                $($t::typecheck(names.next().unwrap(), _types)?;)*
                debug_assert!(names.next().is_none());
                Ok(())
            }
        }
    }};
}

for_each_function_signature!(impl_component_ty_for_tuples);
/// Returns a short name for the kind of `ty`, as used in type-mismatch error
/// messages.
fn desc(ty: &InterfaceType) -> &'static str {
    let name = match ty {
        // Primitive types.
        InterfaceType::Bool => "bool",
        InterfaceType::S8 => "s8",
        InterfaceType::U8 => "u8",
        InterfaceType::S16 => "s16",
        InterfaceType::U16 => "u16",
        InterfaceType::S32 => "s32",
        InterfaceType::U32 => "u32",
        InterfaceType::S64 => "s64",
        InterfaceType::U64 => "u64",
        InterfaceType::Float32 => "f32",
        InterfaceType::Float64 => "f64",
        InterfaceType::Char => "char",
        InterfaceType::String => "string",
        // Compound types; only the kind is reported, not the contents.
        InterfaceType::List(_) => "list",
        InterfaceType::Record(_) => "record",
        InterfaceType::Tuple(_) => "tuple",
        InterfaceType::Variant(_) => "variant",
        InterfaceType::Enum(_) => "enum",
        InterfaceType::Union(_) => "union",
        InterfaceType::Option(_) => "option",
        InterfaceType::Result(_) => "result",
        InterfaceType::Flags(_) => "flags",
    };
    name
}