Skip to main content

minijinja_lua/
environment.rs

1// SPDX-License-Identifier: MIT
2
3use std::{
4    borrow::Cow,
5    fmt,
6    sync::{
7        RwLock,
8        RwLockReadGuard,
9        RwLockWriteGuard,
10        atomic::{AtomicBool, Ordering},
11    },
12};
13
14use minijinja::{
15    Environment,
16    Error as JinjaError,
17    ErrorKind as JinjaErrorKind,
18    State,
19    args,
20    value::{Rest as JinjaRest, Value as JinjaValue},
21};
22use mlua::LuaSerdeExt;
23
24use crate::{
25    convert::{
26        LuaFunctionObject,
27        LuaObject,
28        lua_to_auto_escape,
29        lua_to_minijinja,
30        lua_to_syntax_config,
31        lua_to_undefined_behavior,
32        minijinja_to_lua,
33        undefined_behavior_to_lua,
34    },
35    state::bind_lua,
36};
37
/// A wrapper around a [`minijinja::Environment`] that is exposed to Lua as an
/// [`mlua::UserData`] object (this type implements `mlua::UserData`).
#[derive(Debug)]
pub struct LuaEnvironment {
    /// The wrapped environment, accessed through `read_env` / `write_env`;
    /// a poisoned lock is reported to Lua as a runtime error.
    env: RwLock<Environment<'static>>,
    /// When `true`, `render_template` and `undeclared_variables` clear the
    /// template cache before looking up a template.
    reload_before_render: AtomicBool,
}
45
46impl LuaEnvironment {
47    /// Get a new environment
48    pub(crate) fn new() -> Self {
49        let mut env = Environment::new();
50
51        #[cfg(feature = "minijinja-contrib")]
52        minijinja_contrib::add_to_environment(&mut env);
53
54        #[cfg(feature = "json")]
55        crate::contrib::json::add_to_environment(&mut env);
56
57        #[cfg(feature = "datetime")]
58        crate::contrib::datetime::add_to_environment(&mut env);
59
60        Self {
61            env: RwLock::new(env),
62            reload_before_render: AtomicBool::new(false),
63        }
64    }
65
66    /// Get a new empty environment
67    pub(crate) fn empty() -> Self {
68        let env = Environment::empty();
69
70        Self {
71            env: RwLock::new(env),
72            reload_before_render: AtomicBool::new(false),
73        }
74    }
75
76    pub(crate) fn reload_before_render(&self) -> bool {
77        self.reload_before_render.load(Ordering::Relaxed)
78    }
79
80    pub(crate) fn set_reload_before_render(&self, enable: bool) {
81        self.reload_before_render.store(enable, Ordering::Relaxed);
82    }
83
84    /// Get a read-only lock on the underlying `minijinja::Environment`
85    pub(crate) fn read_env(
86        &self,
87    ) -> Result<RwLockReadGuard<'_, Environment<'static>>, mlua::Error> {
88        self.env
89            .read()
90            .map_err(|_| mlua::Error::runtime("environment lock poisoned"))
91    }
92
93    /// Get a write lock on the underlying [`minijinja::Environment`]
94    pub(crate) fn write_env(
95        &self,
96    ) -> Result<RwLockWriteGuard<'_, Environment<'static>>, mlua::Error> {
97        self.env
98            .write()
99            .map_err(|_| mlua::Error::runtime("environment lock poisoned"))
100    }
101}
102
103impl Default for LuaEnvironment {
104    fn default() -> Self {
105        Self::new()
106    }
107}
108
109impl fmt::Display for LuaEnvironment {
110    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111        write!(f, "Environment")
112    }
113}
114
115impl mlua::UserData for LuaEnvironment {
116    fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
117        fields.add_field_method_get(
118            "reload_before_render",
119            |_, this: &LuaEnvironment| -> Result<bool, mlua::Error> {
120                Ok(this.reload_before_render())
121            },
122        );
123
124        fields.add_field_method_set(
125            "reload_before_render",
126            |_, this: &mut LuaEnvironment, val: bool| -> Result<(), mlua::Error> {
127                this.set_reload_before_render(val);
128
129                Ok(())
130            },
131        );
132
133        fields.add_field_method_get(
134            "keep_trailing_newline",
135            |_, this: &LuaEnvironment| -> Result<bool, mlua::Error> {
136                Ok(this.read_env()?.keep_trailing_newline())
137            },
138        );
139
140        fields.add_field_method_set(
141            "keep_trailing_newline",
142            |_, this: &mut LuaEnvironment, val: bool| -> Result<(), mlua::Error> {
143                this.write_env()?.set_keep_trailing_newline(val);
144
145                Ok(())
146            },
147        );
148
149        fields.add_field_method_get(
150            "trim_blocks",
151            |_, this: &LuaEnvironment| -> Result<bool, mlua::Error> {
152                Ok(this.read_env()?.trim_blocks())
153            },
154        );
155
156        fields.add_field_method_set(
157            "trim_blocks",
158            |_, this: &mut LuaEnvironment, val: bool| -> Result<(), mlua::Error> {
159                this.write_env()?.set_trim_blocks(val);
160
161                Ok(())
162            },
163        );
164
165        fields.add_field_method_get(
166            "lstrip_blocks",
167            |_, this: &LuaEnvironment| -> Result<bool, mlua::Error> {
168                Ok(this.read_env()?.lstrip_blocks())
169            },
170        );
171
172        fields.add_field_method_set(
173            "lstrip_blocks",
174            |_, this: &mut LuaEnvironment, val: bool| -> Result<(), mlua::Error> {
175                this.write_env()?.set_lstrip_blocks(val);
176
177                Ok(())
178            },
179        );
180
181        fields.add_field_method_get(
182            "debug",
183            |_, this: &LuaEnvironment| -> Result<bool, mlua::Error> {
184                Ok(this.read_env()?.debug())
185            },
186        );
187
188        fields.add_field_method_set(
189            "debug",
190            |_, this: &mut LuaEnvironment, val: bool| -> Result<(), mlua::Error> {
191                this.write_env()?.set_debug(val);
192
193                Ok(())
194            },
195        );
196
197        fields.add_field_method_get(
198            "fuel",
199            |_, this: &LuaEnvironment| -> Result<Option<u64>, mlua::Error> {
200                Ok(this.read_env()?.fuel())
201            },
202        );
203
204        fields.add_field_method_set(
205            "fuel",
206            |_, this: &mut LuaEnvironment, val: Option<u64>| -> Result<(), mlua::Error> {
207                this.write_env()?.set_fuel(val);
208
209                Ok(())
210            },
211        );
212
213        fields.add_field_method_get(
214            "recursion_limit",
215            |_, this: &LuaEnvironment| -> Result<usize, mlua::Error> {
216                Ok(this.read_env()?.recursion_limit())
217            },
218        );
219
220        fields.add_field_method_set(
221            "recursion_limit",
222            |_, this: &mut LuaEnvironment, val: usize| -> Result<(), mlua::Error> {
223                this.write_env()?.set_recursion_limit(val);
224
225                Ok(())
226            },
227        );
228
229        fields.add_field_method_get(
230            "undefined_behavior",
231            |_, this: &LuaEnvironment| -> Result<Option<String>, mlua::Error> {
232                let ub = this.read_env()?.undefined_behavior();
233
234                Ok(undefined_behavior_to_lua(ub))
235            },
236        );
237
238        fields.add_field_method_set(
239            "undefined_behavior",
240            |_, this: &mut LuaEnvironment, val: String| -> Result<(), mlua::Error> {
241                let behavior = lua_to_undefined_behavior(&val)?;
242
243                this.write_env()?.set_undefined_behavior(behavior);
244
245                Ok(())
246            },
247        );
248    }
249
250    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
251        methods.add_function(
252            "new",
253            |_, _: mlua::MultiValue| -> Result<LuaEnvironment, _> { Ok(LuaEnvironment::new()) },
254        );
255
256        methods.add_function(
257            "empty",
258            |_, _: mlua::MultiValue| -> Result<LuaEnvironment, _> { Ok(LuaEnvironment::empty()) },
259        );
260
261        methods.add_method(
262            "add_template",
263            |lua: &mlua::Lua,
264             this: &LuaEnvironment,
265             (name, source): (String, String)|
266             -> Result<(), mlua::Error> {
267                bind_lua(lua, || {
268                    this.write_env()?
269                        .add_template_owned(name, source)
270                        .map_err(mlua::Error::external)
271                })
272            },
273        );
274
275        methods.add_method(
276            "remove_template",
277            |lua: &mlua::Lua, this: &LuaEnvironment, name: String| -> Result<(), mlua::Error> {
278                bind_lua(lua, || {
279                    this.write_env()?.remove_template(&name);
280                    Ok(())
281                })
282            },
283        );
284
285        methods.add_method(
286            "clear_templates",
287            |lua: &mlua::Lua, this: &LuaEnvironment, _: mlua::Value| -> Result<(), mlua::Error> {
288                bind_lua(lua, || {
289                    this.write_env()?.clear_templates();
290
291                    Ok(())
292                })
293            },
294        );
295
296        methods.add_method(
297            "undeclared_variables",
298            |lua: &mlua::Lua,
299             this: &LuaEnvironment,
300             (name, nested): (String, Option<bool>)|
301             -> Result<mlua::Value, mlua::Error> {
302                bind_lua(lua, || {
303                    if this.reload_before_render() {
304                        this.write_env()?.clear_templates();
305                    }
306
307                    let nested = nested.unwrap_or(false);
308
309                    let vars = this
310                        .read_env()?
311                        .get_template(&name)
312                        .map_err(mlua::Error::external)?
313                        .undeclared_variables(nested);
314
315                    lua.to_value(&vars)
316                })
317            },
318        );
319
320        methods.add_method(
321            "set_loader",
322            |lua: &mlua::Lua,
323             this: &LuaEnvironment,
324             callback: mlua::Function|
325             -> Result<(), mlua::Error> {
326                let key = lua.create_registry_value(callback)?;
327                let func = LuaFunctionObject::new(key);
328
329                this.write_env()?.set_loader(move |name| {
330                    let source = func.with_func(args!(name), None)?;
331                    Ok(source.and_then(|v| {
332                        // If the lua function returns nil, i.e., no path found
333                        // it is mapped as `minijinja::value::ValueKind::Undefined`, however
334                        // we need to return a `None` to indicate no path was found.
335                        if v.is_undefined() {
336                            None
337                        } else {
338                            Some(v.to_string())
339                        }
340                    }))
341                });
342
343                Ok(())
344            },
345        );
346
347        methods.add_method(
348            "set_path_join_callback",
349            |lua: &mlua::Lua,
350             this: &LuaEnvironment,
351             callback: mlua::Function|
352             -> Result<(), mlua::Error> {
353                let key = lua.create_registry_value(callback)?;
354                let func = LuaFunctionObject::new(key);
355
356                this.write_env()?
357                    .set_path_join_callback(move |name, parent| {
358                        func.with_func(args!(name, parent), None)
359                            .ok()
360                            .flatten()
361                            .and_then(|v| v.as_str().map(|s| Cow::Owned(s.to_string())))
362                            .unwrap_or(Cow::Borrowed(name))
363                    });
364                Ok(())
365            },
366        );
367
368        methods.add_method(
369            "set_unknown_method_callback",
370            |lua: &mlua::Lua,
371             this: &LuaEnvironment,
372             callback: mlua::Function|
373             -> Result<(), mlua::Error> {
374                let key = lua.create_registry_value(callback)?;
375                let mut func = LuaFunctionObject::new(key);
376                func.set_pass_state(true);
377
378                this.write_env()?
379                    .set_unknown_method_callback(move |state, value, method, args| {
380                        func.with_func(args!(value, method, ..args), Some(state))
381                            .map(|v| v.unwrap_or(JinjaValue::UNDEFINED))
382                    });
383
384                Ok(())
385            },
386        );
387
388        methods.add_method(
389            "set_pycompat",
390            |_, this: &LuaEnvironment, enable: Option<bool>| {
391                match enable {
392                    Some(true) | None => {
393                        this.write_env()?.set_unknown_method_callback(
394                            minijinja_contrib::pycompat::unknown_method_callback,
395                        );
396                    },
397                    Some(false) => {
398                        this.write_env()?.set_unknown_method_callback(|_, _, _, _| {
399                            Err(JinjaError::from(JinjaErrorKind::UnknownMethod))
400                        });
401                    },
402                }
403
404                Ok(())
405            },
406        );
407
408        methods.add_method(
409            "set_auto_escape_callback",
410            |lua: &mlua::Lua,
411             this: &LuaEnvironment,
412             callback: mlua::Function|
413             -> Result<(), mlua::Error> {
414                let key = lua.create_registry_value(callback)?;
415                let func = LuaFunctionObject::new(key);
416
417                this.write_env()?.set_auto_escape_callback(move |name| {
418                    func.with_func(args!(name), None)
419                        .ok()
420                        .flatten()
421                        .and_then(|v| {
422                            let s = v.as_str()?.to_string();
423                            lua_to_auto_escape(&s).ok()
424                        })
425                        .unwrap_or(minijinja::AutoEscape::None)
426                });
427                Ok(())
428            },
429        );
430
431        methods.add_method(
432            "set_formatter",
433            |lua: &mlua::Lua,
434             this: &LuaEnvironment,
435             callback: mlua::Function|
436             -> Result<(), mlua::Error> {
437                let key = lua.create_registry_value(callback)?;
438                let mut func = LuaFunctionObject::new(key);
439                func.set_pass_state(true);
440
441                this.write_env()?.set_formatter(move |out, state, value| {
442                    let Some(val) = func.with_func(args!(value), Some(state)).ok().flatten() else {
443                        return Ok(());
444                    };
445
446                    let Some(s) = val.as_str() else {
447                        return Err(JinjaError::new(
448                            JinjaErrorKind::WriteFailure,
449                            "formatter must return a string",
450                        ));
451                    };
452
453                    out.write_str(s)
454                        .map_err(|_| JinjaError::new(JinjaErrorKind::WriteFailure, "write failed"))
455                });
456
457                Ok(())
458            },
459        );
460
461        methods.add_method(
462            "set_syntax",
463            |_, this: &LuaEnvironment, syntax: mlua::Table| -> Result<(), mlua::Error> {
464                let config = lua_to_syntax_config(syntax).map_err(mlua::Error::external)?;
465                this.write_env()?.set_syntax(config);
466
467                Ok(())
468            },
469        );
470
471        methods.add_method(
472            "render_template",
473            |lua: &mlua::Lua,
474             this: &LuaEnvironment,
475             (name, ctx): (String, Option<mlua::Table>)|
476             -> Result<String, mlua::Error> {
477                if this.reload_before_render() {
478                    this.write_env()?.clear_templates();
479                }
480
481                let ctx = ctx.unwrap_or(lua.create_table()?);
482
483                let context = lua_to_minijinja(lua, &mlua::Value::Table(ctx));
484
485                bind_lua(lua, || {
486                    this.read_env()?
487                        .get_template(&name)
488                        .map_err(mlua::Error::external)?
489                        .render(context)
490                        .map_err(mlua::Error::external)
491                })
492            },
493        );
494
495        methods.add_method(
496            "render_str",
497            |lua: &mlua::Lua,
498             this: &LuaEnvironment,
499             (source, ctx, name): (String, Option<mlua::Table>, Option<String>)|
500             -> Result<String, mlua::Error> {
501                let ctx = ctx.unwrap_or(lua.create_table()?);
502
503                let name = name.unwrap_or("<string>".to_string());
504                let context = lua_to_minijinja(lua, &mlua::Value::Table(ctx));
505
506                bind_lua(lua, || {
507                    this.read_env()?
508                        .render_named_str(&name, &source, context)
509                        .map_err(mlua::Error::external)
510                })
511            },
512        );
513
514        methods.add_method(
515            "eval",
516            |lua: &mlua::Lua,
517             this: &LuaEnvironment,
518             (source, ctx): (String, Option<mlua::Table>)|
519             -> Result<mlua::Value, mlua::Error> {
520                let ctx = ctx.unwrap_or(lua.create_table()?);
521
522                let context = lua_to_minijinja(lua, &mlua::Value::Table(ctx));
523
524                bind_lua(lua, || {
525                    let expr = this
526                        .read_env()?
527                        .compile_expression(&source)
528                        .map_err(mlua::Error::external)?
529                        .eval(&context)
530                        .map_err(mlua::Error::external)?;
531
532                    minijinja_to_lua(lua, &expr).ok_or_else(|| mlua::Error::ToLuaConversionError {
533                        from: "".to_string(),
534                        to: "",
535                        message: None,
536                    })
537                })
538            },
539        );
540
541        methods.add_method(
542            "add_filter",
543            |lua: &mlua::Lua,
544             this: &LuaEnvironment,
545             (name, filter, pass_state): (String, mlua::Function, Option<bool>)|
546             -> Result<(), mlua::Error> {
547                let key = lua.create_registry_value(filter)?;
548                let mut func = LuaFunctionObject::new(key);
549                func.set_pass_state(pass_state.unwrap_or(true));
550
551                this.write_env()?.add_filter(
552                    name,
553                    move |state: &State, args: JinjaRest<JinjaValue>| {
554                        func.with_func(&args, Some(state))
555                    },
556                );
557
558                Ok(())
559            },
560        );
561
562        methods.add_method(
563            "remove_filter",
564            |_, this: &LuaEnvironment, name: String| -> Result<(), mlua::Error> {
565                this.write_env()?.remove_filter(&name);
566
567                Ok(())
568            },
569        );
570
571        methods.add_method(
572            "add_test",
573            |lua: &mlua::Lua,
574             this: &LuaEnvironment,
575             (name, test, pass_state): (String, mlua::Function, Option<bool>)|
576             -> Result<(), mlua::Error> {
577                let key = lua.create_registry_value(test)?;
578                let mut func = LuaFunctionObject::new(key);
579                func.set_pass_state(pass_state.unwrap_or(true));
580
581                this.write_env()?.add_test(
582                    name,
583                    move |state: &State, args: JinjaRest<JinjaValue>| {
584                        func.with_func(&args, Some(state))
585                    },
586                );
587
588                Ok(())
589            },
590        );
591
592        methods.add_method(
593            "remove_test",
594            |_, this: &LuaEnvironment, name: String| -> Result<(), mlua::Error> {
595                this.write_env()?.remove_test(&name);
596
597                Ok(())
598            },
599        );
600
601        methods.add_method(
602            "add_global",
603            |lua: &mlua::Lua,
604             this: &LuaEnvironment,
605             (name, val, pass_state): (String, mlua::Value, Option<bool>)|
606             -> Result<(), mlua::Error> {
607                match val {
608                    mlua::Value::Function(f) => {
609                        let key = lua.create_registry_value(f)?;
610                        let mut func = LuaFunctionObject::new(key);
611                        func.set_pass_state(pass_state.unwrap_or(true));
612
613                        this.write_env()?.add_function(
614                            name,
615                            move |state: &State, args: JinjaRest<JinjaValue>| {
616                                func.with_func(&args, Some(state))
617                            },
618                        )
619                    },
620                    _ => this
621                        .write_env()?
622                        .add_global(name, lua_to_minijinja(lua, &val)),
623                };
624
625                Ok(())
626            },
627        );
628
629        methods.add_method(
630            "remove_global",
631            |_, this: &LuaEnvironment, name: String| -> Result<(), mlua::Error> {
632                this.write_env()?.remove_global(&name);
633
634                Ok(())
635            },
636        );
637
638        methods.add_method(
639            "globals",
640            |lua: &mlua::Lua,
641             this: &LuaEnvironment,
642             _val: mlua::Value|
643             -> Result<mlua::Table, mlua::Error> {
644                let table = lua.create_table()?;
645
646                for (name, value) in this.read_env()?.globals() {
647                    let val = minijinja_to_lua(lua, &value);
648                    table.set(name, val)?;
649                }
650
651                Ok(table)
652            },
653        );
654    }
655}