use std::time::{Duration, Instant};
/// Drop guard that uninstalls the debug hook on the wrapped Lua state.
///
/// Keeping one alive for the duration of a call ensures `remove_hook` runs on
/// every exit path: normal return, early `Err` return, or unwinding from a
/// panic.
struct HookGuard<'a>(&'a mlua::Lua);

impl Drop for HookGuard<'_> {
    fn drop(&mut self) {
        let HookGuard(lua) = self;
        lua.remove_hook();
    }
}
/// Runs `f` with a wall-clock deadline enforced on the Lua code it executes.
///
/// Before calling `f`, a debug hook is installed on `lua` that fires every 128
/// VM instructions. Once `duration` has elapsed since this function was entered,
/// the hook raises `mlua::Error::RuntimeError("execution timed out")` inside
/// Lua, aborting the running Lua code. The hook is removed again when this
/// function returns or unwinds.
///
/// # Errors
///
/// Returns whatever `f` returns. A timeout shows up as an error from the Lua
/// call made inside `f`. NOTE(review): it may arrive wrapped (e.g. in
/// `mlua::Error::CallbackError` after crossing a Rust callback), so callers
/// should inspect the error's `source()` chain rather than the top-level
/// variant.
///
/// # Limitations
///
/// - The clock is only checked from the instruction-count hook, so the deadline
///   can be overshot by up to 128 instructions' worth of work. Time spent inside
///   Rust callbacks or C functions is not interrupted.
/// - Once the deadline has passed, the hook errors on every subsequent check, so
///   Lua code that catches the error (e.g. with `pcall`) is interrupted again at
///   the next check.
/// - The hook replaces any hook already installed on `lua`, and it is removed
///   unconditionally on exit. A previously installed hook, including one from an
///   enclosing `with_timeout` call, is not restored.
pub fn with_timeout<F, R>(lua: &mlua::Lua, duration: Duration, f: F) -> mlua::Result<R>
where
    F: FnOnce(&mlua::Lua) -> mlua::Result<R>,
{
    let began = Instant::now();

    // The hook runs periodically rather than on every instruction; the deadline
    // check therefore happens at most every 128 instructions.
    let triggers = mlua::HookTriggers::new().every_nth_instruction(128);

    // NOTE(review): whatever `set_hook` returns is deliberately discarded here.
    let _ = lua.set_hook(triggers, move |_, _| {
        let expired = began.elapsed() >= duration;
        if !expired {
            return Ok(mlua::VmState::Continue);
        }
        Err(mlua::Error::RuntimeError("execution timed out".to_string()))
    });

    // A named binding (not a bare `_`) keeps the guard alive until this function
    // exits, so the hook is removed after `f` returns, errors, or panics.
    let _remove_on_exit = HookGuard(lua);
    f(lua)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::error::Error as _;
    use std::panic::{catch_unwind, AssertUnwindSafe};

    /// Message raised by the timeout hook installed in `with_timeout`.
    const TIMEOUT_MSG: &str = "execution timed out";

    /// Lua chunk that executes far more than 128 instructions, so any hook left
    /// behind with an already-expired deadline is guaranteed to trip on it.
    const BUSY_LOOP: &str = "local x = 0; for i = 1, 1000000 do x = x + 1 end";

    /// Returns `true` if `err`, or any error in its `source()` chain, mentions
    /// the timeout message. A timeout may arrive wrapped (e.g. after crossing a
    /// Rust callback), so checking only the top-level variant is not enough.
    fn is_timeout(err: &mlua::Error) -> bool {
        let mut current: Option<&(dyn std::error::Error + 'static)> = Some(err);
        while let Some(e) = current {
            if e.to_string().contains(TIMEOUT_MSG) {
                return true;
            }
            current = e.source();
        }
        false
    }

    #[test]
    fn test_infinite_loop_times_out() {
        let lua = mlua::Lua::new();
        let result = with_timeout(&lua, Duration::from_millis(100), |lua| {
            lua.load("while true do end").exec()
        });
        let err = result.expect_err("an infinite loop should be interrupted");
        // Pin the failure to the timeout, not just any error.
        assert!(is_timeout(&err), "expected a timeout error, got: {err}");
    }

    #[test]
    fn test_fast_code_completes() {
        let lua = mlua::Lua::new();
        let result = with_timeout(&lua, Duration::from_secs(5), |lua| {
            lua.load("return 1 + 1").eval::<i32>()
        });
        assert_eq!(result.unwrap(), 2);
    }

    #[test]
    fn test_non_timeout_errors_propagate() {
        let lua = mlua::Lua::new();
        let result = with_timeout(&lua, Duration::from_secs(5), |lua| {
            lua.load(r#"error("custom error")"#).exec()
        });
        let err = result.expect_err("error() should propagate out of with_timeout");
        assert!(err.to_string().contains("custom error"));
        assert!(!is_timeout(&err), "unexpected timeout error: {err}");
    }

    #[test]
    fn test_hook_removed_after_execution() {
        let lua = mlua::Lua::new();
        // A zero deadline makes a leaked hook fail BUSY_LOOP deterministically.
        // With a non-zero deadline the loop could finish before the deadline
        // and mask a hook that was never removed.
        let _ = with_timeout(&lua, Duration::ZERO, |lua| {
            lua.load("return 1").eval::<i32>()
        });
        let result: Result<(), mlua::Error> = lua.load(BUSY_LOOP).exec();
        assert!(result.is_ok());
    }

    #[test]
    fn test_hook_removed_after_timeout() {
        let lua = mlua::Lua::new();
        let timed_out = with_timeout(&lua, Duration::ZERO, |lua| {
            lua.load("while true do end").exec()
        });
        assert!(timed_out.is_err());
        // The guard must still remove the hook when `f` returns an error.
        let result: Result<(), mlua::Error> = lua.load(BUSY_LOOP).exec();
        assert!(result.is_ok());
    }

    #[test]
    fn test_hook_removed_after_panic() {
        let lua = mlua::Lua::new();
        let outcome = catch_unwind(AssertUnwindSafe(|| {
            with_timeout(&lua, Duration::ZERO, |_| -> mlua::Result<()> {
                panic!("closure panicked")
            })
        }));
        assert!(outcome.is_err());
        // The guard's Drop runs during unwinding, so the hook must be gone.
        let result: Result<(), mlua::Error> = lua.load(BUSY_LOOP).exec();
        assert!(result.is_ok());
    }

    #[test]
    fn test_callback_error_chain_detected() {
        let lua = mlua::Lua::new();
        let inner_fn = lua
            .create_function(|lua, ()| {
                lua.load("while true do end").exec()?;
                Ok(())
            })
            .unwrap();
        lua.globals().set("inner_fn", inner_fn).unwrap();
        let result = with_timeout(&lua, Duration::from_millis(100), |lua| {
            lua.load("inner_fn()").exec()
        });
        let err = result.expect_err("a timeout inside a Rust callback should propagate");
        // The timeout crosses a Rust callback, so look through the source chain.
        assert!(is_timeout(&err), "expected a timeout in the error chain, got: {err}");
    }
}