1use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use axum::body::Body;
12use axum::http::Response;
13use mlua::{HookTriggers, Lua, LuaOptions, StdLib, Table, Value as LuaValue, VmState};
14use tracing::warn;
15
16use crate::config::{HookConfig, HookErrorBehavior, HookMode};
17
/// Request-side data handed to Lua hooks.
///
/// `build_request_table` copies these fields into the Lua global `ctx` before a
/// hook runs; the field names below are also the Lua keys.
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// Request method, exposed to hooks as `ctx.method`.
    pub method: String,
    /// Request path, exposed to hooks as `ctx.path`.
    pub path: String,
    /// Query string, exposed to hooks as `ctx.query`.
    pub query: String,
    /// Client IP in string form, exposed to hooks as `ctx.client_ip`.
    pub client_ip: String,
    /// Request identifier, exposed to hooks as `ctx.request_id`.
    pub request_id: String,
    /// Request headers. Sync hooks may edit `ctx.headers`; the Lua table is read
    /// back into this map after the hook returns.
    pub headers: HashMap<String, String>,
    /// Additional string key/value pairs. Like `headers`, this map is read back
    /// from `ctx.extra` after a sync hook returns.
    // NOTE(review): presumably a scratch area for passing data between hooks — confirm.
    pub extra: HashMap<String, String>,
    /// Error text; `ctx.error` is only set in Lua when this is `Some`.
    pub error: Option<String>,
    /// Set to `true` by the request stage once a sync hook returns a response.
    pub short_circuited: bool,
}
37
/// Response-side data handed to Lua hooks.
///
/// `build_response_table` copies these fields into the Lua global `ctx` before a
/// hook runs; the field names below are also the Lua keys.
#[derive(Debug, Clone)]
pub struct ResponseContext {
    /// Response status code, exposed to hooks as `ctx.status`. It is not read
    /// back from Lua after a hook runs.
    pub status: u16,
    /// Response headers. Sync hooks may edit `ctx.resp_headers`; the Lua table
    /// is read back into this map after the hook returns.
    pub resp_headers: HashMap<String, String>,
    /// Response body, if available. `ctx.body` is only set (and only read back
    /// after a sync hook) when this is `Some`.
    pub body: Option<String>,
    /// Set to `true` by the response stage once a sync hook returns a response.
    pub short_circuited: bool,
}
49
/// Outcome of running the sync hooks of one stage.
pub enum HookResult {
    /// No sync hook produced a response; normal processing should proceed.
    Continue,
    /// A response to send instead of continuing: either the one returned by a
    /// sync hook, or the 500 response produced when a fail-closed hook errors.
    ShortCircuit(Response<Body>),
}
59
/// A hook whose script has already been compiled by `compile_script`.
struct CompiledHook {
    /// The configuration this hook was built from (event, mode, timeout, error policy, path).
    config: HookConfig,
    /// Dumped Lua chunk; loaded into a fresh Lua state on every execution.
    bytecode: Vec<u8>,
}
67
/// Runs configured Lua hooks for the request and response stages.
///
/// Cloning is cheap: clones share the same compiled hook list through the `Arc`.
#[derive(Clone)]
pub struct HookEngine {
    /// All hooks that compiled successfully, in configuration order.
    hooks: Arc<Vec<CompiledHook>>,
}
75
76impl HookEngine {
77 pub fn new(configs: &[HookConfig]) -> Self {
80 let mut compiled = Vec::with_capacity(configs.len());
81 for cfg in configs {
82 match compile_script(cfg) {
83 Ok(bytecode) => compiled.push(CompiledHook {
84 config: cfg.clone(),
85 bytecode,
86 }),
87 Err(e) => {
88 warn!(
89 event = %cfg.event,
90 script = %cfg.lua.display(),
91 error = %e,
92 "lua hook script failed to compile — skipping"
93 );
94 }
95 }
96 }
97 Self {
98 hooks: Arc::new(compiled),
99 }
100 }
101
102 pub fn run_request_before(&self, ctx: &mut RequestContext) -> HookResult {
104 self.run_request_stage("request.before", ctx)
105 }
106
107 pub fn run_request_error(&self, ctx: &RequestContext) {
109 for hook in self
110 .hooks
111 .iter()
112 .filter(|h| h.config.event == "request.error")
113 {
114 let bytecode = hook.bytecode.clone();
115 let ctx_clone = ctx.clone();
116 let timeout = Duration::from_millis(hook.config.timeout_ms);
117 tokio::spawn(async move {
118 if let Err(e) = exec_request_hook(&bytecode, &ctx_clone, timeout) {
119 warn!(error = %e, "lua request.error hook failed");
120 }
121 });
122 }
123 }
124
125 pub fn run_response_headers(&self, ctx: &mut ResponseContext) -> HookResult {
127 self.run_response_stage("response.headers", ctx)
128 }
129
130 pub fn run_response_after(&self, ctx: &mut ResponseContext) -> HookResult {
132 self.run_response_stage("response.after", ctx)
133 }
134
135 fn run_request_stage(&self, event: &str, ctx: &mut RequestContext) -> HookResult {
138 let hooks: Vec<_> = self
139 .hooks
140 .iter()
141 .filter(|h| h.config.event == event)
142 .collect();
143
144 let (sync_hooks, async_hooks): (Vec<_>, Vec<_>) = hooks
145 .into_iter()
146 .partition(|h| h.config.mode == HookMode::Sync);
147
148 let mut sc_response: Option<Response<Body>> = None;
149
150 for hook in &sync_hooks {
151 if sc_response.is_some() {
152 break;
153 }
154 let timeout = Duration::from_millis(hook.config.timeout_ms);
155 match exec_request_hook_mut(&hook.bytecode, ctx, timeout) {
156 Ok(Some(resp)) => {
157 ctx.short_circuited = true;
158 sc_response = Some(resp);
159 }
160 Ok(None) => {}
161 Err(e) => {
162 warn!(
163 event = event,
164 script = %hook.config.lua.display(),
165 error = %e,
166 "lua sync hook error"
167 );
168 if hook.config.on_error == HookErrorBehavior::FailClosed {
169 return HookResult::ShortCircuit(internal_error_response());
170 }
171 }
172 }
173 }
174
175 let ctx_snap = ctx.clone();
177 for hook in async_hooks {
178 let bytecode = hook.bytecode.clone();
179 let snap = ctx_snap.clone();
180 let timeout = Duration::from_millis(hook.config.timeout_ms);
181 tokio::spawn(async move {
182 if let Err(e) = exec_request_hook(&bytecode, &snap, timeout) {
183 warn!(error = %e, "lua async request hook failed");
184 }
185 });
186 }
187
188 match sc_response {
189 Some(resp) => HookResult::ShortCircuit(resp),
190 None => HookResult::Continue,
191 }
192 }
193
194 fn run_response_stage(&self, event: &str, ctx: &mut ResponseContext) -> HookResult {
195 let hooks: Vec<_> = self
196 .hooks
197 .iter()
198 .filter(|h| h.config.event == event)
199 .collect();
200
201 let (sync_hooks, async_hooks): (Vec<_>, Vec<_>) = hooks
202 .into_iter()
203 .partition(|h| h.config.mode == HookMode::Sync);
204
205 let mut sc_response: Option<Response<Body>> = None;
206
207 for hook in &sync_hooks {
208 if sc_response.is_some() {
209 break;
210 }
211 let timeout = Duration::from_millis(hook.config.timeout_ms);
212 match exec_response_hook_mut(&hook.bytecode, ctx, timeout) {
213 Ok(Some(resp)) => {
214 ctx.short_circuited = true;
215 sc_response = Some(resp);
216 }
217 Ok(None) => {}
218 Err(e) => {
219 warn!(
220 event = event,
221 script = %hook.config.lua.display(),
222 error = %e,
223 "lua sync response hook error"
224 );
225 if hook.config.on_error == HookErrorBehavior::FailClosed {
226 return HookResult::ShortCircuit(internal_error_response());
227 }
228 }
229 }
230 }
231
232 let ctx_snap = ctx.clone();
233 for hook in async_hooks {
234 let bytecode = hook.bytecode.clone();
235 let snap = ctx_snap.clone();
236 let timeout = Duration::from_millis(hook.config.timeout_ms);
237 tokio::spawn(async move {
238 if let Err(e) = exec_response_hook(&bytecode, &snap, timeout) {
239 warn!(error = %e, "lua async response hook failed");
240 }
241 });
242 }
243
244 match sc_response {
245 Some(resp) => HookResult::ShortCircuit(resp),
246 None => HookResult::Continue,
247 }
248 }
249}
250
251fn compile_script(cfg: &HookConfig) -> Result<Vec<u8>, String> {
254 let source =
255 std::fs::read_to_string(&cfg.lua).map_err(|e| format!("read {:?}: {e}", cfg.lua))?;
256 let lua = make_lua()?;
257 let func = lua
258 .load(&source)
259 .into_function()
260 .map_err(|e| format!("compile {:?}: {e}", cfg.lua))?;
261 Ok(func.dump(false))
263}
264
265fn make_lua() -> Result<Lua, String> {
266 Lua::new_with(
268 StdLib::TABLE | StdLib::STRING | StdLib::MATH | StdLib::UTF8,
269 LuaOptions::default(),
270 )
271 .map_err(|e| format!("lua init: {e}"))
272}
273
274fn exec_request_hook(
278 bytecode: &[u8],
279 ctx: &RequestContext,
280 timeout: Duration,
281) -> Result<Option<Response<Body>>, String> {
282 let lua = make_lua()?;
283 install_timeout(&lua, timeout)?;
284
285 let t = build_request_table(&lua, ctx)?;
286 lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
287
288 let func = lua
289 .load(bytecode)
290 .into_function()
291 .map_err(|e| e.to_string())?;
292 let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
293 response_from_lua(result)
294}
295
296fn exec_request_hook_mut(
299 bytecode: &[u8],
300 ctx: &mut RequestContext,
301 timeout: Duration,
302) -> Result<Option<Response<Body>>, String> {
303 let lua = make_lua()?;
304 install_timeout(&lua, timeout)?;
305
306 let t = build_request_table(&lua, ctx)?;
307 lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
308
309 let func = lua
310 .load(bytecode)
311 .into_function()
312 .map_err(|e| e.to_string())?;
313 let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
314
315 let ctx_global: Table = lua.globals().get("ctx").map_err(|e| e.to_string())?;
317 let headers_table: Table = ctx_global.get("headers").map_err(|e| e.to_string())?;
318 ctx.headers = lua_table_to_map(&headers_table)?;
319 let extra_table: Table = ctx_global.get("extra").map_err(|e| e.to_string())?;
320 ctx.extra = lua_table_to_map(&extra_table)?;
321
322 response_from_lua(result)
323}
324
325fn exec_response_hook(
327 bytecode: &[u8],
328 ctx: &ResponseContext,
329 timeout: Duration,
330) -> Result<Option<Response<Body>>, String> {
331 let lua = make_lua()?;
332 install_timeout(&lua, timeout)?;
333
334 let t = build_response_table(&lua, ctx)?;
335 lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
336
337 let func = lua
338 .load(bytecode)
339 .into_function()
340 .map_err(|e| e.to_string())?;
341 let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
342 response_from_lua(result)
343}
344
345fn exec_response_hook_mut(
348 bytecode: &[u8],
349 ctx: &mut ResponseContext,
350 timeout: Duration,
351) -> Result<Option<Response<Body>>, String> {
352 let lua = make_lua()?;
353 install_timeout(&lua, timeout)?;
354
355 let t = build_response_table(&lua, ctx)?;
356 lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
357
358 let func = lua
359 .load(bytecode)
360 .into_function()
361 .map_err(|e| e.to_string())?;
362 let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
363
364 let ctx_global: Table = lua.globals().get("ctx").map_err(|e| e.to_string())?;
366 let rh_table: Table = ctx_global.get("resp_headers").map_err(|e| e.to_string())?;
367 ctx.resp_headers = lua_table_to_map(&rh_table)?;
368
369 if ctx.body.is_some() {
370 let new_body: Option<String> = ctx_global.get("body").ok();
372 ctx.body = new_body;
373 }
374
375 response_from_lua(result)
376}
377
378fn build_request_table(lua: &Lua, ctx: &RequestContext) -> Result<Table, String> {
381 let t = lua.create_table().map_err(|e| e.to_string())?;
382 t.set("method", ctx.method.as_str())
383 .map_err(|e| e.to_string())?;
384 t.set("path", ctx.path.as_str())
385 .map_err(|e| e.to_string())?;
386 t.set("query", ctx.query.as_str())
387 .map_err(|e| e.to_string())?;
388 t.set("client_ip", ctx.client_ip.as_str())
389 .map_err(|e| e.to_string())?;
390 t.set("request_id", ctx.request_id.as_str())
391 .map_err(|e| e.to_string())?;
392 t.set("short_circuited", ctx.short_circuited)
393 .map_err(|e| e.to_string())?;
394 let headers = map_to_lua_table(lua, &ctx.headers)?;
395 t.set("headers", headers).map_err(|e| e.to_string())?;
396 let extra = map_to_lua_table(lua, &ctx.extra)?;
397 t.set("extra", extra).map_err(|e| e.to_string())?;
398 if let Some(ref err) = ctx.error {
399 t.set("error", err.as_str()).map_err(|e| e.to_string())?;
400 }
401 Ok(t)
402}
403
404fn build_response_table(lua: &Lua, ctx: &ResponseContext) -> Result<Table, String> {
405 let t = lua.create_table().map_err(|e| e.to_string())?;
406 t.set("status", ctx.status).map_err(|e| e.to_string())?;
407 t.set("short_circuited", ctx.short_circuited)
408 .map_err(|e| e.to_string())?;
409 let rh = map_to_lua_table(lua, &ctx.resp_headers)?;
410 t.set("resp_headers", rh).map_err(|e| e.to_string())?;
411 if let Some(ref body) = ctx.body {
412 t.set("body", body.as_str()).map_err(|e| e.to_string())?;
413 }
414 Ok(t)
415}
416
417fn install_timeout(lua: &Lua, timeout: Duration) -> Result<(), String> {
420 let start = Instant::now();
421 lua.set_hook(
422 HookTriggers::new().every_nth_instruction(100),
423 move |_lua, _debug| {
424 if start.elapsed() > timeout {
425 Err(mlua::Error::runtime("lua hook timeout"))
426 } else {
427 Ok(VmState::Continue)
428 }
429 },
430 );
431 Ok(())
432}
433
434fn map_to_lua_table(lua: &Lua, map: &HashMap<String, String>) -> Result<Table, String> {
435 let t = lua.create_table().map_err(|e| e.to_string())?;
436 for (k, v) in map {
437 t.set(k.as_str(), v.as_str()).map_err(|e| e.to_string())?;
438 }
439 Ok(t)
440}
441
442fn lua_table_to_map(table: &Table) -> Result<HashMap<String, String>, String> {
443 let mut map = HashMap::new();
444 for pair in table.clone().pairs::<String, String>() {
445 let (k, v) = pair.map_err(|e| e.to_string())?;
446 map.insert(k, v);
447 }
448 Ok(map)
449}
450
451fn response_from_lua(value: LuaValue) -> Result<Option<Response<Body>>, String> {
452 match value {
453 LuaValue::Nil => Ok(None),
454 LuaValue::Table(t) => {
455 let status: u16 = t.get("status").unwrap_or(200u16);
456 let body: String = t.get("body").unwrap_or_default();
457 let resp = Response::builder()
458 .status(status)
459 .body(Body::from(body))
460 .map_err(|e| e.to_string())?;
461 Ok(Some(resp))
462 }
463 _ => Ok(None),
464 }
465}
466
467fn internal_error_response() -> Response<Body> {
468 Response::builder()
469 .status(500)
470 .body(Body::from("internal server error (hook fail_closed)"))
471 .unwrap()
472}