#![cfg(feature = "async")]
use std::sync::Arc;
use std::time::Duration;
use futures_util::stream::TryStreamExt;
use tokio::sync::Mutex;
use luaur_rt::{Error, Function, Lua, MultiValue, Result, Thread, Value};
async fn sleep_ms(ms: u64) {
tokio::time::sleep(Duration::from_millis(ms)).await;
}
#[tokio::test]
async fn test_async_function() -> Result<()> {
let lua = Lua::new();
let f = lua
.create_async_function(|_lua, (a, b, c): (i64, i64, i64)| async move { Ok((a + b) * c) })?;
lua.globals().set("f", f)?;
let res: i64 = lua.load("f(1, 2, 3)").eval_async().await?;
assert_eq!(res, 9);
Ok(())
}
#[tokio::test]
async fn test_async_function_wrap() -> Result<()> {
let lua = Lua::new();
let f = Function::wrap_async(|s: String| async move {
tokio::task::yield_now().await;
Ok::<_, Error>(s)
});
lua.globals().set("f", f)?;
let res: String = lua.load(r#"f("hello")"#).eval_async().await?;
assert_eq!(res, "hello");
let ferr =
Function::wrap_async(|| async move { Err::<(), _>(Error::runtime("some async error")) });
lua.globals().set("ferr", ferr)?;
lua.load(
r#"
local ok, err = pcall(ferr)
assert(not ok and tostring(err):find("some async error"))
"#,
)
.exec_async()
.await
.unwrap();
Ok(())
}
#[tokio::test]
async fn test_async_function_wrap_raw() -> Result<()> {
let lua = Lua::new();
let f = Function::wrap_raw_async(|s: String| async move {
tokio::task::yield_now().await;
s
});
lua.globals().set("f", f)?;
let res: String = lua.load(r#"f("hello")"#).eval_async().await?;
assert_eq!(res, "hello");
let ferr = Function::wrap_raw_async(|| async move {
tokio::task::yield_now().await;
Err::<(), _>("some error")
});
lua.globals().set("ferr", ferr)?;
let (_, err): (Value, String) = lua.load(r#"ferr()"#).eval_async().await?;
assert_eq!(err, "some error");
Ok(())
}
#[tokio::test]
async fn test_async_sleep() -> Result<()> {
let lua = Lua::new();
let sleep = lua.create_async_function(move |_lua, n: u64| async move {
sleep_ms(n).await;
Ok(format!("elapsed:{}ms", n))
})?;
lua.globals().set("sleep", sleep)?;
let res: String = lua.load(r"return sleep(...)").call_async(100).await?;
assert_eq!(res, "elapsed:100ms");
Ok(())
}
#[tokio::test]
async fn test_async_call() -> Result<()> {
let lua = Lua::new();
let hello = lua.create_async_function(|_lua, name: String| async move {
sleep_ms(10).await;
Ok(format!("hello, {}!", name))
})?;
match hello.call::<()>("alex") {
Err(Error::RuntimeError(_)) => {}
err => panic!("expected `RuntimeError`, got {err:?}"),
};
assert_eq!(hello.call_async::<String>("alex").await?, "hello, alex!");
let sum = lua.create_function(|_lua, (a, b): (i64, i64)| return Ok(a + b))?;
assert_eq!(sum.call_async::<i64>((5, 1)).await?, 6);
Ok(())
}
#[tokio::test]
async fn test_async_call_many_returns() -> Result<()> {
let lua = Lua::new();
let hello = lua.create_async_function(|_lua, ()| async move {
sleep_ms(10).await;
Ok(("a", "b", "c", 1))
})?;
let vals = hello.call_async::<MultiValue>(()).await?;
assert_eq!(vals.len(), 4);
assert_eq!(vals[0].to_string()?, "a");
assert_eq!(vals[1].to_string()?, "b");
assert_eq!(vals[2].to_string()?, "c");
assert_eq!(vals[3], Value::Integer(1));
Ok(())
}
#[tokio::test]
async fn test_async_bind_call() -> Result<()> {
let lua = Lua::new();
let sum = lua.create_async_function(|_lua, (a, b): (i64, i64)| async move {
tokio::task::yield_now().await;
Ok(a + b)
})?;
let plus_10 = sum.bind(10)?;
lua.globals().set("plus_10", plus_10)?;
assert_eq!(lua.load("plus_10(-1)").eval_async::<i64>().await?, 9);
assert_eq!(lua.load("plus_10(1)").eval_async::<i64>().await?, 11);
Ok(())
}
#[tokio::test]
async fn test_async_handle_yield() -> Result<()> {
let lua = Lua::new();
let sum = lua.create_async_function(|_lua, (a, b): (i64, i64)| async move {
sleep_ms(10).await;
Ok(a + b)
})?;
lua.globals().set("sleep_sum", sum)?;
let res: String = lua
.load(
r#"
sum = sleep_sum(6, 7)
assert(sum == 13)
coroutine.yield("in progress")
return "done"
"#,
)
.call_async(())
.await?;
assert_eq!(res, "done");
let min = lua
.load(
r#"
function (a, b)
coroutine.yield("ignore me")
if a < b then return a else return b end
end
"#,
)
.eval::<Function>()?;
assert_eq!(min.call_async::<i64>((-1, 1)).await?, -1);
Ok(())
}
#[tokio::test]
async fn test_async_multi_return_nil() -> Result<()> {
let lua = Lua::new();
lua.globals().set(
"func",
lua.create_async_function(|_, _: ()| async { Ok((Option::<String>::None, "error")) })?,
)?;
lua.load(
r#"
local ok, err = func()
assert(err == "error")
"#,
)
.exec_async()
.await
}
#[tokio::test]
async fn test_async_return_async_closure() -> Result<()> {
let lua = Lua::new();
let f = lua.create_async_function(|lua, a: i64| async move {
sleep_ms(10).await;
let g = lua.create_async_function(move |_, b: i64| async move {
sleep_ms(10).await;
return Ok(a + b);
})?;
Ok(g)
})?;
lua.globals().set("f", f)?;
let res: i64 = lua
.load("local g = f(1); return g(2) + g(3)")
.call_async(())
.await?;
assert_eq!(res, 7);
Ok(())
}
#[tokio::test]
async fn test_async_thread_stream() -> Result<()> {
let lua = Lua::new();
let thread = lua.create_thread(
lua.load(
r#"
function (sum)
for i = 1,10 do
sum = sum + i
coroutine.yield(sum)
end
return sum
end
"#,
)
.eval()?,
)?;
let mut stream = thread.into_async::<i64>(1)?;
let mut sum = 0;
while let Some(n) = stream.try_next().await? {
sum += n;
}
assert_eq!(sum, 286);
Ok(())
}
#[tokio::test]
async fn test_async_thread() -> Result<()> {
let lua = Lua::new();
let cnt = Arc::new(10); let cnt2 = cnt.clone();
let f = lua.create_async_function(move |_lua, ()| {
let cnt3 = cnt2.clone();
async move {
sleep_ms(*cnt3.as_ref()).await;
Ok("done")
}
})?;
let res: String = lua.create_thread(f)?.into_async(())?.await?;
assert_eq!(res, "done");
Ok(())
}
#[test]
fn test_async_thread_capture() -> Result<()> {
let lua = Lua::new();
let f = lua.create_async_function(move |_lua, v: Value| async move {
tokio::task::yield_now().await;
drop(v);
Ok(())
})?;
let thread = lua.create_thread(f)?;
thread.resume::<()>("abc").unwrap();
drop(thread);
Ok(())
}
#[tokio::test]
async fn test_async_thread_error() -> Result<()> {
let lua = Lua::new();
let result = lua
.load("function x(...) error(...) end x(...)")
.set_name("chunk")
.call_async::<()>("test error")
.await;
assert!(
matches!(result, Err(Error::RuntimeError(cause)) if cause.contains("test error")),
"improper error from dead thread"
);
Ok(())
}
#[tokio::test]
async fn test_async_terminate_drop_lua() -> Result<()> {
let mutex = Arc::new(Mutex::new(0u32));
{
let lua = Lua::new();
let mutex2 = mutex.clone();
let func = lua.create_async_function(move |lua, ()| {
let mutex = mutex2.clone();
async move {
let _guard = mutex.lock().await;
sleep_ms(100).await;
drop(lua); Ok(())
}
})?;
let _ = tokio::time::timeout(Duration::from_millis(30), func.call_async::<()>(())).await;
}
assert!(mutex.try_lock().is_ok());
Ok(())
}
#[tokio::test]
async fn test_async_task_abort() -> Result<()> {
let lua = Lua::new();
let sleep = lua.create_async_function(move |_lua, n: u64| async move {
sleep_ms(n).await;
Ok(())
})?;
lua.globals().set("sleep", sleep)?;
let local = tokio::task::LocalSet::new();
local
.run_until(async {
let lua2 = lua.clone();
let jh = tokio::task::spawn_local(async move {
lua2.load("sleep(200) result = 'done'")
.exec_async()
.await
.unwrap();
});
sleep_ms(100).await; jh.abort();
})
.await;
local.await;
assert_eq!(lua.globals().get::<Value>("result")?, Value::Nil);
Ok(())
}
#[test]
fn test_async_yield_with() -> Result<()> {
let lua = Lua::new();
let func = lua.create_async_function(|lua, (mut a, mut b): (i32, i32)| async move {
let zero = lua.yield_with::<MultiValue>(()).await?;
assert!(zero.is_empty());
let one = lua.yield_with::<MultiValue>(a + b).await?;
assert_eq!(one.len(), 1);
for _ in 0..3 {
(a, b) = lua.yield_with((a + b, a * b)).await?;
}
Ok((0, 0))
})?;
let thread = lua.create_thread(func)?;
let zero = thread.resume::<MultiValue>((2, 3))?; assert!(zero.is_empty());
let one = thread.resume::<i32>(())?; assert_eq!(one, 5);
assert_eq!(thread.resume::<(i32, i32)>(1)?, (5, 6)); assert_eq!(thread.resume::<(i32, i32)>((10, 11))?, (21, 110));
assert_eq!(thread.resume::<(i32, i32)>((11, 12))?, (23, 132));
assert_eq!(thread.resume::<(i32, i32)>((12, 13))?, (0, 0));
assert!(thread.is_finished());
Ok(())
}
#[tokio::test]
async fn test_async_current_thread() -> Result<()> {
let lua = Lua::new();
let get_inner_thread = lua.create_async_function(move |lua, ()| async move {
let f =
lua.create_async_function(move |lua, ()| async move { Ok(lua.current_thread()) })?;
f.call_async::<Thread>(()).await
})?;
let inner_thread = get_inner_thread.call_async::<Thread>(()).await?;
assert_eq!(inner_thread, lua.current_thread());
Ok(())
}