Skip to main content

courier/transforms/script/
lua.rs

1use anyhow::{Context, Result, anyhow, bail};
2use async_trait::async_trait;
3use mlua::{Function, Lua, LuaSerdeExt, MultiValue, Value};
4
5use crate::config::redact_secret;
6use crate::envelope::Envelope;
7
8use super::{ScriptEngine, ScriptTransformConfig};
9
10pub struct LuaEngine {
11    lua: Lua,
12    entrypoint: String,
13}
14
15#[async_trait]
16impl ScriptEngine for LuaEngine {
17    async fn run(&self, env: Envelope) -> Result<Option<Envelope>> {
18        self.run_inner(env)
19    }
20}
21
22impl LuaEngine {
23    pub(super) fn new(config: &ScriptTransformConfig) -> Result<Self> {
24        let lua = Lua::new();
25        lua.load(&config.script)
26            .exec()
27            .context("failed to compile Lua script")?;
28
29        let globals = lua.globals();
30        let _: Function = globals.get(config.entrypoint.as_str()).with_context(|| {
31            format!(
32                "missing Lua entrypoint '{}'",
33                redact_secret(&config.entrypoint)
34            )
35        })?;
36
37        Ok(Self {
38            lua,
39            entrypoint: config.entrypoint.clone(),
40        })
41    }
42
43    #[cfg(test)]
44    fn run(&self, env: Envelope) -> Result<Option<Envelope>> {
45        self.run_inner(env)
46    }
47
48    fn run_inner(&self, env: Envelope) -> Result<Option<Envelope>> {
49        let globals = self.lua.globals();
50        let entrypoint: Function = globals.get(self.entrypoint.as_str()).with_context(|| {
51            format!(
52                "missing Lua entrypoint '{}'",
53                redact_secret(&self.entrypoint)
54            )
55        })?;
56        let arg = self
57            .lua
58            .to_value(&env)
59            .context("failed to convert envelope into Lua value")?;
60        let out: MultiValue = entrypoint.call((arg,)).with_context(|| {
61            format!(
62                "Lua entrypoint '{}' failed",
63                redact_secret(&self.entrypoint)
64            )
65        })?;
66
67        let mut values = out.into_vec();
68        let value = match values.len() {
69            0 => return Ok(None),
70            1 => values.pop().expect("single return value expected"),
71            _ => bail!(
72                "Lua entrypoint '{}' returned multiple values",
73                redact_secret(&self.entrypoint)
74            ),
75        };
76
77        match value {
78            Value::Nil => Ok(None),
79            other => self.lua.from_value(other).map(Some).map_err(|err| {
80                anyhow!(err).context("failed to convert Lua return value into envelope")
81            }),
82        }
83    }
84}
85
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::LuaEngine;
    use crate::envelope::Envelope;
    use crate::transforms::script::{ScriptRuntime, ScriptTransformConfig};

    /// Builds a Lua transform config whose entrypoint is the global `transform`.
    fn config(script: &str) -> ScriptTransformConfig {
        ScriptTransformConfig {
            runtime: ScriptRuntime::Lua,
            script: script.into(),
            entrypoint: "transform".into(),
            python: None,
            rhai: None,
        }
    }

    #[test]
    fn mutates_payload() {
        let engine = LuaEngine::new(&config(
            r#"
                function transform(env)
                    env.payload.processed = true
                    return env
                end
            "#,
        ))
        .unwrap();

        let out = engine
            .run(Envelope::new("src", json!({ "value": 1 })))
            .unwrap()
            .unwrap();
        assert_eq!(out.payload, json!({ "value": 1, "processed": true }));
    }

    #[test]
    fn mutates_metadata() {
        let engine = LuaEngine::new(&config(
            r#"
                function transform(env)
                    env.meta.headers.script_runtime = "lua"
                    return env
                end
            "#,
        ))
        .unwrap();

        let out = engine
            .run(Envelope::new("src", json!({})))
            .unwrap()
            .unwrap();
        assert_eq!(
            out.meta.headers.get("script_runtime").map(String::as_str),
            Some("lua")
        );
    }

    #[test]
    fn nil_return_filters_envelope() {
        let engine = LuaEngine::new(&config("function transform(env) return nil end")).unwrap();

        let out = engine
            .run(Envelope::new("src", json!({ "skip": true })))
            .unwrap();
        assert!(out.is_none());
    }

    /// A function with no `return` yields zero values, which is also a filter.
    #[test]
    fn no_return_value_filters_envelope() {
        let engine = LuaEngine::new(&config("function transform(env) end")).unwrap();

        let out = engine
            .run(Envelope::new("src", json!({ "skip": true })))
            .unwrap();
        assert!(out.is_none());
    }

    #[test]
    fn compile_error_fails_build() {
        let err = LuaEngine::new(&config("function transform(env) local = end"))
            .err()
            .expect("expected compile error");
        let msg = format!("{err:#}");
        assert!(msg.contains("failed to compile Lua script"), "{msg}");
    }

    #[test]
    fn missing_entrypoint_fails_build() {
        let err = LuaEngine::new(&config("function other(env) return env end"))
            .err()
            .expect("expected missing entrypoint error");
        let msg = format!("{err:#}");
        assert!(msg.contains("missing Lua entrypoint 'transform'"), "{msg}");
    }

    #[test]
    fn invalid_return_shape_fails_run() {
        let engine = LuaEngine::new(&config("function transform(env) return 42 end")).unwrap();

        let err = engine.run(Envelope::new("src", json!({}))).unwrap_err();
        let msg = format!("{err:#}");
        assert!(
            msg.contains("failed to convert Lua return value into envelope"),
            "{msg}"
        );
    }

    /// Returning more than one value is ambiguous and must be rejected.
    #[test]
    fn multiple_return_values_fail_run() {
        let engine =
            LuaEngine::new(&config("function transform(env) return env, env end")).unwrap();

        let err = engine.run(Envelope::new("src", json!({}))).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("returned multiple values"), "{msg}");
    }

    #[test]
    fn runtime_exception_fails_run() {
        let engine = LuaEngine::new(&config("function transform(env) error('boom') end")).unwrap();

        let err = engine.run(Envelope::new("src", json!({}))).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("Lua entrypoint 'transform' failed"), "{msg}");
    }
}