use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use mlua::prelude::*;
use mlua_isle::{AsyncIsle, CancelToken, IsleError};
use serde_json::Value;
use tokio_util::sync::CancellationToken;
use crate::bus::{AckResult, EventBus, Handler};
use crate::error::BlockError;
use crate::host::HostContext;
/// Global name of the Lua dispatcher function on the handler isle. It is installed by
/// `install_bus_dispatcher_on_handler_isle` and invoked by `LuaHandler::call`.
const BUS_DISPATCH_FN: &str = "__bus_dispatch";
/// Global name of the handler-isle table mapping an event kind to the Lua handler
/// registered for it via `bus.on`. The dispatcher looks this global up by name.
const BUS_HANDLERS_TBL: &str = "__bus_handlers";
/// Global name of the handler-isle slot holding the catch-all handler set by `bus.on_any`.
/// The dispatcher consults it when no per-kind handler exists.
const BUS_ON_ANY_GLOBAL: &str = "__bus_on_any";
/// Bus [`Handler`] that forwards each event to the handler isle, where the global
/// `__bus_dispatch` Lua function routes it to the registered Lua handler.
struct LuaHandler {
/// Isle hosting the dispatcher and the re-loaded handler functions.
isle: Arc<AsyncIsle>,
}
#[async_trait]
impl Handler for LuaHandler {
    /// Dispatches one bus event into the handler isle and turns the Lua return value into
    /// the acknowledgement payload.
    ///
    /// `payload` and `meta` are JSON-encoded. Together with `kind` and `id`, they are passed
    /// to the isle's `__bus_dispatch` function as a coroutine call.
    ///
    /// The result string is converted as follows:
    /// - An empty string (the dispatcher's encoding of a `nil` return) becomes `Value::Null`.
    /// - A valid JSON string is parsed into a `Value`.
    /// - Anything else is passed through as `Value::String`, with a warning.
    ///
    /// # Errors
    /// Returns `BlockError::Bus` when `payload`/`meta` cannot be encoded or the isle call
    /// fails (cancelled, isle shut down, or any other isle error).
    async fn call(&self, kind: String, id: String, payload: Value, meta: Value) -> AckResult {
        let payload_str = serde_json::to_string(&payload).map_err(|e| {
            tracing::error!(%kind, %id, error = %e, "bus: payload JSON encode failed");
            BlockError::Bus(format!("payload encode: {e}"))
        })?;
        let meta_str = serde_json::to_string(&meta).map_err(|e| {
            tracing::error!(%kind, %id, error = %e, "bus: meta JSON encode failed");
            BlockError::Bus(format!("meta encode: {e}"))
        })?;
        let args: [&str; 4] = [&kind, &id, &payload_str, &meta_str];
        let task = self.isle.spawn_coroutine_call(BUS_DISPATCH_FN, &args);

        // If this future is dropped before `task` resolves, cancel the Lua coroutine too so
        // it does not keep running unobserved. After a successful completion the guard is
        // disarmed. This replaces the former `mem::forget`, which leaked the cloned token on
        // every successful call.
        struct CancelOnDrop(Option<CancelToken>);
        impl CancelOnDrop {
            /// Releases the token without cancelling the (already finished) task.
            fn disarm(mut self) {
                self.0 = None;
            }
        }
        impl Drop for CancelOnDrop {
            fn drop(&mut self) {
                if let Some(token) = self.0.as_mut() {
                    token.cancel();
                }
            }
        }
        let guard = CancelOnDrop(Some(task.cancel_token().clone()));
        // On the error path the guard stays armed and cancels an already-settled task,
        // which the original code did as well.
        let result_str = task.await.map_err(|e| {
            tracing::error!(%kind, %id, error = %e, "bus: Lua dispatch failed");
            match e {
                IsleError::Cancelled => BlockError::Bus("handler cancelled".into()),
                IsleError::Shutdown => BlockError::Bus("isle shut down".into()),
                other => BlockError::Bus(format!("isle error: {other}")),
            }
        })?;
        guard.disarm();

        if result_str.is_empty() {
            return Ok(Value::Null);
        }
        match serde_json::from_str::<Value>(&result_str) {
            Ok(v) => Ok(v),
            Err(e) => {
                tracing::warn!(
                    %kind, %id, error = %e,
                    "bus: handler return value is not valid JSON; falling back to string"
                );
                Ok(Value::String(result_str))
            }
        }
    }
}
pub fn register(lua: &Lua, ctx: &HostContext) -> LuaResult<()> {
let bus_tbl = lua.create_table()?;
let event_bus_for_on = Arc::clone(&ctx.event_bus);
let event_bus_for_on_any = Arc::clone(&ctx.event_bus);
let event_bus_for_serve = Arc::clone(&ctx.event_bus);
let handler_isle_for_on = Arc::clone(&ctx.handler_isle);
let handler_isle_for_on_any = Arc::clone(&ctx.handler_isle);
bus_tbl.set(
"on",
lua.create_async_function(move |_, (kind, func): (String, LuaFunction)| {
let handler_isle = Arc::clone(&handler_isle_for_on);
let event_bus = Arc::clone(&event_bus_for_on);
async move {
if func.info().what != "Lua" {
return Err(LuaError::external(
"bus.on: handler must be a pure Lua function (C functions and Rust-bound callbacks are not supported)",
));
}
let bytecode = func.dump(true);
if bytecode.is_empty() {
return Err(LuaError::external(
"bus.on: Function::dump returned empty bytecode (handler not serializable)",
));
}
let kind_for_exec = kind.clone();
let bytecode_name = format!("@bus_handler[{kind_for_exec}]");
handler_isle
.exec(move |lua| {
let loaded: LuaFunction = lua
.load(bytecode.as_slice())
.set_mode(mlua::ChunkMode::Binary)
.set_name(&bytecode_name)
.into_function()
.map_err(|e| IsleError::Lua(format!("bus.on load: {e}")))?;
let tbl: LuaTable = lua
.globals()
.get(BUS_HANDLERS_TBL)
.map_err(|e| IsleError::Lua(format!("bus.on handlers tbl: {e}")))?;
tbl.set(kind_for_exec.as_str(), loaded)
.map_err(|e| IsleError::Lua(format!("bus.on set: {e}")))?;
Ok(String::new())
})
.await
.map_err(|e| {
tracing::error!(%kind, error = %e, "bus.on: handler isle load failed");
LuaError::external(format!("bus.on: handler isle load failed: {e}"))
})?;
let handler: Arc<dyn Handler> = Arc::new(LuaHandler {
isle: Arc::clone(&handler_isle),
});
let mut guard = event_bus
.lock()
.map_err(|_| LuaError::external("bus mutex poisoned"))?;
match guard.as_mut() {
Some(bus) => bus
.on(kind.clone(), handler)
.map_err(|e| LuaError::external(format!("bus.on: {e}")))?,
None => {
return Err(LuaError::external(
"bus.on: bus.serve() has already taken ownership; register handlers before calling bus.serve()",
));
}
}
drop(guard);
Ok(())
}
})?,
)?;
bus_tbl.set(
"on_any",
lua.create_async_function(move |_, func: LuaFunction| {
let handler_isle = Arc::clone(&handler_isle_for_on_any);
let event_bus = Arc::clone(&event_bus_for_on_any);
async move {
if func.info().what != "Lua" {
return Err(LuaError::external(
"bus.on_any: handler must be a pure Lua function (C functions and Rust-bound callbacks are not supported)",
));
}
let bytecode = func.dump(true);
if bytecode.is_empty() {
return Err(LuaError::external(
"bus.on_any: Function::dump returned empty bytecode (handler not serializable)",
));
}
let bytecode_name = "@bus_handler[__on_any]".to_string();
handler_isle
.exec(move |lua| {
let loaded: LuaFunction = lua
.load(bytecode.as_slice())
.set_mode(mlua::ChunkMode::Binary)
.set_name(&bytecode_name)
.into_function()
.map_err(|e| IsleError::Lua(format!("bus.on_any load: {e}")))?;
lua.globals()
.set(BUS_ON_ANY_GLOBAL, loaded)
.map_err(|e| IsleError::Lua(format!("bus.on_any set: {e}")))?;
Ok(String::new())
})
.await
.map_err(|e| {
tracing::error!(error = %e, "bus.on_any: handler isle load failed");
LuaError::external(format!("bus.on_any: handler isle load failed: {e}"))
})?;
let handler: Arc<dyn Handler> = Arc::new(LuaHandler {
isle: Arc::clone(&handler_isle),
});
let mut guard = event_bus
.lock()
.map_err(|_| LuaError::external("bus mutex poisoned"))?;
match guard.as_mut() {
Some(bus) => bus
.on_any(handler)
.map_err(|e| LuaError::external(format!("bus.on_any: {e}")))?,
None => {
return Err(LuaError::external(
"bus.on_any: bus.serve() has already taken ownership; register handlers before calling bus.serve()",
));
}
}
drop(guard);
Ok(())
}
})?,
)?;
let serving = Arc::new(AtomicBool::new(false));
bus_tbl.set(
"serve",
lua.create_async_function(move |_, ()| {
let event_bus = Arc::clone(&event_bus_for_serve);
let serving = Arc::clone(&serving);
async move {
if serving.swap(true, Ordering::SeqCst) {
return Err(LuaError::external("bus.serve: already running"));
}
let bus = {
let mut guard = event_bus
.lock()
.map_err(|_| LuaError::external("bus mutex poisoned"))?;
match guard.take() {
Some(b) => b,
None => {
serving.store(false, Ordering::SeqCst);
return Err(LuaError::external("bus.serve: bus already consumed"));
}
}
};
let shutdown = CancellationToken::new();
let signal_task = spawn_signal_task(shutdown.clone());
let grace_ms = crate::bridge::config::task_grace_ms();
let run_result =
run_with_grace(bus, shutdown.clone(), Duration::from_millis(grace_ms)).await;
signal_task.abort();
if let Err(e) = run_result {
tracing::error!(error = %e, "bus.serve: dispatcher loop returned error");
return Err(LuaError::external(format!("bus.serve: {e}")));
}
tracing::info!("bus.serve: dispatcher loop exited cleanly");
Ok(())
}
})?,
)?;
lua.globals().set("bus", bus_tbl)?;
Ok(())
}
/// Prepares a handler-isle VM for bus dispatch.
///
/// This creates an empty per-kind handler table (`BUS_HANDLERS_TBL`), clears the catch-all
/// slot (`BUS_ON_ANY_GLOBAL`), and installs the dispatcher as the `BUS_DISPATCH_FN` global.
///
/// The dispatcher is called as `(kind, id, payload_json, meta_json)`. It:
/// - looks up the handler for `kind`, falling back to the catch-all;
/// - decodes both JSON strings with `std.json.decode`, resolved when the dispatcher is
///   called, so the `std` batteries must be registered by then;
/// - calls the handler with `{ kind, id, payload, meta }`;
/// - returns `""` for a `nil` result, or `std.json.encode(result)` otherwise.
///
/// It raises a Lua error when no handler matches or a JSON argument fails to decode.
///
/// # Errors
/// Returns any Lua error from creating/setting the globals or compiling the dispatcher.
pub(crate) fn install_bus_dispatcher_on_handler_isle(lua: &Lua) -> LuaResult<()> {
    lua.globals().set(BUS_HANDLERS_TBL, lua.create_table()?)?;
    lua.globals().set(BUS_ON_ANY_GLOBAL, LuaValue::Nil)?;
    // The global names arrive as chunk varargs so the Lua side cannot drift from the Rust
    // constants. Globals are re-read on every call because `bus.on_any` may replace the
    // catch-all at any time.
    let src = r#"
local BUS_HANDLERS_TBL, BUS_ON_ANY_GLOBAL = ...
return function(kind, id, payload_json, meta_json)
local handlers = _G[BUS_HANDLERS_TBL]
local h = handlers and handlers[kind]
if type(h) ~= "function" then
h = _G[BUS_ON_ANY_GLOBAL]
end
if type(h) ~= "function" then
error("no Lua handler for kind `" .. tostring(kind) .. "`")
end
local ok_payload, payload = pcall(std.json.decode, payload_json)
if not ok_payload then
error("payload decode: " .. tostring(payload))
end
local ok_meta, meta = pcall(std.json.decode, meta_json)
if not ok_meta then
error("meta decode: " .. tostring(meta))
end
local ev = {
kind = kind,
id = id,
payload = payload,
meta = meta,
}
local ret = h(ev)
if ret == nil then
return ""
end
return std.json.encode(ret)
end
"#;
    let dispatch: LuaFunction = lua
        .load(src)
        .set_name("@agent_block:__bus_dispatch")
        .call((BUS_HANDLERS_TBL, BUS_ON_ANY_GLOBAL))?;
    lua.globals().set(BUS_DISPATCH_FN, dispatch)?;
    Ok(())
}
async fn run_with_grace(
mut bus: EventBus,
shutdown: CancellationToken,
grace: Duration,
) -> Result<(), BlockError> {
let run_fut = bus.run(shutdown.clone());
tokio::pin!(run_fut);
tokio::select! {
res = &mut run_fut => res,
_ = shutdown.cancelled() => {
match tokio::time::timeout(grace, &mut run_fut).await {
Ok(res) => res,
Err(_) => {
tracing::warn!(
grace_ms = grace.as_millis() as u64,
"bus.serve: grace window exceeded; forcing exit"
);
Ok(())
}
}
}
}
}
/// Spawns a task that cancels `shutdown` when SIGTERM (unix only) or Ctrl+C is received.
///
/// A failure to *register* a listener is logged and does not count as a shutdown request.
/// Previously a `ctrl_c()` registration error fell through to `shutdown.cancel()` and
/// stopped the bus immediately.
///
/// The returned handle should be aborted once the listener is no longer needed.
fn spawn_signal_task(shutdown: CancellationToken) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        wait_for_shutdown_signal().await;
        shutdown.cancel();
    })
}

/// Resolves on SIGTERM or Ctrl+C, whichever comes first. If SIGTERM cannot be installed,
/// only Ctrl+C is awaited.
#[cfg(unix)]
async fn wait_for_shutdown_signal() {
    use tokio::signal::unix::{signal, SignalKind};
    match signal(SignalKind::terminate()) {
        Ok(mut term) => {
            tokio::select! {
                _ = term.recv() => tracing::info!("bus.serve: SIGTERM received"),
                _ = wait_for_ctrl_c() => {}
            }
        }
        Err(e) => {
            tracing::error!(error = %e, "bus.serve: SIGTERM install failed; Ctrl+C only");
            wait_for_ctrl_c().await;
        }
    }
}

/// Resolves on Ctrl+C (the only shutdown signal handled on non-unix targets).
#[cfg(not(unix))]
async fn wait_for_shutdown_signal() {
    wait_for_ctrl_c().await;
}

/// Resolves when Ctrl+C is received.
///
/// If the Ctrl+C listener cannot be registered, the error is logged and this future never
/// resolves. Any other shutdown source (e.g. SIGTERM) then remains in charge.
async fn wait_for_ctrl_c() {
    match tokio::signal::ctrl_c().await {
        Ok(()) => tracing::info!("bus.serve: Ctrl+C received"),
        Err(e) => {
            tracing::error!(error = %e, "bus.serve: ctrl_c error");
            std::future::pending::<()>().await;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Fresh VM with the `std` batteries and the bus dispatcher installed.
    fn dispatcher_vm() -> Lua {
        let vm = Lua::new();
        mlua_batteries::register_all(&vm, "std").unwrap();
        install_bus_dispatcher_on_handler_isle(&vm).unwrap();
        vm
    }

    /// The `__bus_dispatch` function installed in `vm`.
    fn dispatch_fn(vm: &Lua) -> LuaFunction {
        vm.globals().get(BUS_DISPATCH_FN).unwrap()
    }

    /// Executes Lua setup code in `vm`, panicking on failure.
    fn run_lua(vm: &Lua, code: &str) {
        vm.load(code).exec().unwrap();
    }

    #[test]
    fn dispatcher_resolves_kind_and_encodes_return() {
        let vm = dispatcher_vm();
        run_lua(
            &vm,
            r#"
__bus_handlers["mesh"] = function(ev)
return { echoed = ev.payload.value, id = ev.id }
end
"#,
        );
        let payload = json!({"value": 42}).to_string();
        let meta = json!({"from": "peer"}).to_string();
        let raw: String = dispatch_fn(&vm)
            .call(("mesh", "evt-1", payload.as_str(), meta.as_str()))
            .unwrap();
        let decoded: Value = serde_json::from_str(&raw).unwrap();
        assert_eq!(decoded, json!({"echoed": 42, "id": "evt-1"}));
    }

    #[test]
    fn dispatcher_falls_back_to_on_any() {
        let vm = dispatcher_vm();
        run_lua(
            &vm,
            r#"
__bus_on_any = function(ev)
return { from_any = ev.kind }
end
"#,
        );
        let raw: String = dispatch_fn(&vm).call(("custom", "e1", "{}", "{}")).unwrap();
        let decoded: Value = serde_json::from_str(&raw).unwrap();
        assert_eq!(decoded, json!({"from_any": "custom"}));
    }

    #[test]
    fn dispatcher_errors_when_no_handler_registered() {
        let vm = dispatcher_vm();
        let msg = dispatch_fn(&vm)
            .call::<String>(("nope", "e1", "{}", "{}"))
            .unwrap_err()
            .to_string();
        assert!(
            msg.contains("no Lua handler for kind `nope`"),
            "unexpected error: {msg}"
        );
    }

    #[test]
    fn dispatcher_reports_invalid_payload_json() {
        let vm = dispatcher_vm();
        run_lua(&vm, r#"__bus_handlers["x"] = function() return nil end"#);
        let err = dispatch_fn(&vm)
            .call::<String>(("x", "e1", "not-json", "{}"))
            .unwrap_err();
        assert!(
            err.to_string().contains("payload decode"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn bytecode_round_trip_to_second_lua_vm_dispatches() {
        // Compile the handler in one VM...
        let origin = Lua::new();
        let handler: LuaFunction = origin
            .load(
                r#"
return function(ev)
return { got = ev.payload.value, kind = ev.kind }
end
"#,
            )
            .eval()
            .unwrap();
        assert_eq!(handler.info().what, "Lua");
        let bytecode = handler.dump(true);
        assert!(!bytecode.is_empty(), "Lua function dump must be non-empty");

        // ...and run it from a second VM that only sees the bytecode.
        let target = dispatcher_vm();
        let reloaded: LuaFunction = target
            .load(bytecode.as_slice())
            .set_mode(mlua::ChunkMode::Binary)
            .set_name("@bus_handler[mesh]")
            .into_function()
            .unwrap();
        let handlers: LuaTable = target.globals().get(BUS_HANDLERS_TBL).unwrap();
        handlers.set("mesh", reloaded).unwrap();

        let payload = json!({"value": 7}).to_string();
        let raw: String = dispatch_fn(&target)
            .call(("mesh", "evt-rt", payload.as_str(), "{}"))
            .unwrap();
        let decoded: Value = serde_json::from_str(&raw).unwrap();
        assert_eq!(decoded, json!({"got": 7, "kind": "mesh"}));
    }

    #[test]
    fn c_function_is_detected_via_info_what() {
        let vm = Lua::new();
        let native: LuaFunction = vm.create_function(|_, ()| Ok(())).unwrap();
        let scripted: LuaFunction = vm.load("return function() end").eval().unwrap();
        assert_ne!(
            native.info().what,
            "Lua",
            "Rust-bound callbacks should not report info().what == \"Lua\""
        );
        assert_eq!(scripted.info().what, "Lua");
    }
}