use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use axum::body::Body;
use axum::http::Response;
use mlua::{HookTriggers, Lua, LuaOptions, StdLib, Table, Value as LuaValue, VmState};
use tracing::warn;
use crate::config::{HookConfig, HookErrorBehavior, HookMode};
/// Per-request data handed to request-stage Lua hooks as the global table `ctx`.
///
/// For each hook run the fields are copied into that table (`error` only when it
/// is set). After a *sync* hook finishes, only `headers` and `extra` are read
/// back from the script; edits to any other field stay inside the Lua state.
#[derive(Clone, Debug)]
pub struct RequestContext {
    /// Exposed to scripts as `ctx.method`.
    pub method: String,
    /// Exposed to scripts as `ctx.path`.
    pub path: String,
    /// Exposed to scripts as `ctx.query`.
    pub query: String,
    /// Exposed to scripts as `ctx.client_ip`.
    pub client_ip: String,
    /// Exposed to scripts as `ctx.request_id`.
    pub request_id: String,
    /// Exposed as the table `ctx.headers`; after a sync hook runs, the table's
    /// contents replace this map.
    pub headers: HashMap<String, String>,
    /// Additional string pairs exposed as `ctx.extra`; read back like `headers`.
    pub extra: HashMap<String, String>,
    /// Exposed as `ctx.error` when `Some`; scripts see `nil` otherwise.
    pub error: Option<String>,
    /// Exposed as `ctx.short_circuited`; the hook engine sets it to `true` once a
    /// sync hook has returned a response.
    pub short_circuited: bool,
}
/// Per-response data handed to response-stage Lua hooks as the global table `ctx`.
///
/// After a *sync* hook finishes, `resp_headers` is always read back, and `body`
/// is read back only if it was `Some` going in. `status` is exposed but never
/// read back.
#[derive(Clone, Debug)]
pub struct ResponseContext {
    /// Exposed to scripts as `ctx.status` (read-only from the engine's view).
    pub status: u16,
    /// Exposed as the table `ctx.resp_headers`; after a sync hook runs, the
    /// table's contents replace this map.
    pub resp_headers: HashMap<String, String>,
    /// Exposed as the byte string `ctx.body` only when `Some`; a sync hook that
    /// assigns `nil` to it clears the body.
    pub body: Option<Vec<u8>>,
    /// Exposed as `ctx.short_circuited`; the hook engine sets it to `true` once a
    /// sync response hook has returned a response.
    pub short_circuited: bool,
}
/// Outcome of running one hook stage.
pub enum HookResult {
    /// No sync hook produced a response.
    Continue,
    /// A response produced by the stage: either the one a sync hook returned,
    /// or the canned 500 built when a `FailClosed` hook errors.
    ShortCircuit(Response<Body>),
}
/// A configured hook whose script compiled successfully.
struct CompiledHook {
    /// The configuration entry the script came from.
    config: HookConfig,
    /// Lua bytecode dumped from the compiled chunk; it is loaded into a fresh Lua
    /// state for every run.
    bytecode: Vec<u8>,
}
/// Runs the configured Lua hooks for the request/response lifecycle.
///
/// Cloning is cheap: all clones share the same compiled hooks through an `Arc`.
#[derive(Clone)]
pub struct HookEngine {
    /// Successfully compiled hooks, in configuration order.
    hooks: Arc<Vec<CompiledHook>>,
}
impl HookEngine {
pub fn new(configs: &[HookConfig]) -> Self {
let mut compiled = Vec::with_capacity(configs.len());
for cfg in configs {
match compile_script(cfg) {
Ok(bytecode) => compiled.push(CompiledHook {
config: cfg.clone(),
bytecode,
}),
Err(e) => {
warn!(
event = %cfg.event,
script = %cfg.lua.display(),
error = %e,
"lua hook script failed to compile — skipping"
);
}
}
}
Self {
hooks: Arc::new(compiled),
}
}
pub fn run_request_before(&self, ctx: &mut RequestContext) -> HookResult {
self.run_request_stage("request.before", ctx)
}
pub fn run_request_error(&self, ctx: &RequestContext) {
for hook in self
.hooks
.iter()
.filter(|h| h.config.event == "request.error")
{
let bytecode = hook.bytecode.clone();
let ctx_clone = ctx.clone();
let timeout = Duration::from_millis(hook.config.timeout_ms);
tokio::spawn(async move {
if let Err(e) = exec_request_hook(&bytecode, &ctx_clone, timeout) {
warn!(error = %e, "lua request.error hook failed");
}
});
}
}
pub fn run_response_headers(&self, ctx: &mut ResponseContext) -> HookResult {
self.run_response_stage("response.headers", ctx)
}
pub fn run_response_after(&self, ctx: &mut ResponseContext) -> HookResult {
self.run_response_stage("response.after", ctx)
}
pub fn has_event(&self, event: &str) -> bool {
self.hooks.iter().any(|h| h.config.event == event)
}
fn run_request_stage(&self, event: &str, ctx: &mut RequestContext) -> HookResult {
let hooks: Vec<_> = self
.hooks
.iter()
.filter(|h| h.config.event == event)
.collect();
let (sync_hooks, async_hooks): (Vec<_>, Vec<_>) = hooks
.into_iter()
.partition(|h| h.config.mode == HookMode::Sync);
let mut sc_response: Option<Response<Body>> = None;
for hook in &sync_hooks {
if sc_response.is_some() {
break;
}
let timeout = Duration::from_millis(hook.config.timeout_ms);
match exec_request_hook_mut(&hook.bytecode, ctx, timeout) {
Ok(Some(resp)) => {
ctx.short_circuited = true;
sc_response = Some(resp);
}
Ok(None) => {}
Err(e) => {
warn!(
event = event,
script = %hook.config.lua.display(),
error = %e,
"lua sync hook error"
);
if hook.config.on_error == HookErrorBehavior::FailClosed {
return HookResult::ShortCircuit(internal_error_response());
}
}
}
}
let ctx_snap = ctx.clone();
for hook in async_hooks {
let bytecode = hook.bytecode.clone();
let snap = ctx_snap.clone();
let timeout = Duration::from_millis(hook.config.timeout_ms);
tokio::spawn(async move {
if let Err(e) = exec_request_hook(&bytecode, &snap, timeout) {
warn!(error = %e, "lua async request hook failed");
}
});
}
match sc_response {
Some(resp) => HookResult::ShortCircuit(resp),
None => HookResult::Continue,
}
}
fn run_response_stage(&self, event: &str, ctx: &mut ResponseContext) -> HookResult {
let hooks: Vec<_> = self
.hooks
.iter()
.filter(|h| h.config.event == event)
.collect();
let (sync_hooks, async_hooks): (Vec<_>, Vec<_>) = hooks
.into_iter()
.partition(|h| h.config.mode == HookMode::Sync);
let mut sc_response: Option<Response<Body>> = None;
for hook in &sync_hooks {
if sc_response.is_some() {
break;
}
let timeout = Duration::from_millis(hook.config.timeout_ms);
match exec_response_hook_mut(&hook.bytecode, ctx, timeout) {
Ok(Some(resp)) => {
ctx.short_circuited = true;
sc_response = Some(resp);
}
Ok(None) => {}
Err(e) => {
warn!(
event = event,
script = %hook.config.lua.display(),
error = %e,
"lua sync response hook error"
);
if hook.config.on_error == HookErrorBehavior::FailClosed {
return HookResult::ShortCircuit(internal_error_response());
}
}
}
}
let ctx_snap = ctx.clone();
for hook in async_hooks {
let bytecode = hook.bytecode.clone();
let snap = ctx_snap.clone();
let timeout = Duration::from_millis(hook.config.timeout_ms);
tokio::spawn(async move {
if let Err(e) = exec_response_hook(&bytecode, &snap, timeout) {
warn!(error = %e, "lua async response hook failed");
}
});
}
match sc_response {
Some(resp) => HookResult::ShortCircuit(resp),
None => HookResult::Continue,
}
}
}
/// Reads the hook script at `cfg.lua` and compiles it to Lua bytecode.
///
/// The chunk is named `@<script path>` (Lua's file-name convention for chunk
/// names), so errors and tracebacks raised while the hook runs point at the
/// script file and line. The dump keeps debug info (`strip = false`), so those
/// names and line numbers survive the bytecode round-trip.
///
/// # Errors
/// Returns a message if the file cannot be read as UTF-8 text, the Lua state
/// cannot be created, or the script fails to compile.
fn compile_script(cfg: &HookConfig) -> Result<Vec<u8>, String> {
    let source =
        std::fs::read_to_string(&cfg.lua).map_err(|e| format!("read {:?}: {e}", cfg.lua))?;
    let lua = make_lua()?;
    let func = lua
        .load(&source)
        .set_name(format!("@{}", cfg.lua.display()))
        .into_function()
        .map_err(|e| format!("compile {:?}: {e}", cfg.lua))?;
    Ok(func.dump(false))
}
/// Creates a new Lua state with only the table, string, math and utf8 standard
/// libraries opened (no io/os/debug).
///
/// # Errors
/// Returns `"lua init: ..."` if mlua fails to create the state.
fn make_lua() -> Result<Lua, String> {
    let libs = StdLib::TABLE | StdLib::STRING | StdLib::MATH | StdLib::UTF8;
    match Lua::new_with(libs, LuaOptions::default()) {
        Ok(lua) => Ok(lua),
        Err(err) => Err(format!("lua init: {err}")),
    }
}
fn exec_request_hook(
bytecode: &[u8],
ctx: &RequestContext,
timeout: Duration,
) -> Result<Option<Response<Body>>, String> {
let lua = make_lua()?;
install_timeout(&lua, timeout)?;
let t = build_request_table(&lua, ctx)?;
lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
let func = lua
.load(bytecode)
.into_function()
.map_err(|e| e.to_string())?;
let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
response_from_lua(result)
}
/// Runs a sync request hook and copies the script's edits back into `ctx`.
///
/// After the chunk returns, the tables `ctx.headers` and `ctx.extra` are read
/// back from the Lua global `ctx`; edits to other fields are discarded. All
/// read-back and response conversion happens before `ctx` is modified, so a hook
/// that fails at any step leaves `ctx` exactly as it was (all-or-nothing).
///
/// # Errors
/// Returns the error text if:
/// - the Lua state cannot be prepared;
/// - the chunk fails or times out;
/// - the script left `ctx`, `ctx.headers` or `ctx.extra` as a non-table;
/// - an entry cannot be converted to a string pair;
/// - the returned response table is invalid.
fn exec_request_hook_mut(
    bytecode: &[u8],
    ctx: &mut RequestContext,
    timeout: Duration,
) -> Result<Option<Response<Body>>, String> {
    let lua = make_lua()?;
    install_timeout(&lua, timeout)?;
    let t = build_request_table(&lua, ctx)?;
    lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
    let func = lua
        .load(bytecode)
        .into_function()
        .map_err(|e| e.to_string())?;
    let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;

    // Re-read the global (not `t`) so a script that reassigns `ctx` is honoured.
    let ctx_global: Table = lua.globals().get("ctx").map_err(|e| e.to_string())?;
    let headers_table: Table = ctx_global.get("headers").map_err(|e| e.to_string())?;
    let headers = lua_table_to_map(&headers_table)?;
    let extra_table: Table = ctx_global.get("extra").map_err(|e| e.to_string())?;
    let extra = lua_table_to_map(&extra_table)?;
    let response = response_from_lua(result)?;

    // Commit only once every step above has succeeded.
    ctx.headers = headers;
    ctx.extra = extra;
    Ok(response)
}
fn exec_response_hook(
bytecode: &[u8],
ctx: &ResponseContext,
timeout: Duration,
) -> Result<Option<Response<Body>>, String> {
let lua = make_lua()?;
install_timeout(&lua, timeout)?;
let t = build_response_table(&lua, ctx)?;
lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
let func = lua
.load(bytecode)
.into_function()
.map_err(|e| e.to_string())?;
let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
response_from_lua(result)
}
/// Runs a sync response hook and copies the script's edits back into `ctx`.
///
/// After the chunk returns, `ctx.resp_headers` is always read back from the Lua
/// global `ctx`. `ctx.body` is read back only when `ctx.body` was `Some` going in.
/// A script that sets it to `nil` clears the body; a value that cannot be
/// converted to a Lua string is reported as an error instead of silently
/// dropping the body. All read-back and response conversion happens before
/// `ctx` is modified, so a failing hook leaves `ctx` unchanged.
///
/// # Errors
/// Returns the error text if:
/// - the Lua state cannot be prepared;
/// - the chunk fails or times out;
/// - `ctx`/`ctx.resp_headers` are not tables;
/// - a header pair or the body cannot be converted;
/// - the returned response table is invalid.
fn exec_response_hook_mut(
    bytecode: &[u8],
    ctx: &mut ResponseContext,
    timeout: Duration,
) -> Result<Option<Response<Body>>, String> {
    let lua = make_lua()?;
    install_timeout(&lua, timeout)?;
    let t = build_response_table(&lua, ctx)?;
    lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
    let func = lua
        .load(bytecode)
        .into_function()
        .map_err(|e| e.to_string())?;
    let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;

    // Re-read the global (not `t`) so a script that reassigns `ctx` is honoured.
    let ctx_global: Table = lua.globals().get("ctx").map_err(|e| e.to_string())?;
    let rh_table: Table = ctx_global.get("resp_headers").map_err(|e| e.to_string())?;
    let resp_headers = lua_table_to_map(&rh_table)?;
    let body = if ctx.body.is_some() {
        let new_body: Option<mlua::String> =
            ctx_global.get("body").map_err(|e| e.to_string())?;
        new_body.map(|s| s.as_bytes().to_vec())
    } else {
        // No body was exposed to the script, so there is nothing to read back.
        None
    };
    let response = response_from_lua(result)?;

    // Commit only once every step above has succeeded.
    ctx.resp_headers = resp_headers;
    ctx.body = body;
    Ok(response)
}
fn build_request_table(lua: &Lua, ctx: &RequestContext) -> Result<Table, String> {
let t = lua.create_table().map_err(|e| e.to_string())?;
t.set("method", ctx.method.as_str())
.map_err(|e| e.to_string())?;
t.set("path", ctx.path.as_str())
.map_err(|e| e.to_string())?;
t.set("query", ctx.query.as_str())
.map_err(|e| e.to_string())?;
t.set("client_ip", ctx.client_ip.as_str())
.map_err(|e| e.to_string())?;
t.set("request_id", ctx.request_id.as_str())
.map_err(|e| e.to_string())?;
t.set("short_circuited", ctx.short_circuited)
.map_err(|e| e.to_string())?;
let headers = map_to_lua_table(lua, &ctx.headers)?;
t.set("headers", headers).map_err(|e| e.to_string())?;
let extra = map_to_lua_table(lua, &ctx.extra)?;
t.set("extra", extra).map_err(|e| e.to_string())?;
if let Some(ref err) = ctx.error {
t.set("error", err.as_str()).map_err(|e| e.to_string())?;
}
Ok(t)
}
fn build_response_table(lua: &Lua, ctx: &ResponseContext) -> Result<Table, String> {
let t = lua.create_table().map_err(|e| e.to_string())?;
t.set("status", ctx.status).map_err(|e| e.to_string())?;
t.set("short_circuited", ctx.short_circuited)
.map_err(|e| e.to_string())?;
let rh = map_to_lua_table(lua, &ctx.resp_headers)?;
t.set("resp_headers", rh).map_err(|e| e.to_string())?;
if let Some(ref body) = ctx.body {
let lua_str = lua.create_string(body).map_err(|e| e.to_string())?;
t.set("body", lua_str).map_err(|e| e.to_string())?;
}
Ok(t)
}
/// Installs a VM hook that aborts the script with a `"lua hook timeout"` runtime
/// error once `timeout` has elapsed.
///
/// The clock starts when this function is called. The check runs every 100 VM
/// instructions, so time spent inside a single long-running library call is
/// only noticed at the next check. Currently always returns `Ok(())`.
fn install_timeout(lua: &Lua, timeout: Duration) -> Result<(), String> {
    let started = Instant::now();
    let triggers = HookTriggers::new().every_nth_instruction(100);
    lua.set_hook(triggers, move |_, _| {
        if started.elapsed() <= timeout {
            Ok(VmState::Continue)
        } else {
            Err(mlua::Error::runtime("lua hook timeout"))
        }
    });
    Ok(())
}
/// Copies every key/value pair of `map` into a new Lua table of strings.
///
/// # Errors
/// Returns the mlua error text if the table cannot be created or an entry
/// cannot be stored; population stops at the first failure.
fn map_to_lua_table(lua: &Lua, map: &HashMap<String, String>) -> Result<Table, String> {
    let table = lua.create_table().map_err(|err| err.to_string())?;
    map.iter()
        .try_for_each(|(key, value)| table.set(key.as_str(), value.as_str()))
        .map_err(|err| err.to_string())?;
    Ok(table)
}
/// Converts every key/value pair of a Lua table into a `HashMap<String, String>`.
///
/// # Errors
/// Returns the mlua error text for the first pair that cannot be converted to
/// two Rust `String`s.
fn lua_table_to_map(table: &Table) -> Result<HashMap<String, String>, String> {
    table
        .clone()
        .pairs::<String, String>()
        .map(|entry| entry.map_err(|err| err.to_string()))
        .collect()
}
fn response_from_lua(value: LuaValue) -> Result<Option<Response<Body>>, String> {
match value {
LuaValue::Nil => Ok(None),
LuaValue::Table(t) => {
let status: u16 = t.get("status").unwrap_or(200u16);
let body: String = t.get("body").unwrap_or_default();
let resp = Response::builder()
.status(status)
.body(Body::from(body))
.map_err(|e| e.to_string())?;
Ok(Some(resp))
}
_ => Ok(None),
}
}
fn internal_error_response() -> Response<Body> {
Response::builder()
.status(500)
.body(Body::from("internal server error (hook fail_closed)"))
.unwrap()
}