1use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use axum::body::Body;
12use axum::http::Response;
13use mlua::{HookTriggers, Lua, LuaOptions, StdLib, Table, Value as LuaValue, VmState};
14use tracing::warn;
15
16use crate::config::{HookConfig, HookErrorBehavior, HookMode};
17
/// Request data exposed to Lua request hooks as the global `ctx` table.
///
/// After each sync hook only `headers` and `extra` are copied back from the
/// Lua table; the other fields are passed to the script but changes made to
/// them in Lua are not read back.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RequestContext {
    /// Request method (exposed as `ctx.method`).
    pub method: String,
    /// Request path (exposed as `ctx.path`).
    pub path: String,
    /// Query string (exposed as `ctx.query`).
    pub query: String,
    /// Client address as a string (exposed as `ctx.client_ip`).
    pub client_ip: String,
    /// Request identifier (exposed as `ctx.request_id`).
    pub request_id: String,
    /// Request headers (exposed as `ctx.headers`); rewritten by sync hooks.
    pub headers: HashMap<String, String>,
    /// Free-form string map (exposed as `ctx.extra`); rewritten by sync hooks
    /// and visible to later hooks in the same stage.
    pub extra: HashMap<String, String>,
    /// Optional error text; exposed as `ctx.error` only when `Some`.
    /// Presumably populated for the `request.error` stage — confirm with caller.
    pub error: Option<String>,
    /// Set to `true` once a sync hook in the stage returned a response.
    pub short_circuited: bool,
}
37
/// Response data exposed to Lua response hooks as the global `ctx` table.
///
/// After each sync hook `resp_headers` is copied back from Lua, and `body`
/// is copied back only when it was `Some` going in. `status` is passed to the
/// script but not read back.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResponseContext {
    /// HTTP status code (exposed as `ctx.status`).
    pub status: u16,
    /// Response headers (exposed as `ctx.resp_headers`); rewritten by sync hooks.
    pub resp_headers: HashMap<String, String>,
    /// Response body bytes, exposed as `ctx.body` only when `Some`. When
    /// `None`, scripts cannot introduce a body.
    pub body: Option<Vec<u8>>,
    /// Set to `true` once a sync hook in the stage returned a response.
    pub short_circuited: bool,
}
49
50pub enum HookResult {
54 Continue,
56 ShortCircuit(Response<Body>),
58}
59
60struct CompiledHook {
63 config: HookConfig,
64 bytecode: Vec<u8>,
66}
67
68#[derive(Clone)]
72pub struct HookEngine {
73 hooks: Arc<Vec<CompiledHook>>,
74}
75
76impl HookEngine {
77 pub fn new(configs: &[HookConfig]) -> Self {
80 let mut compiled = Vec::with_capacity(configs.len());
81 for cfg in configs {
82 match compile_script(cfg) {
83 Ok(bytecode) => compiled.push(CompiledHook {
84 config: cfg.clone(),
85 bytecode,
86 }),
87 Err(e) => {
88 warn!(
89 event = %cfg.event,
90 script = %cfg.lua.display(),
91 error = %e,
92 "lua hook script failed to compile — skipping"
93 );
94 }
95 }
96 }
97 Self {
98 hooks: Arc::new(compiled),
99 }
100 }
101
102 pub fn run_request_before(&self, ctx: &mut RequestContext) -> HookResult {
104 self.run_request_stage("request.before", ctx)
105 }
106
107 pub fn run_request_error(&self, ctx: &mut RequestContext) -> HookResult {
112 self.run_request_stage("request.error", ctx)
113 }
114
115 pub fn run_response_headers(&self, ctx: &mut ResponseContext) -> HookResult {
117 self.run_response_stage("response.headers", ctx)
118 }
119
120 pub fn run_response_after(&self, ctx: &mut ResponseContext) -> HookResult {
122 self.run_response_stage("response.after", ctx)
123 }
124
125 pub fn has_event(&self, event: &str) -> bool {
127 self.hooks.iter().any(|h| h.config.event == event)
128 }
129
130 fn run_request_stage(&self, event: &str, ctx: &mut RequestContext) -> HookResult {
133 let hooks: Vec<_> = self
134 .hooks
135 .iter()
136 .filter(|h| h.config.event == event)
137 .collect();
138
139 let (sync_hooks, async_hooks): (Vec<_>, Vec<_>) = hooks
140 .into_iter()
141 .partition(|h| h.config.mode == HookMode::Sync);
142
143 let mut sc_response: Option<Response<Body>> = None;
144
145 for hook in &sync_hooks {
146 if sc_response.is_some() {
147 break;
148 }
149 let timeout = Duration::from_millis(hook.config.timeout_ms);
150 match exec_request_hook_mut(&hook.bytecode, ctx, timeout) {
151 Ok(Some(resp)) => {
152 ctx.short_circuited = true;
153 sc_response = Some(resp);
154 }
155 Ok(None) => {}
156 Err(e) => {
157 warn!(
158 event = event,
159 script = %hook.config.lua.display(),
160 error = %e,
161 "lua sync hook error"
162 );
163 if hook.config.on_error == HookErrorBehavior::FailClosed {
164 return HookResult::ShortCircuit(internal_error_response());
165 }
166 }
167 }
168 }
169
170 let ctx_snap = ctx.clone();
172 for hook in async_hooks {
173 let bytecode = hook.bytecode.clone();
174 let snap = ctx_snap.clone();
175 let timeout = Duration::from_millis(hook.config.timeout_ms);
176 tokio::spawn(async move {
177 if let Err(e) = exec_request_hook(&bytecode, &snap, timeout) {
178 warn!(error = %e, "lua async request hook failed");
179 }
180 });
181 }
182
183 match sc_response {
184 Some(resp) => HookResult::ShortCircuit(resp),
185 None => HookResult::Continue,
186 }
187 }
188
189 fn run_response_stage(&self, event: &str, ctx: &mut ResponseContext) -> HookResult {
190 let hooks: Vec<_> = self
191 .hooks
192 .iter()
193 .filter(|h| h.config.event == event)
194 .collect();
195
196 let (sync_hooks, async_hooks): (Vec<_>, Vec<_>) = hooks
197 .into_iter()
198 .partition(|h| h.config.mode == HookMode::Sync);
199
200 let mut sc_response: Option<Response<Body>> = None;
201
202 for hook in &sync_hooks {
203 if sc_response.is_some() {
204 break;
205 }
206 let timeout = Duration::from_millis(hook.config.timeout_ms);
207 match exec_response_hook_mut(&hook.bytecode, ctx, timeout) {
208 Ok(Some(resp)) => {
209 ctx.short_circuited = true;
210 sc_response = Some(resp);
211 }
212 Ok(None) => {}
213 Err(e) => {
214 warn!(
215 event = event,
216 script = %hook.config.lua.display(),
217 error = %e,
218 "lua sync response hook error"
219 );
220 if hook.config.on_error == HookErrorBehavior::FailClosed {
221 return HookResult::ShortCircuit(internal_error_response());
222 }
223 }
224 }
225 }
226
227 let ctx_snap = ctx.clone();
228 for hook in async_hooks {
229 let bytecode = hook.bytecode.clone();
230 let snap = ctx_snap.clone();
231 let timeout = Duration::from_millis(hook.config.timeout_ms);
232 tokio::spawn(async move {
233 if let Err(e) = exec_response_hook(&bytecode, &snap, timeout) {
234 warn!(error = %e, "lua async response hook failed");
235 }
236 });
237 }
238
239 match sc_response {
240 Some(resp) => HookResult::ShortCircuit(resp),
241 None => HookResult::Continue,
242 }
243 }
244}
245
246fn compile_script(cfg: &HookConfig) -> Result<Vec<u8>, String> {
249 let source =
250 std::fs::read_to_string(&cfg.lua).map_err(|e| format!("read {:?}: {e}", cfg.lua))?;
251 let lua = make_lua()?;
252 let func = lua
253 .load(&source)
254 .into_function()
255 .map_err(|e| format!("compile {:?}: {e}", cfg.lua))?;
256 Ok(func.dump(false))
258}
259
260fn make_lua() -> Result<Lua, String> {
261 Lua::new_with(
263 StdLib::TABLE | StdLib::STRING | StdLib::MATH | StdLib::UTF8,
264 LuaOptions::default(),
265 )
266 .map_err(|e| format!("lua init: {e}"))
267}
268
269fn exec_request_hook(
273 bytecode: &[u8],
274 ctx: &RequestContext,
275 timeout: Duration,
276) -> Result<Option<Response<Body>>, String> {
277 let lua = make_lua()?;
278 install_timeout(&lua, timeout)?;
279
280 let t = build_request_table(&lua, ctx)?;
281 lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
282
283 let func = lua
284 .load(bytecode)
285 .into_function()
286 .map_err(|e| e.to_string())?;
287 let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
288 response_from_lua(result)
289}
290
291fn exec_request_hook_mut(
294 bytecode: &[u8],
295 ctx: &mut RequestContext,
296 timeout: Duration,
297) -> Result<Option<Response<Body>>, String> {
298 let lua = make_lua()?;
299 install_timeout(&lua, timeout)?;
300
301 let t = build_request_table(&lua, ctx)?;
302 lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
303
304 let func = lua
305 .load(bytecode)
306 .into_function()
307 .map_err(|e| e.to_string())?;
308 let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
309
310 let ctx_global: Table = lua.globals().get("ctx").map_err(|e| e.to_string())?;
312 let headers_table: Table = ctx_global.get("headers").map_err(|e| e.to_string())?;
313 ctx.headers = lua_table_to_map(&headers_table)?;
314 let extra_table: Table = ctx_global.get("extra").map_err(|e| e.to_string())?;
315 ctx.extra = lua_table_to_map(&extra_table)?;
316
317 response_from_lua(result)
318}
319
320fn exec_response_hook(
322 bytecode: &[u8],
323 ctx: &ResponseContext,
324 timeout: Duration,
325) -> Result<Option<Response<Body>>, String> {
326 let lua = make_lua()?;
327 install_timeout(&lua, timeout)?;
328
329 let t = build_response_table(&lua, ctx)?;
330 lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
331
332 let func = lua
333 .load(bytecode)
334 .into_function()
335 .map_err(|e| e.to_string())?;
336 let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
337 response_from_lua(result)
338}
339
340fn exec_response_hook_mut(
343 bytecode: &[u8],
344 ctx: &mut ResponseContext,
345 timeout: Duration,
346) -> Result<Option<Response<Body>>, String> {
347 let lua = make_lua()?;
348 install_timeout(&lua, timeout)?;
349
350 let t = build_response_table(&lua, ctx)?;
351 lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
352
353 let func = lua
354 .load(bytecode)
355 .into_function()
356 .map_err(|e| e.to_string())?;
357 let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
358
359 let ctx_global: Table = lua.globals().get("ctx").map_err(|e| e.to_string())?;
361 let rh_table: Table = ctx_global.get("resp_headers").map_err(|e| e.to_string())?;
362 ctx.resp_headers = lua_table_to_map(&rh_table)?;
363
364 if ctx.body.is_some() {
365 let new_body: Option<mlua::String> = ctx_global.get("body").ok();
368 ctx.body = new_body.map(|s| s.as_bytes().to_vec());
369 }
370
371 response_from_lua(result)
372}
373
374fn build_request_table(lua: &Lua, ctx: &RequestContext) -> Result<Table, String> {
377 let t = lua.create_table().map_err(|e| e.to_string())?;
378 t.set("method", ctx.method.as_str())
379 .map_err(|e| e.to_string())?;
380 t.set("path", ctx.path.as_str())
381 .map_err(|e| e.to_string())?;
382 t.set("query", ctx.query.as_str())
383 .map_err(|e| e.to_string())?;
384 t.set("client_ip", ctx.client_ip.as_str())
385 .map_err(|e| e.to_string())?;
386 t.set("request_id", ctx.request_id.as_str())
387 .map_err(|e| e.to_string())?;
388 t.set("short_circuited", ctx.short_circuited)
389 .map_err(|e| e.to_string())?;
390 let headers = map_to_lua_table(lua, &ctx.headers)?;
391 t.set("headers", headers).map_err(|e| e.to_string())?;
392 let extra = map_to_lua_table(lua, &ctx.extra)?;
393 t.set("extra", extra).map_err(|e| e.to_string())?;
394 if let Some(ref err) = ctx.error {
395 t.set("error", err.as_str()).map_err(|e| e.to_string())?;
396 }
397 Ok(t)
398}
399
400fn build_response_table(lua: &Lua, ctx: &ResponseContext) -> Result<Table, String> {
401 let t = lua.create_table().map_err(|e| e.to_string())?;
402 t.set("status", ctx.status).map_err(|e| e.to_string())?;
403 t.set("short_circuited", ctx.short_circuited)
404 .map_err(|e| e.to_string())?;
405 let rh = map_to_lua_table(lua, &ctx.resp_headers)?;
406 t.set("resp_headers", rh).map_err(|e| e.to_string())?;
407 if let Some(ref body) = ctx.body {
408 let lua_str = lua.create_string(body).map_err(|e| e.to_string())?;
410 t.set("body", lua_str).map_err(|e| e.to_string())?;
411 }
412 Ok(t)
413}
414
415fn install_timeout(lua: &Lua, timeout: Duration) -> Result<(), String> {
418 let start = Instant::now();
419 lua.set_hook(
420 HookTriggers::new().every_nth_instruction(100),
421 move |_lua, _debug| {
422 if start.elapsed() > timeout {
423 Err(mlua::Error::runtime("lua hook timeout"))
424 } else {
425 Ok(VmState::Continue)
426 }
427 },
428 );
429 Ok(())
430}
431
432fn map_to_lua_table(lua: &Lua, map: &HashMap<String, String>) -> Result<Table, String> {
433 let t = lua.create_table().map_err(|e| e.to_string())?;
434 for (k, v) in map {
435 t.set(k.as_str(), v.as_str()).map_err(|e| e.to_string())?;
436 }
437 Ok(t)
438}
439
440fn lua_table_to_map(table: &Table) -> Result<HashMap<String, String>, String> {
441 let mut map = HashMap::new();
442 for pair in table.clone().pairs::<String, String>() {
443 let (k, v) = pair.map_err(|e| e.to_string())?;
444 map.insert(k, v);
445 }
446 Ok(map)
447}
448
449fn response_from_lua(value: LuaValue) -> Result<Option<Response<Body>>, String> {
450 match value {
451 LuaValue::Nil => Ok(None),
452 LuaValue::Table(t) => {
453 let status: u16 = t.get("status").unwrap_or(200u16);
454 let body: String = t.get("body").unwrap_or_default();
455 let resp = Response::builder()
456 .status(status)
457 .body(Body::from(body))
458 .map_err(|e| e.to_string())?;
459 Ok(Some(resp))
460 }
461 _ => Ok(None),
462 }
463}
464
465fn internal_error_response() -> Response<Body> {
466 Response::builder()
467 .status(500)
468 .body(Body::from("internal server error (hook fail_closed)"))
469 .unwrap()
470}