1use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use axum::body::Body;
12use axum::http::Response;
13use mlua::{HookTriggers, Lua, LuaOptions, StdLib, Table, Value as LuaValue, VmState};
14use tracing::warn;
15
16use crate::config::{HookConfig, HookErrorBehavior, HookMode};
17
/// Per-request data handed to Lua request hooks as the global `ctx` table.
///
/// Sync hooks run against a `&mut RequestContext`. After each one, `headers` and `extra`
/// are read back from the script's `ctx` table, so later hooks in the same stage (and the
/// caller) observe those edits. Async and `request.error` hooks only ever see a clone.
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// Request method, exposed to Lua as `ctx.method` (not read back).
    pub method: String,
    /// Request path, exposed as `ctx.path` (not read back).
    pub path: String,
    /// Query string, exposed as `ctx.query` (not read back).
    /// NOTE(review): whether the leading `?` is included depends on the caller — confirm.
    pub query: String,
    /// Client address as text, exposed as `ctx.client_ip` (not read back).
    pub client_ip: String,
    /// Request identifier, exposed as `ctx.request_id` (not read back).
    pub request_id: String,
    /// Header map exposed as `ctx.headers`; sync hooks may add, change or remove entries.
    pub headers: HashMap<String, String>,
    /// Free-form string key/value pairs exposed as `ctx.extra`. Because sync hooks may
    /// write to it and it is read back after each one, it can carry data between hooks.
    pub extra: HashMap<String, String>,
    /// Error text, exposed as `ctx.error` only when `Some`.
    /// Presumably populated by the caller before `request.error` hooks run — confirm.
    pub error: Option<String>,
    /// Exposed as `ctx.short_circuited` (script writes are not read back). The engine sets
    /// it to `true` when a sync hook returns a response table.
    pub short_circuited: bool,
}
37
/// Per-response data handed to Lua response hooks as the global `ctx` table.
#[derive(Debug, Clone)]
pub struct ResponseContext {
    /// Status code, exposed as `ctx.status`; script changes to it are not read back.
    pub status: u16,
    /// Response headers, exposed as `ctx.resp_headers`; sync hooks may rewrite them.
    pub resp_headers: HashMap<String, String>,
    /// Body bytes, exposed as `ctx.body` (a Lua byte string) only when `Some`. A sync hook
    /// can replace the body only when it was `Some` going in.
    pub body: Option<Vec<u8>>,
    /// Exposed as `ctx.short_circuited` (not read back). The engine sets it to `true` when
    /// a sync response hook returns a response table.
    pub short_circuited: bool,
}
49
/// Outcome of running the sync hooks of one stage.
pub enum HookResult {
    /// No sync hook produced a response; processing continues normally.
    Continue,
    /// A sync hook returned a response table (or errored under `fail_closed`, yielding a
    /// 500). Presumably the caller sends this in place of the normal response — confirm.
    ShortCircuit(Response<Body>),
}
59
/// A hook configuration paired with its script, precompiled to Lua bytecode.
struct CompiledHook {
    /// Configuration the hook was built from (event, mode, timeout, error policy, script path).
    config: HookConfig,
    /// Bytecode produced by `compile_script`; loaded into a fresh VM on every invocation.
    bytecode: Vec<u8>,
}
67
/// Dispatches configured Lua hooks by event name.
///
/// Cloning is cheap: every clone shares the same compiled hook list through an `Arc`.
#[derive(Clone)]
pub struct HookEngine {
    /// Hooks whose scripts compiled successfully, in configuration order.
    hooks: Arc<Vec<CompiledHook>>,
}
75
76impl HookEngine {
77 pub fn new(configs: &[HookConfig]) -> Self {
80 let mut compiled = Vec::with_capacity(configs.len());
81 for cfg in configs {
82 match compile_script(cfg) {
83 Ok(bytecode) => compiled.push(CompiledHook {
84 config: cfg.clone(),
85 bytecode,
86 }),
87 Err(e) => {
88 warn!(
89 event = %cfg.event,
90 script = %cfg.lua.display(),
91 error = %e,
92 "lua hook script failed to compile — skipping"
93 );
94 }
95 }
96 }
97 Self {
98 hooks: Arc::new(compiled),
99 }
100 }
101
102 pub fn run_request_before(&self, ctx: &mut RequestContext) -> HookResult {
104 self.run_request_stage("request.before", ctx)
105 }
106
107 pub fn run_request_error(&self, ctx: &RequestContext) {
109 for hook in self
110 .hooks
111 .iter()
112 .filter(|h| h.config.event == "request.error")
113 {
114 let bytecode = hook.bytecode.clone();
115 let ctx_clone = ctx.clone();
116 let timeout = Duration::from_millis(hook.config.timeout_ms);
117 tokio::spawn(async move {
118 if let Err(e) = exec_request_hook(&bytecode, &ctx_clone, timeout) {
119 warn!(error = %e, "lua request.error hook failed");
120 }
121 });
122 }
123 }
124
125 pub fn run_response_headers(&self, ctx: &mut ResponseContext) -> HookResult {
127 self.run_response_stage("response.headers", ctx)
128 }
129
130 pub fn run_response_after(&self, ctx: &mut ResponseContext) -> HookResult {
132 self.run_response_stage("response.after", ctx)
133 }
134
135 pub fn has_event(&self, event: &str) -> bool {
137 self.hooks.iter().any(|h| h.config.event == event)
138 }
139
140 fn run_request_stage(&self, event: &str, ctx: &mut RequestContext) -> HookResult {
143 let hooks: Vec<_> = self
144 .hooks
145 .iter()
146 .filter(|h| h.config.event == event)
147 .collect();
148
149 let (sync_hooks, async_hooks): (Vec<_>, Vec<_>) = hooks
150 .into_iter()
151 .partition(|h| h.config.mode == HookMode::Sync);
152
153 let mut sc_response: Option<Response<Body>> = None;
154
155 for hook in &sync_hooks {
156 if sc_response.is_some() {
157 break;
158 }
159 let timeout = Duration::from_millis(hook.config.timeout_ms);
160 match exec_request_hook_mut(&hook.bytecode, ctx, timeout) {
161 Ok(Some(resp)) => {
162 ctx.short_circuited = true;
163 sc_response = Some(resp);
164 }
165 Ok(None) => {}
166 Err(e) => {
167 warn!(
168 event = event,
169 script = %hook.config.lua.display(),
170 error = %e,
171 "lua sync hook error"
172 );
173 if hook.config.on_error == HookErrorBehavior::FailClosed {
174 return HookResult::ShortCircuit(internal_error_response());
175 }
176 }
177 }
178 }
179
180 let ctx_snap = ctx.clone();
182 for hook in async_hooks {
183 let bytecode = hook.bytecode.clone();
184 let snap = ctx_snap.clone();
185 let timeout = Duration::from_millis(hook.config.timeout_ms);
186 tokio::spawn(async move {
187 if let Err(e) = exec_request_hook(&bytecode, &snap, timeout) {
188 warn!(error = %e, "lua async request hook failed");
189 }
190 });
191 }
192
193 match sc_response {
194 Some(resp) => HookResult::ShortCircuit(resp),
195 None => HookResult::Continue,
196 }
197 }
198
199 fn run_response_stage(&self, event: &str, ctx: &mut ResponseContext) -> HookResult {
200 let hooks: Vec<_> = self
201 .hooks
202 .iter()
203 .filter(|h| h.config.event == event)
204 .collect();
205
206 let (sync_hooks, async_hooks): (Vec<_>, Vec<_>) = hooks
207 .into_iter()
208 .partition(|h| h.config.mode == HookMode::Sync);
209
210 let mut sc_response: Option<Response<Body>> = None;
211
212 for hook in &sync_hooks {
213 if sc_response.is_some() {
214 break;
215 }
216 let timeout = Duration::from_millis(hook.config.timeout_ms);
217 match exec_response_hook_mut(&hook.bytecode, ctx, timeout) {
218 Ok(Some(resp)) => {
219 ctx.short_circuited = true;
220 sc_response = Some(resp);
221 }
222 Ok(None) => {}
223 Err(e) => {
224 warn!(
225 event = event,
226 script = %hook.config.lua.display(),
227 error = %e,
228 "lua sync response hook error"
229 );
230 if hook.config.on_error == HookErrorBehavior::FailClosed {
231 return HookResult::ShortCircuit(internal_error_response());
232 }
233 }
234 }
235 }
236
237 let ctx_snap = ctx.clone();
238 for hook in async_hooks {
239 let bytecode = hook.bytecode.clone();
240 let snap = ctx_snap.clone();
241 let timeout = Duration::from_millis(hook.config.timeout_ms);
242 tokio::spawn(async move {
243 if let Err(e) = exec_response_hook(&bytecode, &snap, timeout) {
244 warn!(error = %e, "lua async response hook failed");
245 }
246 });
247 }
248
249 match sc_response {
250 Some(resp) => HookResult::ShortCircuit(resp),
251 None => HookResult::Continue,
252 }
253 }
254}
255
256fn compile_script(cfg: &HookConfig) -> Result<Vec<u8>, String> {
259 let source =
260 std::fs::read_to_string(&cfg.lua).map_err(|e| format!("read {:?}: {e}", cfg.lua))?;
261 let lua = make_lua()?;
262 let func = lua
263 .load(&source)
264 .into_function()
265 .map_err(|e| format!("compile {:?}: {e}", cfg.lua))?;
266 Ok(func.dump(false))
268}
269
270fn make_lua() -> Result<Lua, String> {
271 Lua::new_with(
273 StdLib::TABLE | StdLib::STRING | StdLib::MATH | StdLib::UTF8,
274 LuaOptions::default(),
275 )
276 .map_err(|e| format!("lua init: {e}"))
277}
278
279fn exec_request_hook(
283 bytecode: &[u8],
284 ctx: &RequestContext,
285 timeout: Duration,
286) -> Result<Option<Response<Body>>, String> {
287 let lua = make_lua()?;
288 install_timeout(&lua, timeout)?;
289
290 let t = build_request_table(&lua, ctx)?;
291 lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
292
293 let func = lua
294 .load(bytecode)
295 .into_function()
296 .map_err(|e| e.to_string())?;
297 let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
298 response_from_lua(result)
299}
300
301fn exec_request_hook_mut(
304 bytecode: &[u8],
305 ctx: &mut RequestContext,
306 timeout: Duration,
307) -> Result<Option<Response<Body>>, String> {
308 let lua = make_lua()?;
309 install_timeout(&lua, timeout)?;
310
311 let t = build_request_table(&lua, ctx)?;
312 lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
313
314 let func = lua
315 .load(bytecode)
316 .into_function()
317 .map_err(|e| e.to_string())?;
318 let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
319
320 let ctx_global: Table = lua.globals().get("ctx").map_err(|e| e.to_string())?;
322 let headers_table: Table = ctx_global.get("headers").map_err(|e| e.to_string())?;
323 ctx.headers = lua_table_to_map(&headers_table)?;
324 let extra_table: Table = ctx_global.get("extra").map_err(|e| e.to_string())?;
325 ctx.extra = lua_table_to_map(&extra_table)?;
326
327 response_from_lua(result)
328}
329
330fn exec_response_hook(
332 bytecode: &[u8],
333 ctx: &ResponseContext,
334 timeout: Duration,
335) -> Result<Option<Response<Body>>, String> {
336 let lua = make_lua()?;
337 install_timeout(&lua, timeout)?;
338
339 let t = build_response_table(&lua, ctx)?;
340 lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
341
342 let func = lua
343 .load(bytecode)
344 .into_function()
345 .map_err(|e| e.to_string())?;
346 let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
347 response_from_lua(result)
348}
349
350fn exec_response_hook_mut(
353 bytecode: &[u8],
354 ctx: &mut ResponseContext,
355 timeout: Duration,
356) -> Result<Option<Response<Body>>, String> {
357 let lua = make_lua()?;
358 install_timeout(&lua, timeout)?;
359
360 let t = build_response_table(&lua, ctx)?;
361 lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
362
363 let func = lua
364 .load(bytecode)
365 .into_function()
366 .map_err(|e| e.to_string())?;
367 let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
368
369 let ctx_global: Table = lua.globals().get("ctx").map_err(|e| e.to_string())?;
371 let rh_table: Table = ctx_global.get("resp_headers").map_err(|e| e.to_string())?;
372 ctx.resp_headers = lua_table_to_map(&rh_table)?;
373
374 if ctx.body.is_some() {
375 let new_body: Option<mlua::String> = ctx_global.get("body").ok();
378 ctx.body = new_body.map(|s| s.as_bytes().to_vec());
379 }
380
381 response_from_lua(result)
382}
383
384fn build_request_table(lua: &Lua, ctx: &RequestContext) -> Result<Table, String> {
387 let t = lua.create_table().map_err(|e| e.to_string())?;
388 t.set("method", ctx.method.as_str())
389 .map_err(|e| e.to_string())?;
390 t.set("path", ctx.path.as_str())
391 .map_err(|e| e.to_string())?;
392 t.set("query", ctx.query.as_str())
393 .map_err(|e| e.to_string())?;
394 t.set("client_ip", ctx.client_ip.as_str())
395 .map_err(|e| e.to_string())?;
396 t.set("request_id", ctx.request_id.as_str())
397 .map_err(|e| e.to_string())?;
398 t.set("short_circuited", ctx.short_circuited)
399 .map_err(|e| e.to_string())?;
400 let headers = map_to_lua_table(lua, &ctx.headers)?;
401 t.set("headers", headers).map_err(|e| e.to_string())?;
402 let extra = map_to_lua_table(lua, &ctx.extra)?;
403 t.set("extra", extra).map_err(|e| e.to_string())?;
404 if let Some(ref err) = ctx.error {
405 t.set("error", err.as_str()).map_err(|e| e.to_string())?;
406 }
407 Ok(t)
408}
409
410fn build_response_table(lua: &Lua, ctx: &ResponseContext) -> Result<Table, String> {
411 let t = lua.create_table().map_err(|e| e.to_string())?;
412 t.set("status", ctx.status).map_err(|e| e.to_string())?;
413 t.set("short_circuited", ctx.short_circuited)
414 .map_err(|e| e.to_string())?;
415 let rh = map_to_lua_table(lua, &ctx.resp_headers)?;
416 t.set("resp_headers", rh).map_err(|e| e.to_string())?;
417 if let Some(ref body) = ctx.body {
418 let lua_str = lua.create_string(body).map_err(|e| e.to_string())?;
420 t.set("body", lua_str).map_err(|e| e.to_string())?;
421 }
422 Ok(t)
423}
424
425fn install_timeout(lua: &Lua, timeout: Duration) -> Result<(), String> {
428 let start = Instant::now();
429 lua.set_hook(
430 HookTriggers::new().every_nth_instruction(100),
431 move |_lua, _debug| {
432 if start.elapsed() > timeout {
433 Err(mlua::Error::runtime("lua hook timeout"))
434 } else {
435 Ok(VmState::Continue)
436 }
437 },
438 );
439 Ok(())
440}
441
442fn map_to_lua_table(lua: &Lua, map: &HashMap<String, String>) -> Result<Table, String> {
443 let t = lua.create_table().map_err(|e| e.to_string())?;
444 for (k, v) in map {
445 t.set(k.as_str(), v.as_str()).map_err(|e| e.to_string())?;
446 }
447 Ok(t)
448}
449
450fn lua_table_to_map(table: &Table) -> Result<HashMap<String, String>, String> {
451 let mut map = HashMap::new();
452 for pair in table.clone().pairs::<String, String>() {
453 let (k, v) = pair.map_err(|e| e.to_string())?;
454 map.insert(k, v);
455 }
456 Ok(map)
457}
458
459fn response_from_lua(value: LuaValue) -> Result<Option<Response<Body>>, String> {
460 match value {
461 LuaValue::Nil => Ok(None),
462 LuaValue::Table(t) => {
463 let status: u16 = t.get("status").unwrap_or(200u16);
464 let body: String = t.get("body").unwrap_or_default();
465 let resp = Response::builder()
466 .status(status)
467 .body(Body::from(body))
468 .map_err(|e| e.to_string())?;
469 Ok(Some(resp))
470 }
471 _ => Ok(None),
472 }
473}
474
475fn internal_error_response() -> Response<Body> {
476 Response::builder()
477 .status(500)
478 .body(Body::from("internal server error (hook fail_closed)"))
479 .unwrap()
480}