1use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use anyhow::anyhow;
12use axum::body::Body;
13use axum::http::Response;
14use mlua::{HookTriggers, Lua, LuaOptions, StdLib, Table, Value as LuaValue, VmState};
15use tracing::warn;
16
17use crate::config::{HookConfig, HookErrorBehavior, HookMode};
18
/// Request data handed to request-stage Lua hooks as the global `ctx` table.
///
/// Every field is copied into the Lua table before a hook runs. After a sync
/// hook finishes, only `headers` and `extra` are read back from Lua, so edits
/// a script makes to the other fields are not applied.
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// HTTP method, exposed to Lua as `ctx.method` (read-only from Lua's view).
    pub method: String,
    /// Request path, exposed to Lua as `ctx.path` (read-only from Lua's view).
    pub path: String,
    /// Query string, exposed to Lua as `ctx.query` (read-only from Lua's view).
    pub query: String,
    /// Client address, exposed to Lua as `ctx.client_ip` (read-only from Lua's view).
    pub client_ip: String,
    /// Request identifier, exposed to Lua as `ctx.request_id` (read-only from Lua's view).
    pub request_id: String,
    /// Request headers, exposed as `ctx.headers`; sync hooks may add, change or
    /// remove entries and the result is written back here.
    pub headers: HashMap<String, String>,
    /// Free-form string map exposed as `ctx.extra`. Sync hooks may modify it and
    /// the updated map is visible to subsequent hooks that receive this context.
    pub extra: HashMap<String, String>,
    /// Error description, exposed as `ctx.error` only when `Some`.
    /// NOTE(review): presumably populated by the caller before running the
    /// `request.error` stage — confirm against the call site.
    pub error: Option<String>,
    /// Set to `true` once a sync hook has returned a response table; also
    /// exposed to Lua as `ctx.short_circuited`.
    pub short_circuited: bool,
}
38
/// Response data handed to response-stage Lua hooks as the global `ctx` table.
///
/// After a sync hook finishes, `resp_headers` (and `body`, when present) are
/// read back from Lua; `status` is exposed but not read back.
#[derive(Debug, Clone)]
pub struct ResponseContext {
    /// HTTP status code, exposed to Lua as `ctx.status` (read-only from Lua's view).
    pub status: u16,
    /// Response headers, exposed as `ctx.resp_headers`; sync hooks may modify
    /// them and the result is written back here.
    pub resp_headers: HashMap<String, String>,
    /// Response body bytes. Exposed as `ctx.body` only when `Some`, and only
    /// then written back (setting `ctx.body = nil` clears it). When `None` the
    /// script cannot set a body.
    /// NOTE(review): `None` presumably means the body was not buffered — confirm.
    pub body: Option<Vec<u8>>,
    /// Set to `true` once a sync hook has returned a response table; also
    /// exposed to Lua as `ctx.short_circuited`.
    pub short_circuited: bool,
}
50
/// Outcome of running all hooks registered for one stage.
pub enum HookResult {
    /// No sync hook produced a response (errors, if any, were fail-open);
    /// processing should proceed normally.
    Continue,
    /// A sync hook returned a response table, or a `fail_closed` hook errored
    /// (yielding a 500). The caller is expected to send this response instead
    /// of continuing.
    ShortCircuit(Response<Body>),
}
60
/// A hook script compiled once when the engine is constructed.
struct CompiledHook {
    /// The configuration the hook was built from (event, mode, timeout, error policy, ...).
    config: HookConfig,
    /// Lua bytecode produced by `compile_script`; it is loaded into a fresh Lua
    /// state on every execution.
    bytecode: Vec<u8>,
}
68
/// Runs the configured Lua hooks at the request and response stages.
///
/// Cloning is cheap: the compiled hooks sit behind an `Arc` and are shared by
/// all clones.
#[derive(Clone)]
pub struct HookEngine {
    /// Successfully compiled hooks in configuration order (hooks that failed to
    /// compile and were not `required` are omitted).
    hooks: Arc<Vec<CompiledHook>>,
}
76
77impl HookEngine {
78 pub fn new(configs: &[HookConfig]) -> anyhow::Result<Self> {
85 let mut compiled = Vec::with_capacity(configs.len());
86 for cfg in configs {
87 match compile_script(cfg) {
88 Ok(bytecode) => compiled.push(CompiledHook {
89 config: cfg.clone(),
90 bytecode,
91 }),
92 Err(e) => {
93 if cfg.required {
94 return Err(anyhow!(
95 "required lua hook script failed to compile \
96 (event = {}, script = {}): {}",
97 cfg.event,
98 cfg.lua.display(),
99 e
100 ));
101 }
102 warn!(
103 event = %cfg.event,
104 script = %cfg.lua.display(),
105 error = %e,
106 "lua hook script failed to compile — skipping"
107 );
108 }
109 }
110 }
111 Ok(Self {
112 hooks: Arc::new(compiled),
113 })
114 }
115
116 pub fn run_request_before(&self, ctx: &mut RequestContext) -> HookResult {
118 self.run_request_stage("request.before", ctx)
119 }
120
121 pub fn run_request_error(&self, ctx: &mut RequestContext) -> HookResult {
126 self.run_request_stage("request.error", ctx)
127 }
128
129 pub fn run_response_headers(&self, ctx: &mut ResponseContext) -> HookResult {
131 self.run_response_stage("response.headers", ctx)
132 }
133
134 pub fn run_response_after(&self, ctx: &mut ResponseContext) -> HookResult {
136 self.run_response_stage("response.after", ctx)
137 }
138
139 pub fn has_event(&self, event: &str) -> bool {
141 self.hooks.iter().any(|h| h.config.event == event)
142 }
143
144 fn run_request_stage(&self, event: &str, ctx: &mut RequestContext) -> HookResult {
147 let hooks: Vec<_> = self
148 .hooks
149 .iter()
150 .filter(|h| h.config.event == event)
151 .collect();
152
153 let (sync_hooks, async_hooks): (Vec<_>, Vec<_>) = hooks
154 .into_iter()
155 .partition(|h| h.config.mode == HookMode::Sync);
156
157 let mut sc_response: Option<Response<Body>> = None;
158
159 for hook in &sync_hooks {
160 if sc_response.is_some() {
161 break;
162 }
163 let timeout = Duration::from_millis(hook.config.timeout_ms);
164 match exec_request_hook_mut(&hook.bytecode, ctx, timeout) {
165 Ok(Some(resp)) => {
166 ctx.short_circuited = true;
167 sc_response = Some(resp);
168 }
169 Ok(None) => {}
170 Err(e) => {
171 warn!(
172 event = event,
173 script = %hook.config.lua.display(),
174 error = %e,
175 "lua sync hook error"
176 );
177 if hook.config.on_error == HookErrorBehavior::FailClosed {
178 return HookResult::ShortCircuit(internal_error_response());
179 }
180 }
181 }
182 }
183
184 let ctx_snap = ctx.clone();
186 for hook in async_hooks {
187 let bytecode = hook.bytecode.clone();
188 let snap = ctx_snap.clone();
189 let timeout = Duration::from_millis(hook.config.timeout_ms);
190 tokio::spawn(async move {
191 if let Err(e) = exec_request_hook(&bytecode, &snap, timeout) {
192 warn!(error = %e, "lua async request hook failed");
193 }
194 });
195 }
196
197 match sc_response {
198 Some(resp) => HookResult::ShortCircuit(resp),
199 None => HookResult::Continue,
200 }
201 }
202
203 fn run_response_stage(&self, event: &str, ctx: &mut ResponseContext) -> HookResult {
204 let hooks: Vec<_> = self
205 .hooks
206 .iter()
207 .filter(|h| h.config.event == event)
208 .collect();
209
210 let (sync_hooks, async_hooks): (Vec<_>, Vec<_>) = hooks
211 .into_iter()
212 .partition(|h| h.config.mode == HookMode::Sync);
213
214 let mut sc_response: Option<Response<Body>> = None;
215
216 for hook in &sync_hooks {
217 if sc_response.is_some() {
218 break;
219 }
220 let timeout = Duration::from_millis(hook.config.timeout_ms);
221 match exec_response_hook_mut(&hook.bytecode, ctx, timeout) {
222 Ok(Some(resp)) => {
223 ctx.short_circuited = true;
224 sc_response = Some(resp);
225 }
226 Ok(None) => {}
227 Err(e) => {
228 warn!(
229 event = event,
230 script = %hook.config.lua.display(),
231 error = %e,
232 "lua sync response hook error"
233 );
234 if hook.config.on_error == HookErrorBehavior::FailClosed {
235 return HookResult::ShortCircuit(internal_error_response());
236 }
237 }
238 }
239 }
240
241 let ctx_snap = ctx.clone();
242 for hook in async_hooks {
243 let bytecode = hook.bytecode.clone();
244 let snap = ctx_snap.clone();
245 let timeout = Duration::from_millis(hook.config.timeout_ms);
246 tokio::spawn(async move {
247 if let Err(e) = exec_response_hook(&bytecode, &snap, timeout) {
248 warn!(error = %e, "lua async response hook failed");
249 }
250 });
251 }
252
253 match sc_response {
254 Some(resp) => HookResult::ShortCircuit(resp),
255 None => HookResult::Continue,
256 }
257 }
258}
259
260fn compile_script(cfg: &HookConfig) -> Result<Vec<u8>, String> {
263 let source =
264 std::fs::read_to_string(&cfg.lua).map_err(|e| format!("read {:?}: {e}", cfg.lua))?;
265 let lua = make_lua()?;
266 let func = lua
267 .load(&source)
268 .into_function()
269 .map_err(|e| format!("compile {:?}: {e}", cfg.lua))?;
270 Ok(func.dump(false))
272}
273
274fn make_lua() -> Result<Lua, String> {
275 Lua::new_with(
277 StdLib::TABLE | StdLib::STRING | StdLib::MATH | StdLib::UTF8,
278 LuaOptions::default(),
279 )
280 .map_err(|e| format!("lua init: {e}"))
281}
282
283fn exec_request_hook(
287 bytecode: &[u8],
288 ctx: &RequestContext,
289 timeout: Duration,
290) -> Result<Option<Response<Body>>, String> {
291 let lua = make_lua()?;
292 install_timeout(&lua, timeout)?;
293
294 let t = build_request_table(&lua, ctx)?;
295 lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
296
297 let func = lua
298 .load(bytecode)
299 .into_function()
300 .map_err(|e| e.to_string())?;
301 let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
302 response_from_lua(result)
303}
304
305fn exec_request_hook_mut(
308 bytecode: &[u8],
309 ctx: &mut RequestContext,
310 timeout: Duration,
311) -> Result<Option<Response<Body>>, String> {
312 let lua = make_lua()?;
313 install_timeout(&lua, timeout)?;
314
315 let t = build_request_table(&lua, ctx)?;
316 lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
317
318 let func = lua
319 .load(bytecode)
320 .into_function()
321 .map_err(|e| e.to_string())?;
322 let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
323
324 let ctx_global: Table = lua.globals().get("ctx").map_err(|e| e.to_string())?;
326 let headers_table: Table = ctx_global.get("headers").map_err(|e| e.to_string())?;
327 ctx.headers = lua_table_to_map(&headers_table)?;
328 let extra_table: Table = ctx_global.get("extra").map_err(|e| e.to_string())?;
329 ctx.extra = lua_table_to_map(&extra_table)?;
330
331 response_from_lua(result)
332}
333
334fn exec_response_hook(
336 bytecode: &[u8],
337 ctx: &ResponseContext,
338 timeout: Duration,
339) -> Result<Option<Response<Body>>, String> {
340 let lua = make_lua()?;
341 install_timeout(&lua, timeout)?;
342
343 let t = build_response_table(&lua, ctx)?;
344 lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
345
346 let func = lua
347 .load(bytecode)
348 .into_function()
349 .map_err(|e| e.to_string())?;
350 let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
351 response_from_lua(result)
352}
353
354fn exec_response_hook_mut(
357 bytecode: &[u8],
358 ctx: &mut ResponseContext,
359 timeout: Duration,
360) -> Result<Option<Response<Body>>, String> {
361 let lua = make_lua()?;
362 install_timeout(&lua, timeout)?;
363
364 let t = build_response_table(&lua, ctx)?;
365 lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
366
367 let func = lua
368 .load(bytecode)
369 .into_function()
370 .map_err(|e| e.to_string())?;
371 let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
372
373 let ctx_global: Table = lua.globals().get("ctx").map_err(|e| e.to_string())?;
375 let rh_table: Table = ctx_global.get("resp_headers").map_err(|e| e.to_string())?;
376 ctx.resp_headers = lua_table_to_map(&rh_table)?;
377
378 if ctx.body.is_some() {
379 let new_body: Option<mlua::String> = ctx_global.get("body").ok();
382 ctx.body = new_body.map(|s| s.as_bytes().to_vec());
383 }
384
385 response_from_lua(result)
386}
387
388fn build_request_table(lua: &Lua, ctx: &RequestContext) -> Result<Table, String> {
391 let t = lua.create_table().map_err(|e| e.to_string())?;
392 t.set("method", ctx.method.as_str())
393 .map_err(|e| e.to_string())?;
394 t.set("path", ctx.path.as_str())
395 .map_err(|e| e.to_string())?;
396 t.set("query", ctx.query.as_str())
397 .map_err(|e| e.to_string())?;
398 t.set("client_ip", ctx.client_ip.as_str())
399 .map_err(|e| e.to_string())?;
400 t.set("request_id", ctx.request_id.as_str())
401 .map_err(|e| e.to_string())?;
402 t.set("short_circuited", ctx.short_circuited)
403 .map_err(|e| e.to_string())?;
404 let headers = map_to_lua_table(lua, &ctx.headers)?;
405 t.set("headers", headers).map_err(|e| e.to_string())?;
406 let extra = map_to_lua_table(lua, &ctx.extra)?;
407 t.set("extra", extra).map_err(|e| e.to_string())?;
408 if let Some(ref err) = ctx.error {
409 t.set("error", err.as_str()).map_err(|e| e.to_string())?;
410 }
411 Ok(t)
412}
413
414fn build_response_table(lua: &Lua, ctx: &ResponseContext) -> Result<Table, String> {
415 let t = lua.create_table().map_err(|e| e.to_string())?;
416 t.set("status", ctx.status).map_err(|e| e.to_string())?;
417 t.set("short_circuited", ctx.short_circuited)
418 .map_err(|e| e.to_string())?;
419 let rh = map_to_lua_table(lua, &ctx.resp_headers)?;
420 t.set("resp_headers", rh).map_err(|e| e.to_string())?;
421 if let Some(ref body) = ctx.body {
422 let lua_str = lua.create_string(body).map_err(|e| e.to_string())?;
424 t.set("body", lua_str).map_err(|e| e.to_string())?;
425 }
426 Ok(t)
427}
428
429fn install_timeout(lua: &Lua, timeout: Duration) -> Result<(), String> {
432 let start = Instant::now();
433 lua.set_hook(
434 HookTriggers::new().every_nth_instruction(100),
435 move |_lua, _debug| {
436 if start.elapsed() > timeout {
437 Err(mlua::Error::runtime("lua hook timeout"))
438 } else {
439 Ok(VmState::Continue)
440 }
441 },
442 );
443 Ok(())
444}
445
446fn map_to_lua_table(lua: &Lua, map: &HashMap<String, String>) -> Result<Table, String> {
447 let t = lua.create_table().map_err(|e| e.to_string())?;
448 for (k, v) in map {
449 t.set(k.as_str(), v.as_str()).map_err(|e| e.to_string())?;
450 }
451 Ok(t)
452}
453
454fn lua_table_to_map(table: &Table) -> Result<HashMap<String, String>, String> {
455 let mut map = HashMap::new();
456 for pair in table.clone().pairs::<String, String>() {
457 let (k, v) = pair.map_err(|e| e.to_string())?;
458 map.insert(k, v);
459 }
460 Ok(map)
461}
462
463fn response_from_lua(value: LuaValue) -> Result<Option<Response<Body>>, String> {
464 match value {
465 LuaValue::Nil => Ok(None),
466 LuaValue::Table(t) => {
467 let status: u16 = t.get("status").unwrap_or(200u16);
468 let body: String = t.get("body").unwrap_or_default();
469 let resp = Response::builder()
470 .status(status)
471 .body(Body::from(body))
472 .map_err(|e| e.to_string())?;
473 Ok(Some(resp))
474 }
475 _ => Ok(None),
476 }
477}
478
479fn internal_error_response() -> Response<Body> {
480 Response::builder()
481 .status(500)
482 .body(Body::from("internal server error (hook fail_closed)"))
483 .unwrap()
484}