use std::panic::catch_unwind;
use luaur_rt::{Error, Function, IntoLua, Lua, Result, Thread, Value};
/// Drives several coroutines from Rust and checks the values they hand back,
/// their status transitions, and the error reported when a finished coroutine
/// is resumed again.
#[test]
fn test_thread() -> Result<()> {
    let lua = Lua::new();

    // The Lua chunks are kept flush-left so the chunk text itself stays unchanged.

    // Coroutine that yields its running total four times and then returns it.
    let summing_fn = lua
        .load(
            r#"
function (s)
local sum = s
for i = 1,4 do
sum = sum + coroutine.yield(sum)
end
return sum
end
"#,
        )
        .eval::<Function>()?;
    let summing = lua.create_thread(summing_fn)?;
    // (value passed to `resume`, value expected back from the coroutine)
    let steps: [(i32, i64); 5] = [(0, 0), (1, 1), (2, 3), (3, 6), (4, 10)];
    for (fed, expected) in steps {
        assert!(summing.is_resumable());
        assert_eq!(summing.resume::<i64>(fed)?, expected);
    }
    assert!(summing.is_finished());

    // Coroutine that never returns; it keeps adding whatever it is resumed with.
    let accumulate_fn = lua
        .load(
            r#"
function (sum)
while true do
sum = sum + coroutine.yield(sum)
end
end
"#,
        )
        .eval::<Function>()?;
    let accumulate = lua.create_thread(accumulate_fn)?;
    for n in 0..4 {
        accumulate.resume::<()>(n)?;
    }
    assert_eq!(accumulate.resume::<i64>(4)?, 10);
    assert!(accumulate.is_resumable());
    // Resuming with a string makes the resume fail and puts the thread in error state.
    assert!(accumulate.resume::<()>("error").is_err());
    assert!(accumulate.is_error());

    // Thread created on the Lua side via `coroutine.create`.
    let forever: Thread = lua
        .load(
            r#"
coroutine.create(function ()
while true do
coroutine.yield(42)
end
end)
"#,
        )
        .eval()?;
    assert!(forever.is_resumable());
    assert_eq!(forever.resume::<i64>(())?, 42);

    // Thread that checks both its initial argument and the value it is resumed with.
    let checked = lua
        .load(
            r#"
coroutine.create(function(arg)
assert(arg == 42)
local yieldarg = coroutine.yield(123)
assert(yieldarg == 43)
return 987
end)
"#,
        )
        .eval::<Thread>()?;
    assert_eq!(checked.resume::<u32>(42)?, 123);
    assert_eq!(checked.resume::<u32>(43)?, 987);

    // The coroutine has returned, so a further resume must be rejected.
    let dead_resume = checked.resume::<u32>(());
    match dead_resume {
        Err(Error::CoroutineUnresumable) => {}
        Err(_) => panic!("resuming dead coroutine error is not CoroutineUnresumable kind"),
        Ok(_) => panic!("resuming dead coroutine did not return error"),
    }
    Ok(())
}
/// Checks that, from inside its own body, a thread reports itself as running
/// and rejects an attempt to resume itself with `CoroutineUnresumable`.
#[test]
fn test_thread_running() -> Result<()> {
    let lua = Lua::new();

    let body = lua.create_function(|lua, ()| {
        assert!(lua.current_thread().is_running());
        // Resuming the thread that is currently executing must fail.
        let result = lua.current_thread().resume::<()>(());
        assert!(
            matches!(result, Err(Error::CoroutineUnresumable)),
            "unexpected result: {result:?}",
        );
        Ok(())
    })?;
    let thread = lua.create_thread(body)?;

    let result = thread.resume::<()>(());
    assert!(result.is_ok(), "unexpected result: {result:?}");
    Ok(())
}
/// Checks that a thread which has resumed another coroutine is reported as
/// `normal` and cannot itself be resumed while the inner coroutine runs.
#[test]
fn test_thread_normal() -> Result<()> {
    let lua = Lua::new();
    let globals = lua.globals();

    // Called from the inner coroutine; inspects the outer thread through the global `outer`.
    let check_outer = lua.create_function(|lua, ()| {
        let outer: Thread = lua.globals().get("outer")?;
        assert!(outer.is_normal());
        let attempt = outer.resume::<()>(());
        assert!(
            matches!(attempt, Err(Error::CoroutineUnresumable)),
            "resuming a `normal` thread must be unresumable",
        );
        Ok(())
    })?;
    globals.set("check_outer", check_outer)?;

    // The Lua chunk is kept flush-left so the chunk text itself stays unchanged.
    let outer_fn = lua
        .load(
            r#"
function()
local inner = coroutine.create(function() check_outer() end)
assert(coroutine.resume(inner))
end
"#,
        )
        .eval::<Function>()?;
    let outer = lua.create_thread(outer_fn)?;
    globals.set("outer", &outer)?;

    outer.resume::<()>(())?;
    assert!(outer.is_finished());
    Ok(())
}
/// Checks that `Thread::reset` installs a new function on a fresh thread, on a
/// finished thread, and on a thread that ended in an error.
#[test]
fn test_thread_reset() -> Result<()> {
    let lua = Lua::new();

    let yielder = lua
        .load(r#"function(x) coroutine.yield(x + 1) end"#)
        .eval::<Function>()?;
    let placeholder = lua.load("return 0").into_function()?;
    let thread = lua.create_thread(placeholder)?;

    // Reset before the thread has ever been resumed.
    assert!(thread.reset(yielder.clone()).is_ok());

    // Run the thread to completion, reset it, and repeat once more.
    for _ in 0..2 {
        assert!(thread.is_resumable());
        assert_eq!(thread.resume::<i64>(10)?, 11);
        assert!(thread.is_resumable());
        thread.resume::<()>(())?;
        assert!(thread.is_finished());
        thread.reset(yielder.clone())?;
    }

    // A thread left in the error state can be reset as well.
    let failing = lua
        .load(r#"function() error("test error") end"#)
        .eval::<Function>()?;
    let thread = lua.create_thread(failing.clone())?;
    // The outcome of this resume is deliberately ignored; only the resulting status is checked.
    let _ = thread.resume::<()>(());
    assert!(thread.is_error());
    assert!(thread.reset(failing.clone()).is_ok());
    assert!(thread.is_resumable());
    Ok(())
}
/// Checks that a Rust closure exposed as the global `main` can be wrapped by
/// Lua's `coroutine.create` and resumed from Rust.
#[test]
fn test_coroutine_from_closure() -> Result<()> {
    let lua = Lua::new();

    let noop = lua.create_function(|_, ()| Ok(()))?;
    lua.globals().set("main", noop)?;

    let coroutine = lua.load("coroutine.create(main)").eval::<Thread>()?;
    coroutine.resume::<()>(())
}
/// Resumes a coroutine whose body panics with "test_panic".
///
/// Two outcomes are accepted: a `RuntimeError` whose message contains
/// "test_panic", or an unwound panic whose payload is "test_panic".
#[test]
#[cfg(not(panic = "abort"))]
fn test_coroutine_panic() {
    let outcome = catch_unwind(|| -> Result<()> {
        let lua = Lua::new();
        let panicking = lua.create_function(|_, ()| -> Result<()> {
            panic!("test_panic");
        })?;
        lua.globals().set("main", &panicking)?;
        let coroutine: Thread = lua.create_thread(panicking)?;
        coroutine.resume::<()>(())
    });

    match outcome {
        Ok(Err(Error::RuntimeError(msg))) => assert!(msg.contains("test_panic")),
        Ok(other) => panic!("unexpected coroutine result: {other:?}"),
        Err(payload) => assert!(*payload.downcast::<&str>().unwrap() == "test_panic"),
    }
}
/// Checks that a thread handle and its clone report the same pointer, and
/// that a newly created thread's pointer differs from the current thread's.
#[test]
fn test_thread_pointer() -> Result<()> {
    let lua = Lua::new();

    let func = lua.load("return 123").into_function()?;
    // `func` is not needed afterwards, so it is moved into the thread instead of cloned.
    let thread = lua.create_thread(func)?;

    assert_eq!(thread.to_pointer(), thread.clone().to_pointer());
    assert_ne!(thread.to_pointer(), lua.current_thread().to_pointer());
    Ok(())
}
/// Checks that `resume_error` raises the given error at the coroutine's
/// pending `coroutine.yield`, where the Lua side catches it with `pcall`.
#[test]
fn test_thread_resume_error() -> Result<()> {
    let lua = Lua::new();

    // The Lua chunk is kept flush-left so the chunk text itself stays unchanged.
    let thread: Thread = lua
        .load(
            r#"
coroutine.create(function()
local ok, err = pcall(coroutine.yield, 123)
assert(not ok, "yield should fail")
assert(err == "myerror", "unexpected error: " .. tostring(err))
return "success"
end)
"#,
        )
        .eval()?;

    // First resume runs up to the yield inside `pcall`.
    assert_eq!(thread.resume::<i64>(())?, 123);

    // Injecting the error lets the coroutine finish and return its final value.
    let status: String = thread.resume_error("myerror").unwrap();
    assert_eq!(status, "success");
    Ok(())
}
/// Checks that a failing argument conversion makes `resume` return that
/// conversion error, and that the thread can still be resumed afterwards.
#[test]
fn test_thread_resume_bad_arg() -> Result<()> {
    /// Argument whose conversion into a Lua value always fails with "bad arg".
    struct BadArg;

    impl IntoLua for BadArg {
        fn into_lua(self, _lua: &Lua) -> Result<Value> {
            Err(Error::runtime("bad arg"))
        }
    }

    let lua = Lua::new();
    let okay_fn = lua.create_function(|_, ()| Ok("okay"))?;
    let thread = lua.create_thread(okay_fn)?;

    // The second tuple element cannot be converted, so the resume reports its error.
    let failed = thread.resume::<()>((123, BadArg));
    assert!(matches!(failed, Err(Error::RuntimeError(msg)) if msg == "bad arg"));

    // A later resume with valid arguments still succeeds.
    let value = thread.resume::<String>(()).unwrap();
    assert_eq!(value, "okay");
    Ok(())
}