use super::{
ExecOutcome, GlobalValueErr, HeapPages, NewErr, OutOfBoundsError, RunErr, Signature, StartErr,
Trap, ValueType, WasmValue,
};
use alloc::{boxed::Box, sync::Arc, vec::Vec};
use core::{fmt, future, mem, pin, slice, task};
use std::sync::Mutex;
/// An instantiated Wasm module, backed by wasmtime, on which no function call has been
/// started yet.
///
/// Built by [`JitPrototype::new`] and turned into a [`Prepare`] through
/// [`JitPrototype::prepare`].
pub struct JitPrototype {
/// Compiled module and resolved imports, kept so that a fresh instance can be rebuilt from
/// them (see `from_base_components`).
base_components: BaseComponents,
/// Store in which `instance` and `memory` were created.
store: wasmtime::Store<()>,
/// The instantiated module.
instance: wasmtime::Instance,
/// State shared with the host-function closures that were registered as imports.
shared: Arc<Mutex<Shared>>,
/// The one linear memory of the instance: either the imported one or the one exported
/// under the name `memory`.
memory: wasmtime::Memory,
/// Type of `memory`, read from the store right after instantiation.
memory_type: wasmtime::MemoryType,
}
/// Everything needed to instantiate the module again from scratch.
struct BaseComponents {
/// The compiled module.
module: wasmtime::Module,
/// One entry per import of `module`, in import order: `Some(index)` for a function import
/// (the index returned by the `symbols` callback of `JitPrototype::new`), `None` for a
/// memory import.
resolved_imports: Vec<Option<usize>>,
}
impl JitPrototype {
/// Compiles `module_bytes` with a freshly-configured engine and instantiates the result.
///
/// `symbols` is invoked for every function import whose signature converts to a
/// [`Signature`]. It receives the import's module name, its function name and that
/// signature; the `usize` it returns is remembered as the index of that import.
///
/// # Errors
///
/// - [`NewErr::InvalidWasm`] if the engine cannot be created or the module fails to compile.
/// - [`NewErr::UnresolvedFunctionImport`] if a function import has a signature that cannot
///   be converted, or if `symbols` returns an error for it.
/// - [`NewErr::ImportTypeNotSupported`] if the module imports a global, a table or a tag.
/// - Any error returned while instantiating the module.
pub fn new(
    module_bytes: &[u8],
    symbols: &mut dyn FnMut(&str, &str, &Signature) -> Result<usize, ()>,
) -> Result<Self, NewErr> {
    let mut config = wasmtime::Config::new();

    // Code generation and runtime behaviour.
    config.cranelift_nan_canonicalization(true);
    config.cranelift_opt_level(wasmtime::OptLevel::Speed);
    config.async_support(true);
    config.wasm_backtrace_details(wasmtime::WasmBacktraceDetails::Enable);

    // Every optional Wasm feature listed below is switched off explicitly.
    config.wasm_threads(false);
    config.wasm_reference_types(false);
    config.wasm_function_references(false);
    config.wasm_simd(false);
    config.wasm_relaxed_simd(false);
    config.wasm_bulk_memory(false);
    config.wasm_multi_value(false);
    config.wasm_multi_memory(false);
    config.wasm_memory64(false);
    config.wasm_tail_call(false);
    config.wasm_component_model(false);
    config.wasm_wide_arithmetic(false);
    config.wasm_extended_const(false);
    config.wasm_shared_everything_threads(false);
    config.wasm_stack_switching(false);

    let engine = wasmtime::Engine::new(&config)
        .map_err(|engine_err| NewErr::InvalidWasm(engine_err.to_string()))?;
    let module = wasmtime::Module::from_binary(&engine, module_bytes)
        .map_err(|compile_err| NewErr::InvalidWasm(compile_err.to_string()))?;

    // Resolve every import up front, in the order the module declares them.
    let mut resolved_imports = Vec::with_capacity(module.imports().len());
    for import in module.imports() {
        let resolved = match import.ty() {
            wasmtime::ExternType::Func(func_type) => {
                let lookup = Signature::try_from(&func_type)
                    .ok()
                    .and_then(|signature| {
                        symbols(import.module(), import.name(), &signature).ok()
                    });
                let Some(function_index) = lookup else {
                    return Err(NewErr::UnresolvedFunctionImport {
                        module_name: import.module().to_owned(),
                        function: import.name().to_owned(),
                    });
                };
                Some(function_index)
            }
            // A memory import has no index; it is handled at instantiation.
            wasmtime::ExternType::Memory(_) => None,
            wasmtime::ExternType::Global(_)
            | wasmtime::ExternType::Table(_)
            | wasmtime::ExternType::Tag(_) => {
                return Err(NewErr::ImportTypeNotSupported);
            }
        };
        resolved_imports.push(resolved);
    }

    Self::from_base_components(BaseComponents {
        module,
        resolved_imports,
    })
}
/// Creates a new store, binds every import of the module, and instantiates it.
///
/// Each function import is bound to an async host function that records the call in the
/// shared state and then stays pending until a return value is placed in that state. A
/// memory import, if any, is allocated here.
///
/// # Errors
///
/// - [`NewErr::MemoryNotNamedMemory`] if a memory import is not `env`/`memory`.
/// - [`NewErr::CouldntAllocateMemory`] if the imported memory cannot be created.
/// - [`NewErr::StartFunctionNotSupported`] if instantiation does not complete in one poll.
/// - [`NewErr::Instantiation`] if wasmtime reports an instantiation error.
/// - [`NewErr::MemoryIsntMemory`] if the export named `memory` is not a memory.
/// - [`NewErr::TwoMemories`] / [`NewErr::NoMemory`] if the module has both an imported and an
///   exported memory, or neither.
///
/// # Panics
///
/// Panics if a function import has no resolved index in `base_components`, or if a return
/// type of an imported function cannot be converted to a [`ValueType`].
fn from_base_components(base_components: BaseComponents) -> Result<Self, NewErr> {
let mut store = wasmtime::Store::new(base_components.module.engine(), ());
// Filled in below if the module imports its memory.
let mut imported_memory = None;
// Starts as `ExecutingStart` so that host functions called during instantiation are
// refused (see the `Shared::ExecutingStart` arm in the closure below).
let shared = Arc::new(Mutex::new(Shared::ExecutingStart));
let imports = {
let mut imports = Vec::with_capacity(base_components.module.imports().len());
// Walk the module imports together with the indices resolved for them.
for (module_import, resolved_function) in base_components
.module
.imports()
.zip(base_components.resolved_imports.iter())
{
match module_import.ty() {
wasmtime::ExternType::Func(func_type) => {
// `new` stores `Some(index)` for every function import, hence the `unwrap`.
let function_index = resolved_function.unwrap();
let shared = shared.clone();
// Only the first result type is looked at; `None` means the import returns nothing.
let expected_return_ty = func_type
.results()
.next()
.map(|v| ValueType::try_from(v).unwrap());
imports.push(wasmtime::Extern::Func(wasmtime::Func::new_async(
&mut store,
func_type,
// Called every time the Wasm code invokes this import.
move |mut caller, params, ret_val| {
{
// `try_lock().unwrap()` panics if the lock is currently held or poisoned.
let mut shared_lock = shared.try_lock().unwrap();
// The state is swapped with `Poisoned` so that its content can be moved out.
match mem::replace(&mut *shared_lock, Shared::Poisoned) {
Shared::OutsideFunctionCall { memory } => {
// Record which function was called and with which parameters, together with
// the current raw location of the memory.
*shared_lock = Shared::EnteredFunctionCall {
function_index,
parameters: params
.iter()
.map(TryFrom::try_from)
.collect::<Result<_, _>>()
.unwrap(),
expected_return_ty,
// The waker is filled in by the first poll of the future returned below.
in_interrupted_waker: None, memory: SliceRawParts(
memory.data_ptr(&caller),
memory.data_size(&caller),
),
};
}
// A host function was called while the module is still being instantiated.
// Note that the state is left as `Poisoned` on this path.
Shared::ExecutingStart => {
return Box::new(future::ready(Err(
wasmtime::Error::new(
NewErr::StartFunctionNotSupported,
),
)));
}
_ => unreachable!(),
}
}
let shared = shared.clone();
// The returned future inspects the shared state on every poll.
Box::new(future::poll_fn(move |cx| {
let mut shared_lock = shared.try_lock().unwrap();
match *shared_lock {
// No answer yet: remember the waker and keep waiting.
Shared::EnteredFunctionCall {
ref mut in_interrupted_waker,
..
}
| Shared::WithinFunctionCall {
ref mut in_interrupted_waker,
..
} => {
*in_interrupted_waker = Some(cx.waker().clone());
task::Poll::Pending
}
// A memory growth was requested while suspended: perform it through `caller`,
// refresh the raw memory location, and go back to waiting.
Shared::MemoryGrowRequired {
ref memory,
additional,
} => {
// Panics if wasmtime refuses to grow the memory.
memory.grow(&mut caller, additional).unwrap();
*shared_lock = Shared::WithinFunctionCall {
in_interrupted_waker: Some(cx.waker().clone()),
memory: SliceRawParts(
memory.data_ptr(&caller),
memory.data_size(&caller),
),
expected_return_ty,
};
task::Poll::Pending
}
// A return value (or the absence of one) was provided: hand it to the Wasm code
// and finish the host call.
Shared::Return {
ref mut return_value,
memory,
} => {
if let Some(returned) = return_value.take() {
assert_eq!(ret_val.len(), 1);
ret_val[0] = From::from(returned);
} else {
assert!(ret_val.is_empty());
}
*shared_lock = Shared::OutsideFunctionCall { memory };
task::Poll::Ready(Ok(()))
}
_ => unreachable!(),
}
}))
},
)));
}
// `new` rejects these import kinds with `ImportTypeNotSupported` before getting here.
wasmtime::ExternType::Global(_)
| wasmtime::ExternType::Table(_)
| wasmtime::ExternType::Tag(_) => {
unreachable!() }
wasmtime::ExternType::Memory(m) => {
// The only accepted memory import is `env`/`memory`.
if module_import.module() != "env" || module_import.name() != "memory" {
return Err(NewErr::MemoryNotNamedMemory);
}
debug_assert!(imported_memory.is_none());
imported_memory = Some(
wasmtime::Memory::new(&mut store, m)
.map_err(|_| NewErr::CouldntAllocateMemory)?,
);
imports.push(wasmtime::Extern::Memory(*imported_memory.as_ref().unwrap()));
}
};
}
imports
};
// The instantiation future is polled exactly once with a no-op waker. If it is still
// pending after that single poll, `StartFunctionNotSupported` is returned.
let instance = match Future::poll(
pin::pin!(wasmtime::Instance::new_async(
&mut store,
&base_components.module,
&imports
)),
&mut task::Context::from_waker(task::Waker::noop()),
) {
task::Poll::Pending => return Err(NewErr::StartFunctionNotSupported), task::Poll::Ready(Ok(i)) => i,
task::Poll::Ready(Err(err)) => return Err(NewErr::Instantiation(err.to_string())),
};
// Instantiation is over. The state stays `Poisoned` until `Jit::run` starts the call and
// sets it to `OutsideFunctionCall`.
*shared.lock().unwrap() = Shared::Poisoned;
// Look for a memory exported under the name `memory`.
let exported_memory = if let Some(mem) = instance.get_export(&mut store, "memory") {
if let Some(mem) = mem.into_memory() {
Some(mem)
} else {
return Err(NewErr::MemoryIsntMemory);
}
} else {
None
};
// Exactly one of the exported and imported memories must exist.
let memory = match (exported_memory, imported_memory) {
(Some(_), Some(_)) => return Err(NewErr::TwoMemories),
(Some(m), None) => m,
(None, Some(m)) => m,
(None, None) => return Err(NewErr::NoMemory),
};
let memory_type = memory.ty(&store);
Ok(JitPrototype {
base_components,
store,
instance,
shared,
memory,
memory_type,
})
}
/// Reads the exported global called `name` and returns its value reinterpreted as a `u32`.
///
/// # Errors
///
/// - [`GlobalValueErr::NotFound`] if there is no export with that name or it is not a global.
/// - [`GlobalValueErr::Invalid`] if the global does not hold an `i32`.
pub fn global_value(&mut self, name: &str) -> Result<u32, GlobalValueErr> {
    let Some(wasmtime::Extern::Global(global)) =
        self.instance.get_export(&mut self.store, name)
    else {
        return Err(GlobalValueErr::NotFound);
    };

    if let wasmtime::Val::I32(raw) = global.get(&mut self.store) {
        // Same bits, unsigned interpretation.
        Ok(u32::from_ne_bytes(raw.to_ne_bytes()))
    } else {
        Err(GlobalValueErr::Invalid)
    }
}
/// Returns the maximum number of pages declared by the memory, if any.
///
/// `None` is returned both when the memory declares no maximum and when the declared
/// maximum does not fit in a `u32`.
pub fn memory_max_pages(&self) -> Option<HeapPages> {
    let declared_max = self.memory.ty(&self.store).maximum()?;
    u32::try_from(declared_max).ok().map(HeapPages::new)
}
/// Wraps this prototype into a [`Prepare`].
pub fn prepare(self) -> Prepare {
    let inner = self;
    Prepare { inner }
}
}
/// Cloning does not copy the running instance: the module is instantiated again from the
/// stored base components.
impl Clone for JitPrototype {
    /// # Panics
    ///
    /// Panics if re-instantiating the module fails.
    fn clone(&self) -> Self {
        let BaseComponents {
            module,
            resolved_imports,
        } = &self.base_components;

        let duplicated = BaseComponents {
            module: module.clone(),
            resolved_imports: resolved_imports.clone(),
        };

        JitPrototype::from_base_components(duplicated).unwrap()
    }
}
/// Opaque debug representation: only the type name is printed.
impl fmt::Debug for JitPrototype {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // A field-less tuple representation is exactly the bare name.
        f.write_str("JitPrototype")
    }
}
/// A [`JitPrototype`] whose memory can be inspected and modified before a function call is
/// started with [`Prepare::start`].
pub struct Prepare {
/// The wrapped prototype.
inner: JitPrototype,
}
impl Prepare {
pub fn into_prototype(self) -> JitPrototype {
JitPrototype::from_base_components(self.inner.base_components).unwrap()
}
pub fn memory_size(&self) -> HeapPages {
let heap_pages = self.inner.memory.size(&self.inner.store);
HeapPages::new(u32::try_from(heap_pages).unwrap())
}
/// Returns a view of `size` bytes of memory starting at `offset`.
///
/// # Errors
///
/// Returns [`OutOfBoundsError`] if the requested range does not lie entirely within the
/// memory.
pub fn read_memory(
    &self,
    offset: u32,
    size: u32,
) -> Result<impl AsRef<[u8]>, OutOfBoundsError> {
    let whole = self.inner.memory.data(&self.inner.store);

    let first = usize::try_from(offset).map_err(|_| OutOfBoundsError)?;
    let length = usize::try_from(size).map_err(|_| OutOfBoundsError)?;
    let past_end = first.checked_add(length).ok_or(OutOfBoundsError)?;

    // `past_end >= first` always holds, so `get` fails only when the range overruns.
    whole.get(first..past_end).ok_or(OutOfBoundsError)
}
/// Copies `value` into the memory, starting at `offset`.
///
/// # Errors
///
/// Returns [`OutOfBoundsError`] if the destination range does not lie entirely within the
/// memory. Nothing is written in that case.
pub fn write_memory(&mut self, offset: u32, value: &[u8]) -> Result<(), OutOfBoundsError> {
    let whole = self.inner.memory.data_mut(&mut self.inner.store);

    let first = usize::try_from(offset).map_err(|_| OutOfBoundsError)?;
    let past_end = first.checked_add(value.len()).ok_or(OutOfBoundsError)?;

    // `past_end >= first` always holds, so `get_mut` fails only when the range overruns.
    let destination = whole.get_mut(first..past_end).ok_or(OutOfBoundsError)?;
    if value.is_empty() {
        return Ok(());
    }
    destination.copy_from_slice(value);
    Ok(())
}
/// Grows the memory by `additional` pages.
///
/// # Errors
///
/// Returns [`OutOfBoundsError`] if wasmtime refuses to grow the memory.
pub fn grow_memory(&mut self, additional: HeapPages) -> Result<(), OutOfBoundsError> {
    let extra_pages = u64::from(u32::from(additional));
    match self.inner.memory.grow(&mut self.inner.store, extra_pages) {
        Ok(_) => Ok(()),
        Err(_) => Err(OutOfBoundsError),
    }
}
pub fn start(
mut self,
function_name: &str,
params: &[WasmValue],
) -> Result<Jit, (StartErr, JitPrototype)> {
let function_to_call = match self
.inner
.instance
.get_export(&mut self.inner.store, function_name)
{
Some(export) => match export.into_func() {
Some(f) => f,
None => return Err((StartErr::NotAFunction, self.inner)),
},
None => return Err((StartErr::FunctionNotFound, self.inner)),
};
let Ok(signature) = Signature::try_from(&function_to_call.ty(&self.inner.store)) else {
return Err((StartErr::SignatureNotSupported, self.inner));
};
if params.len() != signature.parameters().len() {
return Err((StartErr::InvalidParameters, self.inner));
}
for (obtained, expected) in params.iter().zip(signature.parameters()) {
if obtained.ty() != *expected {
return Err((StartErr::InvalidParameters, self.inner));
}
}
Ok(Jit {
base_components: self.inner.base_components,
inner: JitInner::NotStarted {
store: self.inner.store,
function_to_call,
params: params.iter().map(|v| (*v).into()).collect::<Vec<_>>(),
},
shared: self.inner.shared,
memory: self.inner.memory,
memory_type: self.inner.memory_type,
})
}
}
/// Opaque debug representation: only the type name is printed.
impl fmt::Debug for Prepare {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // A field-less tuple representation is exactly the bare name.
        f.write_str("Prepare")
    }
}
/// State shared between the `Jit`/`JitPrototype` side and the host-function closures
/// registered as imports of the module.
enum Shared {
/// Placeholder left in place by `mem::replace` while a transition is computed. It is also
/// the state between the end of instantiation and the first `Jit::run`.
Poisoned,
/// The module is being instantiated; a host function called in this state returns an error.
ExecutingStart,
/// No host function is currently being called.
OutsideFunctionCall {
/// Memory of the instance.
memory: wasmtime::Memory,
},
/// A host function has just been called and `Jit::run` has not reported it yet.
EnteredFunctionCall {
/// Index of the called function, as resolved when the prototype was created.
function_index: usize,
/// Parameters the Wasm code passed to the function.
parameters: Vec<WasmValue>,
/// Raw location of the memory at the time the function was entered.
memory: SliceRawParts,
/// Type of the value the function must return, `None` if it returns nothing.
expected_return_ty: Option<ValueType>,
/// Waker of the suspended host call, set when its future is polled.
in_interrupted_waker: Option<task::Waker>,
},
/// The host call has been reported by `Jit::run`, which is now waiting for a return value
/// or a memory-grow request.
WithinFunctionCall {
/// Raw location of the memory; refreshed after a grow performed by the host call.
memory: SliceRawParts,
/// Type of the value the function must return, `None` if it returns nothing.
expected_return_ty: Option<ValueType>,
/// Waker of the suspended host call, set when its future is polled.
in_interrupted_waker: Option<task::Waker>,
},
/// `Jit::grow_memory` asks the suspended host call to grow the memory.
MemoryGrowRequired {
/// Memory to grow.
memory: wasmtime::Memory,
/// Number of pages to add.
additional: u64,
},
/// A return value for the suspended host call; consumed the next time it is polled.
Return {
/// Value to return, `None` if the function returns nothing.
return_value: Option<WasmValue>,
/// Memory of the instance.
memory: wasmtime::Memory,
},
}
/// Pointer and length of the linear memory, as returned by `Memory::data_ptr` and
/// `Memory::data_size` from inside a host function call.
#[derive(Copy, Clone)]
struct SliceRawParts(*mut u8, usize);
// NOTE(review): a raw pointer is neither `Send` nor `Sync` by default, so both are asserted
// manually here. Presumably sound because the pointer is only dereferenced by
// `Jit::read_memory`/`Jit::write_memory` while the Wasm execution is suspended — confirm.
unsafe impl Send for SliceRawParts {}
unsafe impl Sync for SliceRawParts {}
/// A function call that is either not started yet, in progress, or finished.
///
/// Created by [`Prepare::start`] and driven with [`Jit::run`].
pub struct Jit {
/// Compiled module and resolved imports, used by `into_prototype` to instantiate again.
base_components: BaseComponents,
/// Where the execution currently stands; also owns the store (directly, or through the
/// future of the running call).
inner: JitInner,
/// State shared with the host-function closures registered as imports.
shared: Arc<Mutex<Shared>>,
/// Memory of the instance.
memory: wasmtime::Memory,
/// Type of `memory`; its maximum is consulted when growing during a host call.
memory_type: wasmtime::MemoryType,
}
/// Progress of the function call driven by a [`Jit`].
enum JitInner {
/// Placeholder left in place by `mem::replace` while moving out of `NotStarted`.
Poisoned,
/// `Jit::run` has not been called yet.
NotStarted {
/// Store the call will run in.
store: wasmtime::Store<()>,
/// Exported function to call.
function_to_call: wasmtime::Func,
/// Parameters to pass to `function_to_call`.
params: Vec<wasmtime::Val>,
},
/// The call is in progress. The future owns the store and yields it back together with
/// the outcome once the call completes.
Executing(BoxFuture<(wasmtime::Store<()>, ExecOutcomeValue)>),
/// The call has completed; the store is kept so that memory remains accessible.
Done(wasmtime::Store<()>),
}
/// A pinned, heap-allocated, `Send` future yielding `T`.
type BoxFuture<T> = pin::Pin<Box<dyn Future<Output = T> + Send>>;
/// Result of the Wasm function call: its optional return value, or the error reported by
/// wasmtime.
type ExecOutcomeValue = Result<Option<WasmValue>, wasmtime::Error>;
impl Jit {
pub fn run(&mut self, value: Option<WasmValue>) -> Result<ExecOutcome, RunErr> {
match self.inner {
JitInner::Executing(_) => {
let mut shared_lock = self.shared.try_lock().unwrap();
match mem::replace(&mut *shared_lock, Shared::Poisoned) {
Shared::WithinFunctionCall {
in_interrupted_waker,
expected_return_ty,
memory,
} => {
let provided_value_ty = value.as_ref().map(|v| v.ty());
if expected_return_ty != provided_value_ty {
*shared_lock = Shared::WithinFunctionCall {
in_interrupted_waker,
expected_return_ty,
memory,
};
return Err(RunErr::BadValueTy {
expected: expected_return_ty,
obtained: provided_value_ty,
});
}
*shared_lock = Shared::Return {
return_value: value,
memory: self.memory,
};
if let Some(waker) = in_interrupted_waker {
waker.wake();
}
}
_ => unreachable!(),
}
}
JitInner::Done(_) => return Err(RunErr::Poisoned),
JitInner::Poisoned => unreachable!(),
JitInner::NotStarted { .. } => {
if value.is_some() {
return Err(RunErr::BadValueTy {
expected: None,
obtained: value.as_ref().map(|v| v.ty()),
});
}
let (function_to_call, params, mut store) =
match mem::replace(&mut self.inner, JitInner::Poisoned) {
JitInner::NotStarted {
function_to_call,
params,
store,
} => (function_to_call, params, store),
_ => unreachable!(),
};
*self.shared.try_lock().unwrap() = Shared::OutsideFunctionCall {
memory: self.memory,
};
let has_return_value = Signature::try_from(&function_to_call.ty(&store))
.unwrap()
.return_type()
.is_some();
let function_call = Box::pin(async move {
let mut result = [wasmtime::Val::I32(0)];
let outcome = function_to_call
.call_async(
&mut store,
¶ms,
&mut result[..(if has_return_value { 1 } else { 0 })],
)
.await;
match outcome {
Ok(()) if has_return_value => {
(store, Ok(Some((&result[0]).try_into().unwrap())))
}
Ok(()) => (store, Ok(None)),
Err(err) => (store, Err(err)),
}
});
self.inner = JitInner::Executing(function_call);
}
};
let function_call = match &mut self.inner {
JitInner::Executing(f) => f,
_ => unreachable!(),
};
match Future::poll(
function_call.as_mut(),
&mut task::Context::from_waker(task::Waker::noop()),
) {
task::Poll::Ready((store, Ok(val))) => {
self.inner = JitInner::Done(store);
Ok(ExecOutcome::Finished {
return_value: Ok(val),
})
}
task::Poll::Ready((store, Err(err))) => {
self.inner = JitInner::Done(store);
Ok(ExecOutcome::Finished {
return_value: Err(Trap(err.to_string())),
})
}
task::Poll::Pending => {
let mut shared_lock = self.shared.try_lock().unwrap();
match mem::replace(&mut *shared_lock, Shared::Poisoned) {
Shared::EnteredFunctionCall {
function_index,
parameters,
memory,
expected_return_ty,
in_interrupted_waker,
} => {
*shared_lock = Shared::WithinFunctionCall {
memory,
expected_return_ty,
in_interrupted_waker,
};
Ok(ExecOutcome::Interrupted {
id: function_index,
params: parameters,
})
}
_ => unreachable!(),
}
}
}
}
/// Returns the current size of the memory, in pages.
///
/// While a host call is suspended, the size is derived from the byte length recorded in the
/// shared state, rounded up to a whole number of 64 KiB pages.
///
/// # Panics
///
/// Panics if the number of pages does not fit in a `u32`, or if a call is in progress but
/// the shared state is not `WithinFunctionCall`.
pub fn memory_size(&self) -> HeapPages {
    const PAGE_BYTES: usize = 64 * 1024;

    let page_count = match &self.inner {
        JitInner::Poisoned => unreachable!(),
        JitInner::NotStarted { store, .. } | JitInner::Done(store) => {
            u32::try_from(self.memory.size(store)).unwrap()
        }
        JitInner::Executing(_) => {
            let byte_len = match *self.shared.try_lock().unwrap() {
                Shared::WithinFunctionCall { memory, .. } => memory.1,
                _ => unreachable!(),
            };
            match byte_len {
                0 => 0,
                nonzero => 1 + u32::try_from((nonzero - 1) / PAGE_BYTES).unwrap(),
            }
        }
    };

    HeapPages::new(page_count)
}
pub fn read_memory(
&self,
offset: u32,
size: u32,
) -> Result<impl AsRef<[u8]>, OutOfBoundsError> {
let memory_slice = match &self.inner {
JitInner::NotStarted { store, .. } | JitInner::Done(store) => self.memory.data(store),
JitInner::Executing(_) => {
let memory = match *self.shared.try_lock().unwrap() {
Shared::WithinFunctionCall { memory, .. } => memory,
_ => unreachable!(),
};
unsafe { slice::from_raw_parts(memory.0, memory.1) }
}
JitInner::Poisoned => unreachable!(),
};
let start = usize::try_from(offset).map_err(|_| OutOfBoundsError)?;
let end = start
.checked_add(usize::try_from(size).map_err(|_| OutOfBoundsError)?)
.ok_or(OutOfBoundsError)?;
if end > memory_slice.len() {
return Err(OutOfBoundsError);
}
Ok(&memory_slice[start..end])
}
pub fn write_memory(&mut self, offset: u32, value: &[u8]) -> Result<(), OutOfBoundsError> {
let memory_slice = match &mut self.inner {
JitInner::NotStarted { store, .. } | JitInner::Done(store) => {
self.memory.data_mut(store)
}
JitInner::Executing(_) => {
let memory = match *self.shared.try_lock().unwrap() {
Shared::WithinFunctionCall { memory, .. } => memory,
_ => unreachable!(),
};
unsafe { slice::from_raw_parts_mut(memory.0, memory.1) }
}
JitInner::Poisoned => unreachable!(),
};
let start = usize::try_from(offset).map_err(|_| OutOfBoundsError)?;
let end = start.checked_add(value.len()).ok_or(OutOfBoundsError)?;
if end > memory_slice.len() {
return Err(OutOfBoundsError);
}
if !value.is_empty() {
memory_slice[start..end].copy_from_slice(value);
}
Ok(())
}
/// Grows the memory by `additional` pages.
///
/// When no call is in progress the memory is grown directly. While a host call is
/// suspended, the request is handed to that host call through the shared state and the
/// call's future is polled once so that the growth happens before this function returns.
///
/// # Errors
///
/// Returns [`OutOfBoundsError`] if the memory cannot be grown by that amount. While a host
/// call is suspended, this is determined up front from the memory's declared maximum, or,
/// if it declares none, from the 65536-page ceiling of a 32-bit memory.
///
/// # Panics
///
/// Panics if a call is in progress but the shared state is not `WithinFunctionCall`, or if
/// wasmtime still refuses a growth that passed the up-front check.
pub fn grow_memory(&mut self, additional: HeapPages) -> Result<(), OutOfBoundsError> {
    const PAGE_BYTES: usize = 64 * 1024;
    // Hard ceiling of a 32-bit memory made of 64 KiB pages (4 GiB). The engine is
    // configured with memory64 disabled, so every memory handled here is 32-bit.
    const MAX_PAGES_32BIT: u64 = 1 << 16;

    let additional = u64::from(u32::from(additional));

    match &mut self.inner {
        JitInner::NotStarted { store, .. } | JitInner::Done(store) => {
            self.memory
                .grow(store, additional)
                .map_err(|_| OutOfBoundsError)?;
        }
        JitInner::Poisoned => unreachable!(),
        JitInner::Executing(function_call) => {
            let mut shared_lock = self.shared.try_lock().unwrap();
            match mem::replace(&mut *shared_lock, Shared::Poisoned) {
                Shared::WithinFunctionCall {
                    memory,
                    expected_return_ty,
                    in_interrupted_waker,
                } => {
                    // The bounds are checked here because the growth itself is performed
                    // inside the host call, where a failure is an `unwrap` panic.
                    let current_pages = if memory.1 == 0 {
                        0
                    } else {
                        1 + u64::try_from((memory.1 - 1) / PAGE_BYTES).unwrap()
                    };

                    // Without a declared maximum the memory is still bounded by the
                    // 32-bit address space; falling back to that bound keeps an oversized
                    // request from panicking in the host call.
                    let max_pages = self.memory_type.maximum().unwrap_or(MAX_PAGES_32BIT);

                    // Both operands are far below `u64::MAX`, so the sum cannot overflow.
                    if current_pages + additional > max_pages {
                        // Put the state back exactly as it was before reporting the error.
                        *shared_lock = Shared::WithinFunctionCall {
                            memory,
                            expected_return_ty,
                            in_interrupted_waker,
                        };
                        return Err(OutOfBoundsError);
                    }

                    if let Some(waker) = in_interrupted_waker {
                        waker.wake();
                    }

                    *shared_lock = Shared::MemoryGrowRequired {
                        memory: self.memory,
                        additional,
                    }
                }
                _ => unreachable!(),
            }
            // The host call locks the shared state when polled; release it first.
            drop(shared_lock);

            // Poll the call once so that the host function performs the growth. It must
            // still be suspended afterwards, back in `WithinFunctionCall`.
            match Future::poll(
                function_call.as_mut(),
                &mut task::Context::from_waker(task::Waker::noop()),
            ) {
                task::Poll::Ready(_) => unreachable!(),
                task::Poll::Pending => {
                    debug_assert!(matches!(
                        *self.shared.try_lock().unwrap(),
                        Shared::WithinFunctionCall { .. }
                    ));
                }
            }
        }
    }

    Ok(())
}
pub fn into_prototype(self) -> JitPrototype {
JitPrototype::from_base_components(self.base_components).unwrap()
}
}
/// Opaque debug representation: only the type name is printed.
impl fmt::Debug for Jit {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // A field-less tuple representation is exactly the bare name.
        f.write_str("Jit")
    }
}
// NOTE(review): `Sync` is asserted manually. Presumably needed because the boxed future in
// `JitInner::Executing` is only declared `Send`, and presumably sound because that future is
// only polled from methods taking `&mut self` (`run`, `grow_memory`) — confirm.
unsafe impl Sync for Jit {}