Skip to main content

folk_plugin_http/
hooks.rs

//! Lua hook pipeline for folk-plugin-http.
//!
//! Hook scripts are pre-compiled to Lua bytecode at startup and executed
//! per request on a fresh [`mlua::Lua`] VM instance.  Sync hooks run in
//! the request critical path; async hooks run fire-and-forget on a background
//! Tokio task.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use axum::body::Body;
12use axum::http::Response;
13use mlua::{HookTriggers, Lua, LuaOptions, StdLib, Table, Value as LuaValue, VmState};
14use tracing::warn;
15
16use crate::config::{HookConfig, HookErrorBehavior, HookMode};
17
18// ── Public context types ──────────────────────────────────────────────────────
19
/// Request-side state handed to `request.before` and `request.error` hooks.
///
/// Each field is exposed to Lua under the same name on the global `ctx` table.
#[derive(Clone, Debug)]
pub struct RequestContext {
    /// Request method (`ctx.method`).
    pub method: String,
    /// Request path (`ctx.path`).
    pub path: String,
    /// Query string (`ctx.query`).
    pub query: String,
    /// Client address (`ctx.client_ip`).
    pub client_ip: String,
    /// Request identifier (`ctx.request_id`).
    pub request_id: String,
    /// Request headers. Sync hooks may edit them, and their edits are copied
    /// back so later hooks in the same stage observe them.
    pub headers: HashMap<String, String>,
    /// Free-form key/value scratch space; editable by sync hooks like `headers`.
    pub extra: HashMap<String, String>,
    /// Error description, populated only for `request.error` hooks.
    pub error: Option<String>,
    /// Set once an earlier sync hook has short-circuited the pipeline.
    pub short_circuited: bool,
}
37
/// Response-side state handed to `response.headers` and `response.after` hooks.
///
/// Each field is exposed to Lua under the same name on the global `ctx` table.
#[derive(Clone, Debug)]
pub struct ResponseContext {
    /// Response status code (`ctx.status`).
    pub status: u16,
    /// Response headers. Sync hooks may edit them, and their edits are copied
    /// back so later hooks in the same stage observe them.
    pub resp_headers: HashMap<String, String>,
    /// Response body as raw bytes (never decoded as UTF-8). Only populated for
    /// `response.after` hooks.
    pub body: Option<Vec<u8>>,
    /// Set once an earlier sync hook has short-circuited the pipeline.
    pub short_circuited: bool,
}
49
50// ── Hook result ───────────────────────────────────────────────────────────────
51
52/// Outcome of running a hook stage.
53pub enum HookResult {
54    /// Continue normal processing.
55    Continue,
56    /// A sync hook returned a short-circuit response.
57    ShortCircuit(Response<Body>),
58}
59
60// ── Internals ─────────────────────────────────────────────────────────────────
61
62struct CompiledHook {
63    config: HookConfig,
64    /// Pre-compiled Lua bytecode.
65    bytecode: Vec<u8>,
66}
67
68// ── HookEngine ────────────────────────────────────────────────────────────────
69
70/// Holds all pre-compiled hooks for the lifetime of the server.
71#[derive(Clone)]
72pub struct HookEngine {
73    hooks: Arc<Vec<CompiledHook>>,
74}
75
76impl HookEngine {
77    /// Compile all hook scripts.  Scripts that fail to compile are skipped with
78    /// a WARN log; the server starts normally regardless.
79    pub fn new(configs: &[HookConfig]) -> Self {
80        let mut compiled = Vec::with_capacity(configs.len());
81        for cfg in configs {
82            match compile_script(cfg) {
83                Ok(bytecode) => compiled.push(CompiledHook {
84                    config: cfg.clone(),
85                    bytecode,
86                }),
87                Err(e) => {
88                    warn!(
89                        event = %cfg.event,
90                        script = %cfg.lua.display(),
91                        error = %e,
92                        "lua hook script failed to compile — skipping"
93                    );
94                }
95            }
96        }
97        Self {
98            hooks: Arc::new(compiled),
99        }
100    }
101
102    /// Run `request.before` hooks.
103    pub fn run_request_before(&self, ctx: &mut RequestContext) -> HookResult {
104        self.run_request_stage("request.before", ctx)
105    }
106
107    /// Run `request.error` hooks.
108    ///
109    /// Sync hooks run inline (critical path) and respect `fail_closed` /
110    /// short-circuit.  Async hooks fire-and-forget.  Mirrors `run_request_before`.
111    pub fn run_request_error(&self, ctx: &mut RequestContext) -> HookResult {
112        self.run_request_stage("request.error", ctx)
113    }
114
115    /// Run `response.headers` hooks.
116    pub fn run_response_headers(&self, ctx: &mut ResponseContext) -> HookResult {
117        self.run_response_stage("response.headers", ctx)
118    }
119
120    /// Run `response.after` hooks.
121    pub fn run_response_after(&self, ctx: &mut ResponseContext) -> HookResult {
122        self.run_response_stage("response.after", ctx)
123    }
124
125    /// Returns true when at least one compiled hook is registered for `event`.
126    pub fn has_event(&self, event: &str) -> bool {
127        self.hooks.iter().any(|h| h.config.event == event)
128    }
129
130    // ── Internal stage runners ────────────────────────────────────────────────
131
132    fn run_request_stage(&self, event: &str, ctx: &mut RequestContext) -> HookResult {
133        let hooks: Vec<_> = self
134            .hooks
135            .iter()
136            .filter(|h| h.config.event == event)
137            .collect();
138
139        let (sync_hooks, async_hooks): (Vec<_>, Vec<_>) = hooks
140            .into_iter()
141            .partition(|h| h.config.mode == HookMode::Sync);
142
143        let mut sc_response: Option<Response<Body>> = None;
144
145        for hook in &sync_hooks {
146            if sc_response.is_some() {
147                break;
148            }
149            let timeout = Duration::from_millis(hook.config.timeout_ms);
150            match exec_request_hook_mut(&hook.bytecode, ctx, timeout) {
151                Ok(Some(resp)) => {
152                    ctx.short_circuited = true;
153                    sc_response = Some(resp);
154                }
155                Ok(None) => {}
156                Err(e) => {
157                    warn!(
158                        event = event,
159                        script = %hook.config.lua.display(),
160                        error = %e,
161                        "lua sync hook error"
162                    );
163                    if hook.config.on_error == HookErrorBehavior::FailClosed {
164                        return HookResult::ShortCircuit(internal_error_response());
165                    }
166                }
167            }
168        }
169
170        // Spawn async hooks with a read-only snapshot (taken after sync phase).
171        let ctx_snap = ctx.clone();
172        for hook in async_hooks {
173            let bytecode = hook.bytecode.clone();
174            let snap = ctx_snap.clone();
175            let timeout = Duration::from_millis(hook.config.timeout_ms);
176            tokio::spawn(async move {
177                if let Err(e) = exec_request_hook(&bytecode, &snap, timeout) {
178                    warn!(error = %e, "lua async request hook failed");
179                }
180            });
181        }
182
183        match sc_response {
184            Some(resp) => HookResult::ShortCircuit(resp),
185            None => HookResult::Continue,
186        }
187    }
188
189    fn run_response_stage(&self, event: &str, ctx: &mut ResponseContext) -> HookResult {
190        let hooks: Vec<_> = self
191            .hooks
192            .iter()
193            .filter(|h| h.config.event == event)
194            .collect();
195
196        let (sync_hooks, async_hooks): (Vec<_>, Vec<_>) = hooks
197            .into_iter()
198            .partition(|h| h.config.mode == HookMode::Sync);
199
200        let mut sc_response: Option<Response<Body>> = None;
201
202        for hook in &sync_hooks {
203            if sc_response.is_some() {
204                break;
205            }
206            let timeout = Duration::from_millis(hook.config.timeout_ms);
207            match exec_response_hook_mut(&hook.bytecode, ctx, timeout) {
208                Ok(Some(resp)) => {
209                    ctx.short_circuited = true;
210                    sc_response = Some(resp);
211                }
212                Ok(None) => {}
213                Err(e) => {
214                    warn!(
215                        event = event,
216                        script = %hook.config.lua.display(),
217                        error = %e,
218                        "lua sync response hook error"
219                    );
220                    if hook.config.on_error == HookErrorBehavior::FailClosed {
221                        return HookResult::ShortCircuit(internal_error_response());
222                    }
223                }
224            }
225        }
226
227        let ctx_snap = ctx.clone();
228        for hook in async_hooks {
229            let bytecode = hook.bytecode.clone();
230            let snap = ctx_snap.clone();
231            let timeout = Duration::from_millis(hook.config.timeout_ms);
232            tokio::spawn(async move {
233                if let Err(e) = exec_response_hook(&bytecode, &snap, timeout) {
234                    warn!(error = %e, "lua async response hook failed");
235                }
236            });
237        }
238
239        match sc_response {
240            Some(resp) => HookResult::ShortCircuit(resp),
241            None => HookResult::Continue,
242        }
243    }
244}
245
246// ── Script compilation ────────────────────────────────────────────────────────
247
248fn compile_script(cfg: &HookConfig) -> Result<Vec<u8>, String> {
249    let source =
250        std::fs::read_to_string(&cfg.lua).map_err(|e| format!("read {:?}: {e}", cfg.lua))?;
251    let lua = make_lua()?;
252    let func = lua
253        .load(&source)
254        .into_function()
255        .map_err(|e| format!("compile {:?}: {e}", cfg.lua))?;
256    // dump() returns Vec<u8> directly in mlua 0.10
257    Ok(func.dump(false))
258}
259
260fn make_lua() -> Result<Lua, String> {
261    // Safe subset — no io, os, debug, package, coroutine.
262    Lua::new_with(
263        StdLib::TABLE | StdLib::STRING | StdLib::MATH | StdLib::UTF8,
264        LuaOptions::default(),
265    )
266    .map_err(|e| format!("lua init: {e}"))
267}
268
269// ── Per-request execution helpers ─────────────────────────────────────────────
270
271/// Execute bytecode with an immutable RequestContext.  Used for async hooks.
272fn exec_request_hook(
273    bytecode: &[u8],
274    ctx: &RequestContext,
275    timeout: Duration,
276) -> Result<Option<Response<Body>>, String> {
277    let lua = make_lua()?;
278    install_timeout(&lua, timeout)?;
279
280    let t = build_request_table(&lua, ctx)?;
281    lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
282
283    let func = lua
284        .load(bytecode)
285        .into_function()
286        .map_err(|e| e.to_string())?;
287    let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
288    response_from_lua(result)
289}
290
291/// Execute bytecode with a mutable RequestContext — mutations to headers/extra
292/// are written back into `ctx` so subsequent sync hooks see them.
293fn exec_request_hook_mut(
294    bytecode: &[u8],
295    ctx: &mut RequestContext,
296    timeout: Duration,
297) -> Result<Option<Response<Body>>, String> {
298    let lua = make_lua()?;
299    install_timeout(&lua, timeout)?;
300
301    let t = build_request_table(&lua, ctx)?;
302    lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
303
304    let func = lua
305        .load(bytecode)
306        .into_function()
307        .map_err(|e| e.to_string())?;
308    let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
309
310    // Write back mutations.
311    let ctx_global: Table = lua.globals().get("ctx").map_err(|e| e.to_string())?;
312    let headers_table: Table = ctx_global.get("headers").map_err(|e| e.to_string())?;
313    ctx.headers = lua_table_to_map(&headers_table)?;
314    let extra_table: Table = ctx_global.get("extra").map_err(|e| e.to_string())?;
315    ctx.extra = lua_table_to_map(&extra_table)?;
316
317    response_from_lua(result)
318}
319
320/// Execute bytecode with an immutable ResponseContext.  Used for async hooks.
321fn exec_response_hook(
322    bytecode: &[u8],
323    ctx: &ResponseContext,
324    timeout: Duration,
325) -> Result<Option<Response<Body>>, String> {
326    let lua = make_lua()?;
327    install_timeout(&lua, timeout)?;
328
329    let t = build_response_table(&lua, ctx)?;
330    lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
331
332    let func = lua
333        .load(bytecode)
334        .into_function()
335        .map_err(|e| e.to_string())?;
336    let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
337    response_from_lua(result)
338}
339
340/// Execute bytecode with a mutable ResponseContext — mutations to resp_headers
341/// and body are written back.
342fn exec_response_hook_mut(
343    bytecode: &[u8],
344    ctx: &mut ResponseContext,
345    timeout: Duration,
346) -> Result<Option<Response<Body>>, String> {
347    let lua = make_lua()?;
348    install_timeout(&lua, timeout)?;
349
350    let t = build_response_table(&lua, ctx)?;
351    lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
352
353    let func = lua
354        .load(bytecode)
355        .into_function()
356        .map_err(|e| e.to_string())?;
357    let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
358
359    // Write back mutations.
360    let ctx_global: Table = lua.globals().get("ctx").map_err(|e| e.to_string())?;
361    let rh_table: Table = ctx_global.get("resp_headers").map_err(|e| e.to_string())?;
362    ctx.resp_headers = lua_table_to_map(&rh_table)?;
363
364    if ctx.body.is_some() {
365        // Only write back body if it was present (response.after event).
366        // Use mlua::String to preserve raw bytes — avoids any UTF-8 conversion.
367        let new_body: Option<mlua::String> = ctx_global.get("body").ok();
368        ctx.body = new_body.map(|s| s.as_bytes().to_vec());
369    }
370
371    response_from_lua(result)
372}
373
374// ── Table builders ────────────────────────────────────────────────────────────
375
376fn build_request_table(lua: &Lua, ctx: &RequestContext) -> Result<Table, String> {
377    let t = lua.create_table().map_err(|e| e.to_string())?;
378    t.set("method", ctx.method.as_str())
379        .map_err(|e| e.to_string())?;
380    t.set("path", ctx.path.as_str())
381        .map_err(|e| e.to_string())?;
382    t.set("query", ctx.query.as_str())
383        .map_err(|e| e.to_string())?;
384    t.set("client_ip", ctx.client_ip.as_str())
385        .map_err(|e| e.to_string())?;
386    t.set("request_id", ctx.request_id.as_str())
387        .map_err(|e| e.to_string())?;
388    t.set("short_circuited", ctx.short_circuited)
389        .map_err(|e| e.to_string())?;
390    let headers = map_to_lua_table(lua, &ctx.headers)?;
391    t.set("headers", headers).map_err(|e| e.to_string())?;
392    let extra = map_to_lua_table(lua, &ctx.extra)?;
393    t.set("extra", extra).map_err(|e| e.to_string())?;
394    if let Some(ref err) = ctx.error {
395        t.set("error", err.as_str()).map_err(|e| e.to_string())?;
396    }
397    Ok(t)
398}
399
400fn build_response_table(lua: &Lua, ctx: &ResponseContext) -> Result<Table, String> {
401    let t = lua.create_table().map_err(|e| e.to_string())?;
402    t.set("status", ctx.status).map_err(|e| e.to_string())?;
403    t.set("short_circuited", ctx.short_circuited)
404        .map_err(|e| e.to_string())?;
405    let rh = map_to_lua_table(lua, &ctx.resp_headers)?;
406    t.set("resp_headers", rh).map_err(|e| e.to_string())?;
407    if let Some(ref body) = ctx.body {
408        // Lua strings are byte strings — binary bodies pass through without corruption.
409        let lua_str = lua.create_string(body).map_err(|e| e.to_string())?;
410        t.set("body", lua_str).map_err(|e| e.to_string())?;
411    }
412    Ok(t)
413}
414
415// ── Small utilities ───────────────────────────────────────────────────────────
416
417fn install_timeout(lua: &Lua, timeout: Duration) -> Result<(), String> {
418    let start = Instant::now();
419    lua.set_hook(
420        HookTriggers::new().every_nth_instruction(100),
421        move |_lua, _debug| {
422            if start.elapsed() > timeout {
423                Err(mlua::Error::runtime("lua hook timeout"))
424            } else {
425                Ok(VmState::Continue)
426            }
427        },
428    );
429    Ok(())
430}
431
432fn map_to_lua_table(lua: &Lua, map: &HashMap<String, String>) -> Result<Table, String> {
433    let t = lua.create_table().map_err(|e| e.to_string())?;
434    for (k, v) in map {
435        t.set(k.as_str(), v.as_str()).map_err(|e| e.to_string())?;
436    }
437    Ok(t)
438}
439
440fn lua_table_to_map(table: &Table) -> Result<HashMap<String, String>, String> {
441    let mut map = HashMap::new();
442    for pair in table.clone().pairs::<String, String>() {
443        let (k, v) = pair.map_err(|e| e.to_string())?;
444        map.insert(k, v);
445    }
446    Ok(map)
447}
448
449fn response_from_lua(value: LuaValue) -> Result<Option<Response<Body>>, String> {
450    match value {
451        LuaValue::Nil => Ok(None),
452        LuaValue::Table(t) => {
453            let status: u16 = t.get("status").unwrap_or(200u16);
454            let body: String = t.get("body").unwrap_or_default();
455            let resp = Response::builder()
456                .status(status)
457                .body(Body::from(body))
458                .map_err(|e| e.to_string())?;
459            Ok(Some(resp))
460        }
461        _ => Ok(None),
462    }
463}
464
465fn internal_error_response() -> Response<Body> {
466    Response::builder()
467        .status(500)
468        .body(Body::from("internal server error (hook fail_closed)"))
469        .unwrap()
470}