use std::time::{Duration, Instant};
use mlua::prelude::*;
use mlua::{Function, MultiValue, Value};
use tracing::{debug, warn};
use super::cancel::{effective_token, race_sleep, race_yield, CancelToken};
use super::driver::{spawn_into, Handle};
use super::scope::{abort_all, current_scope, drain_scope, Scope, ScopeGuard, ScopeHandle};
use super::{ms_to_duration, task_config, LOCAL_SCOPE, TASK_INFO, TASK_TOKEN};
pub(super) fn spawn(lua: &Lua, (func, opts): (Function, Option<LuaTable>)) -> LuaResult<Handle> {
let scope = current_scope(lua)?;
spawn_into(lua, &scope, func, opts)
}
pub(super) async fn sleep(_: Lua, ms: f64) -> LuaResult<()> {
let dur = ms_to_duration(ms, "std.task.sleep")?;
race_sleep(dur).await
}
pub(super) async fn yield_now(_: Lua, _: ()) -> LuaResult<()> {
race_yield().await
}
pub(super) async fn checkpoint(_lua: Lua, _: ()) -> LuaResult<()> {
if let Some(t) = effective_token() {
if t.is_cancelled() {
return Err(LuaError::external("task cancelled"));
}
}
race_yield().await
}
pub(super) fn cancel_token(_: &Lua, _: ()) -> LuaResult<CancelToken> {
Ok(CancelToken::new())
}
pub(super) fn current(lua: &Lua, _: ()) -> LuaResult<Value> {
let info = TASK_INFO.try_with(|i| i.clone()).ok();
match info {
None => Ok(Value::Nil),
Some(i) => {
let t = lua.create_table()?;
t.set("id", i.id)?;
t.set("name", i.name)?;
let cancelled = TASK_TOKEN.try_with(|t| t.is_cancelled()).unwrap_or(false);
t.set("cancelled", cancelled)?;
Ok(Value::Table(t))
}
}
}
pub(super) async fn task_scope(_lua: Lua, args: MultiValue) -> LuaResult<Value> {
let (name, func) = parse_scope_args(&args)?;
let scope = Scope::new(name);
let scope_for_body = scope.clone();
let handle = ScopeHandle(scope.clone());
LOCAL_SCOPE
.scope(scope, async move {
let guard = ScopeGuard::new(scope_for_body.clone());
let user_result: LuaResult<Value> = func.call_async::<Value>(handle).await;
if user_result.is_err() {
scope_for_body.borrow().token.cancel();
}
drain_scope(&scope_for_body).await;
guard.disarm();
user_result
})
.await
}
fn parse_scope_args(args: &MultiValue) -> LuaResult<(Option<String>, Function)> {
let mut iter = args.iter();
let first = iter
.next()
.ok_or_else(|| LuaError::external("task.scope requires at least a function"))?;
match first {
Value::Function(f) => Ok((None, f.clone())),
Value::String(s) => {
let n = s.to_str()?.to_string();
let second = iter
.next()
.ok_or_else(|| LuaError::external("task.scope(name, fn) requires a function"))?;
match second {
Value::Function(f) => Ok((Some(n), f.clone())),
_ => Err(LuaError::external(
"task.scope: second argument must be a function",
)),
}
}
_ => Err(LuaError::external(
"task.scope: first argument must be a function or a name string",
)),
}
}
fn parse_timeout_opts(lua: &Lua, opts: Option<&LuaTable>) -> LuaResult<Duration> {
let cfg_default = Duration::from_millis(task_config(lua).grace_ms);
let Some(t) = opts else {
return Ok(cfg_default);
};
let mut grace: Option<Duration> = None;
for pair in t.clone().pairs::<LuaValue, LuaValue>() {
let (k, v) = pair?;
let key = match &k {
LuaValue::String(s) => s.to_str()?.to_string(),
other => {
return Err(LuaError::external(format!(
"task.with_timeout: opts keys must be strings, got {}",
other.type_name()
)));
}
};
match key.as_str() {
"grace_ms" => {
let ms = match v {
LuaValue::Integer(i) => i as f64,
LuaValue::Number(n) => n,
other => {
return Err(LuaError::external(format!(
"task.with_timeout: opts.grace_ms must be a number, got {}",
other.type_name()
)));
}
};
grace = Some(ms_to_duration(ms, "task.with_timeout: opts.grace_ms")?);
}
other => {
return Err(LuaError::external(format!(
"task.with_timeout: unknown opts key '{other}' (expected 'grace_ms')"
)));
}
}
}
Ok(grace.unwrap_or(cfg_default))
}
pub(super) async fn with_timeout(
lua: Lua,
(ms, func, opts): (f64, Function, Option<LuaTable>),
) -> LuaResult<Value> {
let dur = ms_to_duration(ms, "task.with_timeout")?;
let grace = parse_timeout_opts(&lua, opts.as_ref())?;
let scope = Scope::new(None);
let scope_for_body = scope.clone();
let handle = ScopeHandle(scope.clone());
LOCAL_SCOPE
.scope(scope, async move {
let guard = ScopeGuard::new(scope_for_body.clone());
let user_fut = func.call_async::<Value>(handle);
let timed = tokio::time::timeout(dur, user_fut).await;
let user_result: LuaResult<Value> = match timed {
Ok(r) => r,
Err(_) => {
debug!(
target: "mlua_batteries::task",
timeout_ms = ms,
grace_ms = grace.as_millis() as u64,
"with_timeout: deadline exceeded, issuing cooperative cancel",
);
scope_for_body.borrow().token.cancel();
Err(LuaError::external(format!(
"task.with_timeout: exceeded {ms} ms"
)))
}
};
if user_result.is_err() {
scope_for_body.borrow().token.cancel();
let drain_start = Instant::now();
let drained = tokio::time::timeout(grace, drain_scope(&scope_for_body)).await;
if drained.is_err() {
let children_total = scope_for_body.borrow().children.len();
warn!(
target: "mlua_batteries::task",
grace_ms = grace.as_millis() as u64,
elapsed_ms = drain_start.elapsed().as_millis() as u64,
children_total,
"with_timeout: grace expired, hard-aborting non-cooperative children",
);
abort_all(&scope_for_body);
drain_scope(&scope_for_body).await;
} else {
debug!(
target: "mlua_batteries::task",
elapsed_ms = drain_start.elapsed().as_millis() as u64,
"with_timeout: cooperative drain completed within grace",
);
}
} else {
drain_scope(&scope_for_body).await;
}
guard.disarm();
user_result
})
.await
}