Skip to main content

minijinja_lua/
environment.rs

1// SPDX-License-Identifier: MIT
2
3use std::{
4    borrow::Cow,
5    fmt,
6    sync::{
7        RwLock,
8        RwLockReadGuard,
9        RwLockWriteGuard,
10        atomic::{AtomicBool, Ordering},
11    },
12};
13
14use minijinja::{
15    Environment,
16    Error as JinjaError,
17    ErrorKind as JinjaErrorKind,
18    State,
19    args,
20    value::{Rest as JinjaRest, Value as JinjaValue},
21};
22use mlua::prelude::{
23    Lua,
24    LuaError,
25    LuaFunction,
26    LuaMultiValue,
27    LuaSerdeExt,
28    LuaTable,
29    LuaUserData,
30    LuaUserDataMethods,
31    LuaValue,
32};
33
34use crate::{
35    convert::{
36        LuaFunctionObject,
37        LuaObject,
38        lua_to_auto_escape,
39        lua_to_minijinja,
40        lua_to_syntax_config,
41        lua_to_undefined_behavior,
42        minijinja_to_lua,
43        undefined_behavior_to_lua,
44    },
45    state::bind_lua,
46};
47
48/// A wrapper around a `minijinja::Environment`. This wrapper can be serialized into
49/// an `mlua::UserData` object for use within Lua.
50#[derive(Debug)]
51pub struct LuaEnvironment {
52    env: RwLock<Environment<'static>>,
53    reload_before_render: AtomicBool,
54}
55
56impl LuaEnvironment {
57    /// Get a new environment
58    pub(crate) fn new() -> Self {
59        let mut env = Environment::new();
60        minijinja_contrib::add_to_environment(&mut env);
61
62        Self {
63            env: RwLock::new(env),
64            reload_before_render: AtomicBool::new(false),
65        }
66    }
67
68    /// Get a new empty environment
69    pub(crate) fn empty() -> Self {
70        let env = Environment::empty();
71
72        Self {
73            env: RwLock::new(env),
74            reload_before_render: AtomicBool::new(false),
75        }
76    }
77
78    pub(crate) fn reload_before_render(&self) -> bool {
79        self.reload_before_render.load(Ordering::Relaxed)
80    }
81
82    pub(crate) fn set_reload_before_render(&self, enable: bool) {
83        self.reload_before_render.store(enable, Ordering::Relaxed);
84    }
85
86    /// Get a read-only lock on the underlying `minijinja::Environment`
87    pub(crate) fn read_env(&self) -> Result<RwLockReadGuard<'_, Environment<'static>>, LuaError> {
88        self.env
89            .read()
90            .map_err(|_| LuaError::runtime("environment lock poisoned"))
91    }
92
93    /// Get a write lock on the underlying `minijinja::Environment`
94    pub(crate) fn write_env(&self) -> Result<RwLockWriteGuard<'_, Environment<'static>>, LuaError> {
95        self.env
96            .write()
97            .map_err(|_| LuaError::runtime("environment lock poisoned"))
98    }
99}
100
101impl Default for LuaEnvironment {
102    fn default() -> Self {
103        Self::new()
104    }
105}
106
107impl fmt::Display for LuaEnvironment {
108    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109        write!(f, "Environment")
110    }
111}
112
113impl LuaUserData for LuaEnvironment {
114    fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
115        fields.add_field_method_get(
116            "reload_before_render",
117            |_, this: &LuaEnvironment| -> Result<bool, LuaError> {
118                Ok(this.reload_before_render())
119            },
120        );
121
122        fields.add_field_method_set(
123            "reload_before_render",
124            |_, this: &mut LuaEnvironment, val: bool| -> Result<(), LuaError> {
125                this.set_reload_before_render(val);
126
127                Ok(())
128            },
129        );
130
131        fields.add_field_method_get(
132            "keep_trailing_newline",
133            |_, this: &LuaEnvironment| -> Result<bool, LuaError> {
134                Ok(this.read_env()?.keep_trailing_newline())
135            },
136        );
137
138        fields.add_field_method_set(
139            "keep_trailing_newline",
140            |_, this: &mut LuaEnvironment, val: bool| -> Result<(), LuaError> {
141                this.write_env()?.set_keep_trailing_newline(val);
142
143                Ok(())
144            },
145        );
146
147        fields.add_field_method_get(
148            "trim_blocks",
149            |_, this: &LuaEnvironment| -> Result<bool, LuaError> {
150                Ok(this.read_env()?.trim_blocks())
151            },
152        );
153
154        fields.add_field_method_set(
155            "trim_blocks",
156            |_, this: &mut LuaEnvironment, val: bool| -> Result<(), LuaError> {
157                this.write_env()?.set_trim_blocks(val);
158
159                Ok(())
160            },
161        );
162
163        fields.add_field_method_get(
164            "lstrip_blocks",
165            |_, this: &LuaEnvironment| -> Result<bool, LuaError> {
166                Ok(this.read_env()?.lstrip_blocks())
167            },
168        );
169
170        fields.add_field_method_set(
171            "lstrip_blocks",
172            |_, this: &mut LuaEnvironment, val: bool| -> Result<(), LuaError> {
173                this.write_env()?.set_lstrip_blocks(val);
174
175                Ok(())
176            },
177        );
178
179        fields.add_field_method_get(
180            "debug",
181            |_, this: &LuaEnvironment| -> Result<bool, LuaError> { Ok(this.read_env()?.debug()) },
182        );
183
184        fields.add_field_method_set(
185            "debug",
186            |_, this: &mut LuaEnvironment, val: bool| -> Result<(), LuaError> {
187                this.write_env()?.set_debug(val);
188
189                Ok(())
190            },
191        );
192
193        fields.add_field_method_get(
194            "fuel",
195            |_, this: &LuaEnvironment| -> Result<Option<u64>, LuaError> {
196                Ok(this.read_env()?.fuel())
197            },
198        );
199
200        fields.add_field_method_set(
201            "fuel",
202            |_, this: &mut LuaEnvironment, val: Option<u64>| -> Result<(), LuaError> {
203                this.write_env()?.set_fuel(val);
204
205                Ok(())
206            },
207        );
208
209        fields.add_field_method_get(
210            "recursion_limit",
211            |_, this: &LuaEnvironment| -> Result<usize, LuaError> {
212                Ok(this.read_env()?.recursion_limit())
213            },
214        );
215
216        fields.add_field_method_set(
217            "recursion_limit",
218            |_, this: &mut LuaEnvironment, val: usize| -> Result<(), LuaError> {
219                this.write_env()?.set_recursion_limit(val);
220
221                Ok(())
222            },
223        );
224
225        fields.add_field_method_get(
226            "undefined_behavior",
227            |_, this: &LuaEnvironment| -> Result<Option<String>, LuaError> {
228                let ub = this.read_env()?.undefined_behavior();
229
230                Ok(undefined_behavior_to_lua(ub))
231            },
232        );
233
234        fields.add_field_method_set(
235            "undefined_behavior",
236            |_, this: &mut LuaEnvironment, val: String| -> Result<(), LuaError> {
237                let behavior = lua_to_undefined_behavior(&val)?;
238
239                this.write_env()?.set_undefined_behavior(behavior);
240
241                Ok(())
242            },
243        );
244    }
245
246    fn add_methods<M: LuaUserDataMethods<Self>>(methods: &mut M) {
247        methods.add_function("new", |_, _: LuaMultiValue| -> Result<LuaEnvironment, _> {
248            Ok(LuaEnvironment::new())
249        });
250
251        methods.add_function(
252            "empty",
253            |_, _: LuaMultiValue| -> Result<LuaEnvironment, _> { Ok(LuaEnvironment::empty()) },
254        );
255
256        methods.add_method(
257            "add_template",
258            |lua: &Lua,
259             this: &LuaEnvironment,
260             (name, source): (String, String)|
261             -> Result<(), LuaError> {
262                bind_lua(lua, || {
263                    this.write_env()?
264                        .add_template_owned(name, source)
265                        .map_err(LuaError::external)
266                })
267            },
268        );
269
270        methods.add_method(
271            "remove_template",
272            |lua: &Lua, this: &LuaEnvironment, name: String| -> Result<(), LuaError> {
273                bind_lua(lua, || {
274                    this.write_env()?.remove_template(&name);
275                    Ok(())
276                })
277            },
278        );
279
280        methods.add_method(
281            "clear_templates",
282            |lua: &Lua, this: &LuaEnvironment, _: LuaValue| -> Result<(), LuaError> {
283                bind_lua(lua, || {
284                    this.write_env()?.clear_templates();
285
286                    Ok(())
287                })
288            },
289        );
290
291        methods.add_method(
292            "undeclared_variables",
293            |lua: &Lua,
294             this: &LuaEnvironment,
295             (name, nested): (String, Option<bool>)|
296             -> Result<LuaValue, LuaError> {
297                bind_lua(lua, || {
298                    if this.reload_before_render() {
299                        this.write_env()?.clear_templates();
300                    }
301
302                    let nested = nested.unwrap_or(false);
303
304                    let vars = this
305                        .read_env()?
306                        .get_template(&name)
307                        .map_err(LuaError::external)?
308                        .undeclared_variables(nested);
309
310                    lua.to_value(&vars)
311                })
312            },
313        );
314
315        methods.add_method(
316            "set_loader",
317            |lua: &Lua, this: &LuaEnvironment, callback: LuaFunction| -> Result<(), LuaError> {
318                let key = lua.create_registry_value(callback)?;
319                let func = LuaFunctionObject::new(key);
320
321                this.write_env()?.set_loader(move |name| {
322                    let source = func.with_func(args!(name), None)?;
323                    Ok(source.and_then(|v| {
324                        // If the lua function returns nil, i.e., no path found
325                        // it is mapped as `minijinja::value::Undefined`, however
326                        // we need to return a `None` to indicate no path was found.
327                        if v.is_undefined() {
328                            None
329                        } else {
330                            Some(v.to_string())
331                        }
332                    }))
333                });
334
335                Ok(())
336            },
337        );
338
339        methods.add_method(
340            "set_path_join_callback",
341            |lua: &Lua, this: &LuaEnvironment, callback: LuaFunction| -> Result<(), LuaError> {
342                let key = lua.create_registry_value(callback)?;
343                let func = LuaFunctionObject::new(key);
344
345                this.write_env()?
346                    .set_path_join_callback(move |name, parent| {
347                        func.with_func(args!(name, parent), None)
348                            .ok()
349                            .flatten()
350                            .and_then(|v| v.as_str().map(|s| Cow::Owned(s.to_string())))
351                            .unwrap_or(Cow::Borrowed(name))
352                    });
353                Ok(())
354            },
355        );
356
357        methods.add_method(
358            "set_unknown_method_callback",
359            |lua: &Lua, this: &LuaEnvironment, callback: LuaFunction| -> Result<(), LuaError> {
360                let key = lua.create_registry_value(callback)?;
361                let mut func = LuaFunctionObject::new(key);
362                func.set_pass_state(true);
363
364                this.write_env()?
365                    .set_unknown_method_callback(move |state, value, method, args| {
366                        func.with_func(args!(value, method, ..args), Some(state))
367                            .map(|v| v.unwrap_or(JinjaValue::UNDEFINED))
368                    });
369
370                Ok(())
371            },
372        );
373
374        methods.add_method(
375            "set_auto_escape_callback",
376            |lua: &Lua, this: &LuaEnvironment, callback: LuaFunction| -> Result<(), LuaError> {
377                let key = lua.create_registry_value(callback)?;
378                let func = LuaFunctionObject::new(key);
379
380                this.write_env()?.set_auto_escape_callback(move |name| {
381                    func.with_func(args!(name), None)
382                        .ok()
383                        .flatten()
384                        .and_then(|v| {
385                            let s = v.as_str()?.to_string();
386                            lua_to_auto_escape(&s).ok()
387                        })
388                        .unwrap_or(minijinja::AutoEscape::None)
389                });
390                Ok(())
391            },
392        );
393
394        methods.add_method(
395            "set_formatter",
396            |lua: &Lua, this: &LuaEnvironment, callback: LuaFunction| -> Result<(), LuaError> {
397                let key = lua.create_registry_value(callback)?;
398                let mut func = LuaFunctionObject::new(key);
399                func.set_pass_state(true);
400
401                this.write_env()?.set_formatter(move |out, state, value| {
402                    let Some(val) = func.with_func(args!(value), Some(state)).ok().flatten() else {
403                        return Ok(());
404                    };
405
406                    let Some(s) = val.as_str() else {
407                        return Err(JinjaError::new(
408                            JinjaErrorKind::WriteFailure,
409                            "formatter must return a string",
410                        ));
411                    };
412
413                    out.write_str(s)
414                        .map_err(|_| JinjaError::new(JinjaErrorKind::WriteFailure, "write failed"))
415                });
416
417                Ok(())
418            },
419        );
420
421        methods.add_method(
422            "set_syntax",
423            |_, this: &LuaEnvironment, syntax: LuaTable| -> Result<(), LuaError> {
424                let config = lua_to_syntax_config(syntax).map_err(LuaError::external)?;
425                this.write_env()?.set_syntax(config);
426
427                Ok(())
428            },
429        );
430
431        methods.add_method(
432            "render_template",
433            |lua: &Lua,
434             this: &LuaEnvironment,
435             (name, ctx): (String, Option<LuaTable>)|
436             -> Result<String, LuaError> {
437                if this.reload_before_render() {
438                    this.write_env()?.clear_templates();
439                }
440
441                let ctx = ctx.unwrap_or(lua.create_table()?);
442
443                let context = lua_to_minijinja(lua, &LuaValue::Table(ctx));
444
445                bind_lua(lua, || {
446                    this.read_env()?
447                        .get_template(&name)
448                        .map_err(LuaError::external)?
449                        .render(context)
450                        .map_err(LuaError::external)
451                })
452            },
453        );
454
455        methods.add_method(
456            "render_str",
457            |lua: &Lua,
458             this: &LuaEnvironment,
459             (source, ctx, name): (String, Option<LuaTable>, Option<String>)|
460             -> Result<String, LuaError> {
461                let ctx = ctx.unwrap_or(lua.create_table()?);
462
463                let name = name.unwrap_or("<string>".to_string());
464                let context = lua_to_minijinja(lua, &LuaValue::Table(ctx));
465
466                bind_lua(lua, || {
467                    this.read_env()?
468                        .render_named_str(&name, &source, context)
469                        .map_err(LuaError::external)
470                })
471            },
472        );
473
474        methods.add_method(
475            "eval",
476            |lua: &Lua,
477             this: &LuaEnvironment,
478             (source, ctx): (String, Option<LuaTable>)|
479             -> Result<LuaValue, LuaError> {
480                let ctx = ctx.unwrap_or(lua.create_table()?);
481
482                let context = lua_to_minijinja(lua, &LuaValue::Table(ctx));
483
484                bind_lua(lua, || {
485                    let expr = this
486                        .read_env()?
487                        .compile_expression(&source)
488                        .map_err(LuaError::external)?
489                        .eval(&context)
490                        .map_err(LuaError::external)?;
491
492                    minijinja_to_lua(lua, &expr).ok_or_else(|| LuaError::ToLuaConversionError {
493                        from: "".to_string(),
494                        to: "",
495                        message: None,
496                    })
497                })
498            },
499        );
500
501        methods.add_method(
502            "add_filter",
503            |lua: &Lua,
504             this: &LuaEnvironment,
505             (name, filter, pass_state): (String, LuaFunction, Option<bool>)|
506             -> Result<(), LuaError> {
507                let key = lua.create_registry_value(filter)?;
508                let mut func = LuaFunctionObject::new(key);
509                func.set_pass_state(pass_state.unwrap_or(true));
510
511                this.write_env()?.add_filter(
512                    name,
513                    move |state: &State, args: JinjaRest<JinjaValue>| {
514                        func.with_func(&args, Some(state))
515                    },
516                );
517
518                Ok(())
519            },
520        );
521
522        methods.add_method(
523            "remove_filter",
524            |_, this: &LuaEnvironment, name: String| -> Result<(), LuaError> {
525                this.write_env()?.remove_filter(&name);
526
527                Ok(())
528            },
529        );
530
531        methods.add_method(
532            "add_test",
533            |lua: &Lua,
534             this: &LuaEnvironment,
535             (name, test, pass_state): (String, LuaFunction, Option<bool>)|
536             -> Result<(), LuaError> {
537                let key = lua.create_registry_value(test)?;
538                let mut func = LuaFunctionObject::new(key);
539                func.set_pass_state(pass_state.unwrap_or(true));
540
541                this.write_env()?.add_test(
542                    name,
543                    move |state: &State, args: JinjaRest<JinjaValue>| {
544                        func.with_func(&args, Some(state))
545                    },
546                );
547
548                Ok(())
549            },
550        );
551
552        methods.add_method(
553            "remove_test",
554            |_, this: &LuaEnvironment, name: String| -> Result<(), LuaError> {
555                this.write_env()?.remove_test(&name);
556
557                Ok(())
558            },
559        );
560
561        methods.add_method(
562            "add_global",
563            |lua: &Lua,
564             this: &LuaEnvironment,
565             (name, val, pass_state): (String, LuaValue, Option<bool>)|
566             -> Result<(), LuaError> {
567                match val {
568                    LuaValue::Function(f) => {
569                        let key = lua.create_registry_value(f)?;
570                        let mut func = LuaFunctionObject::new(key);
571                        func.set_pass_state(pass_state.unwrap_or(true));
572
573                        this.write_env()?.add_function(
574                            name,
575                            move |state: &State, args: JinjaRest<JinjaValue>| {
576                                func.with_func(&args, Some(state))
577                            },
578                        )
579                    },
580                    _ => this
581                        .write_env()?
582                        .add_global(name, lua_to_minijinja(lua, &val)),
583                };
584
585                Ok(())
586            },
587        );
588
589        methods.add_method(
590            "remove_global",
591            |_, this: &LuaEnvironment, name: String| -> Result<(), LuaError> {
592                this.write_env()?.remove_global(&name);
593
594                Ok(())
595            },
596        );
597
598        methods.add_method(
599            "globals",
600            |lua: &Lua, this: &LuaEnvironment, _val: LuaValue| -> Result<LuaTable, LuaError> {
601                let table = lua.create_table()?;
602
603                for (name, value) in this.read_env()?.globals() {
604                    let val = minijinja_to_lua(lua, &value);
605                    table.set(name, val)?;
606                }
607
608                Ok(table)
609            },
610        );
611    }
612}