Skip to main content

minijinja_lua/
environment.rs

1// SPDX-License-Identifier: MIT
2
3use std::{borrow::Cow, fmt, ops::Deref};
4
5use minijinja::{
6    Environment,
7    Error as JinjaError,
8    ErrorKind as JinjaErrorKind,
9    State,
10    args,
11    value::{Rest as JinjaRest, Value as JinjaValue},
12};
13use mlua::LuaSerdeExt;
14
15use crate::{
16    convert::{
17        LuaAutoEscape,
18        LuaFunctionObject,
19        LuaObject,
20        LuaSyntaxConfig,
21        LuaUndefinedBehavior,
22        lua_to_minijinja,
23        minijinja_to_lua,
24    },
25    state::bind_lua,
26};
27
28/// A wrapper around a [`minijinja::Environment`]. This wrapper can be serialized into
29/// an [`mlua::UserData`] object for use within mlua::Lua.
30#[derive(mlua::UserData, Debug)]
31pub struct LuaEnvironment(Environment<'static>);
32
33impl From<Environment<'static>> for LuaEnvironment {
34    fn from(value: Environment<'static>) -> Self {
35        LuaEnvironment(value)
36    }
37}
38
39impl From<LuaEnvironment> for Environment<'static> {
40    fn from(value: LuaEnvironment) -> Self {
41        value.0
42    }
43}
44
45impl Deref for LuaEnvironment {
46    type Target = Environment<'static>;
47
48    fn deref(&self) -> &Self::Target {
49        &self.0
50    }
51}
52
53impl fmt::Display for LuaEnvironment {
54    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55        write!(f, "Environment")
56    }
57}
58
59#[mlua::userdata_impl]
60impl LuaEnvironment {
61    /// Get a new environment
62    #[lua(name = "new", infallible)]
63    pub(crate) fn lua_new() -> Self {
64        let mut env = Environment::new();
65
66        #[cfg(feature = "minijinja-contrib")]
67        minijinja_contrib::add_to_environment(&mut env);
68
69        #[cfg(feature = "json")]
70        crate::contrib::json::add_to_environment(&mut env);
71
72        #[cfg(feature = "datetime")]
73        crate::contrib::datetime::add_to_environment(&mut env);
74
75        env.into()
76    }
77
78    /// Get a new empty environment
79    #[lua(name = "empty", infallible)]
80    pub(crate) fn lua_empty() -> Self {
81        Environment::empty().into()
82    }
83
84    #[lua(name = "keep_trailing_newline", getter, infallible)]
85    pub(crate) fn lua_keep_trailing_newline(&self) -> bool {
86        self.0.keep_trailing_newline()
87    }
88
89    #[lua(name = "keep_trailing_newline", setter, infallible)]
90    pub(crate) fn lua_set_keep_trailing_newline(&mut self, val: bool) {
91        self.0.set_keep_trailing_newline(val)
92    }
93
94    #[lua(name = "trim_blocks", getter, infallible)]
95    pub(crate) fn lua_trim_blocks(&self) -> bool {
96        self.0.trim_blocks()
97    }
98
99    #[lua(name = "trim_blocks", setter, infallible)]
100    pub(crate) fn lua_set_trim_blocks(&mut self, val: bool) {
101        self.0.set_trim_blocks(val)
102    }
103
104    #[lua(name = "lstrip_blocks", getter, infallible)]
105    pub(crate) fn lua_lstrip_blocks(&self) -> bool {
106        self.0.lstrip_blocks()
107    }
108
109    #[lua(name = "lstrip_blocks", setter, infallible)]
110    pub(crate) fn lua_set_lstrip_blocks(&mut self, val: bool) {
111        self.0.set_lstrip_blocks(val)
112    }
113
114    #[lua(name = "debug", getter, infallible)]
115    pub(crate) fn lua_debug(&self) -> bool {
116        self.0.debug()
117    }
118
119    #[lua(name = "debug", setter, infallible)]
120    pub(crate) fn lua_set_debug(&mut self, val: bool) {
121        self.0.set_debug(val)
122    }
123
124    #[lua(name = "fuel", getter, infallible)]
125    pub(crate) fn lua_fuel(&self) -> Option<u64> {
126        self.0.fuel()
127    }
128
129    #[lua(name = "fuel", setter, infallible)]
130    pub(crate) fn lua_set_fuel(&mut self, val: Option<u64>) {
131        self.0.set_fuel(val)
132    }
133
134    #[lua(name = "recursion_limit", getter, infallible)]
135    pub(crate) fn lua_recursion_limit(&self) -> usize {
136        self.0.recursion_limit()
137    }
138
139    #[lua(name = "recursion_limit", setter, infallible)]
140    pub(crate) fn lua_set_recursion_limit(&mut self, val: usize) {
141        self.0.set_recursion_limit(val)
142    }
143
144    #[lua(name = "undefined_behavior", getter, infallible)]
145    pub(crate) fn lua_undefined_behavior(&self) -> LuaUndefinedBehavior {
146        self.0.undefined_behavior().into()
147    }
148
149    #[lua(name = "undefined_behavior", setter)]
150    pub(crate) fn lua_set_undefined_behavior(
151        &mut self,
152        val: LuaUndefinedBehavior,
153    ) -> mlua::Result<()> {
154        self.0.set_undefined_behavior(val.into());
155
156        Ok(())
157    }
158
159    #[lua(name = "add_template", infallible)]
160    pub(crate) fn lua_add_template(
161        &mut self,
162        lua: &mlua::Lua,
163        name: String,
164        source: String,
165    ) -> mlua::Result<()> {
166        bind_lua(lua, || {
167            self.0
168                .add_template_owned(name, source)
169                .map_err(mlua::Error::external)
170        })
171    }
172
173    #[lua(name = "remove_template", infallible)]
174    pub(crate) fn lua_remove_template(&mut self, lua: &mlua::Lua, name: String) {
175        bind_lua(lua, || self.0.remove_template(&name))
176    }
177
178    #[lua(name = "clear_templates", infallible)]
179    pub(crate) fn lua_clear_templates(&mut self, lua: &mlua::Lua) {
180        bind_lua(lua, || self.0.clear_templates())
181    }
182
183    #[lua(name = "undeclared_variables")]
184    pub(crate) fn lua_undeclared_variables(
185        &mut self,
186        lua: &mlua::Lua,
187        name: String,
188        nested: Option<bool>,
189    ) -> mlua::Result<mlua::Value> {
190        bind_lua(lua, || {
191            let nested = nested.unwrap_or(false);
192
193            let vars = self
194                .0
195                .get_template(&name)
196                .map_err(mlua::Error::external)?
197                .undeclared_variables(nested);
198
199            lua.to_value(&vars)
200        })
201    }
202
203    #[lua(name = "set_loader")]
204    pub(crate) fn lua_set_loader(
205        &mut self,
206        lua: &mlua::Lua,
207        callback: mlua::Function,
208    ) -> mlua::Result<()> {
209        let key = lua.create_registry_value(callback)?;
210        let func = LuaFunctionObject::new(key);
211
212        self.0.set_loader(move |name| {
213            let source = func.with_func(args!(name), None)?;
214            Ok(source.and_then(|v| {
215                // If the lua function returns nil, i.e., no path found
216                // it is mapped as `minijinja::value::ValueKind::Undefined`, however
217                // we need to return a `None` to indicate no path was found.
218                if v.is_undefined() {
219                    None
220                } else {
221                    Some(v.to_string())
222                }
223            }))
224        });
225
226        Ok(())
227    }
228
229    #[lua(name = "set_path_join_callback")]
230    pub(crate) fn lua_set_path_join_callback(
231        &mut self,
232        lua: &mlua::Lua,
233        callback: mlua::Function,
234    ) -> mlua::Result<()> {
235        let key = lua.create_registry_value(callback)?;
236        let func = LuaFunctionObject::new(key);
237
238        self.0.set_path_join_callback(move |name, parent| {
239            func.with_func(args!(name, parent), None)
240                .ok()
241                .flatten()
242                .and_then(|v| v.as_str().map(|s| Cow::Owned(s.to_string())))
243                .unwrap_or(Cow::Borrowed(name))
244        });
245
246        Ok(())
247    }
248
249    #[lua(name = "set_unknown_method_callback")]
250    pub(crate) fn lua_set_unknown_method_callback(
251        &mut self,
252        lua: &mlua::Lua,
253        callback: mlua::Function,
254    ) -> mlua::Result<()> {
255        let key = lua.create_registry_value(callback)?;
256        let mut func = LuaFunctionObject::new(key);
257        func.set_pass_state(true);
258
259        self.0
260            .set_unknown_method_callback(move |state, value, method, args| {
261                func.with_func(args!(value, method, ..args), Some(state))
262                    .map(|v| v.unwrap_or_default())
263            });
264
265        Ok(())
266    }
267
268    #[cfg(feature = "minijinja-contrib")]
269    #[lua(name = "set_pycompat", infallible)]
270    pub(crate) fn lua_set_pycompat(&mut self, enable: Option<bool>) {
271        match enable {
272            Some(true) | None => self
273                .0
274                .set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback),
275            Some(false) => self.0.set_unknown_method_callback(|_, _, _, _| {
276                Err(JinjaError::from(JinjaErrorKind::UnknownMethod))
277            }),
278        }
279    }
280
281    #[lua(name = "set_auto_escape_callback")]
282    pub(crate) fn lua_set_auto_escape_callback(
283        &mut self,
284        lua: &mlua::Lua,
285        callback: mlua::Function,
286    ) -> mlua::Result<()> {
287        let key = lua.create_registry_value(callback)?;
288        let func = LuaFunctionObject::new(key);
289
290        self.0
291            .set_auto_escape_callback(move |name| -> minijinja::AutoEscape {
292                func.with_func(args!(name), None)
293                    .ok()
294                    .flatten()
295                    .and_then(|v| LuaAutoEscape::try_from(v.to_string().as_str()).ok())
296                    .unwrap_or_default()
297                    .into()
298            });
299
300        Ok(())
301    }
302
303    #[lua(name = "set_formatter")]
304    pub(crate) fn lua_set_formatter(
305        &mut self,
306        lua: &mlua::Lua,
307        callback: mlua::Function,
308    ) -> mlua::Result<()> {
309        let key = lua.create_registry_value(callback)?;
310        let mut func = LuaFunctionObject::new(key);
311        func.set_pass_state(true);
312
313        self.0.set_formatter(move |out, state, value| {
314            let Some(val) = func.with_func(args!(value), Some(state)).ok().flatten() else {
315                return Ok(());
316            };
317
318            let Some(s) = val.as_str() else {
319                return Err(JinjaError::new(
320                    JinjaErrorKind::WriteFailure,
321                    "formatter must return a string",
322                ));
323            };
324
325            out.write_str(s)
326                .map_err(|_| JinjaError::new(JinjaErrorKind::WriteFailure, "write failed"))
327        });
328
329        Ok(())
330    }
331
332    #[lua(name = "set_syntax")]
333    pub(crate) fn lua_set_syntax(&mut self, syntax: LuaSyntaxConfig) -> mlua::Result<()> {
334        self.0.set_syntax(syntax.into());
335
336        Ok(())
337    }
338
339    #[lua(name = "render_template")]
340    pub(crate) fn lua_render_template(
341        &mut self,
342        lua: &mlua::Lua,
343        name: String,
344        ctx: Option<mlua::Table>,
345    ) -> mlua::Result<String> {
346        let ctx = ctx.unwrap_or(lua.create_table()?);
347
348        let context = lua_to_minijinja(lua, &mlua::Value::Table(ctx));
349
350        bind_lua(lua, || {
351            self.0
352                .get_template(&name)
353                .map_err(mlua::Error::external)?
354                .render(context)
355                .map_err(mlua::Error::external)
356        })
357    }
358
359    #[lua(name = "render_str")]
360    pub(crate) fn lua_render_str(
361        &self,
362        lua: &mlua::Lua,
363        source: String,
364        ctx: Option<mlua::Table>,
365        name: Option<String>,
366    ) -> mlua::Result<String> {
367        let ctx = ctx.unwrap_or(lua.create_table()?);
368
369        let name = name.unwrap_or("<string>".to_string());
370        let context = lua_to_minijinja(lua, &mlua::Value::Table(ctx));
371
372        bind_lua(lua, || {
373            self.0
374                .render_named_str(&name, &source, context)
375                .map_err(mlua::Error::external)
376        })
377    }
378
379    #[lua(name = "render_captured")]
380    pub(crate) fn lua_render_captured(
381        &mut self,
382        lua: &mlua::Lua,
383        name: String,
384        ctx: Option<mlua::Table>,
385        callback: mlua::Function,
386    ) -> mlua::Result<mlua::MultiValue> {
387        let key = lua.create_registry_value(callback)?;
388        let mut func = LuaFunctionObject::new(key);
389        func.set_pass_state(true);
390
391        let ctx = ctx.unwrap_or(lua.create_table()?);
392
393        let context = lua_to_minijinja(lua, &mlua::Value::Table(ctx));
394
395        bind_lua(lua, || {
396            let mut captured = self
397                .0
398                .get_template(&name)
399                .map_err(mlua::Error::external)?
400                .render_captured(context)
401                .map_err(mlua::Error::external)?;
402
403            let mut mv = captured
404                .with_state_mut(|state| func.with_func_mut(&[], Some(state)))
405                .map_err(mlua::Error::external)?
406                .and_then(|v| minijinja_to_lua(lua, &v))
407                .unwrap_or_default();
408
409            let rendered = captured.into_output();
410
411            mv.push_front(mlua::Value::String(lua.create_string(rendered)?));
412
413            Ok(mv)
414        })
415    }
416
417    #[lua(name = "eval")]
418    pub(crate) fn lua_eval(
419        &self,
420        lua: &mlua::Lua,
421        source: String,
422        ctx: Option<mlua::Table>,
423    ) -> mlua::Result<mlua::MultiValue> {
424        let ctx = ctx.unwrap_or(lua.create_table()?);
425
426        let context = lua_to_minijinja(lua, &mlua::Value::Table(ctx));
427
428        bind_lua(lua, || {
429            let expr = self
430                .0
431                .compile_expression(&source)
432                .map_err(mlua::Error::external)?
433                .eval(&context)
434                .map_err(mlua::Error::external)?;
435
436            minijinja_to_lua(lua, &expr).ok_or_else(|| {
437                mlua::Error::DeserializeError("could not convert output to lua".to_string())
438            })
439        })
440    }
441
442    #[lua(name = "add_filter")]
443    pub(crate) fn lua_add_filter(
444        &mut self,
445        lua: &mlua::Lua,
446        name: String,
447        filter: mlua::Function,
448        pass_state: Option<bool>,
449    ) -> mlua::Result<()> {
450        let key = lua.create_registry_value(filter)?;
451        let mut func = LuaFunctionObject::new(key);
452        func.set_pass_state(pass_state.unwrap_or(true));
453
454        self.0
455            .add_filter(name, move |state: &State, args: JinjaRest<JinjaValue>| {
456                func.with_func(&args, Some(state))
457            });
458
459        Ok(())
460    }
461
462    #[lua(name = "remove_filter", infallible)]
463    pub(crate) fn lua_remove_filter(&mut self, name: String) {
464        self.0.remove_filter(&name)
465    }
466
467    #[lua(name = "add_test")]
468    pub(crate) fn lua_add_test(
469        &mut self,
470        lua: &mlua::Lua,
471        name: String,
472        test: mlua::Function,
473        pass_state: Option<bool>,
474    ) -> mlua::Result<()> {
475        let key = lua.create_registry_value(test)?;
476        let mut func = LuaFunctionObject::new(key);
477        func.set_pass_state(pass_state.unwrap_or(true));
478
479        self.0
480            .add_test(name, move |state: &State, args: JinjaRest<JinjaValue>| {
481                func.with_func(&args, Some(state))
482            });
483
484        Ok(())
485    }
486
487    #[lua(name = "remove_test", infallible)]
488    pub(crate) fn lua_remove_test(&mut self, name: String) {
489        self.0.remove_test(&name)
490    }
491
492    #[lua(name = "add_global")]
493    pub(crate) fn add_global(
494        &mut self,
495        lua: &mlua::Lua,
496        name: String,
497        val: mlua::Value,
498        pass_state: Option<bool>,
499    ) -> mlua::Result<()> {
500        match val {
501            mlua::Value::Function(f) => {
502                let key = lua.create_registry_value(f)?;
503                let mut func = LuaFunctionObject::new(key);
504                func.set_pass_state(pass_state.unwrap_or(true));
505
506                self.0
507                    .add_function(name, move |state: &State, args: JinjaRest<JinjaValue>| {
508                        func.with_func(&args, Some(state))
509                    })
510            },
511            _ => self.0.add_global(name, lua_to_minijinja(lua, &val)),
512        };
513
514        Ok(())
515    }
516
517    #[lua(name = "remove_global", infallible)]
518    pub(crate) fn lua_remove_global(&mut self, name: String) {
519        self.0.remove_global(&name)
520    }
521
522    #[lua(name = "globals")]
523    pub(crate) fn lua_globals(&self, lua: &mlua::Lua) -> mlua::Result<mlua::Table> {
524        let table = lua.create_table()?;
525
526        for (name, value) in self.0.globals() {
527            minijinja_to_lua(lua, &value)
528                .and_then(|mut v| table.set(name, v.pop_front().unwrap_or_default()).ok());
529        }
530
531        Ok(table)
532    }
533}