Skip to main content

minijinja_lua/
environment.rs

1// SPDX-License-Identifier: MIT
2
3use std::{borrow::Cow, fmt, ops::Deref};
4
5use minijinja::{
6    Environment,
7    Error as JinjaError,
8    ErrorKind as JinjaErrorKind,
9    State,
10    args,
11    value::{Rest as JinjaRest, Value as JinjaValue},
12};
13use mlua::LuaSerdeExt;
14
15use crate::{
16    convert::{
17        LuaAutoEscape,
18        LuaFunctionObject,
19        LuaSyntaxConfig,
20        LuaTableObject,
21        LuaUndefinedBehavior,
22        lua_to_minijinja,
23        minijinja_to_lua,
24    },
25    state::bind_lua,
26};
27
28/// A wrapper around a [`minijinja::Environment`]. This wrapper can be serialized into
29/// an [`mlua::UserData`] object for use within mlua::Lua.
30#[derive(mlua::UserData, Debug)]
31pub struct LuaEnvironment(Environment<'static>);
32
33impl From<Environment<'static>> for LuaEnvironment {
34    fn from(value: Environment<'static>) -> Self {
35        LuaEnvironment(value)
36    }
37}
38
39impl From<LuaEnvironment> for Environment<'static> {
40    fn from(value: LuaEnvironment) -> Self {
41        value.0
42    }
43}
44
45impl Deref for LuaEnvironment {
46    type Target = Environment<'static>;
47
48    fn deref(&self) -> &Self::Target {
49        &self.0
50    }
51}
52
53impl fmt::Display for LuaEnvironment {
54    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55        write!(f, "Environment")
56    }
57}
58
59#[mlua::userdata_impl]
60impl LuaEnvironment {
61    /// Get a new environment
62    #[lua(name = "new", infallible)]
63    pub(crate) fn lua_new() -> Self {
64        let mut env = Environment::new();
65
66        #[cfg(feature = "minijinja-contrib")]
67        minijinja_contrib::add_to_environment(&mut env);
68
69        #[cfg(feature = "json")]
70        crate::contrib::json::add_to_environment(&mut env);
71
72        #[cfg(feature = "datetime")]
73        crate::contrib::datetime::add_to_environment(&mut env);
74
75        env.into()
76    }
77
78    /// Get a new empty environment
79    #[lua(name = "empty", infallible)]
80    pub(crate) fn lua_empty() -> Self {
81        Environment::empty().into()
82    }
83
84    #[lua(name = "keep_trailing_newline", getter, infallible)]
85    pub(crate) fn lua_keep_trailing_newline(&self) -> bool {
86        self.0.keep_trailing_newline()
87    }
88
89    #[lua(name = "keep_trailing_newline", setter, infallible)]
90    pub(crate) fn lua_set_keep_trailing_newline(&mut self, val: bool) {
91        self.0.set_keep_trailing_newline(val)
92    }
93
94    #[lua(name = "trim_blocks", getter, infallible)]
95    pub(crate) fn lua_trim_blocks(&self) -> bool {
96        self.0.trim_blocks()
97    }
98
99    #[lua(name = "trim_blocks", setter, infallible)]
100    pub(crate) fn lua_set_trim_blocks(&mut self, val: bool) {
101        self.0.set_trim_blocks(val)
102    }
103
104    #[lua(name = "lstrip_blocks", getter, infallible)]
105    pub(crate) fn lua_lstrip_blocks(&self) -> bool {
106        self.0.lstrip_blocks()
107    }
108
109    #[lua(name = "lstrip_blocks", setter, infallible)]
110    pub(crate) fn lua_set_lstrip_blocks(&mut self, val: bool) {
111        self.0.set_lstrip_blocks(val)
112    }
113
114    #[lua(name = "debug", getter, infallible)]
115    pub(crate) fn lua_debug(&self) -> bool {
116        self.0.debug()
117    }
118
119    #[lua(name = "debug", setter, infallible)]
120    pub(crate) fn lua_set_debug(&mut self, val: bool) {
121        self.0.set_debug(val)
122    }
123
124    #[lua(name = "fuel", getter, infallible)]
125    pub(crate) fn lua_fuel(&self) -> Option<u64> {
126        self.0.fuel()
127    }
128
129    #[lua(name = "fuel", setter, infallible)]
130    pub(crate) fn lua_set_fuel(&mut self, val: Option<u64>) {
131        self.0.set_fuel(val)
132    }
133
134    #[lua(name = "recursion_limit", getter, infallible)]
135    pub(crate) fn lua_recursion_limit(&self) -> usize {
136        self.0.recursion_limit()
137    }
138
139    #[lua(name = "recursion_limit", setter, infallible)]
140    pub(crate) fn lua_set_recursion_limit(&mut self, val: usize) {
141        self.0.set_recursion_limit(val)
142    }
143
144    #[lua(name = "undefined_behavior", getter, infallible)]
145    pub(crate) fn lua_undefined_behavior(&self) -> LuaUndefinedBehavior {
146        self.0.undefined_behavior().into()
147    }
148
149    #[lua(name = "undefined_behavior", setter)]
150    pub(crate) fn lua_set_undefined_behavior(
151        &mut self,
152        val: LuaUndefinedBehavior,
153    ) -> mlua::Result<()> {
154        self.0.set_undefined_behavior(val.into());
155
156        Ok(())
157    }
158
159    #[lua(name = "add_template", infallible)]
160    pub(crate) fn lua_add_template(
161        &mut self,
162        lua: &mlua::Lua,
163        name: String,
164        source: String,
165    ) -> mlua::Result<()> {
166        bind_lua(lua, || {
167            self.0
168                .add_template_owned(name, source)
169                .map_err(mlua::Error::external)
170        })
171    }
172
173    #[lua(name = "remove_template", infallible)]
174    pub(crate) fn lua_remove_template(&mut self, lua: &mlua::Lua, name: &str) {
175        bind_lua(lua, || self.0.remove_template(name))
176    }
177
178    #[lua(name = "clear_templates", infallible)]
179    pub(crate) fn lua_clear_templates(&mut self, lua: &mlua::Lua) {
180        bind_lua(lua, || self.0.clear_templates())
181    }
182
183    #[lua(name = "undeclared_variables")]
184    pub(crate) fn lua_undeclared_variables(
185        &mut self,
186        lua: &mlua::Lua,
187        name: &str,
188        nested: Option<bool>,
189    ) -> mlua::Result<mlua::Value> {
190        bind_lua(lua, || {
191            let nested = nested.unwrap_or(false);
192
193            let vars = self
194                .0
195                .get_template(name)
196                .map_err(mlua::Error::external)?
197                .undeclared_variables(nested);
198
199            lua.to_value(&vars)
200        })
201    }
202
203    #[lua(name = "set_loader")]
204    pub(crate) fn lua_set_loader(
205        &mut self,
206        lua: &mlua::Lua,
207        callback: mlua::Function,
208    ) -> mlua::Result<()> {
209        let func = LuaFunctionObject::from_value(lua, &callback)?;
210
211        self.0.set_loader(move |name| {
212            let source = func.with_func(args!(name), None)?;
213            Ok(source.and_then(|v| {
214                // If the lua function returns nil, i.e., no path found
215                // it is mapped as `minijinja::value::ValueKind::Undefined`, however
216                // we need to return a `None` to indicate no path was found.
217                if v.is_undefined() {
218                    None
219                } else {
220                    Some(v.to_string())
221                }
222            }))
223        });
224
225        Ok(())
226    }
227
228    #[lua(name = "set_path_join_callback")]
229    pub(crate) fn lua_set_path_join_callback(
230        &mut self,
231        lua: &mlua::Lua,
232        callback: mlua::Function,
233    ) -> mlua::Result<()> {
234        let func = LuaFunctionObject::from_value(lua, &callback)?;
235
236        self.0.set_path_join_callback(move |name, parent| {
237            func.with_func(args!(name, parent), None)
238                .ok()
239                .flatten()
240                .and_then(|v| v.as_str().map(|s| Cow::Owned(s.to_string())))
241                .unwrap_or(Cow::Borrowed(name))
242        });
243
244        Ok(())
245    }
246
247    #[lua(name = "set_unknown_method_callback")]
248    pub(crate) fn lua_set_unknown_method_callback(
249        &mut self,
250        lua: &mlua::Lua,
251        callback: mlua::Function,
252    ) -> mlua::Result<()> {
253        let mut func = LuaFunctionObject::from_value(lua, &callback)?;
254        func.set_pass_state(true);
255
256        self.0
257            .set_unknown_method_callback(move |state, value, method, args| {
258                func.with_func(args!(value, method, ..args), Some(state))
259                    .map(|v| v.unwrap_or_default())
260            });
261
262        Ok(())
263    }
264
265    #[cfg(feature = "minijinja-contrib")]
266    #[lua(name = "set_pycompat", infallible)]
267    pub(crate) fn lua_set_pycompat(&mut self, enable: Option<bool>) {
268        match enable {
269            Some(true) | None => self
270                .0
271                .set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback),
272            Some(false) => self.0.set_unknown_method_callback(|_, _, _, _| {
273                Err(JinjaError::from(JinjaErrorKind::UnknownMethod))
274            }),
275        }
276    }
277
278    #[lua(name = "set_auto_escape_callback")]
279    pub(crate) fn lua_set_auto_escape_callback(
280        &mut self,
281        lua: &mlua::Lua,
282        callback: mlua::Function,
283    ) -> mlua::Result<()> {
284        let func = LuaFunctionObject::from_value(lua, &callback)?;
285
286        self.0
287            .set_auto_escape_callback(move |name| -> minijinja::AutoEscape {
288                func.with_func(args!(name), None)
289                    .ok()
290                    .flatten()
291                    .and_then(|v| LuaAutoEscape::try_from(v.to_string().as_str()).ok())
292                    .unwrap_or_default()
293                    .into()
294            });
295
296        Ok(())
297    }
298
299    #[lua(name = "set_formatter")]
300    pub(crate) fn lua_set_formatter(
301        &mut self,
302        lua: &mlua::Lua,
303        callback: mlua::Function,
304    ) -> mlua::Result<()> {
305        let mut func = LuaFunctionObject::from_value(lua, &callback)?;
306        func.set_pass_state(true);
307
308        self.0.set_formatter(move |out, state, value| {
309            let Some(val) = func.with_func(args!(value), Some(state)).ok().flatten() else {
310                return Ok(());
311            };
312
313            let Some(s) = val.as_str() else {
314                return Err(JinjaError::new(
315                    JinjaErrorKind::WriteFailure,
316                    "formatter must return a string",
317                ));
318            };
319
320            out.write_str(s)
321                .map_err(|_| JinjaError::new(JinjaErrorKind::WriteFailure, "write failed"))
322        });
323
324        Ok(())
325    }
326
327    #[lua(name = "set_syntax")]
328    pub(crate) fn lua_set_syntax(&mut self, syntax: LuaSyntaxConfig) -> mlua::Result<()> {
329        self.0.set_syntax(syntax.into());
330
331        Ok(())
332    }
333
334    #[lua(name = "render_template")]
335    pub(crate) fn lua_render_template(
336        &mut self,
337        lua: &mlua::Lua,
338        name: &str,
339        ctx: Option<mlua::Table>,
340    ) -> mlua::Result<String> {
341        let ctx: Option<JinjaValue> = ctx
342            .and_then(|t| LuaTableObject::from_value(lua, &t).ok())
343            .map(|obj| obj.into());
344
345        bind_lua(lua, || {
346            self.0
347                .get_template(name)
348                .map_err(mlua::Error::external)?
349                .render(ctx)
350                .map_err(mlua::Error::external)
351        })
352    }
353
354    #[lua(name = "render_str")]
355    pub(crate) fn lua_render_str(
356        &self,
357        lua: &mlua::Lua,
358        source: &str,
359        ctx: Option<mlua::Table>,
360        name: Option<String>,
361    ) -> mlua::Result<String> {
362        let ctx: Option<JinjaValue> = ctx
363            .and_then(|t| LuaTableObject::from_value(lua, &t).ok())
364            .map(|obj| obj.into());
365
366        let name = name.unwrap_or("<string>".to_string());
367
368        bind_lua(lua, || {
369            self.0
370                .render_named_str(&name, source, ctx)
371                .map_err(mlua::Error::external)
372        })
373    }
374
375    #[lua(name = "render_captured")]
376    pub(crate) fn lua_render_captured(
377        &mut self,
378        lua: &mlua::Lua,
379        name: &str,
380        ctx: Option<mlua::Table>,
381        callback: mlua::Function,
382    ) -> mlua::Result<mlua::MultiValue> {
383        let mut func = LuaFunctionObject::from_value(lua, &callback)?;
384        func.set_pass_state(true);
385
386        let ctx: Option<JinjaValue> = ctx
387            .and_then(|t| LuaTableObject::from_value(lua, &t).ok())
388            .map(|obj| obj.into());
389
390        bind_lua(lua, || {
391            let mut captured = self
392                .0
393                .get_template(name)
394                .map_err(mlua::Error::external)?
395                .render_captured(ctx)
396                .map_err(mlua::Error::external)?;
397
398            let mut mv = captured
399                .with_state_mut(|state| func.with_func_mut(&[], Some(state)))
400                .map_err(mlua::Error::external)?
401                .and_then(|v| minijinja_to_lua(lua, &v))
402                .unwrap_or_default();
403
404            let rendered = captured.into_output();
405
406            mv.push_front(mlua::Value::String(lua.create_string(rendered)?));
407
408            Ok(mv)
409        })
410    }
411
412    #[lua(name = "eval")]
413    pub(crate) fn lua_eval(
414        &self,
415        lua: &mlua::Lua,
416        source: &str,
417        ctx: Option<mlua::Table>,
418    ) -> mlua::Result<mlua::MultiValue> {
419        let ctx: Option<JinjaValue> = ctx
420            .and_then(|t| LuaTableObject::from_value(lua, &t).ok())
421            .map(|obj| obj.into());
422
423        bind_lua(lua, || {
424            let expr = self
425                .0
426                .compile_expression(source)
427                .map_err(mlua::Error::external)?
428                .eval(ctx)
429                .map_err(mlua::Error::external)?;
430
431            minijinja_to_lua(lua, &expr).ok_or_else(|| {
432                mlua::Error::DeserializeError("could not convert output to lua".to_string())
433            })
434        })
435    }
436
437    #[lua(name = "add_filter")]
438    pub(crate) fn lua_add_filter(
439        &mut self,
440        lua: &mlua::Lua,
441        name: String,
442        filter: mlua::Function,
443        pass_state: Option<bool>,
444    ) -> mlua::Result<()> {
445        let mut func = LuaFunctionObject::from_value(lua, &filter)?;
446        func.set_pass_state(pass_state.unwrap_or(true));
447
448        self.0
449            .add_filter(name, move |state: &State, args: JinjaRest<JinjaValue>| {
450                func.with_func(&args, Some(state))
451            });
452
453        Ok(())
454    }
455
456    #[lua(name = "remove_filter", infallible)]
457    pub(crate) fn lua_remove_filter(&mut self, name: String) {
458        self.0.remove_filter(&name)
459    }
460
461    #[lua(name = "add_test")]
462    pub(crate) fn lua_add_test(
463        &mut self,
464        lua: &mlua::Lua,
465        name: String,
466        test: mlua::Function,
467        pass_state: Option<bool>,
468    ) -> mlua::Result<()> {
469        let mut func = LuaFunctionObject::from_value(lua, &test)?;
470        func.set_pass_state(pass_state.unwrap_or(true));
471
472        self.0
473            .add_test(name, move |state: &State, args: JinjaRest<JinjaValue>| {
474                func.with_func(&args, Some(state))
475            });
476
477        Ok(())
478    }
479
480    #[lua(name = "remove_test", infallible)]
481    pub(crate) fn lua_remove_test(&mut self, name: String) {
482        self.0.remove_test(&name)
483    }
484
485    #[lua(name = "add_global")]
486    pub(crate) fn add_global(
487        &mut self,
488        lua: &mlua::Lua,
489        name: String,
490        val: mlua::Value,
491        pass_state: Option<bool>,
492    ) -> mlua::Result<()> {
493        match val {
494            mlua::Value::Function(f) => {
495                let mut func = LuaFunctionObject::from_value(lua, &f)?;
496                func.set_pass_state(pass_state.unwrap_or(true));
497
498                self.0
499                    .add_function(name, move |state: &State, args: JinjaRest<JinjaValue>| {
500                        func.with_func(&args, Some(state))
501                    })
502            },
503            _ => self.0.add_global(name, lua_to_minijinja(lua, &val)),
504        };
505
506        Ok(())
507    }
508
509    #[lua(name = "remove_global", infallible)]
510    pub(crate) fn lua_remove_global(&mut self, name: &str) {
511        self.0.remove_global(name)
512    }
513
514    #[lua(name = "globals")]
515    pub(crate) fn lua_globals(&self, lua: &mlua::Lua) -> mlua::Result<mlua::Table> {
516        let table = lua.create_table()?;
517
518        for (name, value) in self.0.globals() {
519            minijinja_to_lua(lua, &value)
520                .and_then(|mut v| table.set(name, v.pop_front().unwrap_or_default()).ok());
521        }
522
523        Ok(table)
524    }
525}