use crate::component::{ComponentInstance, VMComponentContext};
use anyhow::{anyhow, Result};
use std::cell::Cell;
use std::slice;
use wasmtime_environ::component::TypeResourceTableIndex;
/// Flag bit (bit 31) OR-ed into the length returned by
/// `utf16_to_compact_probably_utf16` when the transcoded string contained at
/// least one character outside the latin1 range and was therefore left as
/// UTF-16.
const UTF16_TAG: usize = 1_usize << 31;
/// Aggregate of every component libcall function pointer defined in this file.
///
/// The struct is `#[repr(C)]`, so the field order below is part of its layout:
/// the general builtins table comes first, followed by the string transcoders
/// table.
#[repr(C)]
pub struct VMComponentLibcalls {
    /// Function pointers generated by `define_builtins!`.
    builtins: VMComponentBuiltins,
    /// Function pointers generated by `define_transcoders!`.
    transcoders: VMBuiltinTranscodeArray,
}
impl VMComponentLibcalls {
    /// Fully populated table: every slot points at the matching trampoline, as
    /// provided by the `INIT` constants of the two nested tables.
    pub const INIT: Self = Self {
        builtins: VMComponentBuiltins::INIT,
        transcoders: VMBuiltinTranscodeArray::INIT,
    };
}
// Maps the type keywords used in the builtin/transcoder declaration lists onto
// the concrete Rust types used in the generated `extern "C"` signatures.
//
// * `@ty <kind>` expands to the type of a parameter or result of that kind.
//   Note that `size_pair` expands to a single `usize`.
// * `@retptr <kind>` expands to `*mut usize` for `size_pair` and to `()` for
//   every other kind.
//   NOTE(review): no `@retptr` use is visible in this part of the file.
macro_rules! signature {
(@ty size) => (usize);
(@ty size_pair) => (usize);
(@ty ptr_u8) => (*mut u8);
(@ty ptr_u16) => (*mut u16);
(@ty ptr_size) => (*mut usize);
(@ty u32) => (u32);
(@ty u64) => (u64);
(@ty vmctx) => (*mut VMComponentContext);
// Arms are tried in order, so the specific `size_pair` arm has to stay ahead
// of the catch-all `$other` arm.
(@retptr size_pair) => (*mut usize);
(@retptr $other:ident) => (());
}
// Callback macro handed to
// `wasmtime_environ::foreach_builtin_component_function!` (see the invocation
// right after the definition).
//
// Given a list of declarations of the form `name(param: kind, ...) -> kind;`
// it generates:
//
// * `VMComponentBuiltins`: a `#[repr(C)]` struct with one
//   `unsafe extern "C" fn` pointer field per declaration, in declaration
//   order, whose parameter/result types come from `signature!(@ty ...)`.
// * `VMComponentBuiltins::INIT`: a constant whose fields point at the
//   same-named functions in the `trampolines` module.
macro_rules! define_builtins {
(
$(
$( #[$attr:meta] )*
$name:ident( $( $pname:ident: $param:ident ),* ) $( -> $result:ident )?;
)*
) => {
#[repr(C)]
struct VMComponentBuiltins {
$(
$name: unsafe extern "C" fn(
$(signature!(@ty $param),)*
) $( -> signature!(@ty $result))?,
)*
}
impl VMComponentBuiltins {
const INIT: VMComponentBuiltins = VMComponentBuiltins {
$($name: trampolines::$name,)*
};
}
};
}
wasmtime_environ::foreach_builtin_component_function!(define_builtins);
// Callback macro handed to `wasmtime_environ::foreach_transcoder!` (see the
// invocation right after the definition).
//
// It mirrors `define_builtins!`: for every transcoder declaration it adds an
// `unsafe extern "C" fn` pointer field to the `#[repr(C)]` struct
// `VMBuiltinTranscodeArray` (in declaration order) and initialises that field
// in `VMBuiltinTranscodeArray::INIT` with the same-named function from the
// `trampolines` module.
macro_rules! define_transcoders {
(
$(
$( #[$attr:meta] )*
$name:ident( $( $pname:ident: $param:ident ),* ) $( -> $result:ident )?;
)*
) => {
#[repr(C)]
struct VMBuiltinTranscodeArray {
$(
$name: unsafe extern "C" fn(
$(signature!(@ty $param),)*
) $( -> signature!(@ty $result))?,
)*
}
impl VMBuiltinTranscodeArray {
const INIT: VMBuiltinTranscodeArray = VMBuiltinTranscodeArray {
$($name: trampolines::$name,)*
};
}
};
}
wasmtime_environ::foreach_transcoder!(define_transcoders);
/// `extern "C"` entry points for every builtin and transcoder.
///
/// Each generated function has the raw C signature described by
/// `signature!`, validates its pointer arguments, forwards to the same-named
/// function in the parent module (which returns a `Result`), and turns an
/// `Err` into a trap via `crate::traphandlers::raise_trap`.
#[allow(improper_ctypes_definitions)]
mod trampolines {
use super::VMComponentContext;
macro_rules! shims {
// Main entry point: one shim function per declaration in the list.
(
$(
$( #[$attr:meta] )*
$name:ident( $( $pname:ident: $param:ident ),* ) $( -> $result:ident )?;
)*
) => (
$(
pub unsafe extern "C" fn $name(
$($pname : signature!(@ty $param),)*
) $( -> signature!(@ty $result))? {
// Run the per-kind argument checks (see `@validate_param`) first.
$(shims!(@validate_param $pname $param);)*
// The real implementation runs inside `catch_unwind_and_longjmp`.
// NOTE(review): presumably this keeps a panic from unwinding
// across the `extern "C"` boundary -- confirm in `traphandlers`.
let result = crate::traphandlers::catch_unwind_and_longjmp(|| {
shims!(@invoke $name() $($pname)*)
});
// Every implementation returns a `Result`; errors become traps.
match result {
Ok(ret) => shims!(@convert_ret ret $($pname: $param)*),
Err(err) => crate::traphandlers::raise_trap(
crate::traphandlers::TrapReason::User {
error: err,
needs_backtrace: true,
},
),
}
}
)*
);
// `@convert_ret`: walks the parameter list. When the *last* parameter has
// kind `ptr_size`, the implementation's result is a 2-tuple: the second
// element is stored through that pointer and the first is returned.
// Otherwise the result is returned unchanged.
(@convert_ret $ret:ident) => ($ret);
(@convert_ret $ret:ident $retptr:ident: ptr_size) => ({
let (a, b) = $ret;
*$retptr = b;
a
});
// Peel one leading parameter and recurse on the remainder.
(@convert_ret $ret:ident $name:ident: $ty:ident $($rest:tt)*) => (
shims!(@convert_ret $ret $($rest)*)
);
// `@validate_param`: `ptr_u16` arguments must be 2-byte aligned; every
// other kind needs no check.
(@validate_param $arg:ident ptr_u16) => ({
assert!(($arg as usize) % 2 == 0, "unaligned 16-bit pointer");
});
(@validate_param $arg:ident $ty:ident) => ();
// `@invoke`: accumulates the argument list for `super::$m(...)`.
// Base case: no parameters left, emit the call.
(@invoke $m:ident ($($args:tt)*)) => (super::$m($($args)*));
// A parameter literally named `ret2` is not forwarded (it is only consumed
// by `@convert_ret`). This arm must precede the generic one below.
(@invoke $m:ident ($($args:tt)*) ret2 $($rest:tt)*) => (
shims!(@invoke $m ($($args)*) $($rest)*)
);
// Any other parameter is appended to the accumulated argument list.
(@invoke $m:ident ($($args:tt)*) $param:ident $($rest:tt)*) => (
shims!(@invoke $m ($($args)* $param,) $($rest)*)
);
}
wasmtime_environ::foreach_builtin_component_function!(shims);
wasmtime_environ::foreach_transcoder!(shims);
}
/// Panics unless the memory regions covered by `a` and `b` are disjoint.
///
/// The regions are compared as byte address ranges. The comparison is strict:
/// the region that starts lower must end *before* the other one starts, so
/// two regions that merely touch (or two empty slices at the same address)
/// are rejected as well.
fn assert_no_overlap<T, U>(a: &[T], b: &[U]) {
    let (a_start, b_start) = (a.as_ptr() as usize, b.as_ptr() as usize);
    // `size_of_val` on a slice is its length times the element size.
    let a_end = a_start + std::mem::size_of_val(a);
    let b_end = b_start + std::mem::size_of_val(b);

    if b_start > a_start {
        // `a` is the lower region.
        assert!(a_end < b_start);
    } else {
        // `b` is the lower region (or both start at the same address).
        assert!(b_end < a_start);
    }
}
/// Validates `len` bytes at `src` as UTF-8 and copies them verbatim to `dst`.
///
/// # Errors
///
/// Returns "invalid utf8 encoding" if the source bytes are not valid UTF-8;
/// nothing is written to `dst` in that case.
///
/// # Safety
///
/// `src` must be valid for reading `len` bytes and `dst` for writing `len`
/// bytes. The two regions must not overlap (this is asserted).
unsafe fn utf8_to_utf8(src: *mut u8, len: usize, dst: *mut u8) -> Result<()> {
    let input = slice::from_raw_parts(src, len);
    let output = slice::from_raw_parts_mut(dst, len);
    assert_no_overlap(input, output);
    log::trace!("utf8-to-utf8 {len}");
    match std::str::from_utf8(input) {
        Ok(text) => {
            output.copy_from_slice(text.as_bytes());
            Ok(())
        }
        Err(_) => Err(anyhow!("invalid utf8 encoding")),
    }
}
/// Validates `len` little-endian UTF-16 code units at `src` and re-encodes
/// them into `dst` via `run_utf16_to_utf16`.
///
/// # Errors
///
/// Returns "invalid utf16 encoding" if the source is not valid UTF-16.
///
/// # Safety
///
/// `src` must be valid for reading and `dst` for writing `len` `u16` values.
/// The two regions must not overlap (this is asserted).
unsafe fn utf16_to_utf16(src: *mut u16, len: usize, dst: *mut u16) -> Result<()> {
    let units_in = slice::from_raw_parts(src, len);
    let units_out = slice::from_raw_parts_mut(dst, len);
    assert_no_overlap(units_in, units_out);
    log::trace!("utf16-to-utf16 {len}");
    // The "all latin1" flag is irrelevant for a plain utf16 -> utf16 copy.
    run_utf16_to_utf16(units_in, units_out).map(|_all_latin1| ())
}
/// Decodes little-endian UTF-16 from `src` and re-encodes each character as
/// little-endian UTF-16 at the front of `dst`.
///
/// Returns `true` when every decoded character has a code point that fits in
/// a `u8` (i.e. the whole string is representable as latin1).
///
/// # Errors
///
/// Returns "invalid utf16 encoding" on the first ill-formed code unit
/// sequence. Characters decoded before that point have already been written.
///
/// # Panics
///
/// Panics if `dst` runs out of room for an encoded character.
fn run_utf16_to_utf16(src: &[u16], mut dst: &mut [u16]) -> Result<bool> {
    let decoder = std::char::decode_utf16(src.iter().copied().map(u16::from_le));
    let mut all_latin1 = true;
    for decoded in decoder {
        let ch = match decoded {
            Ok(ch) => ch,
            Err(_) => return Err(anyhow!("invalid utf16 encoding")),
        };
        all_latin1 &= u32::from(ch) <= 0xFF;

        // Encode in native order first, then fix up each unit to little-endian.
        let encoded = ch.encode_utf16(dst);
        let written = encoded.len();
        encoded.iter_mut().for_each(|unit| *unit = unit.to_le());
        dst = &mut dst[written..];
    }
    Ok(all_latin1)
}
/// Copies `len` latin1 bytes from `src` to `dst` unchanged.
///
/// This never fails; the `Result` return type matches the other transcoders.
///
/// # Safety
///
/// `src` must be valid for reading `len` bytes and `dst` for writing `len`
/// bytes. The two regions must not overlap (this is asserted).
unsafe fn latin1_to_latin1(src: *mut u8, len: usize, dst: *mut u8) -> Result<()> {
    let input = slice::from_raw_parts(src, len);
    let output = slice::from_raw_parts_mut(dst, len);
    assert_no_overlap(input, output);
    log::trace!("latin1-to-latin1 {len}");
    output.copy_from_slice(input);
    Ok(())
}
/// Widens `len` latin1 bytes at `src` into `len` little-endian UTF-16 code
/// units at `dst`.
///
/// This never fails; the `Result` return type matches the other transcoders.
///
/// # Safety
///
/// `src` must be valid for reading `len` bytes and `dst` for writing `len`
/// `u16` values. The two regions must not overlap (this is asserted).
unsafe fn latin1_to_utf16(src: *mut u8, len: usize, dst: *mut u16) -> Result<()> {
    let bytes = slice::from_raw_parts(src, len);
    let units = slice::from_raw_parts_mut(dst, len);
    assert_no_overlap(bytes, units);
    for (unit, byte) in units.iter_mut().zip(bytes) {
        // Each latin1 byte is its own code point, so widening is enough.
        *unit = u16::from(*byte).to_le();
    }
    log::trace!("latin1-to-utf16 {len}");
    Ok(())
}
/// Transcodes `len` UTF-8 bytes at `src` into little-endian UTF-16 at `dst`
/// and returns the number of code units written.
///
/// The destination is treated as having room for `len` code units.
///
/// # Errors
///
/// Returns "invalid utf8 encoding" if the source is not valid UTF-8.
///
/// # Safety
///
/// `src` must be valid for reading `len` bytes and `dst` for writing `len`
/// `u16` values. The two regions must not overlap (this is asserted).
unsafe fn utf8_to_utf16(src: *mut u8, len: usize, dst: *mut u16) -> Result<usize> {
    let bytes = slice::from_raw_parts(src, len);
    let units = slice::from_raw_parts_mut(dst, len);
    assert_no_overlap(bytes, units);

    let result = run_utf8_to_utf16(bytes, units)?;
    log::trace!("utf8-to-utf16 {len} => {result}");
    Ok(result)
}
/// Validates `src` as UTF-8 and writes its UTF-16 encoding (little-endian)
/// to the front of `dst`, returning how many code units were written.
///
/// Writing stops silently once either the source or `dst` is exhausted.
///
/// # Errors
///
/// Returns "invalid utf8 encoding" if `src` is not valid UTF-8.
fn run_utf8_to_utf16(src: &[u8], dst: &mut [u16]) -> Result<usize> {
    let text = match std::str::from_utf8(src) {
        Ok(text) => text,
        Err(_) => return Err(anyhow!("invalid utf8 encoding")),
    };
    let written = dst
        .iter_mut()
        .zip(text.encode_utf16())
        .map(|(slot, unit)| *slot = unit.to_le())
        .count();
    Ok(written)
}
/// Transcodes as much little-endian UTF-16 from `src` as fits into the UTF-8
/// buffer at `dst`.
///
/// Returns `(code units consumed from src, bytes written to dst)`. Conversion
/// stops early, without error, when fewer than 4 bytes remain in `dst` and
/// the next character needs more than what is left.
///
/// # Errors
///
/// Returns "invalid utf16 encoding" if an ill-formed sequence is decoded.
///
/// # Safety
///
/// `src` must be valid for reading `src_len` `u16` values and `dst` for
/// writing `dst_len` bytes. The two regions must not overlap (this is
/// asserted).
unsafe fn utf16_to_utf8(
    src: *mut u16,
    src_len: usize,
    dst: *mut u8,
    dst_len: usize,
) -> Result<(usize, usize)> {
    let units = slice::from_raw_parts(src, src_len);
    let mut out = slice::from_raw_parts_mut(dst, dst_len);
    assert_no_overlap(units, out);

    // Counts every code unit pulled by the decoder, so the loop below can
    // tell how far into `src` each successfully written character reached.
    let consumed = Cell::new(0);
    let decoder = std::char::decode_utf16(
        units
            .iter()
            .inspect(|_| consumed.set(consumed.get() + 1))
            .map(|unit| u16::from_le(*unit)),
    );

    let mut src_read = 0;
    let mut dst_written = 0;
    for decoded in decoder {
        let ch = decoded.map_err(|_| anyhow!("invalid utf16 encoding"))?;

        // Not enough room for this character: stop before committing it.
        let remaining = out.len();
        if remaining < 4 && remaining < ch.len_utf8() {
            break;
        }

        // Only now is the character counted as read.
        src_read = consumed.get();
        let encoded_len = ch.encode_utf8(out).len();
        dst_written += encoded_len;
        out = &mut out[encoded_len..];
    }
    log::trace!("utf16-to-utf8 {src_len}/{dst_len} => {src_read}/{dst_written}");
    Ok((src_read, dst_written))
}
/// Transcodes latin1 bytes at `src` into UTF-8 at `dst`, as far as the
/// destination allows.
///
/// Delegates to `encoding_rs::mem::convert_latin1_to_utf8_partial` and returns
/// its `(bytes read, bytes written)` pair. This never fails; the `Result`
/// return type matches the other transcoders.
///
/// # Safety
///
/// `src` must be valid for reading `src_len` bytes and `dst` for writing
/// `dst_len` bytes. The two regions must not overlap (this is asserted).
unsafe fn latin1_to_utf8(
    src: *mut u8,
    src_len: usize,
    dst: *mut u8,
    dst_len: usize,
) -> Result<(usize, usize)> {
    let input = slice::from_raw_parts(src, src_len);
    let output = slice::from_raw_parts_mut(dst, dst_len);
    assert_no_overlap(input, output);

    let (read, written) = encoding_rs::mem::convert_latin1_to_utf8_partial(input, output);
    log::trace!("latin1-to-utf8 {src_len}/{dst_len} => ({read}, {written})");
    Ok((read, written))
}
/// Transcodes `len` UTF-16 code units from `src` into `dst`, then compacts the
/// result to latin1 if every character turned out to fit in one byte.
///
/// Returns `len` when the output was compacted to `len` latin1 bytes, or
/// `len | UTF16_TAG` when the output was left as `len` UTF-16 code units.
///
/// # Errors
///
/// Returns "invalid utf16 encoding" if the source is not valid UTF-16.
///
/// # Safety
///
/// `src` must be valid for reading and `dst` for writing `len` `u16` values.
/// The two regions must not overlap (this is asserted).
unsafe fn utf16_to_compact_probably_utf16(
    src: *mut u16,
    len: usize,
    dst: *mut u16,
) -> Result<usize> {
    let input = slice::from_raw_parts(src, len);
    let output = slice::from_raw_parts_mut(dst, len);
    assert_no_overlap(input, output);

    if !run_utf16_to_utf16(input, output)? {
        log::trace!("utf16-to-compact-probably-utf16 {len} => utf16 {len}");
        return Ok(len | UTF16_TAG);
    }

    // Everything is latin1: squeeze the low byte of each (little-endian) code
    // unit down into the first `len` bytes of the destination.
    let (left, bytes, right) = output.align_to_mut::<u8>();
    assert!(left.is_empty());
    assert!(right.is_empty());
    for i in 0..len {
        bytes[i] = bytes[2 * i];
    }
    log::trace!("utf16-to-compact-probably-utf16 {len} => latin1 {len}");
    Ok(len)
}
/// Converts the leading latin1-representable part of the UTF-8 at `src` into
/// latin1 bytes at `dst`.
///
/// The prefix length is whatever `encoding_rs::mem::utf8_latin1_up_to`
/// reports; that prefix is then converted with
/// `encoding_rs::mem::convert_utf8_to_latin1_lossy`. Returns
/// `(source bytes read, latin1 bytes written)`. This never fails; the
/// `Result` return type matches the other transcoders.
///
/// # Safety
///
/// `src` must be valid for reading `len` bytes and `dst` for writing `len`
/// bytes. The two regions must not overlap (this is asserted).
unsafe fn utf8_to_latin1(src: *mut u8, len: usize, dst: *mut u8) -> Result<(usize, usize)> {
    let input = slice::from_raw_parts(src, len);
    let output = slice::from_raw_parts_mut(dst, len);
    assert_no_overlap(input, output);

    let read = encoding_rs::mem::utf8_latin1_up_to(input);
    let convertible = &input[..read];
    let written = encoding_rs::mem::convert_utf8_to_latin1_lossy(convertible, output);
    log::trace!("utf8-to-latin1 {len} => ({read}, {written})");
    Ok((read, written))
}
/// Narrows little-endian UTF-16 code units at `src` into latin1 bytes at
/// `dst`, stopping at the first code unit that does not fit in a byte.
///
/// Returns `(size, size)`, where `size` is the number of code units converted
/// (which equals the number of bytes written). This never fails; the `Result`
/// return type matches the other transcoders.
///
/// # Safety
///
/// `src` must be valid for reading `len` `u16` values and `dst` for writing
/// `len` bytes. The two regions must not overlap (this is asserted).
unsafe fn utf16_to_latin1(src: *mut u16, len: usize, dst: *mut u8) -> Result<(usize, usize)> {
    let units = slice::from_raw_parts(src, len);
    let bytes = slice::from_raw_parts_mut(dst, len);
    assert_no_overlap(units, bytes);

    let mut size = 0;
    for (unit, slot) in units.iter().zip(bytes) {
        if let Ok(byte) = u8::try_from(u16::from_le(*unit)) {
            *slot = byte;
            size += 1;
        } else {
            // First non-latin1 code unit: hand back what was converted so far.
            break;
        }
    }
    log::trace!("utf16-to-latin1 {len} => {size}");
    Ok((size, size))
}
/// Continues a compact-UTF-16 transcode from UTF-8 after a non-latin1
/// character was hit.
///
/// The first `latin1_bytes_so_far` bytes of `dst` hold latin1 output produced
/// earlier. They are inflated in place to UTF-16 code units, and `src` is then
/// transcoded into the code units that follow. Returns the total number of
/// code units now in `dst` (inflated plus newly written).
///
/// # Errors
///
/// Returns "invalid utf8 encoding" if `src` is not valid UTF-8.
///
/// # Safety
///
/// `src` must be valid for reading `src_len` bytes and `dst` for reading and
/// writing `dst_len` `u16` values. The two regions must not overlap (this is
/// asserted).
unsafe fn utf8_to_compact_utf16(
    src: *mut u8,
    src_len: usize,
    dst: *mut u16,
    dst_len: usize,
    latin1_bytes_so_far: usize,
) -> Result<usize> {
    let input = slice::from_raw_parts(src, src_len);
    let output = slice::from_raw_parts_mut(dst, dst_len);
    assert_no_overlap(input, output);

    let remaining = inflate_latin1_bytes(output, latin1_bytes_so_far);
    let result = run_utf8_to_utf16(input, remaining)?;
    log::trace!("utf8-to-compact-utf16 {src_len}/{dst_len}/{latin1_bytes_so_far} => {result}");
    Ok(result + latin1_bytes_so_far)
}
/// Continues a compact-UTF-16 transcode from UTF-16 after a non-latin1
/// character was hit.
///
/// The first `latin1_bytes_so_far` bytes of `dst` hold latin1 output produced
/// earlier. They are inflated in place to UTF-16 code units, and `src` is then
/// transcoded into the code units that follow. Returns
/// `src_len + latin1_bytes_so_far`.
///
/// # Errors
///
/// Returns "invalid utf16 encoding" if `src` is not valid UTF-16.
///
/// # Safety
///
/// `src` must be valid for reading `src_len` `u16` values and `dst` for
/// reading and writing `dst_len` `u16` values. The two regions must not
/// overlap (this is asserted).
unsafe fn utf16_to_compact_utf16(
    src: *mut u16,
    src_len: usize,
    dst: *mut u16,
    dst_len: usize,
    latin1_bytes_so_far: usize,
) -> Result<usize> {
    let input = slice::from_raw_parts(src, src_len);
    let output = slice::from_raw_parts_mut(dst, dst_len);
    assert_no_overlap(input, output);

    let remaining = inflate_latin1_bytes(output, latin1_bytes_so_far);
    run_utf16_to_utf16(input, remaining)?;
    // The reported count is the number of source code units.
    let result = input.len();
    log::trace!("utf16-to-compact-utf16 {src_len}/{dst_len}/{latin1_bytes_so_far} => {result}");
    Ok(result + latin1_bytes_so_far)
}
/// Expands the first `latin1_bytes_so_far` latin1 bytes stored at the start of
/// `dst` into that many 16-bit code units, in place, and returns the part of
/// `dst` that follows them.
///
/// `latin1_bytes_so_far` is a byte count, but it is also used as the split
/// index into the `u16` buffer, because every latin1 byte becomes one whole
/// code unit. Each unit is laid out as the original byte followed by a zero
/// byte, i.e. little-endian.
///
/// # Panics
///
/// Panics if `latin1_bytes_so_far` exceeds `dst.len()`.
fn inflate_latin1_bytes(dst: &mut [u16], latin1_bytes_so_far: usize) -> &mut [u16] {
    let (head, tail) = dst.split_at_mut(latin1_bytes_so_far);

    // SAFETY: any initialized `u16` memory is valid to view as `u8`s, and
    // `u8` has alignment 1, so the whole region lands in the middle slice.
    let (left, bytes, right) = unsafe { head.align_to_mut::<u8>() };
    assert!(left.is_empty());
    assert!(right.is_empty());

    // Walk backwards so that no source byte is overwritten before it has been
    // moved to its final position.
    let mut i = latin1_bytes_so_far;
    while i > 0 {
        i -= 1;
        bytes[2 * i] = bytes[i];
        bytes[2 * i + 1] = 0;
    }

    tail
}
/// Libcall body for creating a resource: converts the raw `resource` table
/// index and forwards to `ComponentInstance::resource_new32` with `rep`,
/// returning that call's result.
///
/// # Safety
///
/// `vmctx` is passed to `ComponentInstance::from_vmctx` and must satisfy
/// whatever that function requires of it.
unsafe fn resource_new32(vmctx: *mut VMComponentContext, resource: u32, rep: u32) -> Result<u32> {
    let table = TypeResourceTableIndex::from_u32(resource);
    ComponentInstance::from_vmctx(vmctx, |component| component.resource_new32(table, rep))
}
/// Libcall body for looking up a resource: converts the raw `resource` table
/// index and forwards to `ComponentInstance::resource_rep32` with `idx`,
/// returning that call's result.
///
/// # Safety
///
/// `vmctx` is passed to `ComponentInstance::from_vmctx` and must satisfy
/// whatever that function requires of it.
unsafe fn resource_rep32(vmctx: *mut VMComponentContext, resource: u32, idx: u32) -> Result<u32> {
    let table = TypeResourceTableIndex::from_u32(resource);
    ComponentInstance::from_vmctx(vmctx, |component| component.resource_rep32(table, idx))
}
/// Libcall body for dropping a resource handle.
///
/// Forwards to `ComponentInstance::resource_drop` and packs its optional
/// result into a `u64`:
///
/// * `Some(rep)` becomes `(rep << 1) | 1`, so the low bit is set and the
///   representation sits in the bits above it;
/// * `None` becomes `0`.
///
/// # Safety
///
/// `vmctx` is passed to `ComponentInstance::from_vmctx` and must satisfy
/// whatever that function requires of it.
unsafe fn resource_drop(vmctx: *mut VMComponentContext, resource: u32, idx: u32) -> Result<u64> {
    let table = TypeResourceTableIndex::from_u32(resource);
    ComponentInstance::from_vmctx(vmctx, |component| {
        let dropped = component.resource_drop(table, idx)?;
        Ok(dropped.map_or(0, |rep| (u64::from(rep) << 1) | 1))
    })
}
/// Libcall body for transferring an owned resource handle between tables:
/// converts both raw table indices and forwards to
/// `ComponentInstance::resource_transfer_own`, returning that call's result.
///
/// # Safety
///
/// `vmctx` is passed to `ComponentInstance::from_vmctx` and must satisfy
/// whatever that function requires of it.
unsafe fn resource_transfer_own(
    vmctx: *mut VMComponentContext,
    src_idx: u32,
    src_table: u32,
    dst_table: u32,
) -> Result<u32> {
    let from = TypeResourceTableIndex::from_u32(src_table);
    let to = TypeResourceTableIndex::from_u32(dst_table);
    ComponentInstance::from_vmctx(vmctx, |component| {
        component.resource_transfer_own(src_idx, from, to)
    })
}
/// Libcall body for transferring a borrowed resource handle between tables:
/// converts both raw table indices and forwards to
/// `ComponentInstance::resource_transfer_borrow`, returning that call's
/// result.
///
/// # Safety
///
/// `vmctx` is passed to `ComponentInstance::from_vmctx` and must satisfy
/// whatever that function requires of it.
unsafe fn resource_transfer_borrow(
    vmctx: *mut VMComponentContext,
    src_idx: u32,
    src_table: u32,
    dst_table: u32,
) -> Result<u32> {
    let from = TypeResourceTableIndex::from_u32(src_table);
    let to = TypeResourceTableIndex::from_u32(dst_table);
    ComponentInstance::from_vmctx(vmctx, |component| {
        component.resource_transfer_borrow(src_idx, from, to)
    })
}
/// Libcall body that invokes `ComponentInstance::resource_enter_call` on the
/// instance behind `vmctx`. That call has no result, so this always returns
/// `Ok(())`.
///
/// # Safety
///
/// `vmctx` is passed to `ComponentInstance::from_vmctx` and must satisfy
/// whatever that function requires of it.
unsafe fn resource_enter_call(vmctx: *mut VMComponentContext) -> Result<()> {
    ComponentInstance::from_vmctx(vmctx, |component| {
        component.resource_enter_call();
        Ok(())
    })
}
/// Libcall body that invokes `ComponentInstance::resource_exit_call` on the
/// instance behind `vmctx` and returns its result unchanged.
///
/// # Safety
///
/// `vmctx` is passed to `ComponentInstance::from_vmctx` and must satisfy
/// whatever that function requires of it.
unsafe fn resource_exit_call(vmctx: *mut VMComponentContext) -> Result<()> {
    ComponentInstance::from_vmctx(vmctx, |component| component.resource_exit_call())
}