use std::panic::{self, AssertUnwindSafe};
use std::sync::{Arc, Mutex};
use mlua::{Debug, HookTriggers, Lua, VmState};
/// Shared side channel holding the panic cause captured by the line hook that
/// [`install`] registers.
///
/// It starts out as `None`. Once a cause has been recorded it is not overwritten:
/// the first cause wins.
pub(crate) type PanicCause = Arc<Mutex<Option<String>>>;
/// Per-line callback driven by the global hook that [`install`] registers.
///
/// Any `Fn(&Lua, &Debug) -> mlua::Result<VmState> + 'static` closure implements this
/// trait through a blanket impl, so plain closures can be passed to [`install`].
pub(crate) trait LineHook: 'static {
    /// Called for each line event with the running `lua` state and its `debug` info.
    ///
    /// Note: [`install`] currently discards both the returned [`VmState`] and any `Err`.
    /// The VM is always resumed with `VmState::Continue`.
    fn on_line(&self, lua: &Lua, debug: &Debug) -> mlua::Result<VmState>;
}
/// Lets any closure with the hook signature be used directly as a [`LineHook`].
impl<Callback> LineHook for Callback
where
    Callback: Fn(&Lua, &Debug) -> mlua::Result<VmState> + 'static,
{
    fn on_line(&self, lua: &Lua, debug: &Debug) -> mlua::Result<VmState> {
        // Forward the event unchanged to the wrapped closure.
        let callback: &Callback = self;
        callback(lua, debug)
    }
}
/// Handle returned by [`install`]; it exposes the hook's panic side channel.
///
/// Cloning is cheap. Every clone shares the same [`PanicCause`] allocation, so a cause
/// recorded by the hook is visible through all of them.
#[derive(Debug, Clone)]
pub(crate) struct HookHandle {
    /// First panic cause captured from the installed handler, if any.
    // The only readers visible here are tests, so the field would be flagged as dead
    // code in non-test builds; the allowance silences that.
    #[cfg_attr(not(test), allow(dead_code))]
    panic_cause: PanicCause,
}
impl HookHandle {
    /// Borrows the shared side channel into which the hook records panic causes.
    // Only test code reads this here, hence the dead-code allowance in other builds.
    #[cfg_attr(not(test), allow(dead_code))]
    pub(crate) fn panic_cause(&self) -> &PanicCause {
        let Self { panic_cause } = self;
        panic_cause
    }
}
/// Switches the LuaJIT compiler off for the whole engine.
///
/// It does this by running `jit.off()` with no arguments in `lua`.
///
/// # Errors
/// Propagates any error from loading or running the snippet. For example, it fails
/// if the `jit` library is not available in this VM.
fn apply_jit_off(lua: &Lua) -> mlua::Result<()> {
    let disable_jit = lua.load("jit.off()");
    disable_jit.exec()
}
/// Installs `handler` as a global hook that fires on every executed Lua line.
///
/// Before registering, the JIT engine is switched off via [`apply_jit_off`].
/// NOTE(review): presumably because LuaJIT does not invoke hooks from JIT-compiled
/// code — confirm.
///
/// The hook is registered with `Lua::set_global_hook` and always resumes the VM with
/// [`VmState::Continue`]:
/// * the `VmState` returned by the handler, and any `Err` it returns, are discarded;
/// * a panic inside the handler is caught, and its cause is recorded into the
///   [`PanicCause`] side channel exposed by the returned [`HookHandle`].
///
/// # Errors
/// Returns an error if `jit.off()` fails or if the hook cannot be registered.
pub(crate) fn install<H>(lua: &Lua, handler: H) -> mlua::Result<HookHandle>
where
    H: LineHook,
{
    apply_jit_off(lua)?;

    let handle = HookHandle {
        panic_cause: Arc::new(Mutex::new(None)),
    };
    // The hook closure owns its own reference to the side channel the handle exposes.
    let side_channel = Arc::clone(&handle.panic_cause);

    lua.set_global_hook(HookTriggers::EVERY_LINE, move |vm, frame| {
        // Unwinding must not escape into the Lua VM; catch it and keep only the cause.
        let outcome = panic::catch_unwind(AssertUnwindSafe(|| handler.on_line(vm, frame)));
        if let Err(payload) = outcome {
            record_panic_cause(&side_channel, payload);
        }
        Ok(VmState::Continue)
    })?;

    Ok(handle)
}
/// Records the cause of a panic caught inside the line hook into `cause`.
///
/// Preference order:
/// 1. A cause already present in `cause` is kept untouched (first cause wins).
/// 2. Otherwise the unwind `payload` is downcast to `&'static str` or `String`, the
///    payload types produced by `panic!`.
/// 3. Any other payload type records the generic marker `"hook handler panicked"`.
///
/// A poisoned mutex is recovered instead of being skipped. The guarded `Option<String>`
/// cannot be left half-written by a panicking holder, and silently dropping the cause
/// would defeat the purpose of the side channel. The mutex stays poisoned afterwards,
/// so readers should also tolerate `PoisonError` (e.g. via `PoisonError::into_inner`).
fn record_panic_cause(cause: &PanicCause, payload: Box<dyn std::any::Any + Send>) {
    let mut guard = cause
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    if guard.is_some() {
        return;
    }
    let recovered = payload
        .downcast_ref::<&'static str>()
        .map(|s| (*s).to_string())
        .or_else(|| payload.downcast_ref::<String>().cloned());
    *guard = Some(recovered.unwrap_or_else(|| "hook handler panicked".to_string()));
}
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use mlua::{Debug, Lua, LuaOptions, StdLib, VmState};

    use super::*;

    /// Builds a VM with the `StdLib::ALL_SAFE` library set.
    ///
    /// The tests below rely on this set including `jit` and excluding `debug`.
    fn build_all_safe_vm() -> Lua {
        // SAFETY: only the `ALL_SAFE` library set is requested, so no library that mlua
        // classifies as unsafe is loaded. The absence of `debug` is checked by
        // `install_keeps_std_debug_sandboxed`.
        // NOTE(review): the safe `Lua::new_with` constructor may work here as well —
        // confirm whether the unsafe variant is needed to mirror the production VM.
        unsafe { Lua::unsafe_new_with(StdLib::ALL_SAFE, LuaOptions::default()) }
    }

    /// One hook firing as seen by the recording handler.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct Fired {
        /// Chunk source (e.g. `@scene_co_0`), or empty if unavailable.
        source: String,
        /// Current line reported by the debug info, or 0 if unavailable.
        line: u32,
    }

    /// Converts the hook's debug info into a [`Fired`] record.
    ///
    /// The full `source` is preferred, then `short_src`, then an empty string.
    fn fired_from_debug(debug: &Debug) -> Fired {
        let src = debug.source();
        let source = src
            .source
            .as_ref()
            .map(|c| c.as_ref().to_string())
            .or_else(|| src.short_src.as_ref().map(|c| c.as_ref().to_string()))
            .unwrap_or_default();
        let line = debug.current_line().unwrap_or(0) as u32;
        Fired { source, line }
    }

    /// Returns a line handler that appends every firing to `sink` and always continues.
    fn recording_handler(
        sink: Arc<Mutex<Vec<Fired>>>,
    ) -> impl Fn(&Lua, &Debug) -> mlua::Result<VmState> {
        move |_lua: &Lua, debug: &Debug| {
            if let Ok(mut g) = sink.lock() {
                g.push(fired_from_debug(debug));
            }
            Ok(VmState::Continue)
        }
    }

    /// Lines recorded for `source`, in firing order.
    fn lines_for_source(events: &[Fired], source: &str) -> Vec<u32> {
        events
            .iter()
            .filter(|e| e.source == source)
            .map(|e| e.line)
            .collect()
    }

    /// Runs `coroutine_count` scene-like coroutines to completion.
    ///
    /// Each coroutine body is compiled as its own chunk named `@scene_co_{i}`. A shared
    /// Lua driver (`@scene_driver`) resumes it until it is dead. Every body yields once
    /// on line 3 and continues on lines 4-5 after being resumed.
    fn run_scene_like_scenario(lua: &Lua, coroutine_count: usize) -> mlua::Result<()> {
        // The driver does not depend on the loop index, so it is compiled once.
        let driver: mlua::Function = lua
            .load(
                "local scene_fn = ...\n\
                 local co = coroutine.create(scene_fn)\n\
                 while coroutine.status(co) ~= 'dead' do\n\
                 local ok, err = coroutine.resume(co)\n\
                 if not ok then error(err) end\n\
                 end\n",
            )
            .set_name("@scene_driver")
            .into_function()?;
        for i in 0..coroutine_count {
            let chunk_name = format!("@scene_co_{i}");
            let body = format!(
                "local marker = {i}\n\
                 local acc = marker * 2\n\
                 coroutine.yield()\n\
                 acc = acc + marker\n\
                 return acc\n"
            );
            let scene_fn: mlua::Function =
                lua.load(&body).set_name(&chunk_name).into_function()?;
            driver.call::<()>(scene_fn)?;
        }
        Ok(())
    }

    #[test]
    fn install_applies_engine_wide_jit_off() {
        let lua = build_all_safe_vm();
        // The parentheses keep only the first return value of `jit.status()`.
        let before: bool = lua
            .load("return (jit.status())")
            .eval()
            .expect("jit.status() must be callable on an ALL_SAFE VM");
        assert!(before, "JIT engine must be ON before install (premise)");
        let _h = install(&lua, |_lua: &Lua, _debug: &Debug| Ok(VmState::Continue))
            .expect("install must succeed");
        let after: bool = lua
            .load("return (jit.status())")
            .eval()
            .expect("jit.status() must be callable after install");
        assert!(
            !after,
            "no-arg jit.off() must disable the global JIT engine after install (R5.2)"
        );
    }

    #[test]
    fn install_keeps_std_debug_sandboxed() {
        let lua = build_all_safe_vm();
        let _h = install(&lua, |_lua: &Lua, _debug: &Debug| Ok(VmState::Continue))
            .expect("install must succeed");
        let debug_is_nil: bool = lua
            .load("return debug == nil")
            .eval()
            .expect("eval should succeed");
        assert!(
            debug_is_nil,
            "install must NOT expose std_debug (sandbox maintained, R5.3)"
        );
    }

    #[test]
    fn hook_fires_across_dynamic_coroutines() {
        const N: usize = 3;
        let lua = build_all_safe_vm();
        let sink: Arc<Mutex<Vec<Fired>>> = Arc::new(Mutex::new(Vec::new()));
        let _h = install(&lua, recording_handler(Arc::clone(&sink)))
            .expect("install must succeed");
        run_scene_like_scenario(&lua, N).expect("scene-like scenario must run");
        lua.remove_global_hook();
        let events = sink.lock().unwrap();
        for i in 0..N {
            let name = format!("@scene_co_{i}");
            let lines = lines_for_source(&events, &name);
            assert!(
                !lines.is_empty(),
                "hook must fire inside coroutine {i} body ({name}). \
                 recorded sources: {:?}",
                events.iter().map(|e| &e.source).collect::<Vec<_>>()
            );
            // Lines 4 and 5 of each body only execute after the coroutine is resumed.
            assert!(
                lines.iter().any(|&l| l >= 4),
                "hook must keep firing after coroutine.yield/resume in {name} \
                 (post-yield lines). got: {lines:?}"
            );
        }
    }

    #[test]
    fn firing_is_attributable_to_install() {
        const CHUNK: &str = "local a = 1\n\
                             local b = a + 1\n\
                             local c = b + 1\n\
                             return c\n";
        const SRC: &str = "@attribution_chunk";
        // NOTE(review): nothing is wired to this sink, so the assertion below documents
        // the no-install premise rather than observing hook behaviour.
        let no_install_sink: Arc<Mutex<Vec<Fired>>> = Arc::new(Mutex::new(Vec::new()));
        {
            let lua = build_all_safe_vm();
            lua.load("jit.off()").exec().expect("jit.off must run");
            lua.load(CHUNK)
                .set_name(SRC)
                .exec()
                .expect("workload must run without a hook");
            assert!(
                no_install_sink.lock().unwrap().is_empty(),
                "with NO install there must be no recorded firing (zero-cost premise, R5.2)"
            );
        }
        let install_sink: Arc<Mutex<Vec<Fired>>> = Arc::new(Mutex::new(Vec::new()));
        {
            let lua = build_all_safe_vm();
            let _h = install(&lua, recording_handler(Arc::clone(&install_sink)))
                .expect("install must succeed");
            lua.load(CHUNK)
                .set_name(SRC)
                .exec()
                .expect("workload must run with the hook");
            lua.remove_global_hook();
        }
        let events = install_sink.lock().unwrap();
        let lines = lines_for_source(&events, SRC);
        assert!(
            !lines.is_empty(),
            "WITH install the hook must fire on the workload (R5.4). got: {lines:?}"
        );
        for expected in [1u32, 2, 3] {
            assert!(
                lines.contains(&expected),
                "expected line {expected} must fire with install. got: {lines:?}"
            );
        }
    }

    #[test]
    fn hook_panic_is_captured_and_vm_survives() {
        use std::sync::mpsc;
        use std::time::Duration;
        const PANIC_SOURCE: &str = "@panic_scenario";
        const PANIC_LINE: u32 = 2;
        const PANIC_MSG: &str = "injected hook panic for task 1.3";
        const WATCHDOG: Duration = Duration::from_secs(10);
        // The VM runs on its own thread so that a hang or an abort is observable.
        let handle = std::thread::spawn(move || -> Result<Option<String>, String> {
            let lua = build_all_safe_vm();
            let handle = install(&lua, move |_lua: &Lua, debug: &Debug| {
                let line = debug.current_line().unwrap_or(0) as u32;
                let src = debug
                    .source()
                    .source
                    .as_ref()
                    .map(|c| c.as_ref().to_string())
                    .unwrap_or_default();
                if src == PANIC_SOURCE && line == PANIC_LINE {
                    panic!("{PANIC_MSG}");
                }
                Ok(VmState::Continue)
            })
            .map_err(|e| e.to_string())?;
            let chunk = "local a = 1\n\
                         local b = a + 1\n\
                         return b\n";
            lua.load(chunk)
                .set_name(PANIC_SOURCE)
                .exec()
                .map_err(|e| format!("VM must survive a captured hook panic: {e}"))?;
            let recorded = handle.panic_cause().lock().unwrap().clone();
            Ok(recorded)
        });
        // A second thread joins the VM thread, so the test can enforce a timeout.
        let (done_tx, done_rx) = mpsc::channel();
        std::thread::spawn(move || {
            let _ = done_tx.send(handle.join());
        });
        let joined = done_rx
            .recv_timeout(WATCHDOG)
            .expect("VM host thread must finish (no hang) after a hook panic is captured");
        let thread_outcome =
            joined.expect("VM host thread must terminate gracefully, NOT abort on a hook panic");
        let recorded_cause = thread_outcome
            .expect("VM must survive and side channel must be readable")
            .expect("a captured hook panic must record a cause into the side channel");
        assert!(
            recorded_cause.contains(PANIC_MSG),
            "the recorded panic cause must carry the injected message. got: {recorded_cause:?}"
        );
    }

    #[test]
    fn handler_error_is_swallowed_and_vm_continues() {
        const ERR_SOURCE: &str = "@handler_err_chunk";
        let lua = build_all_safe_vm();
        let handle = install(&lua, move |_lua: &Lua, debug: &Debug| {
            let src = debug
                .source()
                .source
                .as_ref()
                .map(|c| c.as_ref().to_string())
                .unwrap_or_default();
            if src == ERR_SOURCE {
                return Err(mlua::Error::runtime("injected handler error"));
            }
            Ok(VmState::Continue)
        })
        .expect("install must succeed");
        let result: i64 = lua
            .load(
                "local a = 1\n\
                 local b = a + 1\n\
                 return b + 1\n",
            )
            .set_name(ERR_SOURCE)
            .eval()
            .expect("a handler Err must be swallowed: the chunk keeps executing");
        lua.remove_global_hook();
        assert_eq!(
            result, 3,
            "the chunk must run to completion with the correct result despite handler Errs"
        );
        assert!(
            handle.panic_cause().lock().unwrap().is_none(),
            "a handler Err is NOT a panic: the panic side channel must stay None"
        );
    }

    #[test]
    fn record_panic_cause_prefers_prerecorded_cause_over_payload() {
        let cause: PanicCause = Arc::new(Mutex::new(Some("pre-recorded specific cause".into())));
        record_panic_cause(&cause, Box::new("payload that must NOT win"));
        assert_eq!(
            cause.lock().unwrap().as_deref(),
            Some("pre-recorded specific cause"),
            "a pre-recorded cause must be preserved (preference step 1), not overwritten"
        );
    }

    #[test]
    fn record_panic_cause_recovers_static_str_payload() {
        let cause: PanicCause = Arc::new(Mutex::new(None));
        record_panic_cause(&cause, Box::new("static str payload"));
        assert_eq!(
            cause.lock().unwrap().as_deref(),
            Some("static str payload"),
            "a &'static str payload must be recovered (preference step 2)"
        );
    }

    #[test]
    fn record_panic_cause_recovers_string_payload() {
        let cause: PanicCause = Arc::new(Mutex::new(None));
        record_panic_cause(&cause, Box::new(String::from("owned String payload")));
        assert_eq!(
            cause.lock().unwrap().as_deref(),
            Some("owned String payload"),
            "an owned String payload must be recovered (preference step 2)"
        );
    }

    #[test]
    fn record_panic_cause_falls_back_to_generic_marker() {
        let cause: PanicCause = Arc::new(Mutex::new(None));
        record_panic_cause(&cause, Box::new(42_i32));
        assert_eq!(
            cause.lock().unwrap().as_deref(),
            Some("hook handler panicked"),
            "a non-string payload must record the generic marker (preference step 3)"
        );
    }

    #[test]
    fn record_panic_cause_tolerates_poisoned_lock() {
        let cause: PanicCause = Arc::new(Mutex::new(None));
        let poisoner = Arc::clone(&cause);
        let _ = panic::catch_unwind(AssertUnwindSafe(move || {
            let _guard = poisoner.lock().unwrap();
            panic!("poison the lock");
        }));
        assert!(cause.lock().is_err(), "the lock must be poisoned (premise)");
        // The only requirement checked here is that recording does not panic.
        record_panic_cause(&cause, Box::new("payload after poison"));
    }

    #[test]
    fn hook_panic_without_prerecord_still_records_a_cause() {
        use std::sync::mpsc;
        use std::time::Duration;
        const PANIC_SOURCE: &str = "@panic_bare";
        const PANIC_LINE: u32 = 2;
        const WATCHDOG: Duration = Duration::from_secs(10);
        // The VM runs on its own thread so that a hang or an abort is observable.
        let handle = std::thread::spawn(move || -> Result<Option<String>, String> {
            let lua = build_all_safe_vm();
            let handle = install(&lua, move |_lua: &Lua, debug: &Debug| {
                let line = debug.current_line().unwrap_or(0) as u32;
                let src = debug
                    .source()
                    .source
                    .as_ref()
                    .map(|c| c.as_ref().to_string())
                    .unwrap_or_default();
                if src == PANIC_SOURCE && line == PANIC_LINE {
                    panic!("bare panic with no pre-record");
                }
                Ok(VmState::Continue)
            })
            .map_err(|e| e.to_string())?;
            let chunk = "local a = 1\n\
                         local b = a + 1\n\
                         return b\n";
            lua.load(chunk)
                .set_name(PANIC_SOURCE)
                .exec()
                .map_err(|e| format!("VM must survive: {e}"))?;
            Ok(handle.panic_cause().lock().unwrap().clone())
        });
        // A second thread joins the VM thread, so the test can enforce a timeout.
        let (done_tx, done_rx) = mpsc::channel();
        std::thread::spawn(move || {
            let _ = done_tx.send(handle.join());
        });
        let joined = done_rx
            .recv_timeout(WATCHDOG)
            .expect("VM host thread must finish (no hang)");
        let recorded = joined
            .expect("VM host thread must terminate gracefully")
            .expect("VM must survive and side channel readable")
            .expect("the wrapper must record SOME cause even without a pre-record");
        assert!(
            !recorded.is_empty(),
            "the recorded cause must be non-empty (payload recovery or generic marker)"
        );
    }
}