#![cfg(feature = "async")]
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll, Waker};
use crate::error::{Error, Result};
use crate::function::Function;
use crate::multi::MultiValue;
use crate::state::Lua;
use crate::sync::MaybeSend;
use crate::sys::*;
use crate::table::Table;
use crate::thread::{AsyncResume, Thread};
use crate::traits::{FromLuaMulti, IntoLuaMulti};
/// Boxed future returned by an [`AsyncCallback`]; it must be `Send` when the `send`
/// feature is enabled.
#[cfg(feature = "send")]
pub(crate) type LocalResultFuture = Pin<Box<dyn Future<Output = Result<MultiValue>> + Send>>;
/// Boxed future returned by an [`AsyncCallback`] (non-`Send` build).
#[cfg(not(feature = "send"))]
pub(crate) type LocalResultFuture = Pin<Box<dyn Future<Output = Result<MultiValue>>>>;
/// Type-erased async callback. It receives the calling `Lua` handle and the call's arguments,
/// and it returns the future that `poll_c` drives.
#[cfg(feature = "send")]
pub(crate) type AsyncCallback = Box<dyn Fn(Lua, MultiValue) -> LocalResultFuture + Send>;
/// Type-erased async callback (non-`Send` build).
#[cfg(not(feature = "send"))]
pub(crate) type AsyncCallback = Box<dyn Fn(Lua, MultiValue) -> LocalResultFuture>;
// Only the *addresses* of these statics matter: each one is passed through Lua as a
// light-userdata sentinel. They must remain separate items so the addresses differ.
static PENDING_MARK: u8 = 0;
static YIELD_MARK: u8 = 0;
static TERMINATE_MARK: u8 = 0;

/// Sentinel returned by `poll_c` when the future is still pending. The poller then yields
/// this sentinel to whoever resumed the coroutine.
#[inline]
pub(crate) fn poll_pending() -> *mut c_void {
    core::ptr::addr_of!(PENDING_MARK).cast_mut().cast::<c_void>()
}

/// Sentinel pushed by `Lua::yield_with` to tell `poll_c` that the future asked to yield
/// values.
#[inline]
pub(crate) fn poll_yield() -> *mut c_void {
    core::ptr::addr_of!(YIELD_MARK).cast_mut().cast::<c_void>()
}

/// Sentinel passed to `poll` as the second argument to request that the future be dropped.
#[inline]
pub(crate) fn poll_terminate() -> *mut c_void {
    core::ptr::addr_of!(TERMINATE_MARK).cast_mut().cast::<c_void>()
}
use std::collections::HashMap;
/// Async bookkeeping for one VM. `async_store` holds these entries, keyed by `vm_key`.
#[derive(Default)]
struct AsyncVmState {
    /// Waker installed by `set_current_waker`. When this is `None`, `current_waker`
    /// falls back to a no-op waker.
    waker: Option<Waker>,
    /// Maps an implicit coroutine's `lua_State` address to the address of its root owner
    /// (see `register_implicit_thread`).
    ownership: HashMap<usize, usize>,
}
/// Key identifying the VM that `state` belongs to: the address stored in its `global` field.
/// The waker/ownership store relies on this being identical for every thread of one VM.
///
/// # Safety
/// `state` must point to a live `lua_State`.
#[inline]
unsafe fn vm_key(state: *mut lua_State) -> usize {
    // SAFETY: the caller guarantees `state` is valid for reads.
    let global = unsafe { (*state).global };
    global as usize
}
#[cfg(feature = "send")]
mod async_store {
    use super::AsyncVmState;
    use std::collections::HashMap;
    use std::sync::{LazyLock, Mutex};

    /// Process-wide map from `vm_key` to that VM's async state, shared by all threads.
    static STORE: LazyLock<Mutex<HashMap<usize, AsyncVmState>>> =
        LazyLock::new(|| Mutex::new(HashMap::new()));

    /// Runs `f` while holding the store's lock. A poisoned lock is recovered rather than
    /// turned into a panic.
    pub(super) fn with<R>(f: impl FnOnce(&mut HashMap<usize, AsyncVmState>) -> R) -> R {
        let mut map = match STORE.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        f(&mut map)
    }
}
#[cfg(not(feature = "send"))]
mod async_store {
    use super::AsyncVmState;
    use std::cell::RefCell;
    use std::collections::HashMap;

    thread_local! {
        /// Per-thread map from `vm_key` to that VM's async state (non-`send` builds).
        static STORE: RefCell<HashMap<usize, AsyncVmState>> = RefCell::new(HashMap::new());
    }

    /// Runs `f` with a mutable borrow of this thread's store.
    ///
    /// # Panics
    /// Panics if re-entered from inside `f`, because the `RefCell` is already borrowed.
    pub(super) fn with<R>(f: impl FnOnce(&mut HashMap<usize, AsyncVmState>) -> R) -> R {
        STORE.with_borrow_mut(f)
    }
}
/// Guard returned by `set_current_waker`. Dropping it reinstates the waker that was installed
/// before it.
pub(crate) struct WakerGuard {
    /// `vm_key` of the VM whose waker slot this guard restores.
    key: usize,
    /// Waker that was installed before this guard's waker, if any.
    prev: Option<Waker>,
}
impl Drop for WakerGuard {
    /// Puts the previous waker back. If the VM's entry has been removed in the meantime
    /// (e.g. by `clear_async_state`), no entry is recreated.
    fn drop(&mut self) {
        let (key, restored) = (self.key, self.prev.take());
        async_store::with(move |store| {
            if let Some(vm) = store.get_mut(&key) {
                vm.waker = restored;
            }
        });
    }
}
/// Installs `waker` as the current waker for the VM owning `state`. Returns a guard that puts
/// the previously installed waker back when dropped.
pub(crate) fn set_current_waker(state: *mut lua_State, waker: Waker) -> WakerGuard {
    // SAFETY: assumes callers pass the state of a live VM.
    let key = unsafe { vm_key(state) };
    let prev = async_store::with(|store| {
        let vm = store.entry(key).or_default();
        std::mem::replace(&mut vm.waker, Some(waker))
    });
    WakerGuard { key, prev }
}
/// Registers `co_state` as an implicit thread owned by `owner`.
///
/// Ownership is flattened: if `owner` is itself registered, `co_state` is mapped to
/// `owner`'s recorded root, so every lookup yields the outermost owner.
pub(crate) fn register_implicit_thread(co_state: *mut lua_State, owner: *mut lua_State) {
    let key = unsafe { vm_key(co_state) };
    let (child, parent) = (co_state as usize, owner as usize);
    async_store::with(|store| {
        let ownership = &mut store.entry(key).or_default().ownership;
        let root = match ownership.get(&parent) {
            Some(&root) => root,
            None => parent,
        };
        ownership.insert(child, root);
    });
}
/// Forgets the ownership record for `co_state`, if there is one. The VM entry itself is kept.
pub(crate) fn unregister_implicit_thread(co_state: *mut lua_State) {
    let key = unsafe { vm_key(co_state) };
    let thread = co_state as usize;
    async_store::with(|store| {
        if let Some(vm) = store.get_mut(&key) {
            vm.ownership.remove(&thread);
        }
    });
}
/// Returns the root owner that `register_implicit_thread` recorded for `state`. Returns `None`
/// when `state` is not a registered implicit thread.
pub(crate) fn implicit_thread_owner(state: *mut lua_State) -> Option<*mut lua_State> {
    let key = unsafe { vm_key(state) };
    let thread = state as usize;
    let owner = async_store::with(|store| store.get(&key)?.ownership.get(&thread).copied());
    owner.map(|addr| addr as *mut lua_State)
}
/// Drops all async bookkeeping (the installed waker and the implicit-thread records) for the
/// VM that `state` belongs to.
pub(crate) fn clear_async_state(state: *mut lua_State) {
    let key = unsafe { vm_key(state) };
    async_store::with(|store| {
        // Dropped while the store is still held, just like the removed entry always was.
        drop(store.remove(&key));
    });
}
/// Returns a clone of the waker installed for the VM owning `state`. When no waker is
/// installed, returns a no-op waker instead.
fn current_waker(state: *mut lua_State) -> Waker {
    let key = unsafe { vm_key(state) };
    let installed = async_store::with(|store| match store.get(&key) {
        Some(vm) => vm.waker.clone(),
        None => None,
    });
    installed.unwrap_or_else(noop_waker)
}
/// Builds a waker whose clone, wake, wake-by-ref and drop operations all do nothing.
fn noop_waker() -> Waker {
    use std::task::{RawWaker, RawWakerVTable};

    unsafe fn clone_raw(_: *const ()) -> RawWaker {
        RawWaker::new(std::ptr::null(), &VTABLE)
    }
    unsafe fn ignore(_: *const ()) {}

    const VTABLE: RawWakerVTable = RawWakerVTable::new(clone_raw, ignore, ignore, ignore);

    let raw = RawWaker::new(std::ptr::null(), &VTABLE);
    // SAFETY: no vtable entry dereferences the data pointer, so null is acceptable.
    unsafe { Waker::from_raw(raw) }
}
/// Contents of the userdata created by `get_future_c`: the future of one call to an async
/// callback.
struct AsyncPollUpvalue {
    /// Set to `None` when `poll_c` handles a termination request. Polling after that raises
    /// `Error::CallbackDestructed`.
    data: Option<LocalResultFuture>,
}
/// Userdata destructor for `AsyncPollUpvalue`: drops the stored value in place. A null
/// pointer is ignored.
unsafe extern "C" fn poll_upvalue_dtor(ptr: *mut c_void) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: a non-null `ptr` points at the `AsyncPollUpvalue` written by `get_future_c`.
    unsafe { core::ptr::drop_in_place(ptr.cast::<AsyncPollUpvalue>()) };
}
/// Contents of the userdata that `create_async_callback` attaches as upvalue 1 of the
/// `get_future` C closure.
struct AsyncCallbackUpvalue {
    /// The user's async callback. `get_future_c` invokes it once per call of the Lua function.
    callback: AsyncCallback,
}
/// Userdata destructor for `AsyncCallbackUpvalue`: drops the stored callback in place. A
/// null pointer is ignored.
unsafe extern "C" fn callback_upvalue_dtor(ptr: *mut c_void) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: a non-null `ptr` points at the `AsyncCallbackUpvalue` written by
    // `create_async_callback`.
    unsafe { core::ptr::drop_in_place(ptr.cast::<AsyncCallbackUpvalue>()) };
}
/// C entry point behind the poller's `get_future(...)`. It collects the call's arguments,
/// invokes the async callback stored in upvalue 1 to obtain a future, and returns a new
/// userdata owning that future (an `AsyncPollUpvalue`, freed by `poll_upvalue_dtor`).
///
/// # Safety
/// Must only be called by Lua, as a C closure whose first upvalue is an
/// `AsyncCallbackUpvalue` userdata (see `create_async_callback`).
unsafe fn get_future_c(state: *mut lua_State) -> c_int {
    unsafe {
        let ud = lua_touserdata(state, lua_upvalueindex(1));
        if ud.is_null() {
            return raise(state, "luaur-rt: missing async callback upvalue");
        }
        let upvalue = &*(ud as *const AsyncCallbackUpvalue);
        let lua = Lua::from_borrowed(state);
        // Convert every stack argument (indices 1..=nargs) into a `MultiValue`.
        let nargs = lua_gettop(state);
        let mut args = MultiValue::with_capacity(nargs.max(0) as usize);
        for i in 1..=nargs {
            match lua.value_from_stack(i) {
                Ok(v) => args.push_back(v),
                Err(e) => return raise(state, &e.to_string()),
            }
        }
        // The callback returns the future; `poll_c` polls it later.
        let fut = (upvalue.callback)(lua.clone(), args);
        let storage = lua_newuserdatadtor(
            state,
            core::mem::size_of::<AsyncPollUpvalue>(),
            Some(poll_upvalue_dtor),
        );
        if storage.is_null() {
            return raise(state, "luaur-rt: failed to allocate async future userdata");
        }
        // Move the future into the new userdata; its dtor drops it when the userdata is freed.
        core::ptr::write(
            storage as *mut AsyncPollUpvalue,
            AsyncPollUpvalue { data: Some(fut) },
        );
        // The new userdata is on top of the stack and is the single return value.
        1
    }
}
/// C entry point behind the poller's `poll(future, ...)`. It polls the future in the userdata
/// at stack index 1 once and reports the outcome to the poller as `nres, res, res2`:
///
/// * termination request (exactly two arguments, the second being `poll_terminate()`):
///   drops the future and returns `-1`;
/// * `Ready(Ok(values))` with fewer than three values: returns the count followed by the
///   values;
/// * `Ready(Ok(values))` with three or more values: returns the count and a sequence table
///   holding them;
/// * `Pending` after `Lua::yield_with` left `[yield mark, payload, count]` on the stack:
///   returns `nil, payload, count`;
/// * any other `Pending`: returns `nil, <poll_pending() mark>`.
///
/// `Ready(Err(e))` is raised as a Lua error carrying `e.to_string()`.
///
/// # Safety
/// Must only be called by Lua with an `AsyncPollUpvalue` userdata as argument 1.
unsafe fn poll_c(state: *mut lua_State) -> c_int {
    unsafe {
        let ud = lua_touserdata(state, 1);
        if ud.is_null() {
            return raise(state, "luaur-rt: missing async future argument");
        }
        let future = &mut *(ud as *mut AsyncPollUpvalue);
        let nargs = lua_gettop(state);
        // Termination request: drop the future now and report a negative count. The poller
        // then stops polling and only yields from here on.
        if nargs == 2 && lua_tolightuserdata(state, -1) == poll_terminate() {
            future.data.take();
            lua_pushinteger(state, -1);
            return 1;
        }
        let lua = Lua::from_borrowed(state);
        // Waker installed via `set_current_waker`, or a no-op waker when none is installed.
        let waker = current_waker(state);
        let mut cx = std::task::Context::from_waker(&waker);
        let poll = match future.data.as_mut() {
            Some(f) => f.as_mut().poll(&mut cx),
            // The future was already dropped by a termination request.
            None => return raise_destructed(state),
        };
        use std::task::Poll;
        match poll {
            Poll::Pending => {
                // Count everything above the future (index 1): resume arguments plus anything
                // pushed during the poll. If `Lua::yield_with` ran, the top three slots are
                // `[yield mark, payload, count]`.
                let fut_nvals = lua_gettop(state) - 1;
                if fut_nvals >= 3 && lua_tolightuserdata(state, -3) == poll_yield() {
                    // Overwrite the yield mark with nil so the poller receives
                    // `nres == nil, res = payload, res2 = count`.
                    lua_pushnil(state);
                    lua_replace(state, -4);
                    return 3;
                }
                // Plain pending: `nres == nil, res = pending mark, res2 == nil`. The poller
                // then yields the pending mark to its resumer.
                lua_pushnil(state);
                lua_pushlightuserdatatagged(state, poll_pending(), 0);
                2
            }
            Poll::Ready(result) => {
                let results = match result {
                    Ok(r) => r,
                    Err(e) => {
                        // NOTE(review): only the error's string form crosses into Lua here.
                        return raise(state, &e.to_string());
                    }
                };
                let nres = results.len() as c_int;
                if nres < 3 {
                    // Zero to two results are returned directly after the count.
                    lua_pushinteger(state, nres);
                    for v in results.iter() {
                        if let Err(e) = lua.push_value(v) {
                            return raise(state, &e.to_string());
                        }
                    }
                    1 + nres
                } else {
                    // Larger result sets travel as a sequence table plus the count. The poller
                    // expands them with `unpack(res, nres)`, and the explicit count keeps nils.
                    lua_pushinteger(state, nres);
                    let seq = match lua.create_sequence_from(results) {
                        Ok(t) => t,
                        Err(e) => return raise(state, &e.to_string()),
                    };
                    seq.push_to_stack();
                    2
                }
            }
        }
    }
}
/// C entry point behind the poller's `unpack(t, n)`. Pushes `t[1]..=t[n]` from the table at
/// stack index 1, where `n` is the integer at index 2, and returns `n` so Lua receives exactly
/// that many results (missing entries come back as `nil`).
///
/// A missing, non-numeric or negative count is treated as zero, because a C function must
/// never report a negative number of results.
///
/// # Safety
/// Must only be called by Lua with a table at index 1.
unsafe fn unpack_c(state: *mut lua_State) -> c_int {
    unsafe {
        let mut isnum: c_int = 0;
        let n = lua_tointegerx(state, 2, &mut isnum as *mut c_int);
        // Clamp so a bogus count can never become a negative return value.
        let n = if isnum == 0 { 0 } else { n.max(0) };
        if lua_checkstack(state, n.saturating_add(1)) == 0 {
            return raise(state, "luaur-rt: stack overflow unpacking async results");
        }
        for i in 1..=n {
            lua_rawgeti(state, 1, i);
        }
        n
    }
}
/// Raises a Lua error whose value is the string `msg`.
///
/// The `c_int` return type lets C entry points write `return raise(..)`.
unsafe fn raise(state: *mut lua_State, msg: &str) -> c_int {
    let bytes = msg.as_bytes();
    unsafe {
        lua_pushlstring(state, bytes.as_ptr().cast::<c_char>(), bytes.len());
        lua_error(state)
    }
}
unsafe fn raise_destructed(state: *mut lua_State) -> c_int {
unsafe { crate::callback::raise_structured_error(state, Error::CallbackDestructed) }
}
/// Lua body of every async function created by `create_async_callback`.
///
/// It obtains a future via `get_future(...)` and repeatedly calls `poll(future, ...)`,
/// which returns `nres, res, res2`:
/// * `nres >= 0`: the future resolved with `nres` values. These are `res`/`res2` when
///   `nres <= 2`; otherwise `res` is a table that `unpack(res, nres)` expands.
/// * `nres < 0`: the future was terminated. The coroutine then only yields.
/// * `nres == nil`: still pending. With `res2 == nil`, `res` is the pending mark and is
///   yielded as-is. Otherwise `res2` is the `yield_with` value count (0, 1, or more with
///   `res` as a table).
///
/// Whatever the coroutine is resumed with is passed back into the next `poll` call.
const POLLER_SOURCE: &str = r#"
local poll, yield = poll, yield
local future = get_future(...)
local nres, res, res2 = poll(future)
while true do
if nres ~= nil then
if nres == 0 then
return
elseif nres == 1 then
return res
elseif nres == 2 then
return res, res2
elseif nres < 0 then
yield()
else
return unpack(res, nres)
end
end
if res2 == nil then
nres, res, res2 = poll(future, yield(res))
elseif res2 == 0 then
nres, res, res2 = poll(future, yield())
elseif res2 == 1 then
nres, res, res2 = poll(future, yield(res))
else
nres, res, res2 = poll(future, yield(unpack(res, res2)))
end
end
"#;
/// Pushes `f` as a C closure named `name` that captures exactly one upvalue: the value
/// currently on top of the stack.
unsafe fn push_c_closure_with_upvalue(
    state: *mut lua_State,
    f: unsafe fn(*mut lua_State) -> c_int,
    name: &core::ffi::CStr,
) {
    let debug_name = name.as_ptr();
    // The `1` is the upvalue count.
    unsafe { lua_pushcclosurek(state, Some(f), debug_name, 1, None) };
}
/// Wraps `callback` into a Lua function that runs it asynchronously.
///
/// The result is the compiled `POLLER_SOURCE` chunk. Its environment holds exactly the
/// helpers the chunk uses:
/// * `get_future`, a C closure owning `callback`;
/// * `poll`;
/// * `unpack`;
/// * the global `coroutine.yield`.
///
/// # Errors
/// Fails in these cases:
/// * the callback userdata cannot be allocated;
/// * `coroutine.yield` cannot be read from the globals;
/// * populating the environment fails;
/// * loading the chunk fails.
pub(crate) fn create_async_callback(lua: &Lua, callback: AsyncCallback) -> Result<Function> {
    let state = lua.state();

    // `get_future`: a closure whose single upvalue is a userdata owning `callback`.
    let get_future = unsafe {
        let storage = lua_newuserdatadtor(
            state,
            core::mem::size_of::<AsyncCallbackUpvalue>(),
            Some(callback_upvalue_dtor),
        );
        if storage.is_null() {
            return Err(Error::runtime(
                "luaur-rt: failed to allocate async callback userdata",
            ));
        }
        storage
            .cast::<AsyncCallbackUpvalue>()
            .write(AsyncCallbackUpvalue { callback });
        push_c_closure_with_upvalue(state, get_future_c, c"luaur-rt-get-future");
        Function::from_ref(lua.pop_ref())
    };

    // Upvalue-free helpers used by the poller.
    let poll = unsafe {
        lua_pushcclosurek(state, Some(poll_c), c"luaur-rt-poll".as_ptr(), 0, None);
        Function::from_ref(lua.pop_ref())
    };
    let unpack = unsafe {
        lua_pushcclosurek(state, Some(unpack_c), c"luaur-rt-unpack".as_ptr(), 0, None);
        Function::from_ref(lua.pop_ref())
    };

    let coroutine: Table = lua.globals().get("coroutine")?;
    let yield_fn: Function = coroutine.get("yield")?;

    let env = lua.create_table();
    for (name, helper) in [
        ("get_future", get_future),
        ("poll", poll),
        ("yield", yield_fn),
        ("unpack", unpack),
    ] {
        env.set(name, helper)?;
    }

    lua.load(POLLER_SOURCE)
        .set_name("__luaur_async_poll")
        .set_environment(env)
        .into_function()
}
impl Lua {
    /// Returns a future that, when awaited inside an async callback, yields `args` from the
    /// Lua coroutine running that callback. It resolves to the values the coroutine is next
    /// resumed with, converted to `R`.
    ///
    /// The first poll leaves `[yield mark, payload, count]` on the Lua stack and returns
    /// `Pending`, which `poll_c` turns into a coroutine yield. The second poll reads the
    /// resume values from stack slots `2..=top` and clears them.
    ///
    /// # Errors
    /// The future resolves to an error in these cases:
    /// * `args` cannot be converted;
    /// * a yielded value cannot be pushed (the Lua stack is then restored to its previous
    ///   height);
    /// * the resume values cannot be converted into `R`.
    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
    pub fn yield_with<R: crate::traits::FromLuaMulti + 'static>(
        &self,
        args: impl IntoLuaMulti,
    ) -> impl std::future::Future<Output = Result<R>> + 'static {
        let lua = self.clone();
        let args = args.into_lua_multi(self);
        async move {
            let mut args = Some(args?);
            std::future::poll_fn(move |_cx| {
                use std::task::Poll;
                match args.take() {
                    Some(values) => {
                        let state = lua.state();
                        unsafe {
                            // Stack height before anything is pushed. A failed push is rolled
                            // back to this height instead of leaving a stray yield mark.
                            let top = lua_gettop(state);
                            lua_pushlightuserdatatagged(state, poll_yield(), 0);
                            let count = values.len() as c_int;
                            // Zero or one value travels as-is (nil for none). More values
                            // travel as a sequence table that the poller unpacks using `count`.
                            let pushed = if count <= 1 {
                                match values.iter().next() {
                                    Some(v) => lua.push_value(v).map(|_| ()),
                                    None => {
                                        lua_pushnil(state);
                                        Ok(())
                                    }
                                }
                            } else {
                                match lua.create_sequence_from(values) {
                                    Ok(t) => {
                                        t.push_to_stack();
                                        Ok(())
                                    }
                                    Err(e) => Err(e),
                                }
                            };
                            if let Err(e) = pushed {
                                lua_settop(state, top);
                                return Poll::Ready(Err(e));
                            }
                            lua_pushinteger(state, count);
                        }
                        Poll::Pending
                    }
                    None => {
                        // Resumed: slot 1 holds the future (inside `poll_c`), and slots
                        // 2..=top hold the values the coroutine was resumed with.
                        let state = lua.state();
                        let result = unsafe {
                            let top = lua_gettop(state);
                            let mut results = MultiValue::with_capacity((top.max(1) - 1) as usize);
                            let mut err = None;
                            for i in 2..=top {
                                match lua.value_from_stack(i) {
                                    Ok(v) => results.push_back(v),
                                    Err(e) => {
                                        err = Some(e);
                                        break;
                                    }
                                }
                            }
                            // Drop the consumed resume values and keep the future in slot 1.
                            if top > 1 {
                                lua_settop(state, 1);
                            }
                            match err {
                                Some(e) => Err(e),
                                None => R::from_lua_multi(results, &lua),
                            }
                        };
                        Poll::Ready(result)
                    }
                }
            })
            .await
        }
    }
}
/// Boxed, type-erased future returned by [`LuaNativeAsyncFn::call`]. It is `Send` when the
/// `send` feature is enabled.
#[cfg(feature = "send")]
pub type BoxedAsyncFnFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
/// Boxed, type-erased future returned by [`LuaNativeAsyncFn::call`] (non-`Send` build).
#[cfg(not(feature = "send"))]
pub type BoxedAsyncFnFuture<T> = Pin<Box<dyn Future<Output = T>>>;
/// Async Rust function that can be called with an argument tuple `A` converted from Lua values.
///
/// Implemented for every `Fn` with zero to eight arguments that returns a future, via
/// `impl_lua_native_async_fn!` in this module.
pub trait LuaNativeAsyncFn<A: FromLuaMulti> {
    /// Value the returned future resolves to.
    type Output;
    /// Calls the function with the unpacked argument tuple and boxes the resulting future.
    fn call(&self, args: A) -> BoxedAsyncFnFuture<Self::Output>;
}
// Implements `LuaNativeAsyncFn<(A, B, ...)>` for any `Fn(A, B, ...) -> Fut` where the function
// and its future are `MaybeSend + 'static`. `call` destructures the argument tuple, calls the
// function, and pins the returned future in a box.
macro_rules! impl_lua_native_async_fn {
    ($($A:ident),*) => {
        impl<FN, $($A,)* Fut, R> LuaNativeAsyncFn<($($A,)*)> for FN
        where
            FN: Fn($($A,)*) -> Fut + MaybeSend + 'static,
            ($($A,)*): FromLuaMulti,
            Fut: Future<Output = R> + MaybeSend + 'static,
        {
            type Output = R;
            // The tuple elements are bound to variables named after the type parameters.
            #[allow(non_snake_case)]
            fn call(&self, args: ($($A,)*)) -> BoxedAsyncFnFuture<R> {
                let ($($A,)*) = args;
                Box::pin(self($($A,)*))
            }
        }
    };
}
// Arities 0 through 8.
impl_lua_native_async_fn!();
impl_lua_native_async_fn!(A);
impl_lua_native_async_fn!(A, B);
impl_lua_native_async_fn!(A, B, C);
impl_lua_native_async_fn!(A, B, C, D);
impl_lua_native_async_fn!(A, B, C, D, E);
impl_lua_native_async_fn!(A, B, C, D, E, F);
impl_lua_native_async_fn!(A, B, C, D, E, F, G);
impl_lua_native_async_fn!(A, B, C, D, E, F, G, H);
/// Deferred async function. It holds `func` until converted with `IntoLua`, which registers
/// the function through `Lua::create_async_function`.
pub struct WrappedAsync<F, A, FR, R> {
    func: F,
    // Records the type parameters without owning any `A`, `FR` or `R` value.
    _marker: PhantomData<fn(A) -> (FR, R)>,
}
impl<F, A, FR, R> WrappedAsync<F, A, FR, R>
where
    F: Fn(Lua, A) -> FR + MaybeSend + 'static,
    A: FromLuaMulti,
    FR: Future<Output = Result<R>> + MaybeSend + 'static,
    R: crate::traits::IntoLuaMulti,
{
    /// Wraps `func` without registering it. Registration happens in `into_lua`.
    pub(crate) fn new(func: F) -> Self {
        Self {
            _marker: PhantomData,
            func,
        }
    }
}
impl<F, A, FR, R> crate::traits::IntoLua for WrappedAsync<F, A, FR, R>
where
F: Fn(Lua, A) -> FR + MaybeSend + 'static,
A: FromLuaMulti,
FR: Future<Output = Result<R>> + MaybeSend + 'static,
R: crate::traits::IntoLuaMulti,
{
fn into_lua(self, lua: &Lua) -> Result<crate::value::Value> {
let func = self.func;
let f = lua.create_async_function(move |lua, a: A| func(lua, a))?;
Ok(crate::value::Value::Function(f))
}
}
/// Future and stream that drive a Lua [`Thread`], resuming it once per poll.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct AsyncThread<R> {
    thread: Thread,
    /// Arguments for the first resume. `None` once they have been taken.
    args: Option<MultiValue>,
    /// Set once the thread has returned or errored. `Drop` terminates only threads where
    /// this is still false.
    done: bool,
    /// Whether `Drop` must unregister `thread` as an implicit thread.
    implicit: bool,
    _ret: PhantomData<fn() -> R>,
}
impl<R> AsyncThread<R> {
pub(crate) fn new(thread: Thread, args: MultiValue) -> AsyncThread<R> {
AsyncThread {
thread,
args: Some(args),
done: false,
implicit: false,
_ret: PhantomData,
}
}
pub(crate) fn set_implicit(&mut self, implicit: bool) {
self.implicit = implicit;
}
fn take_args(&mut self) -> MultiValue {
self.args.take().unwrap_or_default()
}
}
impl<R> Drop for AsyncThread<R> {
    /// Terminates the thread if it has not finished, then removes its implicit-thread
    /// registration (if any).
    fn drop(&mut self) {
        let unfinished = !self.done;
        if unfinished {
            self.thread.terminate_async();
        }
        if !self.implicit {
            return;
        }
        unregister_implicit_thread(self.thread.state());
    }
}
impl<R: FromLuaMulti> Future for AsyncThread<R> {
    type Output = Result<R>;

    /// Resumes the thread once per poll. Intermediate yields are discarded, and the task
    /// immediately schedules itself again. Only the thread's return value (or an error)
    /// resolves the future.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        if this.done {
            return Poll::Ready(Err(Error::CoroutineUnresumable));
        }
        let lua = this.thread.lua();
        // Keep this task's waker installed for the VM for the duration of the resume.
        let _waker_guard = set_current_waker(lua.state(), cx.waker().clone());
        let args = this.take_args();
        let outcome = this.thread.resume_for_async(args);
        match outcome {
            Ok(AsyncResume::Pending) => Poll::Pending,
            Ok(AsyncResume::Yielded(_yielded)) => {
                cx.waker().wake_by_ref();
                Poll::Pending
            }
            Ok(AsyncResume::Returned(vals)) => {
                this.done = true;
                Poll::Ready(R::from_lua_multi(vals, &lua))
            }
            Err(err) => {
                this.done = true;
                Poll::Ready(Err(err))
            }
        }
    }
}
impl<R: FromLuaMulti> futures_util::stream::Stream for AsyncThread<R> {
    type Item = Result<R>;

    /// Resumes the thread once per poll and emits each yielded batch as an item. The return
    /// value (or an error) is emitted as the last item, after which the stream ends.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        if this.done {
            return Poll::Ready(None);
        }
        let lua = this.thread.lua();
        // Keep this task's waker installed for the VM for the duration of the resume.
        let _waker_guard = set_current_waker(lua.state(), cx.waker().clone());
        let args = this.take_args();
        let vals = match this.thread.resume_for_async(args) {
            Ok(AsyncResume::Pending) => return Poll::Pending,
            Ok(AsyncResume::Yielded(vals)) => vals,
            Ok(AsyncResume::Returned(vals)) => {
                this.done = true;
                vals
            }
            Err(err) => {
                this.done = true;
                return Poll::Ready(Some(Err(err)));
            }
        };
        Poll::Ready(Some(R::from_lua_multi(vals, &lua)))
    }
}