//! Lua hook pipeline for folk-plugin-http.
//!
//! Hook scripts are pre-compiled to Lua bytecode at startup and executed
//! per-request on a fresh [`mlua::Lua`] VM instance.  Sync hooks run in
//! the request critical path; async hooks run fire-and-forget in the
//! background on a read-only snapshot of the context.

7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use axum::body::Body;
12use axum::http::Response;
13use mlua::{HookTriggers, Lua, LuaOptions, StdLib, Table, Value as LuaValue, VmState};
14use tracing::warn;
15
16use crate::config::{HookConfig, HookErrorBehavior, HookMode};

// ── Public context types ──────────────────────────────────────────────────────

/// Request-side state handed to `request.before` and `request.error` hooks.
///
/// Scripts receive this as the global `ctx` table. In sync hooks, edits to
/// `headers` and `extra` are copied back and become visible to later hooks.
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// Request method.
    pub method: String,
    /// Request path.
    pub path: String,
    /// Request query string.
    pub query: String,
    /// Client address, as a string.
    pub client_ip: String,
    /// Identifier of this request.
    pub request_id: String,
    /// Request headers; sync hooks may edit them and later hooks see the edits.
    pub headers: HashMap<String, String>,
    /// Free-form key/value bag that sync hooks may read and write.
    pub extra: HashMap<String, String>,
    /// Error description; populated only for `request.error` hooks.
    pub error: Option<String>,
    /// Set once an earlier sync hook has short-circuited the pipeline.
    pub short_circuited: bool,
}

/// Response-side state handed to `response.headers` and `response.after` hooks.
///
/// Scripts receive this as the global `ctx` table. In sync hooks, edits to
/// `resp_headers` (and `body`, when present) are copied back for later hooks.
#[derive(Debug, Clone)]
pub struct ResponseContext {
    /// Response status code.
    pub status: u16,
    /// Response headers; sync hooks may edit them and later hooks see the edits.
    pub resp_headers: HashMap<String, String>,
    /// Response body as raw bytes (never UTF-8 decoded); only set for
    /// `response.after` hooks.
    pub body: Option<Vec<u8>>,
    /// Set once an earlier sync hook has short-circuited the pipeline.
    pub short_circuited: bool,
}
49
50// ── Hook result ───────────────────────────────────────────────────────────────
51
52/// Outcome of running a hook stage.
53pub enum HookResult {
54    /// Continue normal processing.
55    Continue,
56    /// A sync hook returned a short-circuit response.
57    ShortCircuit(Response<Body>),
58}
59
60// ── Internals ─────────────────────────────────────────────────────────────────
61
62struct CompiledHook {
63    config: HookConfig,
64    /// Pre-compiled Lua bytecode.
65    bytecode: Vec<u8>,
66}
67
68// ── HookEngine ────────────────────────────────────────────────────────────────
69
70/// Holds all pre-compiled hooks for the lifetime of the server.
71#[derive(Clone)]
72pub struct HookEngine {
73    hooks: Arc<Vec<CompiledHook>>,
74}
75
76impl HookEngine {
77    /// Compile all hook scripts.  Scripts that fail to compile are skipped with
78    /// a WARN log; the server starts normally regardless.
79    pub fn new(configs: &[HookConfig]) -> Self {
80        let mut compiled = Vec::with_capacity(configs.len());
81        for cfg in configs {
82            match compile_script(cfg) {
83                Ok(bytecode) => compiled.push(CompiledHook {
84                    config: cfg.clone(),
85                    bytecode,
86                }),
87                Err(e) => {
88                    warn!(
89                        event = %cfg.event,
90                        script = %cfg.lua.display(),
91                        error = %e,
92                        "lua hook script failed to compile — skipping"
93                    );
94                }
95            }
96        }
97        Self {
98            hooks: Arc::new(compiled),
99        }
100    }
101
102    /// Run `request.before` hooks.
103    pub fn run_request_before(&self, ctx: &mut RequestContext) -> HookResult {
104        self.run_request_stage("request.before", ctx)
105    }
106
107    /// Run `request.error` hooks — all spawned as async; no short-circuit.
108    pub fn run_request_error(&self, ctx: &RequestContext) {
109        for hook in self
110            .hooks
111            .iter()
112            .filter(|h| h.config.event == "request.error")
113        {
114            let bytecode = hook.bytecode.clone();
115            let ctx_clone = ctx.clone();
116            let timeout = Duration::from_millis(hook.config.timeout_ms);
117            tokio::spawn(async move {
118                if let Err(e) = exec_request_hook(&bytecode, &ctx_clone, timeout) {
119                    warn!(error = %e, "lua request.error hook failed");
120                }
121            });
122        }
123    }
124
125    /// Run `response.headers` hooks.
126    pub fn run_response_headers(&self, ctx: &mut ResponseContext) -> HookResult {
127        self.run_response_stage("response.headers", ctx)
128    }
129
130    /// Run `response.after` hooks.
131    pub fn run_response_after(&self, ctx: &mut ResponseContext) -> HookResult {
132        self.run_response_stage("response.after", ctx)
133    }
134
135    /// Returns true when at least one compiled hook is registered for `event`.
136    pub fn has_event(&self, event: &str) -> bool {
137        self.hooks.iter().any(|h| h.config.event == event)
138    }
139
140    // ── Internal stage runners ────────────────────────────────────────────────
141
142    fn run_request_stage(&self, event: &str, ctx: &mut RequestContext) -> HookResult {
143        let hooks: Vec<_> = self
144            .hooks
145            .iter()
146            .filter(|h| h.config.event == event)
147            .collect();
148
149        let (sync_hooks, async_hooks): (Vec<_>, Vec<_>) = hooks
150            .into_iter()
151            .partition(|h| h.config.mode == HookMode::Sync);
152
153        let mut sc_response: Option<Response<Body>> = None;
154
155        for hook in &sync_hooks {
156            if sc_response.is_some() {
157                break;
158            }
159            let timeout = Duration::from_millis(hook.config.timeout_ms);
160            match exec_request_hook_mut(&hook.bytecode, ctx, timeout) {
161                Ok(Some(resp)) => {
162                    ctx.short_circuited = true;
163                    sc_response = Some(resp);
164                }
165                Ok(None) => {}
166                Err(e) => {
167                    warn!(
168                        event = event,
169                        script = %hook.config.lua.display(),
170                        error = %e,
171                        "lua sync hook error"
172                    );
173                    if hook.config.on_error == HookErrorBehavior::FailClosed {
174                        return HookResult::ShortCircuit(internal_error_response());
175                    }
176                }
177            }
178        }
179
180        // Spawn async hooks with a read-only snapshot (taken after sync phase).
181        let ctx_snap = ctx.clone();
182        for hook in async_hooks {
183            let bytecode = hook.bytecode.clone();
184            let snap = ctx_snap.clone();
185            let timeout = Duration::from_millis(hook.config.timeout_ms);
186            tokio::spawn(async move {
187                if let Err(e) = exec_request_hook(&bytecode, &snap, timeout) {
188                    warn!(error = %e, "lua async request hook failed");
189                }
190            });
191        }
192
193        match sc_response {
194            Some(resp) => HookResult::ShortCircuit(resp),
195            None => HookResult::Continue,
196        }
197    }
198
199    fn run_response_stage(&self, event: &str, ctx: &mut ResponseContext) -> HookResult {
200        let hooks: Vec<_> = self
201            .hooks
202            .iter()
203            .filter(|h| h.config.event == event)
204            .collect();
205
206        let (sync_hooks, async_hooks): (Vec<_>, Vec<_>) = hooks
207            .into_iter()
208            .partition(|h| h.config.mode == HookMode::Sync);
209
210        let mut sc_response: Option<Response<Body>> = None;
211
212        for hook in &sync_hooks {
213            if sc_response.is_some() {
214                break;
215            }
216            let timeout = Duration::from_millis(hook.config.timeout_ms);
217            match exec_response_hook_mut(&hook.bytecode, ctx, timeout) {
218                Ok(Some(resp)) => {
219                    ctx.short_circuited = true;
220                    sc_response = Some(resp);
221                }
222                Ok(None) => {}
223                Err(e) => {
224                    warn!(
225                        event = event,
226                        script = %hook.config.lua.display(),
227                        error = %e,
228                        "lua sync response hook error"
229                    );
230                    if hook.config.on_error == HookErrorBehavior::FailClosed {
231                        return HookResult::ShortCircuit(internal_error_response());
232                    }
233                }
234            }
235        }
236
237        let ctx_snap = ctx.clone();
238        for hook in async_hooks {
239            let bytecode = hook.bytecode.clone();
240            let snap = ctx_snap.clone();
241            let timeout = Duration::from_millis(hook.config.timeout_ms);
242            tokio::spawn(async move {
243                if let Err(e) = exec_response_hook(&bytecode, &snap, timeout) {
244                    warn!(error = %e, "lua async response hook failed");
245                }
246            });
247        }
248
249        match sc_response {
250            Some(resp) => HookResult::ShortCircuit(resp),
251            None => HookResult::Continue,
252        }
253    }
254}
255
256// ── Script compilation ────────────────────────────────────────────────────────
257
258fn compile_script(cfg: &HookConfig) -> Result<Vec<u8>, String> {
259    let source =
260        std::fs::read_to_string(&cfg.lua).map_err(|e| format!("read {:?}: {e}", cfg.lua))?;
261    let lua = make_lua()?;
262    let func = lua
263        .load(&source)
264        .into_function()
265        .map_err(|e| format!("compile {:?}: {e}", cfg.lua))?;
266    // dump() returns Vec<u8> directly in mlua 0.10
267    Ok(func.dump(false))
268}
269
270fn make_lua() -> Result<Lua, String> {
271    // Safe subset — no io, os, debug, package, coroutine.
272    Lua::new_with(
273        StdLib::TABLE | StdLib::STRING | StdLib::MATH | StdLib::UTF8,
274        LuaOptions::default(),
275    )
276    .map_err(|e| format!("lua init: {e}"))
277}
278
279// ── Per-request execution helpers ─────────────────────────────────────────────
280
281/// Execute bytecode with an immutable RequestContext.  Used for async hooks.
282fn exec_request_hook(
283    bytecode: &[u8],
284    ctx: &RequestContext,
285    timeout: Duration,
286) -> Result<Option<Response<Body>>, String> {
287    let lua = make_lua()?;
288    install_timeout(&lua, timeout)?;
289
290    let t = build_request_table(&lua, ctx)?;
291    lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
292
293    let func = lua
294        .load(bytecode)
295        .into_function()
296        .map_err(|e| e.to_string())?;
297    let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
298    response_from_lua(result)
299}
300
301/// Execute bytecode with a mutable RequestContext — mutations to headers/extra
302/// are written back into `ctx` so subsequent sync hooks see them.
303fn exec_request_hook_mut(
304    bytecode: &[u8],
305    ctx: &mut RequestContext,
306    timeout: Duration,
307) -> Result<Option<Response<Body>>, String> {
308    let lua = make_lua()?;
309    install_timeout(&lua, timeout)?;
310
311    let t = build_request_table(&lua, ctx)?;
312    lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
313
314    let func = lua
315        .load(bytecode)
316        .into_function()
317        .map_err(|e| e.to_string())?;
318    let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
319
320    // Write back mutations.
321    let ctx_global: Table = lua.globals().get("ctx").map_err(|e| e.to_string())?;
322    let headers_table: Table = ctx_global.get("headers").map_err(|e| e.to_string())?;
323    ctx.headers = lua_table_to_map(&headers_table)?;
324    let extra_table: Table = ctx_global.get("extra").map_err(|e| e.to_string())?;
325    ctx.extra = lua_table_to_map(&extra_table)?;
326
327    response_from_lua(result)
328}
329
330/// Execute bytecode with an immutable ResponseContext.  Used for async hooks.
331fn exec_response_hook(
332    bytecode: &[u8],
333    ctx: &ResponseContext,
334    timeout: Duration,
335) -> Result<Option<Response<Body>>, String> {
336    let lua = make_lua()?;
337    install_timeout(&lua, timeout)?;
338
339    let t = build_response_table(&lua, ctx)?;
340    lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
341
342    let func = lua
343        .load(bytecode)
344        .into_function()
345        .map_err(|e| e.to_string())?;
346    let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
347    response_from_lua(result)
348}
349
350/// Execute bytecode with a mutable ResponseContext — mutations to resp_headers
351/// and body are written back.
352fn exec_response_hook_mut(
353    bytecode: &[u8],
354    ctx: &mut ResponseContext,
355    timeout: Duration,
356) -> Result<Option<Response<Body>>, String> {
357    let lua = make_lua()?;
358    install_timeout(&lua, timeout)?;
359
360    let t = build_response_table(&lua, ctx)?;
361    lua.globals().set("ctx", t).map_err(|e| e.to_string())?;
362
363    let func = lua
364        .load(bytecode)
365        .into_function()
366        .map_err(|e| e.to_string())?;
367    let result: LuaValue = func.call(()).map_err(|e| e.to_string())?;
368
369    // Write back mutations.
370    let ctx_global: Table = lua.globals().get("ctx").map_err(|e| e.to_string())?;
371    let rh_table: Table = ctx_global.get("resp_headers").map_err(|e| e.to_string())?;
372    ctx.resp_headers = lua_table_to_map(&rh_table)?;
373
374    if ctx.body.is_some() {
375        // Only write back body if it was present (response.after event).
376        // Use mlua::String to preserve raw bytes — avoids any UTF-8 conversion.
377        let new_body: Option<mlua::String> = ctx_global.get("body").ok();
378        ctx.body = new_body.map(|s| s.as_bytes().to_vec());
379    }
380
381    response_from_lua(result)
382}
383
384// ── Table builders ────────────────────────────────────────────────────────────
385
386fn build_request_table(lua: &Lua, ctx: &RequestContext) -> Result<Table, String> {
387    let t = lua.create_table().map_err(|e| e.to_string())?;
388    t.set("method", ctx.method.as_str())
389        .map_err(|e| e.to_string())?;
390    t.set("path", ctx.path.as_str())
391        .map_err(|e| e.to_string())?;
392    t.set("query", ctx.query.as_str())
393        .map_err(|e| e.to_string())?;
394    t.set("client_ip", ctx.client_ip.as_str())
395        .map_err(|e| e.to_string())?;
396    t.set("request_id", ctx.request_id.as_str())
397        .map_err(|e| e.to_string())?;
398    t.set("short_circuited", ctx.short_circuited)
399        .map_err(|e| e.to_string())?;
400    let headers = map_to_lua_table(lua, &ctx.headers)?;
401    t.set("headers", headers).map_err(|e| e.to_string())?;
402    let extra = map_to_lua_table(lua, &ctx.extra)?;
403    t.set("extra", extra).map_err(|e| e.to_string())?;
404    if let Some(ref err) = ctx.error {
405        t.set("error", err.as_str()).map_err(|e| e.to_string())?;
406    }
407    Ok(t)
408}
409
410fn build_response_table(lua: &Lua, ctx: &ResponseContext) -> Result<Table, String> {
411    let t = lua.create_table().map_err(|e| e.to_string())?;
412    t.set("status", ctx.status).map_err(|e| e.to_string())?;
413    t.set("short_circuited", ctx.short_circuited)
414        .map_err(|e| e.to_string())?;
415    let rh = map_to_lua_table(lua, &ctx.resp_headers)?;
416    t.set("resp_headers", rh).map_err(|e| e.to_string())?;
417    if let Some(ref body) = ctx.body {
418        // Lua strings are byte strings — binary bodies pass through without corruption.
419        let lua_str = lua.create_string(body).map_err(|e| e.to_string())?;
420        t.set("body", lua_str).map_err(|e| e.to_string())?;
421    }
422    Ok(t)
423}
424
425// ── Small utilities ───────────────────────────────────────────────────────────
426
427fn install_timeout(lua: &Lua, timeout: Duration) -> Result<(), String> {
428    let start = Instant::now();
429    lua.set_hook(
430        HookTriggers::new().every_nth_instruction(100),
431        move |_lua, _debug| {
432            if start.elapsed() > timeout {
433                Err(mlua::Error::runtime("lua hook timeout"))
434            } else {
435                Ok(VmState::Continue)
436            }
437        },
438    );
439    Ok(())
440}
441
442fn map_to_lua_table(lua: &Lua, map: &HashMap<String, String>) -> Result<Table, String> {
443    let t = lua.create_table().map_err(|e| e.to_string())?;
444    for (k, v) in map {
445        t.set(k.as_str(), v.as_str()).map_err(|e| e.to_string())?;
446    }
447    Ok(t)
448}
449
450fn lua_table_to_map(table: &Table) -> Result<HashMap<String, String>, String> {
451    let mut map = HashMap::new();
452    for pair in table.clone().pairs::<String, String>() {
453        let (k, v) = pair.map_err(|e| e.to_string())?;
454        map.insert(k, v);
455    }
456    Ok(map)
457}
458
459fn response_from_lua(value: LuaValue) -> Result<Option<Response<Body>>, String> {
460    match value {
461        LuaValue::Nil => Ok(None),
462        LuaValue::Table(t) => {
463            let status: u16 = t.get("status").unwrap_or(200u16);
464            let body: String = t.get("body").unwrap_or_default();
465            let resp = Response::builder()
466                .status(status)
467                .body(Body::from(body))
468                .map_err(|e| e.to_string())?;
469            Ok(Some(resp))
470        }
471        _ => Ok(None),
472    }
473}
474
475fn internal_error_response() -> Response<Body> {
476    Response::builder()
477        .status(500)
478        .body(Body::from("internal server error (hook fail_closed)"))
479        .unwrap()
480}