Skip to main content

folk_plugin_http/
hooks.rs

//! Lua hook pipeline for folk-plugin-http.
//!
//! Hook scripts are pre-compiled to Lua bytecode at startup.  Every hook
//! invocation then loads that bytecode into its own fresh [`mlua::Lua`] VM,
//! so no Lua state is shared between hooks or requests.  Sync hooks run in
//! the request critical path; async hooks run fire-and-forget on a background
//! Tokio task against a snapshot of the context.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use anyhow::anyhow;
12use axum::body::Body;
13use axum::http::Response;
14use mlua::{HookTriggers, Lua, LuaOptions, StdLib, Table, Value as LuaValue, VmState};
15use tracing::warn;
16
17use crate::config::{HookConfig, HookErrorBehavior, HookMode};
18
19// ── Public context types ──────────────────────────────────────────────────────
20
/// Context passed to `request.before` and `request.error` hooks.
///
/// Scripts see it as the global table `ctx` with the fields `method`, `path`,
/// `query`, `client_ip`, `request_id`, `short_circuited`, `headers`, `extra`
/// and, when [`error`](Self::error) is `Some`, `error`.  After a sync hook
/// runs, only `headers` and `extra` are copied back into this struct; changes
/// a script makes to any other field are discarded.  Async hooks receive a
/// snapshot, and nothing they change is copied back.
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// Exposed to Lua as `ctx.method`.
    pub method: String,
    /// Exposed to Lua as `ctx.path`.
    pub path: String,
    /// Exposed to Lua as `ctx.query`.
    pub query: String,
    /// Exposed to Lua as `ctx.client_ip`.
    pub client_ip: String,
    /// Exposed to Lua as `ctx.request_id`.
    pub request_id: String,
    /// Exposed to Lua as `ctx.headers`.
    /// Mutable in sync hooks — mutations propagate to subsequent hooks.
    /// Holds one string value per name.
    pub headers: HashMap<String, String>,
    /// Arbitrary key-value bag exposed as `ctx.extra`, mutable in sync hooks.
    pub extra: HashMap<String, String>,
    /// Error message, exposed to Lua as `ctx.error` only when `Some`.
    /// Intended to be set only for `request.error` hooks.
    pub error: Option<String>,
    /// Set to `true` by the engine when a sync hook in the current stage
    /// returned a response.  Async hooks started afterwards see it in their
    /// snapshot.  Exposed to Lua as `ctx.short_circuited`.
    pub short_circuited: bool,
}
38
/// Context passed to `response.headers` and `response.after` hooks.
///
/// Scripts see it as the global table `ctx` with the fields `status`,
/// `short_circuited`, `resp_headers` and, when [`body`](Self::body) is
/// `Some`, `body`.  After a sync hook runs, `resp_headers` is copied back.
/// `body` is copied back only if it was `Some` going in.  Changes a script
/// makes to `status` are discarded.
#[derive(Debug, Clone)]
pub struct ResponseContext {
    /// Response status, exposed to Lua as `ctx.status`.
    pub status: u16,
    /// Exposed to Lua as `ctx.resp_headers`.
    /// Mutable in sync hooks — mutations propagate to subsequent hooks.
    pub resp_headers: HashMap<String, String>,
    /// Present only for `response.after` hooks. Raw bytes — no lossy UTF-8
    /// conversion; handed to Lua as a byte string in `ctx.body`.
    pub body: Option<Vec<u8>>,
    /// Set to `true` by the engine when a sync hook in the current stage
    /// returned a response.  Exposed to Lua as `ctx.short_circuited`.
    pub short_circuited: bool,
}
50
51// ── Hook result ───────────────────────────────────────────────────────────────
52
53/// Outcome of running a hook stage.
54pub enum HookResult {
55    /// Continue normal processing.
56    Continue,
57    /// A sync hook returned a short-circuit response.
58    ShortCircuit(Response<Body>),
59}
60
61// ── Internals ─────────────────────────────────────────────────────────────────
62
/// A hook whose script compiled successfully, paired with its configuration.
struct CompiledHook {
    /// Source of the event name, mode, timeout and error behaviour used when
    /// the hook runs.
    config: HookConfig,
    /// Pre-compiled Lua bytecode produced by `Function::dump(false)` at
    /// startup.  It is loaded into a fresh VM on every invocation.
    bytecode: Vec<u8>,
}
68
69// ── HookEngine ────────────────────────────────────────────────────────────────
70
/// Holds all pre-compiled hooks for the lifetime of the server.
///
/// Cloning is cheap: clones share the same hook list through an [`Arc`].
#[derive(Clone)]
pub struct HookEngine {
    /// Successfully compiled hooks, in configuration order.  Sync hooks for
    /// the same event run in this order.
    hooks: Arc<Vec<CompiledHook>>,
}
76
77impl HookEngine {
78    /// Compile all hook scripts.
79    ///
80    /// If a script fails to compile and `required = false` (default), a WARN is
81    /// emitted and the hook is skipped.  If `required = true`, compilation failure
82    /// returns an error and aborts server startup — use this for security-critical
83    /// hooks where a silently-missing hook would leave requests unprotected.
84    pub fn new(configs: &[HookConfig]) -> anyhow::Result<Self> {
85        let mut compiled = Vec::with_capacity(configs.len());
86        for cfg in configs {
87            match compile_script(cfg) {
88                Ok(bytecode) => compiled.push(CompiledHook {
89                    config: cfg.clone(),
90                    bytecode,
91                }),
92                Err(e) => {
93                    if cfg.required {
94                        return Err(anyhow!(
95                            "required lua hook script failed to compile \
96                             (event = {}, script = {}): {}",
97                            cfg.event,
98                            cfg.lua.display(),
99                            e
100                        ));
101                    }
102                    warn!(
103                        event = %cfg.event,
104                        script = %cfg.lua.display(),
105                        error = %e,
106                        "lua hook script failed to compile — skipping"
107                    );
108                }
109            }
110        }
111        Ok(Self {
112            hooks: Arc::new(compiled),
113        })
114    }
115
116    /// Run `request.before` hooks.
117    pub fn run_request_before(&self, ctx: &mut RequestContext) -> HookResult {
118        self.run_request_stage("request.before", ctx)
119    }
120
121    /// Run `request.error` hooks.
122    ///
123    /// Sync hooks run inline (critical path) and respect `fail_closed` /
124    /// short-circuit.  Async hooks fire-and-forget.  Mirrors `run_request_before`.
125    pub fn run_request_error(&self, ctx: &mut RequestContext) -> HookResult {
126        self.run_request_stage("request.error", ctx)
127    }
128
129    /// Run `response.headers` hooks.
130    pub fn run_response_headers(&self, ctx: &mut ResponseContext) -> HookResult {
131        self.run_response_stage("response.headers", ctx)
132    }
133
134    /// Run `response.after` hooks.
135    pub fn run_response_after(&self, ctx: &mut ResponseContext) -> HookResult {
136        self.run_response_stage("response.after", ctx)
137    }
138
139    /// Returns true when at least one compiled hook is registered for `event`.
140    pub fn has_event(&self, event: &str) -> bool {
141        self.hooks.iter().any(|h| h.config.event == event)
142    }
143
144    // ── Internal stage runners ────────────────────────────────────────────────
145
146    fn run_request_stage(&self, event: &str, ctx: &mut RequestContext) -> HookResult {
147        let hooks: Vec<_> = self
148            .hooks
149            .iter()
150            .filter(|h| h.config.event == event)
151            .collect();
152
153        let (sync_hooks, async_hooks): (Vec<_>, Vec<_>) = hooks
154            .into_iter()
155            .partition(|h| h.config.mode == HookMode::Sync);
156
157        let mut sc_response: Option<Response<Body>> = None;
158
159        for hook in &sync_hooks {
160            if sc_response.is_some() {
161                break;
162            }
163            let timeout = Duration::from_millis(hook.config.timeout_ms);
164            match exec_request_hook_mut(&hook.bytecode, ctx, timeout) {
165                Ok(Some(resp)) => {
166                    ctx.short_circuited = true;
167                    sc_response = Some(resp);
168                }
169                Ok(None) => {}
170                Err(e) => {
171                    warn!(
172                        event = event,
173                        script = %hook.config.lua.display(),
174                        error = %e,
175                        "lua sync hook error"
176                    );
177                    if hook.config.on_error == HookErrorBehavior::FailClosed {
178                        return HookResult::ShortCircuit(internal_error_response());
179                    }
180                }
181            }
182        }
183
184        // Spawn async hooks with a read-only snapshot (taken after sync phase).
185        let ctx_snap = ctx.clone();
186        for hook in async_hooks {
187            let bytecode = hook.bytecode.clone();
188            let snap = ctx_snap.clone();
189            let timeout = Duration::from_millis(hook.config.timeout_ms);
190            tokio::spawn(async move {
191                if let Err(e) = exec_request_hook(&bytecode, &snap, timeout) {
192                    warn!(error = %e, "lua async request hook failed");
193                }
194            });
195        }
196
197        match sc_response {
198            Some(resp) => HookResult::ShortCircuit(resp),
199            None => HookResult::Continue,
200        }
201    }
202
203    fn run_response_stage(&self, event: &str, ctx: &mut ResponseContext) -> HookResult {
204        let hooks: Vec<_> = self
205            .hooks
206            .iter()
207            .filter(|h| h.config.event == event)
208            .collect();
209
210        let (sync_hooks, async_hooks): (Vec<_>, Vec<_>) = hooks
211            .into_iter()
212            .partition(|h| h.config.mode == HookMode::Sync);
213
214        let mut sc_response: Option<Response<Body>> = None;
215
216        for hook in &sync_hooks {
217            if sc_response.is_some() {
218                break;
219            }
220            let timeout = Duration::from_millis(hook.config.timeout_ms);
221            match exec_response_hook_mut(&hook.bytecode, ctx, timeout) {
222                Ok(Some(resp)) => {
223                    ctx.short_circuited = true;
224                    sc_response = Some(resp);
225                }
226                Ok(None) => {}
227                Err(e) => {
228                    warn!(
229                        event = event,
230                        script = %hook.config.lua.display(),
231                        error = %e,
232                        "lua sync response hook error"
233                    );
234                    if hook.config.on_error == HookErrorBehavior::FailClosed {
235                        return HookResult::ShortCircuit(internal_error_response());
236                    }
237                }
238            }
239        }
240
241        let ctx_snap = ctx.clone();
242        for hook in async_hooks {
243            let bytecode = hook.bytecode.clone();
244            let snap = ctx_snap.clone();
245            let timeout = Duration::from_millis(hook.config.timeout_ms);
246            tokio::spawn(async move {
247                if let Err(e) = exec_response_hook(&bytecode, &snap, timeout) {
248                    warn!(error = %e, "lua async response hook failed");
249                }
250            });
251        }
252
253        match sc_response {
254            Some(resp) => HookResult::ShortCircuit(resp),
255            None => HookResult::Continue,
256        }
257    }
258}
259
260// ── Script compilation ────────────────────────────────────────────────────────
261
262fn compile_script(cfg: &HookConfig) -> Result<Vec<u8>, String> {
263    let source =
264        std::fs::read_to_string(&cfg.lua).map_err(|e| format!("read {:?}: {e}", cfg.lua))?;
265    let lua = make_lua()?;
266    let func = lua
267        .load(&source)
268        .into_function()
269        .map_err(|e| format!("compile {:?}: {e}", cfg.lua))?;
270    // dump() returns Vec<u8> directly in mlua 0.10
271    Ok(func.dump(false))
272}
273
274fn make_lua() -> Result<Lua, String> {
275    // Safe subset — no io, os, debug, package, coroutine.
276    Lua::new_with(
277        StdLib::TABLE | StdLib::STRING | StdLib::MATH | StdLib::UTF8,
278        LuaOptions::default(),
279    )
280    .map_err(|e| format!("lua init: {e}"))
281}
282
283// ── Per-request execution helpers ─────────────────────────────────────────────
284
285/// Execute bytecode with an immutable RequestContext.  Used for async hooks.
286fn exec_request_hook(
287    bytecode: &[u8],
288    ctx: &RequestContext,
289    timeout: Duration,
290) -> Result<Option<Response<Body>>, String> {
291    let lua = make_lua()?;
292    install_timeout(&lua, timeout)?;
293
294    let t = build_request_table(&lua, ctx)?;
295    lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
296
297    let func = lua
298        .load(bytecode)
299        .into_function()
300        .map_err(|e| e.to_string())?;
301    let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
302    response_from_lua(result)
303}
304
305/// Execute bytecode with a mutable RequestContext — mutations to headers/extra
306/// are written back into `ctx` so subsequent sync hooks see them.
307fn exec_request_hook_mut(
308    bytecode: &[u8],
309    ctx: &mut RequestContext,
310    timeout: Duration,
311) -> Result<Option<Response<Body>>, String> {
312    let lua = make_lua()?;
313    install_timeout(&lua, timeout)?;
314
315    let t = build_request_table(&lua, ctx)?;
316    lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
317
318    let func = lua
319        .load(bytecode)
320        .into_function()
321        .map_err(|e| e.to_string())?;
322    let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
323
324    // Write back mutations.
325    let ctx_global: Table = lua.globals().get("ctx").map_err(|e| e.to_string())?;
326    let headers_table: Table = ctx_global.get("headers").map_err(|e| e.to_string())?;
327    ctx.headers = lua_table_to_map(&headers_table)?;
328    let extra_table: Table = ctx_global.get("extra").map_err(|e| e.to_string())?;
329    ctx.extra = lua_table_to_map(&extra_table)?;
330
331    response_from_lua(result)
332}
333
334/// Execute bytecode with an immutable ResponseContext.  Used for async hooks.
335fn exec_response_hook(
336    bytecode: &[u8],
337    ctx: &ResponseContext,
338    timeout: Duration,
339) -> Result<Option<Response<Body>>, String> {
340    let lua = make_lua()?;
341    install_timeout(&lua, timeout)?;
342
343    let t = build_response_table(&lua, ctx)?;
344    lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
345
346    let func = lua
347        .load(bytecode)
348        .into_function()
349        .map_err(|e| e.to_string())?;
350    let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
351    response_from_lua(result)
352}
353
354/// Execute bytecode with a mutable ResponseContext — mutations to resp_headers
355/// and body are written back.
356fn exec_response_hook_mut(
357    bytecode: &[u8],
358    ctx: &mut ResponseContext,
359    timeout: Duration,
360) -> Result<Option<Response<Body>>, String> {
361    let lua = make_lua()?;
362    install_timeout(&lua, timeout)?;
363
364    let t = build_response_table(&lua, ctx)?;
365    lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
366
367    let func = lua
368        .load(bytecode)
369        .into_function()
370        .map_err(|e| e.to_string())?;
371    let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
372
373    // Write back mutations.
374    let ctx_global: Table = lua.globals().get("ctx").map_err(|e| e.to_string())?;
375    let rh_table: Table = ctx_global.get("resp_headers").map_err(|e| e.to_string())?;
376    ctx.resp_headers = lua_table_to_map(&rh_table)?;
377
378    if ctx.body.is_some() {
379        // Only write back body if it was present (response.after event).
380        // Use mlua::String to preserve raw bytes — avoids any UTF-8 conversion.
381        let new_body: Option<mlua::String> = ctx_global.get("body").ok();
382        ctx.body = new_body.map(|s| s.as_bytes().to_vec());
383    }
384
385    response_from_lua(result)
386}
387
388// ── Table builders ────────────────────────────────────────────────────────────
389
390fn build_request_table(lua: &Lua, ctx: &RequestContext) -> Result<Table, String> {
391    let t = lua.create_table().map_err(|e| e.to_string())?;
392    t.set("method", ctx.method.as_str())
393        .map_err(|e| e.to_string())?;
394    t.set("path", ctx.path.as_str())
395        .map_err(|e| e.to_string())?;
396    t.set("query", ctx.query.as_str())
397        .map_err(|e| e.to_string())?;
398    t.set("client_ip", ctx.client_ip.as_str())
399        .map_err(|e| e.to_string())?;
400    t.set("request_id", ctx.request_id.as_str())
401        .map_err(|e| e.to_string())?;
402    t.set("short_circuited", ctx.short_circuited)
403        .map_err(|e| e.to_string())?;
404    let headers = map_to_lua_table(lua, &ctx.headers)?;
405    t.set("headers", headers).map_err(|e| e.to_string())?;
406    let extra = map_to_lua_table(lua, &ctx.extra)?;
407    t.set("extra", extra).map_err(|e| e.to_string())?;
408    if let Some(ref err) = ctx.error {
409        t.set("error", err.as_str()).map_err(|e| e.to_string())?;
410    }
411    Ok(t)
412}
413
414fn build_response_table(lua: &Lua, ctx: &ResponseContext) -> Result<Table, String> {
415    let t = lua.create_table().map_err(|e| e.to_string())?;
416    t.set("status", ctx.status).map_err(|e| e.to_string())?;
417    t.set("short_circuited", ctx.short_circuited)
418        .map_err(|e| e.to_string())?;
419    let rh = map_to_lua_table(lua, &ctx.resp_headers)?;
420    t.set("resp_headers", rh).map_err(|e| e.to_string())?;
421    if let Some(ref body) = ctx.body {
422        // Lua strings are byte strings — binary bodies pass through without corruption.
423        let lua_str = lua.create_string(body).map_err(|e| e.to_string())?;
424        t.set("body", lua_str).map_err(|e| e.to_string())?;
425    }
426    Ok(t)
427}
428
429// ── Small utilities ───────────────────────────────────────────────────────────
430
431fn install_timeout(lua: &Lua, timeout: Duration) -> Result<(), String> {
432    let start = Instant::now();
433    lua.set_hook(
434        HookTriggers::new().every_nth_instruction(100),
435        move |_lua, _debug| {
436            if start.elapsed() > timeout {
437                Err(mlua::Error::runtime("lua hook timeout"))
438            } else {
439                Ok(VmState::Continue)
440            }
441        },
442    );
443    Ok(())
444}
445
446fn map_to_lua_table(lua: &Lua, map: &HashMap<String, String>) -> Result<Table, String> {
447    let t = lua.create_table().map_err(|e| e.to_string())?;
448    for (k, v) in map {
449        t.set(k.as_str(), v.as_str()).map_err(|e| e.to_string())?;
450    }
451    Ok(t)
452}
453
454fn lua_table_to_map(table: &Table) -> Result<HashMap<String, String>, String> {
455    let mut map = HashMap::new();
456    for pair in table.clone().pairs::<String, String>() {
457        let (k, v) = pair.map_err(|e| e.to_string())?;
458        map.insert(k, v);
459    }
460    Ok(map)
461}
462
463fn response_from_lua(value: LuaValue) -> Result<Option<Response<Body>>, String> {
464    match value {
465        LuaValue::Nil => Ok(None),
466        LuaValue::Table(t) => {
467            let status: u16 = t.get("status").unwrap_or(200u16);
468            let body: String = t.get("body").unwrap_or_default();
469            let resp = Response::builder()
470                .status(status)
471                .body(Body::from(body))
472                .map_err(|e| e.to_string())?;
473            Ok(Some(resp))
474        }
475        _ => Ok(None),
476    }
477}
478
479fn internal_error_response() -> Response<Body> {
480    Response::builder()
481        .status(500)
482        .body(Body::from("internal server error (hook fail_closed)"))
483        .unwrap()
484}