Skip to main content

minijinja_lua/
environment.rs

1// SPDX-License-Identifier: MIT
2
3use std::{
4    borrow::Cow,
5    fmt,
6    sync::{
7        RwLock,
8        RwLockReadGuard,
9        RwLockWriteGuard,
10        atomic::{AtomicBool, Ordering},
11    },
12};
13
14use minijinja::{
15    Environment,
16    Error as JinjaError,
17    ErrorKind as JinjaErrorKind,
18    State,
19    args,
20    value::{Rest as JinjaRest, Value as JinjaValue},
21};
22use mlua::prelude::{
23    Lua,
24    LuaError,
25    LuaFunction,
26    LuaMultiValue,
27    LuaSerdeExt,
28    LuaTable,
29    LuaUserData,
30    LuaUserDataMethods,
31    LuaValue,
32};
33
34use crate::{
35    convert::{
36        LuaFunctionObject,
37        LuaObject,
38        lua_to_auto_escape,
39        lua_to_minijinja,
40        lua_to_syntax_config,
41        lua_to_undefined_behavior,
42        minijinja_to_lua,
43        undefined_behavior_to_lua,
44    },
45    state::bind_lua,
46};
47
48/// A wrapper around a `minijinja::Environment`. This wrapper can be serialized into
49/// an `mlua::UserData` object for use within Lua.
50#[derive(Debug)]
51pub struct LuaEnvironment {
52    env: RwLock<Environment<'static>>,
53    reload_before_render: AtomicBool,
54}
55
56impl LuaEnvironment {
57    /// Get a new environment
58    pub(crate) fn new() -> Self {
59        let mut env = Environment::new();
60        minijinja_contrib::add_to_environment(&mut env);
61
62        Self {
63            env: RwLock::new(env),
64            reload_before_render: AtomicBool::new(false),
65        }
66    }
67
68    /// Get a new empty environment
69    pub(crate) fn empty() -> Self {
70        let env = Environment::empty();
71
72        Self {
73            env: RwLock::new(env),
74            reload_before_render: AtomicBool::new(false),
75        }
76    }
77
78    pub(crate) fn reload_before_render(&self) -> bool {
79        self.reload_before_render.load(Ordering::Relaxed)
80    }
81
82    pub(crate) fn set_reload_before_render(&self, enable: bool) {
83        self.reload_before_render.store(enable, Ordering::Relaxed);
84    }
85
86    /// Get a read-only lock on the underlying `minijinja::Environment`
87    pub(crate) fn read_env(&self) -> Result<RwLockReadGuard<'_, Environment<'static>>, LuaError> {
88        self.env
89            .read()
90            .map_err(|_| LuaError::runtime("environment lock poisoned"))
91    }
92
93    /// Get a write lock on the underlying `minijinja::Environment`
94    pub(crate) fn write_env(&self) -> Result<RwLockWriteGuard<'_, Environment<'static>>, LuaError> {
95        self.env
96            .write()
97            .map_err(|_| LuaError::runtime("environment lock poisoned"))
98    }
99}
100
101impl Default for LuaEnvironment {
102    fn default() -> Self {
103        Self::new()
104    }
105}
106
107impl fmt::Display for LuaEnvironment {
108    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109        write!(f, "Environment")
110    }
111}
112
113impl LuaUserData for LuaEnvironment {
114    fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
115        fields.add_field_method_get(
116            "reload_before_render",
117            |_, this: &LuaEnvironment| -> Result<bool, LuaError> {
118                Ok(this.reload_before_render())
119            },
120        );
121
122        fields.add_field_method_set(
123            "reload_before_render",
124            |_, this: &mut LuaEnvironment, val: bool| -> Result<(), LuaError> {
125                this.set_reload_before_render(val);
126
127                Ok(())
128            },
129        );
130
131        fields.add_field_method_get(
132            "keep_trailing_newline",
133            |_, this: &LuaEnvironment| -> Result<bool, LuaError> {
134                Ok(this.read_env()?.keep_trailing_newline())
135            },
136        );
137
138        fields.add_field_method_set(
139            "keep_trailing_newline",
140            |_, this: &mut LuaEnvironment, val: bool| -> Result<(), LuaError> {
141                this.write_env()?.set_keep_trailing_newline(val);
142
143                Ok(())
144            },
145        );
146
147        fields.add_field_method_get(
148            "trim_blocks",
149            |_, this: &LuaEnvironment| -> Result<bool, LuaError> {
150                Ok(this.read_env()?.trim_blocks())
151            },
152        );
153
154        fields.add_field_method_set(
155            "trim_blocks",
156            |_, this: &mut LuaEnvironment, val: bool| -> Result<(), LuaError> {
157                this.write_env()?.set_trim_blocks(val);
158
159                Ok(())
160            },
161        );
162
163        fields.add_field_method_get(
164            "lstrip_blocks",
165            |_, this: &LuaEnvironment| -> Result<bool, LuaError> {
166                Ok(this.read_env()?.lstrip_blocks())
167            },
168        );
169
170        fields.add_field_method_set(
171            "lstrip_blocks",
172            |_, this: &mut LuaEnvironment, val: bool| -> Result<(), LuaError> {
173                this.write_env()?.set_lstrip_blocks(val);
174
175                Ok(())
176            },
177        );
178
179        fields.add_field_method_get(
180            "debug",
181            |_, this: &LuaEnvironment| -> Result<bool, LuaError> { Ok(this.read_env()?.debug()) },
182        );
183
184        fields.add_field_method_set(
185            "debug",
186            |_, this: &mut LuaEnvironment, val: bool| -> Result<(), LuaError> {
187                this.write_env()?.set_debug(val);
188
189                Ok(())
190            },
191        );
192
193        fields.add_field_method_get(
194            "fuel",
195            |_, this: &LuaEnvironment| -> Result<Option<u64>, LuaError> {
196                Ok(this.read_env()?.fuel())
197            },
198        );
199
200        fields.add_field_method_set(
201            "fuel",
202            |_, this: &mut LuaEnvironment, val: Option<u64>| -> Result<(), LuaError> {
203                this.write_env()?.set_fuel(val);
204
205                Ok(())
206            },
207        );
208
209        fields.add_field_method_get(
210            "recursion_limit",
211            |_, this: &LuaEnvironment| -> Result<usize, LuaError> {
212                Ok(this.read_env()?.recursion_limit())
213            },
214        );
215
216        fields.add_field_method_set(
217            "recursion_limit",
218            |_, this: &mut LuaEnvironment, val: usize| -> Result<(), LuaError> {
219                this.write_env()?.set_recursion_limit(val);
220
221                Ok(())
222            },
223        );
224
225        fields.add_field_method_get(
226            "undefined_behavior",
227            |_, this: &LuaEnvironment| -> Result<Option<String>, LuaError> {
228                let ub = this.read_env()?.undefined_behavior();
229
230                Ok(undefined_behavior_to_lua(ub))
231            },
232        );
233
234        fields.add_field_method_set(
235            "undefined_behavior",
236            |_, this: &mut LuaEnvironment, val: String| -> Result<(), LuaError> {
237                let behavior = lua_to_undefined_behavior(&val)?;
238
239                this.write_env()?.set_undefined_behavior(behavior);
240
241                Ok(())
242            },
243        );
244    }
245
246    fn add_methods<M: LuaUserDataMethods<Self>>(methods: &mut M) {
247        methods.add_function("new", |_, _: LuaMultiValue| -> Result<LuaEnvironment, _> {
248            Ok(LuaEnvironment::new())
249        });
250
251        methods.add_function(
252            "empty",
253            |_, _: LuaMultiValue| -> Result<LuaEnvironment, _> { Ok(LuaEnvironment::empty()) },
254        );
255
256        methods.add_method(
257            "add_template",
258            |lua: &Lua,
259             this: &LuaEnvironment,
260             (name, source): (String, String)|
261             -> Result<(), LuaError> {
262                bind_lua(lua, || {
263                    this.write_env()?
264                        .add_template_owned(name, source)
265                        .map_err(LuaError::external)
266                })
267            },
268        );
269
270        methods.add_method(
271            "remove_template",
272            |lua: &Lua, this: &LuaEnvironment, name: String| -> Result<(), LuaError> {
273                bind_lua(lua, || {
274                    this.write_env()?.remove_template(&name);
275                    Ok(())
276                })
277            },
278        );
279
280        methods.add_method(
281            "clear_templates",
282            |lua: &Lua, this: &LuaEnvironment, _: LuaValue| -> Result<(), LuaError> {
283                bind_lua(lua, || {
284                    this.write_env()?.clear_templates();
285
286                    Ok(())
287                })
288            },
289        );
290
291        methods.add_method(
292            "undeclared_variables",
293            |lua: &Lua,
294             this: &LuaEnvironment,
295             (name, nested): (String, Option<bool>)|
296             -> Result<LuaValue, LuaError> {
297                bind_lua(lua, || {
298                    if this.reload_before_render() {
299                        this.write_env()?.clear_templates();
300                    }
301
302                    let nested = nested.unwrap_or(false);
303
304                    let vars = this
305                        .read_env()?
306                        .get_template(&name)
307                        .map_err(LuaError::external)?
308                        .undeclared_variables(nested);
309
310                    lua.to_value(&vars)
311                })
312            },
313        );
314
315        methods.add_method(
316            "set_loader",
317            |lua: &Lua, this: &LuaEnvironment, callback: LuaFunction| -> Result<(), LuaError> {
318                let key = lua.create_registry_value(callback)?;
319                let func = LuaFunctionObject::new(key);
320
321                this.write_env()?.set_loader(move |name| {
322                    let source = func.with_func(args!(name), None)?;
323                    Ok(source.and_then(|v| {
324                        // If the lua function returns nil, i.e., no path found
325                        // it is mapped as `minijinja::value::Undefined`, however
326                        // we need to return a `None` to indicate no path was found.
327                        if v.is_undefined() {
328                            None
329                        } else {
330                            Some(v.to_string())
331                        }
332                    }))
333                });
334
335                Ok(())
336            },
337        );
338
339        methods.add_method(
340            "set_path_join_callback",
341            |lua: &Lua, this: &LuaEnvironment, callback: LuaFunction| -> Result<(), LuaError> {
342                let key = lua.create_registry_value(callback)?;
343                let func = LuaFunctionObject::new(key);
344
345                this.write_env()?
346                    .set_path_join_callback(move |name, parent| {
347                        func.with_func(args!(name, parent), None)
348                            .ok()
349                            .flatten()
350                            .and_then(|v| v.as_str().map(|s| Cow::Owned(s.to_string())))
351                            .unwrap_or(Cow::Borrowed(name))
352                    });
353                Ok(())
354            },
355        );
356
357        methods.add_method(
358            "set_unknown_method_callback",
359            |lua: &Lua, this: &LuaEnvironment, callback: LuaFunction| -> Result<(), LuaError> {
360                let key = lua.create_registry_value(callback)?;
361                let mut func = LuaFunctionObject::new(key);
362                func.set_pass_state(true);
363
364                this.write_env()?
365                    .set_unknown_method_callback(move |state, value, method, args| {
366                        func.with_func(args!(value, method, ..args), Some(state))
367                            .map(|v| v.unwrap_or(JinjaValue::UNDEFINED))
368                    });
369
370                Ok(())
371            },
372        );
373
374        methods.add_method(
375            "set_pycompat",
376            |_, this: &LuaEnvironment, enable: Option<bool>| {
377                match enable {
378                    Some(true) | None => {
379                        this.write_env()?.set_unknown_method_callback(
380                            minijinja_contrib::pycompat::unknown_method_callback,
381                        );
382                    },
383                    Some(false) => {
384                        this.write_env()?.set_unknown_method_callback(|_, _, _, _| {
385                            Err(JinjaError::from(JinjaErrorKind::UnknownMethod))
386                        });
387                    },
388                }
389
390                Ok(())
391            },
392        );
393
394        methods.add_method(
395            "set_auto_escape_callback",
396            |lua: &Lua, this: &LuaEnvironment, callback: LuaFunction| -> Result<(), LuaError> {
397                let key = lua.create_registry_value(callback)?;
398                let func = LuaFunctionObject::new(key);
399
400                this.write_env()?.set_auto_escape_callback(move |name| {
401                    func.with_func(args!(name), None)
402                        .ok()
403                        .flatten()
404                        .and_then(|v| {
405                            let s = v.as_str()?.to_string();
406                            lua_to_auto_escape(&s).ok()
407                        })
408                        .unwrap_or(minijinja::AutoEscape::None)
409                });
410                Ok(())
411            },
412        );
413
414        methods.add_method(
415            "set_formatter",
416            |lua: &Lua, this: &LuaEnvironment, callback: LuaFunction| -> Result<(), LuaError> {
417                let key = lua.create_registry_value(callback)?;
418                let mut func = LuaFunctionObject::new(key);
419                func.set_pass_state(true);
420
421                this.write_env()?.set_formatter(move |out, state, value| {
422                    let Some(val) = func.with_func(args!(value), Some(state)).ok().flatten() else {
423                        return Ok(());
424                    };
425
426                    let Some(s) = val.as_str() else {
427                        return Err(JinjaError::new(
428                            JinjaErrorKind::WriteFailure,
429                            "formatter must return a string",
430                        ));
431                    };
432
433                    out.write_str(s)
434                        .map_err(|_| JinjaError::new(JinjaErrorKind::WriteFailure, "write failed"))
435                });
436
437                Ok(())
438            },
439        );
440
441        methods.add_method(
442            "set_syntax",
443            |_, this: &LuaEnvironment, syntax: LuaTable| -> Result<(), LuaError> {
444                let config = lua_to_syntax_config(syntax).map_err(LuaError::external)?;
445                this.write_env()?.set_syntax(config);
446
447                Ok(())
448            },
449        );
450
451        methods.add_method(
452            "render_template",
453            |lua: &Lua,
454             this: &LuaEnvironment,
455             (name, ctx): (String, Option<LuaTable>)|
456             -> Result<String, LuaError> {
457                if this.reload_before_render() {
458                    this.write_env()?.clear_templates();
459                }
460
461                let ctx = ctx.unwrap_or(lua.create_table()?);
462
463                let context = lua_to_minijinja(lua, &LuaValue::Table(ctx));
464
465                bind_lua(lua, || {
466                    this.read_env()?
467                        .get_template(&name)
468                        .map_err(LuaError::external)?
469                        .render(context)
470                        .map_err(LuaError::external)
471                })
472            },
473        );
474
475        methods.add_method(
476            "render_str",
477            |lua: &Lua,
478             this: &LuaEnvironment,
479             (source, ctx, name): (String, Option<LuaTable>, Option<String>)|
480             -> Result<String, LuaError> {
481                let ctx = ctx.unwrap_or(lua.create_table()?);
482
483                let name = name.unwrap_or("<string>".to_string());
484                let context = lua_to_minijinja(lua, &LuaValue::Table(ctx));
485
486                bind_lua(lua, || {
487                    this.read_env()?
488                        .render_named_str(&name, &source, context)
489                        .map_err(LuaError::external)
490                })
491            },
492        );
493
494        methods.add_method(
495            "eval",
496            |lua: &Lua,
497             this: &LuaEnvironment,
498             (source, ctx): (String, Option<LuaTable>)|
499             -> Result<LuaValue, LuaError> {
500                let ctx = ctx.unwrap_or(lua.create_table()?);
501
502                let context = lua_to_minijinja(lua, &LuaValue::Table(ctx));
503
504                bind_lua(lua, || {
505                    let expr = this
506                        .read_env()?
507                        .compile_expression(&source)
508                        .map_err(LuaError::external)?
509                        .eval(&context)
510                        .map_err(LuaError::external)?;
511
512                    minijinja_to_lua(lua, &expr).ok_or_else(|| LuaError::ToLuaConversionError {
513                        from: "".to_string(),
514                        to: "",
515                        message: None,
516                    })
517                })
518            },
519        );
520
521        methods.add_method(
522            "add_filter",
523            |lua: &Lua,
524             this: &LuaEnvironment,
525             (name, filter, pass_state): (String, LuaFunction, Option<bool>)|
526             -> Result<(), LuaError> {
527                let key = lua.create_registry_value(filter)?;
528                let mut func = LuaFunctionObject::new(key);
529                func.set_pass_state(pass_state.unwrap_or(true));
530
531                this.write_env()?.add_filter(
532                    name,
533                    move |state: &State, args: JinjaRest<JinjaValue>| {
534                        func.with_func(&args, Some(state))
535                    },
536                );
537
538                Ok(())
539            },
540        );
541
542        methods.add_method(
543            "remove_filter",
544            |_, this: &LuaEnvironment, name: String| -> Result<(), LuaError> {
545                this.write_env()?.remove_filter(&name);
546
547                Ok(())
548            },
549        );
550
551        methods.add_method(
552            "add_test",
553            |lua: &Lua,
554             this: &LuaEnvironment,
555             (name, test, pass_state): (String, LuaFunction, Option<bool>)|
556             -> Result<(), LuaError> {
557                let key = lua.create_registry_value(test)?;
558                let mut func = LuaFunctionObject::new(key);
559                func.set_pass_state(pass_state.unwrap_or(true));
560
561                this.write_env()?.add_test(
562                    name,
563                    move |state: &State, args: JinjaRest<JinjaValue>| {
564                        func.with_func(&args, Some(state))
565                    },
566                );
567
568                Ok(())
569            },
570        );
571
572        methods.add_method(
573            "remove_test",
574            |_, this: &LuaEnvironment, name: String| -> Result<(), LuaError> {
575                this.write_env()?.remove_test(&name);
576
577                Ok(())
578            },
579        );
580
581        methods.add_method(
582            "add_global",
583            |lua: &Lua,
584             this: &LuaEnvironment,
585             (name, val, pass_state): (String, LuaValue, Option<bool>)|
586             -> Result<(), LuaError> {
587                match val {
588                    LuaValue::Function(f) => {
589                        let key = lua.create_registry_value(f)?;
590                        let mut func = LuaFunctionObject::new(key);
591                        func.set_pass_state(pass_state.unwrap_or(true));
592
593                        this.write_env()?.add_function(
594                            name,
595                            move |state: &State, args: JinjaRest<JinjaValue>| {
596                                func.with_func(&args, Some(state))
597                            },
598                        )
599                    },
600                    _ => this
601                        .write_env()?
602                        .add_global(name, lua_to_minijinja(lua, &val)),
603                };
604
605                Ok(())
606            },
607        );
608
609        methods.add_method(
610            "remove_global",
611            |_, this: &LuaEnvironment, name: String| -> Result<(), LuaError> {
612                this.write_env()?.remove_global(&name);
613
614                Ok(())
615            },
616        );
617
618        methods.add_method(
619            "globals",
620            |lua: &Lua, this: &LuaEnvironment, _val: LuaValue| -> Result<LuaTable, LuaError> {
621                let table = lua.create_table()?;
622
623                for (name, value) in this.read_env()?.globals() {
624                    let val = minijinja_to_lua(lua, &value);
625                    table.set(name, val)?;
626                }
627
628                Ok(table)
629            },
630        );
631    }
632}