use crate::runtime::Value;
use crate::vm::error::LuaError;
use crate::vm::exec::Vm;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
/// An async host function that Lua code can call.
///
/// Arguments are a raw pointer to the running VM, `func_slot` (presumably the stack slot of
/// the called function — verify) and `nargs`, the argument count. The returned future resolves
/// to the result count later handed to `Vm::commit_async_native_result`, or to a Lua error.
pub type AsyncNativeFn = fn(
    vm: *mut Vm,
    func_slot: u32,
    nargs: u32,
) -> Pin<Box<dyn Future<Output = Result<u32, LuaError>>>>;
/// How a single call to `Vm::drive_one` ended.
pub(crate) enum DispatchOutcome {
    /// The evaluation returned normally with these values.
    Complete(Vec<Value>),
    /// The evaluation failed; the error is surfaced to the host.
    Error(LuaError),
    /// The VM yielded back to the host (`host_yield_pending` was set), presumably because the
    /// slice's instruction budget ran out; execution can be resumed later.
    BudgetExhausted,
    /// An async native was invoked; execution is suspended until this future resolves to the
    /// native's result count.
    AsyncNativeAwaiting(Pin<Box<dyn Future<Output = Result<u32, LuaError>>>>),
}
impl Vm {
pub fn create_async_native(&mut self, f: AsyncNativeFn) -> Value {
let raw_fn: crate::runtime::value::NativeFn =
unsafe { std::mem::transmute(f) };
Value::Native(self.heap.new_async_native(raw_fn, Box::new([])))
}
pub fn set_async_native(
&mut self,
name: &str,
f: AsyncNativeFn,
) -> Result<(), LuaError> {
let v = self.create_async_native(f);
self.set_global(name, v)
}
pub fn eval_async<'vm>(&'vm mut self, src: &str) -> EvalFuture<'vm> {
self.eval_async_chunk(src, "=eval")
}
pub fn eval_async_chunk<'vm>(&'vm mut self, src: &str, name: &str) -> EvalFuture<'vm> {
EvalFuture {
vm: self,
state: EvalState::Initial {
src: src.to_string(),
name: name.to_string(),
},
saved_jit_enabled: None,
saved_async_slice: None,
}
}
pub fn set_async_slice(&mut self, n: i64) {
self.async_slice_size = n.max(1);
}
pub fn async_slice(&self) -> i64 {
self.async_slice_size
}
pub(crate) fn drive_one(
&mut self,
bootstrap: Option<Value>,
entry_depth: usize,
) -> DispatchOutcome {
self.async_mode = true;
self.instr_budget = Some(self.async_slice_size);
let raw = match bootstrap {
Some(closure_val) => {
self.call_value(closure_val, &[])
}
None => {
self.exec_with_async(entry_depth)
}
};
match raw {
Ok(values) => DispatchOutcome::Complete(values),
Err(e) => {
if self.pending_async_native_fut.is_some() {
let fut = self
.pending_async_native_fut
.take()
.expect("checked above");
DispatchOutcome::AsyncNativeAwaiting(fut)
} else if self.host_yield_pending {
self.host_yield_pending = false;
DispatchOutcome::BudgetExhausted
} else {
DispatchOutcome::Error(e)
}
}
}
}
pub(crate) fn commit_async_native_result(&mut self, nret: u32) {
let ctx = self
.pending_async_native_ctx
.take()
.expect("commit_async_native_result without a pending ctx");
self.finish_results(ctx.func_slot, nret, ctx.nresults);
self.maybe_collect_garbage(self.top);
}
}
/// Future returned by [`Vm::eval_async`] and [`Vm::eval_async_chunk`].
///
/// Holds the VM mutably for its whole lifetime and drives it in budgeted slices, returning
/// `Poll::Pending` between slices. The JIT is switched off on first poll; the host's setting
/// is restored on completion or when the future is dropped.
pub struct EvalFuture<'vm> {
    vm: &'vm mut Vm,
    state: EvalState,
    // The host's JIT setting captured on first poll; `Some` means it still needs restoring.
    saved_jit_enabled: Option<bool>,
    // NOTE(review): initialised to `None` and never read in the visible code, hence the
    // `allow`. Either use it to save/restore `async_slice_size` or remove it.
    #[allow(dead_code)]
    saved_async_slice: Option<i64>,
}
/// Internal state machine driven by `EvalFuture::poll`.
enum EvalState {
    /// Not yet polled: the source and chunk name still need compiling.
    Initial {
        src: String,
        name: String,
    },
    /// Executing bytecode one budgeted slice at a time.
    Running {
        /// `frame_count() + 1` captured at start; passed to `drive_one` on every slice.
        entry_depth: usize,
        /// True until the first slice has been dispatched.
        first_slice: bool,
        /// The compiled chunk, taken by the first slice to bootstrap the call.
        closure: Option<Value>,
    },
    /// Suspended inside an async native call until `fut` resolves.
    AwaitingNative {
        /// Carried over so `Running` can resume at the same depth.
        entry_depth: usize,
        /// The native's future; its `Ok` value is the result count given to
        /// `commit_async_native_result`.
        fut: Pin<Box<dyn Future<Output = Result<u32, LuaError>>>>,
    },
    /// Finished, successfully or not; polling again panics.
    Done,
}
impl<'vm> Future for EvalFuture<'vm> {
type Output = Result<Vec<Value>, LuaError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = unsafe { self.as_mut().get_unchecked_mut() };
loop {
if let EvalState::Initial { src, name } = &this.state {
if this.saved_jit_enabled.is_none() {
this.saved_jit_enabled = Some(this.vm.jit_enabled());
this.vm.set_jit_enabled(false);
}
let cl = match this.vm.load(src.as_bytes(), name.as_bytes()) {
Ok(c) => c,
Err(syntax) => {
this.vm.set_error_kind(crate::vm::error::LuaErrorKind::Syntax);
this.vm.set_error_source(name.clone(), syntax.line);
let msg = format!("{}", syntax);
let s = this.vm.intern_str(&msg);
if let Some(prev) = this.saved_jit_enabled.take() {
this.vm.set_jit_enabled(prev);
}
this.vm.async_mode = false;
this.vm.async_waker = None;
this.state = EvalState::Done;
return Poll::Ready(Err(LuaError(Value::Str(s))));
}
};
let entry_depth = this.vm.frame_count().saturating_add(1);
this.state = EvalState::Running {
entry_depth,
first_slice: true,
closure: Some(Value::Closure(cl)),
};
}
match &mut this.state {
EvalState::Running {
entry_depth,
first_slice,
closure,
} => {
this.vm.async_waker = Some(cx.waker().clone());
let (bootstrap_arg, ed) = if *first_slice {
(closure.take(), *entry_depth)
} else {
(None, *entry_depth)
};
let ed_for_resume = *entry_depth;
let outcome = this.vm.drive_one(bootstrap_arg, ed);
*first_slice = false;
match outcome {
DispatchOutcome::Complete(values) => {
if let Some(prev) = this.saved_jit_enabled.take() {
this.vm.set_jit_enabled(prev);
}
this.vm.async_mode = false;
this.vm.async_waker = None;
this.state = EvalState::Done;
return Poll::Ready(Ok(values));
}
DispatchOutcome::Error(e) => {
if let Some(prev) = this.saved_jit_enabled.take() {
this.vm.set_jit_enabled(prev);
}
this.vm.async_mode = false;
this.vm.async_waker = None;
this.state = EvalState::Done;
return Poll::Ready(Err(e));
}
DispatchOutcome::BudgetExhausted => {
cx.waker().wake_by_ref();
return Poll::Pending;
}
DispatchOutcome::AsyncNativeAwaiting(fut) => {
this.state = EvalState::AwaitingNative {
entry_depth: ed_for_resume,
fut,
};
continue;
}
}
}
EvalState::AwaitingNative { entry_depth, fut } => {
match fut.as_mut().poll(cx) {
Poll::Ready(Ok(nret)) => {
let ed = *entry_depth;
this.vm.commit_async_native_result(nret);
this.state = EvalState::Running {
entry_depth: ed,
first_slice: false,
closure: None,
};
continue;
}
Poll::Ready(Err(e)) => {
this.vm.pending_async_native_ctx = None;
if let Some(prev) = this.saved_jit_enabled.take() {
this.vm.set_jit_enabled(prev);
}
this.vm.async_mode = false;
this.vm.async_waker = None;
this.state = EvalState::Done;
return Poll::Ready(Err(e));
}
Poll::Pending => return Poll::Pending,
}
}
EvalState::Initial { .. } => unreachable!("transitioned above"),
EvalState::Done => panic!("EvalFuture polled after Poll::Ready"),
}
}
}
}
impl Drop for EvalFuture<'_> {
    /// Returns the VM to ordinary (non-async) operation, whether the evaluation finished or
    /// was abandoned part-way.
    ///
    /// Restores the saved JIT setting if one is still held. Then it clears async mode, the
    /// stored waker, the host-yield flag, and any in-flight async native future and its call
    /// context. NOTE(review): frames pushed by an abandoned evaluation are not popped here —
    /// confirm the VM tolerates that after cancellation.
    fn drop(&mut self) {
        let jit_before = self.saved_jit_enabled.take();
        let vm = &mut *self.vm;
        if let Some(enabled) = jit_before {
            vm.set_jit_enabled(enabled);
        }
        vm.async_mode = false;
        vm.async_waker = None;
        vm.host_yield_pending = false;
        vm.pending_async_native_fut = None;
        vm.pending_async_native_ctx = None;
    }
}