use crate::{externals::Memory, FromToNativeWasmType};
use std::{cell::Cell, fmt, marker::PhantomData, mem};
use wasmer_types::ValueType;
/// Marker type for the `Ty` parameter of [`WasmPtr`]: `WasmPtr<T, Array>` treats the
/// offset as the start of an array of `T` and exposes the slice and string accessors.
pub struct Array;
/// Marker type for the `Ty` parameter of [`WasmPtr`]: `WasmPtr<T, Item>` treats the
/// offset as the location of a single `T`. This is the default marker, so it can
/// usually be left out (`WasmPtr<T>`).
pub struct Item;
/// A typed offset into a Wasm linear memory.
///
/// `T` is the pointee type and `Ty` is a marker ([`Item`] by default, or [`Array`])
/// that selects which dereferencing methods are available. The only field that
/// occupies space is the `u32` offset, and the type is `#[repr(transparent)]`, so a
/// `WasmPtr` is laid out exactly like a `u32`.
#[repr(transparent)]
pub struct WasmPtr<T: Copy, Ty = Item> {
// Byte offset from the start of the memory this pointer is later dereferenced against.
offset: u32,
// Zero-sized: records the pointee type and the marker without storing either.
_phantom: PhantomData<(T, Ty)>,
}
impl<T: Copy, Ty> WasmPtr<T, Ty> {
#[inline]
pub fn new(offset: u32) -> Self {
Self {
offset,
_phantom: PhantomData,
}
}
#[inline]
pub fn offset(self) -> u32 {
self.offset
}
}
/// Rounds `ptr` down to the nearest multiple of `align`.
///
/// `align` must be a power of two; this is only verified in debug builds.
#[inline(always)]
fn align_pointer(ptr: usize, align: usize) -> usize {
    debug_assert!(align.count_ones() == 1);
    // For a power-of-two `align`, the low bits `ptr & (align - 1)` are exactly the
    // remainder of `ptr / align`; removing them yields the aligned address.
    let low_bits = align - 1;
    ptr - (ptr & low_bits)
}
/// Methods for `WasmPtr`s that point at a single `T` (the default [`Item`] flavour).
impl<T: Copy + ValueType> WasmPtr<T, Item> {
    /// Bounds-checks the pointee against `memory` and returns its host address.
    ///
    /// Returns `None` when `T` is zero-sized, or when `offset + size_of::<T>()`
    /// overflows or reaches past the end of `memory`. The returned address is rounded
    /// down to a multiple of `align_of::<T>()`.
    #[inline]
    fn checked_item_addr(self, memory: &Memory) -> Option<usize> {
        let item_size = mem::size_of::<T>();
        let start = self.offset as usize;
        // `checked_add` keeps the bounds check sound where `usize` is 32 bits wide: a
        // wrapped sum could otherwise compare as "in bounds".
        let end = start.checked_add(item_size)?;
        if item_size == 0 || end > memory.size().bytes().0 {
            return None;
        }
        // SAFETY: `start < end <= memory.size()` was checked above, so the offset pointer
        // stays within the bytes exposed by the view (assumes the `u8` view spans
        // `memory.size()` bytes — TODO confirm against `Memory::view`).
        let unaligned = unsafe { memory.view::<u8>().as_ptr().add(start) } as usize;
        Some(align_pointer(unaligned, mem::align_of::<T>()))
    }

    /// Dereferences the pointer, returning a shared reference to the `Cell<T>` stored
    /// in `memory`.
    ///
    /// Returns `None` if `T` is zero-sized or if the item does not fit inside `memory`.
    /// The host address is rounded down to `T`'s alignment before it is dereferenced.
    #[inline]
    pub fn deref<'a>(self, memory: &'a Memory) -> Option<&'a Cell<T>> {
        let addr = self.checked_item_addr(memory)?;
        // SAFETY: `checked_item_addr` only returns addresses that passed the bounds
        // check; the reference is tied to the borrow of `memory`.
        unsafe { Some(&*(addr as *const Cell<T>)) }
    }

    /// Dereferences the pointer, returning an exclusive reference to the `Cell<T>`
    /// stored in `memory`.
    ///
    /// Returns `None` under the same conditions as [`deref`](Self::deref).
    ///
    /// # Safety
    ///
    /// No aliasing checks are performed: the `&mut` is produced from a shared
    /// `&Memory`, so the caller must guarantee that no other reference to the same
    /// location is alive while the returned reference is in use.
    #[inline]
    pub unsafe fn deref_mut<'a>(self, memory: &'a Memory) -> Option<&'a mut Cell<T>> {
        let addr = self.checked_item_addr(memory)?;
        Some(&mut *(addr as *mut Cell<T>))
    }
}
/// Methods for `WasmPtr`s that point at the first element of an array of `T`.
impl<T: Copy + ValueType> WasmPtr<T, Array> {
    /// Bounds-checks the elements `0..index + length` against `memory`.
    ///
    /// On success returns the host address of element 0 (rounded down to
    /// `align_of::<T>()`) together with `index + length`. Returns `None` when `T` is
    /// zero-sized, when the offset is not inside `memory`, or when the byte range
    /// overflows or reaches past the end of `memory`.
    #[inline]
    fn checked_array_addr(
        self,
        memory: &Memory,
        index: u32,
        length: u32,
    ) -> Option<(usize, usize)> {
        // Rust guarantees `size_of::<T>()` is a multiple of `align_of::<T>()`, so the
        // distance between consecutive elements is exactly `size_of::<T>()`.
        let item_size = mem::size_of::<T>();
        // Checked arithmetic: on hosts with a 32-bit `usize` any of these steps could
        // wrap and make an out-of-bounds request look valid.
        let slice_full_len = (index as usize).checked_add(length as usize)?;
        let byte_len = item_size.checked_mul(slice_full_len)?;
        let start = self.offset as usize;
        let end = start.checked_add(byte_len)?;
        let memory_size = memory.size().bytes().0;
        if end > memory_size || start >= memory_size || item_size == 0 {
            return None;
        }
        // SAFETY: `start < memory_size` was checked above, so the offset pointer stays
        // within the bytes exposed by the view (assumes the `u8` view spans
        // `memory.size()` bytes — TODO confirm against `Memory::view`).
        let unaligned = unsafe { memory.view::<u8>().as_ptr().add(start) } as usize;
        Some((
            align_pointer(unaligned, mem::align_of::<T>()),
            slice_full_len,
        ))
    }

    /// Bounds-checks a run of `str_len` bytes starting at the offset.
    ///
    /// Returns the byte range inside `memory`, or `None` when the offset is not inside
    /// `memory` or the range overflows or reaches past its end.
    #[inline]
    fn checked_byte_range(
        self,
        memory: &Memory,
        str_len: u32,
    ) -> Option<std::ops::Range<usize>> {
        let start = self.offset as usize;
        let end = start.checked_add(str_len as usize)?;
        let memory_size = memory.size().bytes().0;
        if end > memory_size || start >= memory_size {
            None
        } else {
            Some(start..end)
        }
    }

    /// Returns the elements `index..index + length` of the array as a slice of cells.
    ///
    /// Returns `None` if `T` is zero-sized, if the offset is not inside `memory`, or
    /// if any of the elements `0..index + length` would lie outside `memory`.
    #[inline]
    pub fn deref(self, memory: &Memory, index: u32, length: u32) -> Option<&[Cell<T>]> {
        let (base, slice_full_len) = self.checked_array_addr(memory, index, length)?;
        // SAFETY: `checked_array_addr` verified that `slice_full_len` elements starting
        // at the (unaligned) offset fit in `memory`; aligning down only moves the start
        // towards the beginning of memory.
        unsafe {
            let whole = std::slice::from_raw_parts(base as *const Cell<T>, slice_full_len);
            Some(&whole[index as usize..slice_full_len])
        }
    }

    /// Returns the elements `index..index + length` of the array as a mutable slice
    /// of cells.
    ///
    /// Returns `None` under the same conditions as [`deref`](Self::deref).
    ///
    /// # Safety
    ///
    /// No aliasing checks are performed: the `&mut` slice is produced from a shared
    /// `&Memory`, so the caller must guarantee that no other reference to the same
    /// region is alive while the returned slice is in use.
    #[inline]
    pub unsafe fn deref_mut(
        self,
        memory: &Memory,
        index: u32,
        length: u32,
    ) -> Option<&mut [Cell<T>]> {
        let (base, slice_full_len) = self.checked_array_addr(memory, index, length)?;
        let whole = std::slice::from_raw_parts_mut(base as *mut Cell<T>, slice_full_len);
        Some(&mut whole[index as usize..slice_full_len])
    }

    /// Borrows `str_len` bytes starting at the offset as a UTF-8 string slice.
    ///
    /// Returns `None` if the range is out of bounds or the bytes are not valid UTF-8.
    ///
    /// # Safety
    ///
    /// The returned `&str` points directly into `memory`. The caller must ensure those
    /// bytes are not modified for as long as the string slice is alive.
    pub unsafe fn get_utf8_str<'a>(self, memory: &'a Memory, str_len: u32) -> Option<&'a str> {
        let range = self.checked_byte_range(memory, str_len)?;
        let ptr = memory.view::<u8>().as_ptr().add(range.start) as *const u8;
        let bytes: &[u8] = std::slice::from_raw_parts(ptr, str_len as usize);
        std::str::from_utf8(bytes).ok()
    }

    /// Copies `str_len` bytes starting at the offset into an owned `String`.
    ///
    /// Returns `None` if the range is out of bounds or the bytes are not valid UTF-8.
    pub fn get_utf8_string(self, memory: &Memory, str_len: u32) -> Option<String> {
        let range = self.checked_byte_range(memory, str_len)?;
        let view = memory.view::<u8>();
        // `get` rather than indexing so a too-short view yields `None`, not a panic.
        let bytes: Vec<u8> = view.get(range)?.iter().map(Cell::get).collect();
        String::from_utf8(bytes).ok()
    }

    /// Borrows the bytes from the offset up to (not including) the first NUL byte as a
    /// UTF-8 string slice.
    ///
    /// Returns `None` if the offset is out of bounds, if no NUL byte is found before
    /// the end of memory, or if the bytes are not valid UTF-8.
    ///
    /// # Safety
    ///
    /// Same contract as [`get_utf8_str`](Self::get_utf8_str): the returned `&str`
    /// points directly into `memory`.
    pub unsafe fn get_utf8_str_with_nul<'a>(self, memory: &'a Memory) -> Option<&'a str> {
        let view = memory.view::<u8>();
        // The offset comes from the guest, so it may lie past the end of memory; `get`
        // turns that into `None` where slice indexing would panic.
        let tail = view.get(self.offset as usize..)?;
        let nul_at = tail.iter().position(|cell| cell.get() == 0)?;
        // NOTE(review): assumes the memory is at most 4 GiB so the index fits in `u32`;
        // `get_utf8_str` re-validates the bounds either way.
        self.get_utf8_str(memory, nul_at as u32)
    }

    /// Copies the bytes from the offset up to (not including) the first NUL byte into
    /// an owned `String`.
    ///
    /// Returns `None` under the same conditions as
    /// [`get_utf8_str_with_nul`](Self::get_utf8_str_with_nul).
    pub fn get_utf8_string_with_nul(self, memory: &Memory) -> Option<String> {
        // SAFETY: the borrowed `&str` is copied into an owned `String` before this
        // function returns, so it never escapes this call.
        // NOTE(review): writers running concurrently on a shared memory are not
        // excluded here — confirm whether that can happen.
        unsafe { self.get_utf8_str_with_nul(memory) }.map(String::from)
    }
}
unsafe impl<T: Copy, Ty> FromToNativeWasmType for WasmPtr<T, Ty> {
type Native = i32;
fn to_native(self) -> Self::Native {
self.offset as i32
}
fn from_native(n: Self::Native) -> Self {
Self {
offset: n as u32,
_phantom: PhantomData,
}
}
}
// NOTE(review): `ValueType` is an `unsafe` trait whose contract is declared in
// `wasmer_types`. `WasmPtr` is `#[repr(transparent)]` over a `u32` offset plus a
// zero-sized `PhantomData`, so presumably every bit pattern is a valid value — confirm
// against the trait's documentation.
unsafe impl<T: Copy, Ty> ValueType for WasmPtr<T, Ty> {}
impl<T: Copy, Ty> Clone for WasmPtr<T, Ty> {
fn clone(&self) -> Self {
Self {
offset: self.offset,
_phantom: PhantomData,
}
}
}
// Implemented by hand: `#[derive(Copy)]` would additionally require `Ty: Copy`, which
// the `Array` and `Item` markers do not implement.
impl<T: Copy, Ty> Copy for WasmPtr<T, Ty> {}
/// Two pointers compare equal exactly when their offsets are equal.
///
/// Written by hand rather than derived so that no `PartialEq` bound is placed on `T`
/// or `Ty`.
impl<T: Copy, Ty> PartialEq for WasmPtr<T, Ty> {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.offset, other.offset);
        lhs == rhs
    }
}
// `eq` compares the `u32` offsets, which is a total equivalence relation, so the
// marker trait `Eq` can be implemented as well.
impl<T: Copy, Ty> Eq for WasmPtr<T, Ty> {}
/// Formats the pointer as `WasmPtr(0x…)`, showing the offset in lowercase hexadecimal.
impl<T: Copy, Ty> fmt::Debug for WasmPtr<T, Ty> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_fmt(format_args!("WasmPtr({:#x})", self.offset))
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Memory, MemoryType, Store};

    /// Checks that every `WasmPtr` accessor accepts in-bounds accesses and rejects
    /// out-of-bounds ones, for both `u8` and `u32` pointees.
    #[test]
    fn wasm_ptr_memory_bounds_checks_hold() {
        let store = Store::default();
        let memory = Memory::new(&store, MemoryType::new(1, Some(1), false)).unwrap();
        let bytes_in_memory = || memory.size().bytes().0;

        // Pointers to the very first byte: everything, including empty ranges, is valid.
        let head_item: WasmPtr<u8> = WasmPtr::new(0);
        let head_array: WasmPtr<u8, Array> = WasmPtr::new(0);
        assert!(head_item.deref(&memory).is_some());
        assert!(unsafe { head_item.deref_mut(&memory).is_some() });
        assert!(head_array.deref(&memory, 0, 0).is_some());
        assert!(unsafe { head_array.get_utf8_str(&memory, 0).is_some() });
        assert!(head_array.get_utf8_string(&memory, 0).is_some());
        assert!(unsafe { head_array.deref_mut(&memory, 0, 0).is_some() });
        assert!(head_array.deref(&memory, 0, 1).is_some());
        assert!(unsafe { head_array.deref_mut(&memory, 0, 1).is_some() });

        // `u8` pointee at the last byte of memory.
        let last_u8 = (bytes_in_memory() - 1) as u32;
        let tail_item: WasmPtr<u8> = WasmPtr::new(last_u8);
        assert!(tail_item.deref(&memory).is_some());
        assert!(unsafe { tail_item.deref_mut(&memory).is_some() });

        let tail_array: WasmPtr<u8, Array> = WasmPtr::new(last_u8);
        assert!(tail_array.deref(&memory, 0, 1).is_some());
        assert!(unsafe { tail_array.deref_mut(&memory, 0, 1).is_some() });

        // (index, length) pairs that reach past the end from the last byte.
        let rejected: [(u32, u32); 3] = [(last_u8 + 1, 0), (0, 2), (1, 1)];
        for &(idx, len) in &rejected {
            assert!(tail_array.deref(&memory, idx, len).is_none());
            assert!(unsafe { tail_array.deref_mut(&memory, idx, len).is_none() });
        }
        assert!(unsafe { tail_array.get_utf8_str(&memory, 2).is_none() });
        assert!(tail_array.get_utf8_string(&memory, 2).is_none());

        // `u32` pointee occupying the last four bytes of memory.
        let last_u32 = (bytes_in_memory() - 4) as u32;
        let tail_item: WasmPtr<u32> = WasmPtr::new(last_u32);
        assert!(tail_item.deref(&memory).is_some());
        assert!(unsafe { tail_item.deref_mut(&memory).is_some() });
        // The same accesses are asserted a second time.
        assert!(tail_item.deref(&memory).is_some());
        assert!(unsafe { tail_item.deref_mut(&memory).is_some() });

        // One to four bytes further on, a `u32` no longer fits.
        let past_end_items: Vec<WasmPtr<u32>> = (1..=4)
            .map(|step| WasmPtr::new(last_u32 + step))
            .collect();
        for &ptr in past_end_items.iter() {
            assert!(ptr.deref(&memory).is_none());
            assert!(unsafe { ptr.deref_mut(&memory).is_none() });
        }

        let tail_array: WasmPtr<u32, Array> = WasmPtr::new(last_u32);
        assert!(tail_array.deref(&memory, 0, 1).is_some());
        assert!(unsafe { tail_array.deref_mut(&memory, 0, 1).is_some() });

        let rejected: [(u32, u32); 3] = [(last_u32 + 1, 0), (0, 2), (1, 1)];
        for &(idx, len) in &rejected {
            assert!(tail_array.deref(&memory, idx, len).is_none());
            assert!(unsafe { tail_array.deref_mut(&memory, idx, len).is_none() });
        }

        // Array pointers starting one to four bytes past the last valid `u32` slot.
        let past_end_arrays: Vec<WasmPtr<u32, Array>> = (1..=4)
            .map(|step| WasmPtr::new(last_u32 + step))
            .collect();
        for &ptr in past_end_arrays.iter() {
            assert!(ptr.deref(&memory, 0, 1).is_none());
            assert!(unsafe { ptr.deref_mut(&memory, 0, 1).is_none() });
            assert!(ptr.deref(&memory, 1, 0).is_none());
            assert!(unsafe { ptr.deref_mut(&memory, 1, 0).is_none() });
        }
    }
}