Skip to main content

minijinja_lua/
environment.rs

1// SPDX-License-Identifier: MIT
2
3use std::{
4    borrow::Cow,
5    fmt,
6    sync::{
7        RwLock,
8        RwLockReadGuard,
9        RwLockWriteGuard,
10        atomic::{AtomicBool, Ordering},
11    },
12};
13
14use minijinja::{
15    Environment,
16    Error as JinjaError,
17    ErrorKind as JinjaErrorKind,
18    State,
19    args,
20    value::{Rest as JinjaRest, Value as JinjaValue},
21};
22use mlua::LuaSerdeExt;
23
24use crate::{
25    convert::{
26        LuaFunctionObject,
27        LuaObject,
28        lua_to_auto_escape,
29        lua_to_minijinja,
30        lua_to_syntax_config,
31        lua_to_undefined_behavior,
32        minijinja_to_lua,
33        undefined_behavior_to_lua,
34    },
35    state::bind_lua,
36};
37
38/// A wrapper around a [`minijinja::Environment`]. This wrapper can be serialized into
39/// an [`mlua::UserData`] object for use within mlua::Lua.
40#[derive(Debug)]
41pub struct LuaEnvironment {
42    env: RwLock<Environment<'static>>,
43    reload_before_render: AtomicBool,
44}
45
46impl LuaEnvironment {
47    /// Get a new environment
48    pub(crate) fn new() -> Self {
49        let mut env = Environment::new();
50
51        #[cfg(feature = "minijinja-contrib")]
52        minijinja_contrib::add_to_environment(&mut env);
53
54        #[cfg(feature = "json")]
55        crate::contrib::json::add_to_environment(&mut env);
56
57        #[cfg(feature = "datetime")]
58        crate::contrib::datetime::add_to_environment(&mut env);
59
60        Self {
61            env: RwLock::new(env),
62            reload_before_render: AtomicBool::new(false),
63        }
64    }
65
66    /// Get a new empty environment
67    pub(crate) fn empty() -> Self {
68        let env = Environment::empty();
69
70        Self {
71            env: RwLock::new(env),
72            reload_before_render: AtomicBool::new(false),
73        }
74    }
75
76    pub(crate) fn reload_before_render(&self) -> bool {
77        self.reload_before_render.load(Ordering::Relaxed)
78    }
79
80    pub(crate) fn set_reload_before_render(&self, enable: bool) {
81        self.reload_before_render.store(enable, Ordering::Relaxed);
82    }
83
84    /// Get a read-only lock on the underlying `minijinja::Environment`
85    pub(crate) fn read_env(
86        &self,
87    ) -> Result<RwLockReadGuard<'_, Environment<'static>>, mlua::Error> {
88        self.env
89            .read()
90            .map_err(|_| mlua::Error::runtime("environment lock poisoned"))
91    }
92
93    /// Get a write lock on the underlying [`minijinja::Environment`]
94    pub(crate) fn write_env(
95        &self,
96    ) -> Result<RwLockWriteGuard<'_, Environment<'static>>, mlua::Error> {
97        self.env
98            .write()
99            .map_err(|_| mlua::Error::runtime("environment lock poisoned"))
100    }
101}
102
103impl Default for LuaEnvironment {
104    fn default() -> Self {
105        Self::new()
106    }
107}
108
109impl fmt::Display for LuaEnvironment {
110    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111        write!(f, "Environment")
112    }
113}
114
115impl mlua::UserData for LuaEnvironment {
116    fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
117        fields.add_field_method_get(
118            "reload_before_render",
119            |_, this: &LuaEnvironment| -> mlua::Result<bool> { Ok(this.reload_before_render()) },
120        );
121
122        fields.add_field_method_set(
123            "reload_before_render",
124            |_, this: &mut LuaEnvironment, val: bool| -> mlua::Result<()> {
125                this.set_reload_before_render(val);
126
127                Ok(())
128            },
129        );
130
131        fields.add_field_method_get(
132            "keep_trailing_newline",
133            |_, this: &LuaEnvironment| -> mlua::Result<bool> {
134                Ok(this.read_env()?.keep_trailing_newline())
135            },
136        );
137
138        fields.add_field_method_set(
139            "keep_trailing_newline",
140            |_, this: &mut LuaEnvironment, val: bool| -> mlua::Result<()> {
141                this.write_env()?.set_keep_trailing_newline(val);
142
143                Ok(())
144            },
145        );
146
147        fields.add_field_method_get(
148            "trim_blocks",
149            |_, this: &LuaEnvironment| -> mlua::Result<bool> { Ok(this.read_env()?.trim_blocks()) },
150        );
151
152        fields.add_field_method_set(
153            "trim_blocks",
154            |_, this: &mut LuaEnvironment, val: bool| -> mlua::Result<()> {
155                this.write_env()?.set_trim_blocks(val);
156
157                Ok(())
158            },
159        );
160
161        fields.add_field_method_get(
162            "lstrip_blocks",
163            |_, this: &LuaEnvironment| -> mlua::Result<bool> {
164                Ok(this.read_env()?.lstrip_blocks())
165            },
166        );
167
168        fields.add_field_method_set(
169            "lstrip_blocks",
170            |_, this: &mut LuaEnvironment, val: bool| -> mlua::Result<()> {
171                this.write_env()?.set_lstrip_blocks(val);
172
173                Ok(())
174            },
175        );
176
177        fields.add_field_method_get("debug", |_, this: &LuaEnvironment| -> mlua::Result<bool> {
178            Ok(this.read_env()?.debug())
179        });
180
181        fields.add_field_method_set(
182            "debug",
183            |_, this: &mut LuaEnvironment, val: bool| -> mlua::Result<()> {
184                this.write_env()?.set_debug(val);
185
186                Ok(())
187            },
188        );
189
190        fields.add_field_method_get(
191            "fuel",
192            |_, this: &LuaEnvironment| -> mlua::Result<Option<u64>> { Ok(this.read_env()?.fuel()) },
193        );
194
195        fields.add_field_method_set(
196            "fuel",
197            |_, this: &mut LuaEnvironment, val: Option<u64>| -> mlua::Result<()> {
198                this.write_env()?.set_fuel(val);
199
200                Ok(())
201            },
202        );
203
204        fields.add_field_method_get(
205            "recursion_limit",
206            |_, this: &LuaEnvironment| -> mlua::Result<usize> {
207                Ok(this.read_env()?.recursion_limit())
208            },
209        );
210
211        fields.add_field_method_set(
212            "recursion_limit",
213            |_, this: &mut LuaEnvironment, val: usize| -> mlua::Result<()> {
214                this.write_env()?.set_recursion_limit(val);
215
216                Ok(())
217            },
218        );
219
220        fields.add_field_method_get(
221            "undefined_behavior",
222            |_, this: &LuaEnvironment| -> mlua::Result<Option<String>> {
223                let ub = this.read_env()?.undefined_behavior();
224
225                Ok(undefined_behavior_to_lua(ub))
226            },
227        );
228
229        fields.add_field_method_set(
230            "undefined_behavior",
231            |_, this: &mut LuaEnvironment, val: String| -> mlua::Result<()> {
232                let behavior = lua_to_undefined_behavior(&val)?;
233
234                this.write_env()?.set_undefined_behavior(behavior);
235
236                Ok(())
237            },
238        );
239    }
240
241    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
242        methods.add_function(
243            "new",
244            |_, _: mlua::MultiValue| -> mlua::Result<LuaEnvironment> { Ok(LuaEnvironment::new()) },
245        );
246
247        methods.add_function(
248            "empty",
249            |_, _: mlua::MultiValue| -> mlua::Result<LuaEnvironment> {
250                Ok(LuaEnvironment::empty())
251            },
252        );
253
254        methods.add_method(
255            "add_template",
256            |lua: &mlua::Lua,
257             this: &LuaEnvironment,
258             (name, source): (String, String)|
259             -> mlua::Result<()> {
260                bind_lua(lua, || {
261                    this.write_env()?
262                        .add_template_owned(name, source)
263                        .map_err(mlua::Error::external)
264                })
265            },
266        );
267
268        methods.add_method(
269            "remove_template",
270            |lua: &mlua::Lua, this: &LuaEnvironment, name: String| -> mlua::Result<()> {
271                bind_lua(lua, || {
272                    this.write_env()?.remove_template(&name);
273                    Ok(())
274                })
275            },
276        );
277
278        methods.add_method(
279            "clear_templates",
280            |lua: &mlua::Lua, this: &LuaEnvironment, _: mlua::Value| -> mlua::Result<()> {
281                bind_lua(lua, || {
282                    this.write_env()?.clear_templates();
283
284                    Ok(())
285                })
286            },
287        );
288
289        methods.add_method(
290            "undeclared_variables",
291            |lua: &mlua::Lua,
292             this: &LuaEnvironment,
293             (name, nested): (String, Option<bool>)|
294             -> mlua::Result<mlua::Value> {
295                bind_lua(lua, || {
296                    if this.reload_before_render() {
297                        this.write_env()?.clear_templates();
298                    }
299
300                    let nested = nested.unwrap_or(false);
301
302                    let vars = this
303                        .read_env()?
304                        .get_template(&name)
305                        .map_err(mlua::Error::external)?
306                        .undeclared_variables(nested);
307
308                    lua.to_value(&vars)
309                })
310            },
311        );
312
313        methods.add_method(
314            "set_loader",
315            |lua: &mlua::Lua,
316             this: &LuaEnvironment,
317             callback: mlua::Function|
318             -> mlua::Result<()> {
319                let key = lua.create_registry_value(callback)?;
320                let func = LuaFunctionObject::new(key);
321
322                this.write_env()?.set_loader(move |name| {
323                    let source = func.with_func(args!(name), None)?;
324                    Ok(source.and_then(|v| {
325                        // If the lua function returns nil, i.e., no path found
326                        // it is mapped as `minijinja::value::ValueKind::Undefined`, however
327                        // we need to return a `None` to indicate no path was found.
328                        if v.is_undefined() {
329                            None
330                        } else {
331                            Some(v.to_string())
332                        }
333                    }))
334                });
335
336                Ok(())
337            },
338        );
339
340        methods.add_method(
341            "set_path_join_callback",
342            |lua: &mlua::Lua,
343             this: &LuaEnvironment,
344             callback: mlua::Function|
345             -> mlua::Result<()> {
346                let key = lua.create_registry_value(callback)?;
347                let func = LuaFunctionObject::new(key);
348
349                this.write_env()?
350                    .set_path_join_callback(move |name, parent| {
351                        func.with_func(args!(name, parent), None)
352                            .ok()
353                            .flatten()
354                            .and_then(|v| v.as_str().map(|s| Cow::Owned(s.to_string())))
355                            .unwrap_or(Cow::Borrowed(name))
356                    });
357                Ok(())
358            },
359        );
360
361        methods.add_method(
362            "set_unknown_method_callback",
363            |lua: &mlua::Lua,
364             this: &LuaEnvironment,
365             callback: mlua::Function|
366             -> mlua::Result<()> {
367                let key = lua.create_registry_value(callback)?;
368                let mut func = LuaFunctionObject::new(key);
369                func.set_pass_state(true);
370
371                this.write_env()?
372                    .set_unknown_method_callback(move |state, value, method, args| {
373                        func.with_func(args!(value, method, ..args), Some(state))
374                            .map(|v| v.unwrap_or(JinjaValue::UNDEFINED))
375                    });
376
377                Ok(())
378            },
379        );
380
381        methods.add_method(
382            "set_pycompat",
383            |_, this: &LuaEnvironment, enable: Option<bool>| -> mlua::Result<()> {
384                match enable {
385                    Some(true) | None => {
386                        this.write_env()?.set_unknown_method_callback(
387                            minijinja_contrib::pycompat::unknown_method_callback,
388                        );
389                    },
390                    Some(false) => {
391                        this.write_env()?.set_unknown_method_callback(|_, _, _, _| {
392                            Err(JinjaError::from(JinjaErrorKind::UnknownMethod))
393                        });
394                    },
395                }
396
397                Ok(())
398            },
399        );
400
401        methods.add_method(
402            "set_auto_escape_callback",
403            |lua: &mlua::Lua,
404             this: &LuaEnvironment,
405             callback: mlua::Function|
406             -> mlua::Result<()> {
407                let key = lua.create_registry_value(callback)?;
408                let func = LuaFunctionObject::new(key);
409
410                this.write_env()?.set_auto_escape_callback(move |name| {
411                    func.with_func(args!(name), None)
412                        .ok()
413                        .flatten()
414                        .and_then(|v| {
415                            let s = v.as_str()?.to_string();
416                            lua_to_auto_escape(&s).ok()
417                        })
418                        .unwrap_or(minijinja::AutoEscape::None)
419                });
420                Ok(())
421            },
422        );
423
424        methods.add_method(
425            "set_formatter",
426            |lua: &mlua::Lua,
427             this: &LuaEnvironment,
428             callback: mlua::Function|
429             -> mlua::Result<()> {
430                let key = lua.create_registry_value(callback)?;
431                let mut func = LuaFunctionObject::new(key);
432                func.set_pass_state(true);
433
434                this.write_env()?.set_formatter(move |out, state, value| {
435                    let Some(val) = func.with_func(args!(value), Some(state)).ok().flatten() else {
436                        return Ok(());
437                    };
438
439                    let Some(s) = val.as_str() else {
440                        return Err(JinjaError::new(
441                            JinjaErrorKind::WriteFailure,
442                            "formatter must return a string",
443                        ));
444                    };
445
446                    out.write_str(s)
447                        .map_err(|_| JinjaError::new(JinjaErrorKind::WriteFailure, "write failed"))
448                });
449
450                Ok(())
451            },
452        );
453
454        methods.add_method(
455            "set_syntax",
456            |_, this: &LuaEnvironment, syntax: mlua::Table| -> mlua::Result<()> {
457                let config = lua_to_syntax_config(syntax).map_err(mlua::Error::external)?;
458                this.write_env()?.set_syntax(config);
459
460                Ok(())
461            },
462        );
463
464        methods.add_method(
465            "render_template",
466            |lua: &mlua::Lua,
467             this: &LuaEnvironment,
468             (name, ctx): (String, Option<mlua::Table>)|
469             -> mlua::Result<String> {
470                if this.reload_before_render() {
471                    this.write_env()?.clear_templates();
472                }
473
474                let ctx = ctx.unwrap_or(lua.create_table()?);
475
476                let context = lua_to_minijinja(lua, &mlua::Value::Table(ctx));
477
478                bind_lua(lua, || {
479                    this.read_env()?
480                        .get_template(&name)
481                        .map_err(mlua::Error::external)?
482                        .render(context)
483                        .map_err(mlua::Error::external)
484                })
485            },
486        );
487
488        methods.add_method(
489            "render_str",
490            |lua: &mlua::Lua,
491             this: &LuaEnvironment,
492             (source, ctx, name): (String, Option<mlua::Table>, Option<String>)|
493             -> mlua::Result<String> {
494                let ctx = ctx.unwrap_or(lua.create_table()?);
495
496                let name = name.unwrap_or("<string>".to_string());
497                let context = lua_to_minijinja(lua, &mlua::Value::Table(ctx));
498
499                bind_lua(lua, || {
500                    this.read_env()?
501                        .render_named_str(&name, &source, context)
502                        .map_err(mlua::Error::external)
503                })
504            },
505        );
506
507        methods.add_method(
508            "render_captured",
509            |lua: &mlua::Lua,
510             this: &LuaEnvironment,
511             (name, ctx, callback): (String, Option<mlua::Table>, mlua::Function)|
512             -> mlua::Result<mlua::MultiValue> {
513                if this.reload_before_render() {
514                    this.write_env()?.clear_templates();
515                }
516
517                let key = lua.create_registry_value(callback)?;
518                let mut func = LuaFunctionObject::new(key);
519                func.set_pass_state(true);
520
521                let ctx = ctx.unwrap_or(lua.create_table()?);
522
523                let context = lua_to_minijinja(lua, &mlua::Value::Table(ctx));
524
525                bind_lua(lua, || {
526                    let env = this.read_env()?;
527
528                    let mut captured = env
529                        .get_template(&name)
530                        .map_err(mlua::Error::external)?
531                        .render_captured(context)
532                        .map_err(mlua::Error::external)?;
533
534                    let expr = captured
535                        .with_state_mut(|state| func.with_func_mut(&[], Some(state)))
536                        .map_err(mlua::Error::external)?
537                        .and_then(|v| minijinja_to_lua(lua, &v))
538                        .unwrap_or_else(|| mlua::Value::Nil);
539
540                    let rendered = captured.into_output();
541
542                    let mut mv = mlua::MultiValue::new();
543
544                    mv.push_back(mlua::Value::String(lua.create_string(rendered)?));
545                    mv.push_back(expr);
546
547                    Ok(mv)
548                })
549            },
550        );
551
552        methods.add_method(
553            "eval",
554            |lua: &mlua::Lua,
555             this: &LuaEnvironment,
556             (source, ctx): (String, Option<mlua::Table>)|
557             -> mlua::Result<mlua::Value> {
558                let ctx = ctx.unwrap_or(lua.create_table()?);
559
560                let context = lua_to_minijinja(lua, &mlua::Value::Table(ctx));
561
562                bind_lua(lua, || {
563                    let expr = this
564                        .read_env()?
565                        .compile_expression(&source)
566                        .map_err(mlua::Error::external)?
567                        .eval(&context)
568                        .map_err(mlua::Error::external)?;
569
570                    minijinja_to_lua(lua, &expr).ok_or_else(|| {
571                        mlua::Error::DeserializeError("could not convert output to lua".to_string())
572                    })
573                })
574            },
575        );
576
577        methods.add_method(
578            "add_filter",
579            |lua: &mlua::Lua,
580             this: &LuaEnvironment,
581             (name, filter, pass_state): (String, mlua::Function, Option<bool>)|
582             -> mlua::Result<()> {
583                let key = lua.create_registry_value(filter)?;
584                let mut func = LuaFunctionObject::new(key);
585                func.set_pass_state(pass_state.unwrap_or(true));
586
587                this.write_env()?.add_filter(
588                    name,
589                    move |state: &State, args: JinjaRest<JinjaValue>| {
590                        func.with_func(&args, Some(state))
591                    },
592                );
593
594                Ok(())
595            },
596        );
597
598        methods.add_method(
599            "remove_filter",
600            |_, this: &LuaEnvironment, name: String| -> mlua::Result<()> {
601                this.write_env()?.remove_filter(&name);
602
603                Ok(())
604            },
605        );
606
607        methods.add_method(
608            "add_test",
609            |lua: &mlua::Lua,
610             this: &LuaEnvironment,
611             (name, test, pass_state): (String, mlua::Function, Option<bool>)|
612             -> mlua::Result<()> {
613                let key = lua.create_registry_value(test)?;
614                let mut func = LuaFunctionObject::new(key);
615                func.set_pass_state(pass_state.unwrap_or(true));
616
617                this.write_env()?.add_test(
618                    name,
619                    move |state: &State, args: JinjaRest<JinjaValue>| {
620                        func.with_func(&args, Some(state))
621                    },
622                );
623
624                Ok(())
625            },
626        );
627
628        methods.add_method(
629            "remove_test",
630            |_, this: &LuaEnvironment, name: String| -> mlua::Result<()> {
631                this.write_env()?.remove_test(&name);
632
633                Ok(())
634            },
635        );
636
637        methods.add_method(
638            "add_global",
639            |lua: &mlua::Lua,
640             this: &LuaEnvironment,
641             (name, val, pass_state): (String, mlua::Value, Option<bool>)|
642             -> mlua::Result<()> {
643                match val {
644                    mlua::Value::Function(f) => {
645                        let key = lua.create_registry_value(f)?;
646                        let mut func = LuaFunctionObject::new(key);
647                        func.set_pass_state(pass_state.unwrap_or(true));
648
649                        this.write_env()?.add_function(
650                            name,
651                            move |state: &State, args: JinjaRest<JinjaValue>| {
652                                func.with_func(&args, Some(state))
653                            },
654                        )
655                    },
656                    _ => this
657                        .write_env()?
658                        .add_global(name, lua_to_minijinja(lua, &val)),
659                };
660
661                Ok(())
662            },
663        );
664
665        methods.add_method(
666            "remove_global",
667            |_, this: &LuaEnvironment, name: String| -> mlua::Result<()> {
668                this.write_env()?.remove_global(&name);
669
670                Ok(())
671            },
672        );
673
674        methods.add_method(
675            "globals",
676            |lua: &mlua::Lua,
677             this: &LuaEnvironment,
678             _val: mlua::Value|
679             -> mlua::Result<mlua::Table> {
680                let table = lua.create_table()?;
681
682                for (name, value) in this.read_env()?.globals() {
683                    let val = minijinja_to_lua(lua, &value);
684                    table.set(name, val)?;
685                }
686
687                Ok(table)
688            },
689        );
690    }
691}