Skip to main content

folk_plugin_http/
hooks.rs

1//! Lua hook pipeline for folk-plugin-http.
2//!
3//! Hook scripts are pre-compiled to Lua bytecode at startup and executed
4//! per-request on a fresh [`mlua::Lua`] VM instance.  Sync hooks run in
//! the request critical path; async hooks are fire-and-forget tasks on the Tokio runtime.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use axum::body::Body;
12use axum::http::Response;
13use mlua::{HookTriggers, Lua, LuaOptions, StdLib, Table, Value as LuaValue, VmState};
14use tracing::warn;
15
16use crate::config::{HookConfig, HookErrorBehavior, HookMode};
17
18// ── Public context types ──────────────────────────────────────────────────────
19
/// Context passed to `request.before` and `request.error` hooks.
///
/// Every field except `error` is always copied into the Lua global table
/// `ctx` under the same name before a hook runs; `error` is copied only when
/// it is `Some`.  After a sync hook returns, only `headers` and `extra` are
/// read back from Lua — the remaining fields are effectively read-only from
/// the script's point of view.
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// Exposed to Lua as `ctx.method`.
    pub method: String,
    /// Exposed to Lua as `ctx.path`.
    pub path: String,
    /// Exposed to Lua as `ctx.query`.
    pub query: String,
    /// Exposed to Lua as `ctx.client_ip`.
    pub client_ip: String,
    /// Exposed to Lua as `ctx.request_id`.
    pub request_id: String,
    /// Mutable in sync hooks — mutations propagate to subsequent hooks.
    /// The whole map is replaced by the contents of `ctx.headers` after each
    /// successful sync hook, so keys a script removes are removed here too.
    pub headers: HashMap<String, String>,
    /// Arbitrary key-value bag, mutable in sync hooks.
    /// Replaced wholesale from `ctx.extra` in the same way as `headers`.
    pub extra: HashMap<String, String>,
    /// Error message. Set only for `request.error` hooks.
    /// When `None`, the `ctx.error` key is not set at all in Lua.
    pub error: Option<String>,
    /// True when a previous sync hook short-circuited the pipeline.
    /// Set by the stage runner when a sync hook returns a response table.
    pub short_circuited: bool,
}
37
/// Context passed to `response.headers` and `response.after` hooks.
///
/// `status`, `resp_headers` and `short_circuited` are always copied into the
/// Lua global table `ctx` before a hook runs; `body` is copied only when it
/// is `Some`.  After a sync hook returns, `resp_headers` (and `body`, when it
/// was present) are read back from Lua.
#[derive(Debug, Clone)]
pub struct ResponseContext {
    /// Exposed to Lua as `ctx.status`; never read back from the script.
    pub status: u16,
    /// Mutable in sync hooks — mutations propagate to subsequent hooks.
    /// The whole map is replaced by the contents of `ctx.resp_headers` after
    /// each successful sync hook.
    pub resp_headers: HashMap<String, String>,
    /// Present only for `response.after` hooks.
    /// When `None`, `ctx.body` is not set in Lua and is not read back.
    pub body: Option<String>,
    /// True when a previous sync hook short-circuited the pipeline.
    /// Set by the stage runner when a sync hook returns a response table.
    pub short_circuited: bool,
}
49
50// ── Hook result ───────────────────────────────────────────────────────────────
51
/// Outcome of running a hook stage.
///
/// Returned by the `run_*` stage methods on [`HookEngine`] that execute sync
/// hooks.
pub enum HookResult {
    /// Continue normal processing.
    /// No sync hook returned a response and no fail-closed hook errored.
    Continue,
    /// A sync hook returned a short-circuit response.
    /// Also produced, with a 500 response, when a sync hook configured as
    /// fail-closed returns an error.
    ShortCircuit(Response<Body>),
}
59
60// ── Internals ─────────────────────────────────────────────────────────────────
61
/// One hook script that compiled successfully at startup, paired with the
/// configuration entry it came from.
struct CompiledHook {
    /// The originating config; read for `event`, `mode`, `timeout_ms`,
    /// `on_error` and the script path (`lua`) used in log messages.
    config: HookConfig,
    /// Pre-compiled Lua bytecode.
    /// Produced by `compile_script` and loaded into a fresh VM per execution.
    bytecode: Vec<u8>,
}
67
68// ── HookEngine ────────────────────────────────────────────────────────────────
69
/// Holds all pre-compiled hooks for the lifetime of the server.
///
/// Cloning is cheap: the hook list lives behind an [`Arc`], so clones share
/// the same compiled scripts.
#[derive(Clone)]
pub struct HookEngine {
    /// Compiled hooks in the order their configs were supplied to `new`;
    /// scripts that failed to compile are absent.
    hooks: Arc<Vec<CompiledHook>>,
}
75
76impl HookEngine {
77    /// Compile all hook scripts.  Scripts that fail to compile are skipped with
78    /// a WARN log; the server starts normally regardless.
79    pub fn new(configs: &[HookConfig]) -> Self {
80        let mut compiled = Vec::with_capacity(configs.len());
81        for cfg in configs {
82            match compile_script(cfg) {
83                Ok(bytecode) => compiled.push(CompiledHook {
84                    config: cfg.clone(),
85                    bytecode,
86                }),
87                Err(e) => {
88                    warn!(
89                        event = %cfg.event,
90                        script = %cfg.lua.display(),
91                        error = %e,
92                        "lua hook script failed to compile — skipping"
93                    );
94                }
95            }
96        }
97        Self {
98            hooks: Arc::new(compiled),
99        }
100    }
101
102    /// Run `request.before` hooks.
103    pub fn run_request_before(&self, ctx: &mut RequestContext) -> HookResult {
104        self.run_request_stage("request.before", ctx)
105    }
106
107    /// Run `request.error` hooks — all spawned as async; no short-circuit.
108    pub fn run_request_error(&self, ctx: &RequestContext) {
109        for hook in self
110            .hooks
111            .iter()
112            .filter(|h| h.config.event == "request.error")
113        {
114            let bytecode = hook.bytecode.clone();
115            let ctx_clone = ctx.clone();
116            let timeout = Duration::from_millis(hook.config.timeout_ms);
117            tokio::spawn(async move {
118                if let Err(e) = exec_request_hook(&bytecode, &ctx_clone, timeout) {
119                    warn!(error = %e, "lua request.error hook failed");
120                }
121            });
122        }
123    }
124
125    /// Run `response.headers` hooks.
126    pub fn run_response_headers(&self, ctx: &mut ResponseContext) -> HookResult {
127        self.run_response_stage("response.headers", ctx)
128    }
129
130    /// Run `response.after` hooks.
131    pub fn run_response_after(&self, ctx: &mut ResponseContext) -> HookResult {
132        self.run_response_stage("response.after", ctx)
133    }
134
135    // ── Internal stage runners ────────────────────────────────────────────────
136
137    fn run_request_stage(&self, event: &str, ctx: &mut RequestContext) -> HookResult {
138        let hooks: Vec<_> = self
139            .hooks
140            .iter()
141            .filter(|h| h.config.event == event)
142            .collect();
143
144        let (sync_hooks, async_hooks): (Vec<_>, Vec<_>) = hooks
145            .into_iter()
146            .partition(|h| h.config.mode == HookMode::Sync);
147
148        let mut sc_response: Option<Response<Body>> = None;
149
150        for hook in &sync_hooks {
151            if sc_response.is_some() {
152                break;
153            }
154            let timeout = Duration::from_millis(hook.config.timeout_ms);
155            match exec_request_hook_mut(&hook.bytecode, ctx, timeout) {
156                Ok(Some(resp)) => {
157                    ctx.short_circuited = true;
158                    sc_response = Some(resp);
159                }
160                Ok(None) => {}
161                Err(e) => {
162                    warn!(
163                        event = event,
164                        script = %hook.config.lua.display(),
165                        error = %e,
166                        "lua sync hook error"
167                    );
168                    if hook.config.on_error == HookErrorBehavior::FailClosed {
169                        return HookResult::ShortCircuit(internal_error_response());
170                    }
171                }
172            }
173        }
174
175        // Spawn async hooks with a read-only snapshot (taken after sync phase).
176        let ctx_snap = ctx.clone();
177        for hook in async_hooks {
178            let bytecode = hook.bytecode.clone();
179            let snap = ctx_snap.clone();
180            let timeout = Duration::from_millis(hook.config.timeout_ms);
181            tokio::spawn(async move {
182                if let Err(e) = exec_request_hook(&bytecode, &snap, timeout) {
183                    warn!(error = %e, "lua async request hook failed");
184                }
185            });
186        }
187
188        match sc_response {
189            Some(resp) => HookResult::ShortCircuit(resp),
190            None => HookResult::Continue,
191        }
192    }
193
194    fn run_response_stage(&self, event: &str, ctx: &mut ResponseContext) -> HookResult {
195        let hooks: Vec<_> = self
196            .hooks
197            .iter()
198            .filter(|h| h.config.event == event)
199            .collect();
200
201        let (sync_hooks, async_hooks): (Vec<_>, Vec<_>) = hooks
202            .into_iter()
203            .partition(|h| h.config.mode == HookMode::Sync);
204
205        let mut sc_response: Option<Response<Body>> = None;
206
207        for hook in &sync_hooks {
208            if sc_response.is_some() {
209                break;
210            }
211            let timeout = Duration::from_millis(hook.config.timeout_ms);
212            match exec_response_hook_mut(&hook.bytecode, ctx, timeout) {
213                Ok(Some(resp)) => {
214                    ctx.short_circuited = true;
215                    sc_response = Some(resp);
216                }
217                Ok(None) => {}
218                Err(e) => {
219                    warn!(
220                        event = event,
221                        script = %hook.config.lua.display(),
222                        error = %e,
223                        "lua sync response hook error"
224                    );
225                    if hook.config.on_error == HookErrorBehavior::FailClosed {
226                        return HookResult::ShortCircuit(internal_error_response());
227                    }
228                }
229            }
230        }
231
232        let ctx_snap = ctx.clone();
233        for hook in async_hooks {
234            let bytecode = hook.bytecode.clone();
235            let snap = ctx_snap.clone();
236            let timeout = Duration::from_millis(hook.config.timeout_ms);
237            tokio::spawn(async move {
238                if let Err(e) = exec_response_hook(&bytecode, &snap, timeout) {
239                    warn!(error = %e, "lua async response hook failed");
240                }
241            });
242        }
243
244        match sc_response {
245            Some(resp) => HookResult::ShortCircuit(resp),
246            None => HookResult::Continue,
247        }
248    }
249}
250
251// ── Script compilation ────────────────────────────────────────────────────────
252
253fn compile_script(cfg: &HookConfig) -> Result<Vec<u8>, String> {
254    let source =
255        std::fs::read_to_string(&cfg.lua).map_err(|e| format!("read {:?}: {e}", cfg.lua))?;
256    let lua = make_lua()?;
257    let func = lua
258        .load(&source)
259        .into_function()
260        .map_err(|e| format!("compile {:?}: {e}", cfg.lua))?;
261    // dump() returns Vec<u8> directly in mlua 0.10
262    Ok(func.dump(false))
263}
264
265fn make_lua() -> Result<Lua, String> {
266    // Safe subset — no io, os, debug, package, coroutine.
267    Lua::new_with(
268        StdLib::TABLE | StdLib::STRING | StdLib::MATH | StdLib::UTF8,
269        LuaOptions::default(),
270    )
271    .map_err(|e| format!("lua init: {e}"))
272}
273
274// ── Per-request execution helpers ─────────────────────────────────────────────
275
276/// Execute bytecode with an immutable RequestContext.  Used for async hooks.
277fn exec_request_hook(
278    bytecode: &[u8],
279    ctx: &RequestContext,
280    timeout: Duration,
281) -> Result<Option<Response<Body>>, String> {
282    let lua = make_lua()?;
283    install_timeout(&lua, timeout)?;
284
285    let t = build_request_table(&lua, ctx)?;
286    lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
287
288    let func = lua
289        .load(bytecode)
290        .into_function()
291        .map_err(|e| e.to_string())?;
292    let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
293    response_from_lua(result)
294}
295
296/// Execute bytecode with a mutable RequestContext — mutations to headers/extra
297/// are written back into `ctx` so subsequent sync hooks see them.
298fn exec_request_hook_mut(
299    bytecode: &[u8],
300    ctx: &mut RequestContext,
301    timeout: Duration,
302) -> Result<Option<Response<Body>>, String> {
303    let lua = make_lua()?;
304    install_timeout(&lua, timeout)?;
305
306    let t = build_request_table(&lua, ctx)?;
307    lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
308
309    let func = lua
310        .load(bytecode)
311        .into_function()
312        .map_err(|e| e.to_string())?;
313    let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
314
315    // Write back mutations.
316    let ctx_global: Table = lua.globals().get("ctx").map_err(|e| e.to_string())?;
317    let headers_table: Table = ctx_global.get("headers").map_err(|e| e.to_string())?;
318    ctx.headers = lua_table_to_map(&headers_table)?;
319    let extra_table: Table = ctx_global.get("extra").map_err(|e| e.to_string())?;
320    ctx.extra = lua_table_to_map(&extra_table)?;
321
322    response_from_lua(result)
323}
324
325/// Execute bytecode with an immutable ResponseContext.  Used for async hooks.
326fn exec_response_hook(
327    bytecode: &[u8],
328    ctx: &ResponseContext,
329    timeout: Duration,
330) -> Result<Option<Response<Body>>, String> {
331    let lua = make_lua()?;
332    install_timeout(&lua, timeout)?;
333
334    let t = build_response_table(&lua, ctx)?;
335    lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
336
337    let func = lua
338        .load(bytecode)
339        .into_function()
340        .map_err(|e| e.to_string())?;
341    let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
342    response_from_lua(result)
343}
344
345/// Execute bytecode with a mutable ResponseContext — mutations to resp_headers
346/// and body are written back.
347fn exec_response_hook_mut(
348    bytecode: &[u8],
349    ctx: &mut ResponseContext,
350    timeout: Duration,
351) -> Result<Option<Response<Body>>, String> {
352    let lua = make_lua()?;
353    install_timeout(&lua, timeout)?;
354
355    let t = build_response_table(&lua, ctx)?;
356    lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
357
358    let func = lua
359        .load(bytecode)
360        .into_function()
361        .map_err(|e| e.to_string())?;
362    let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
363
364    // Write back mutations.
365    let ctx_global: Table = lua.globals().get("ctx").map_err(|e| e.to_string())?;
366    let rh_table: Table = ctx_global.get("resp_headers").map_err(|e| e.to_string())?;
367    ctx.resp_headers = lua_table_to_map(&rh_table)?;
368
369    if ctx.body.is_some() {
370        // Only write back body if it was present (response.after event).
371        let new_body: Option<String> = ctx_global.get("body").ok();
372        ctx.body = new_body;
373    }
374
375    response_from_lua(result)
376}
377
378// ── Table builders ────────────────────────────────────────────────────────────
379
380fn build_request_table(lua: &Lua, ctx: &RequestContext) -> Result<Table, String> {
381    let t = lua.create_table().map_err(|e| e.to_string())?;
382    t.set("method", ctx.method.as_str())
383        .map_err(|e| e.to_string())?;
384    t.set("path", ctx.path.as_str())
385        .map_err(|e| e.to_string())?;
386    t.set("query", ctx.query.as_str())
387        .map_err(|e| e.to_string())?;
388    t.set("client_ip", ctx.client_ip.as_str())
389        .map_err(|e| e.to_string())?;
390    t.set("request_id", ctx.request_id.as_str())
391        .map_err(|e| e.to_string())?;
392    t.set("short_circuited", ctx.short_circuited)
393        .map_err(|e| e.to_string())?;
394    let headers = map_to_lua_table(lua, &ctx.headers)?;
395    t.set("headers", headers).map_err(|e| e.to_string())?;
396    let extra = map_to_lua_table(lua, &ctx.extra)?;
397    t.set("extra", extra).map_err(|e| e.to_string())?;
398    if let Some(ref err) = ctx.error {
399        t.set("error", err.as_str()).map_err(|e| e.to_string())?;
400    }
401    Ok(t)
402}
403
404fn build_response_table(lua: &Lua, ctx: &ResponseContext) -> Result<Table, String> {
405    let t = lua.create_table().map_err(|e| e.to_string())?;
406    t.set("status", ctx.status).map_err(|e| e.to_string())?;
407    t.set("short_circuited", ctx.short_circuited)
408        .map_err(|e| e.to_string())?;
409    let rh = map_to_lua_table(lua, &ctx.resp_headers)?;
410    t.set("resp_headers", rh).map_err(|e| e.to_string())?;
411    if let Some(ref body) = ctx.body {
412        t.set("body", body.as_str()).map_err(|e| e.to_string())?;
413    }
414    Ok(t)
415}
416
417// ── Small utilities ───────────────────────────────────────────────────────────
418
419fn install_timeout(lua: &Lua, timeout: Duration) -> Result<(), String> {
420    let start = Instant::now();
421    lua.set_hook(
422        HookTriggers::new().every_nth_instruction(100),
423        move |_lua, _debug| {
424            if start.elapsed() > timeout {
425                Err(mlua::Error::runtime("lua hook timeout"))
426            } else {
427                Ok(VmState::Continue)
428            }
429        },
430    );
431    Ok(())
432}
433
434fn map_to_lua_table(lua: &Lua, map: &HashMap<String, String>) -> Result<Table, String> {
435    let t = lua.create_table().map_err(|e| e.to_string())?;
436    for (k, v) in map {
437        t.set(k.as_str(), v.as_str()).map_err(|e| e.to_string())?;
438    }
439    Ok(t)
440}
441
442fn lua_table_to_map(table: &Table) -> Result<HashMap<String, String>, String> {
443    let mut map = HashMap::new();
444    for pair in table.clone().pairs::<String, String>() {
445        let (k, v) = pair.map_err(|e| e.to_string())?;
446        map.insert(k, v);
447    }
448    Ok(map)
449}
450
451fn response_from_lua(value: LuaValue) -> Result<Option<Response<Body>>, String> {
452    match value {
453        LuaValue::Nil => Ok(None),
454        LuaValue::Table(t) => {
455            let status: u16 = t.get("status").unwrap_or(200u16);
456            let body: String = t.get("body").unwrap_or_default();
457            let resp = Response::builder()
458                .status(status)
459                .body(Body::from(body))
460                .map_err(|e| e.to_string())?;
461            Ok(Some(resp))
462        }
463        _ => Ok(None),
464    }
465}
466
467fn internal_error_response() -> Response<Body> {
468    Response::builder()
469        .status(500)
470        .body(Body::from("internal server error (hook fail_closed)"))
471        .unwrap()
472}