use std::cell::RefCell;
use std::future::Future;
use std::panic;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use std::time::Instant;
use mlua::prelude::*;
use mlua::{
Function, MultiValue, ThreadStatus, UserData, UserDataMethods, UserDataRegistry, Value,
};
use tokio::sync::oneshot;
use tokio::task::AbortHandle;
use tracing::{info_span, Instrument};
use super::cancel::race_sleep;
use super::scope::Scope;
use super::{
duration_to_ms, lua_to_string, ms_to_duration, task_config, TaskInfo, LOCAL_SCOPE, TASK_INFO,
TASK_TOKEN,
};
#[derive(Clone, Copy, Debug)]
pub enum Driver {
AsyncFn,
Coroutine,
}
pub(super) fn parse_opts(opts: Option<&LuaTable>) -> LuaResult<(Option<String>, Option<Driver>)> {
let Some(t) = opts else {
return Ok((None, None));
};
let mut name: Option<String> = None;
let mut driver: Option<Driver> = None;
for pair in t.clone().pairs::<LuaValue, LuaValue>() {
let (k, v) = pair?;
let key = match &k {
LuaValue::String(s) => s.to_str()?.to_string(),
other => {
return Err(LuaError::external(format!(
"std.task: opts keys must be strings, got {}",
other.type_name()
)));
}
};
match key.as_str() {
"name" => {
name = Some(lua_to_string(&v, "std.task: opts.name")?);
}
"driver" => {
let s = lua_to_string(&v, "std.task: opts.driver")?;
driver = Some(match s.as_str() {
"coroutine" => Driver::Coroutine,
"async_fn" | "async" => Driver::AsyncFn,
other => {
return Err(LuaError::external(format!(
"std.task: unknown driver '{other}' (expected 'async_fn' or 'coroutine')"
)));
}
});
}
other => {
return Err(LuaError::external(format!(
"std.task: unknown opts key '{other}' (expected 'name' or 'driver')"
)));
}
}
}
Ok((name, driver))
}
pub(super) enum JoinState {
Pending(oneshot::Receiver<LuaResult<Value>>),
Taken,
}
pub(super) struct Handle {
pub(super) id: String,
pub(super) name: Option<String>,
pub(super) abort: AbortHandle,
pub(super) state: JoinState,
pub(super) started_at: Instant,
}
impl UserData for Handle {
fn register(reg: &mut UserDataRegistry<Self>) {
reg.add_field_method_get("id", |_, this| Ok(this.id.clone()));
reg.add_field_method_get("name", |_, this| Ok(this.name.clone()));
reg.add_method("is_finished", |_, this, ()| Ok(this.abort.is_finished()));
reg.add_method("elapsed", |_, this, ()| {
Ok(duration_to_ms(this.started_at.elapsed()))
});
reg.add_method("abort", |_, this, ()| {
this.abort.abort();
Ok(())
});
reg.add_async_method_mut("join", |_, mut this, ()| {
let state = std::mem::replace(&mut this.state, JoinState::Taken);
async move {
match state {
JoinState::Pending(rx) => match rx.await {
Ok(res) => res,
Err(_) => Err(LuaError::external("task aborted")),
},
JoinState::Taken => Err(LuaError::external("task already joined")),
}
}
});
}
}
pub(super) fn spawn_into(
lua: &Lua,
scope: &Rc<RefCell<Scope>>,
func: Function,
opts: Option<LuaTable>,
) -> LuaResult<Handle> {
let (name, driver_opt) = parse_opts(opts.as_ref())?;
let driver = driver_opt.unwrap_or_else(|| task_config(lua).default_driver);
let token = scope.borrow().token.clone();
let (tx, rx) = oneshot::channel::<LuaResult<Value>>();
let id = format!("t{}", TASK_SEQ.with(|s| s.next()));
let info = TaskInfo {
id: id.clone(),
name: name.clone(),
};
let lua_for_cr = lua.clone();
let user_fut: Pin<Box<dyn Future<Output = LuaResult<Value>>>> = match driver {
Driver::AsyncFn => {
Box::pin(async move { func.call_async::<Value>(MultiValue::new()).await })
}
Driver::Coroutine => Box::pin(async move { run_coroutine(&lua_for_cr, func).await }),
};
let span = info_span!(
"task",
id = %id,
name = name.as_deref().unwrap_or(""),
driver = ?driver,
);
let traced = user_fut.instrument(span);
let with_info = TASK_INFO.scope(info, traced);
let with_token = TASK_TOKEN.scope(token, with_info);
let with_scope = LOCAL_SCOPE.scope(scope.clone(), with_token);
let catching = Catch {
inner: Box::pin(with_scope),
tx: Some(tx),
};
let join_handle = tokio::task::spawn_local(catching);
let abort = join_handle.abort_handle();
scope.borrow_mut().attach(join_handle);
Ok(Handle {
id,
name,
abort,
state: JoinState::Pending(rx),
started_at: Instant::now(),
})
}
thread_local! {
static TASK_SEQ: SeqGen = SeqGen::default();
}
#[derive(Default)]
struct SeqGen(std::cell::Cell<u64>);
impl SeqGen {
fn next(&self) -> u64 {
let v = self.0.get().wrapping_add(1);
self.0.set(v);
v
}
}
async fn run_coroutine(lua: &Lua, func: Function) -> LuaResult<Value> {
let thread = lua.create_thread(func)?;
loop {
if TASK_TOKEN.try_with(|t| t.is_cancelled()).unwrap_or(false) {
return Err(LuaError::external("task cancelled"));
}
let yielded: MultiValue = thread.resume(MultiValue::new())?;
match thread.status() {
ThreadStatus::Finished => {
return Ok(yielded.into_iter().next().unwrap_or(Value::Nil));
}
ThreadStatus::Resumable => {
let ctrl = yielded.into_iter().next().unwrap_or(Value::Nil);
match ctrl {
Value::Nil => tokio::task::yield_now().await,
Value::Integer(ms) => {
let dur = ms_to_duration(ms as f64, "coroutine yield")?;
race_sleep(dur).await?;
}
Value::Number(ms) => {
let dur = ms_to_duration(ms, "coroutine yield")?;
race_sleep(dur).await?;
}
other => {
return Err(LuaError::external(format!(
"coroutine yield: unsupported value type '{}' (expected nil or number)",
other.type_name()
)));
}
}
}
ThreadStatus::Running => {
return Err(LuaError::external(
"coroutine in Running state after resume (impossible)",
));
}
ThreadStatus::Error => {
return Err(LuaError::external("coroutine entered Error state"));
}
}
}
}
struct Catch<F> {
inner: Pin<Box<F>>,
tx: Option<oneshot::Sender<LuaResult<Value>>>,
}
impl<F: Future<Output = LuaResult<Value>>> Future for Catch<F> {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
let this = self.as_mut().get_mut();
match panic::catch_unwind(panic::AssertUnwindSafe(|| this.inner.as_mut().poll(cx))) {
Ok(Poll::Ready(result)) => {
if let Some(tx) = this.tx.take() {
let _ = tx.send(result);
}
Poll::Ready(())
}
Ok(Poll::Pending) => Poll::Pending,
Err(_) => {
if let Some(tx) = this.tx.take() {
let _ = tx.send(Err(LuaError::external("task panicked")));
}
Poll::Ready(())
}
}
}
}