Skip to main content

minijinja_lua/
environment.rs

1// SPDX-License-Identifier: MIT
2
3use std::{
4    borrow::Cow,
5    fmt,
6    sync::{
7        RwLock,
8        RwLockReadGuard,
9        RwLockWriteGuard,
10        atomic::{AtomicBool, Ordering},
11    },
12};
13
14use minijinja::{
15    Environment,
16    Error as JinjaError,
17    ErrorKind as JinjaErrorKind,
18    State,
19    args,
20    value::{Rest as JinjaRest, Value as JinjaValue},
21};
22use mlua::prelude::{
23    Lua,
24    LuaError,
25    LuaFunction,
26    LuaMultiValue,
27    LuaSerdeExt,
28    LuaTable,
29    LuaUserData,
30    LuaUserDataMethods,
31    LuaValue,
32};
33
34use crate::{
35    convert::{
36        LuaFunctionObject,
37        LuaObject,
38        lua_to_auto_escape,
39        lua_to_minijinja,
40        lua_to_syntax_config,
41        lua_to_undefined_behavior,
42        minijinja_to_lua,
43        undefined_behavior_to_lua,
44    },
45    state::bind_lua,
46};
47
48/// A wrapper around a [`minijinja::Environment`]. This wrapper can be serialized into
49/// an [`mlua::UserData`] object for use within Lua.
50#[derive(Debug)]
51pub struct LuaEnvironment {
52    env: RwLock<Environment<'static>>,
53    reload_before_render: AtomicBool,
54}
55
56impl LuaEnvironment {
57    /// Get a new environment
58    pub(crate) fn new() -> Self {
59        let mut env = Environment::new();
60
61        #[cfg(feature = "minijinja-contrib")]
62        minijinja_contrib::add_to_environment(&mut env);
63
64        #[cfg(feature = "json")]
65        crate::contrib::json::add_to_environment(&mut env);
66
67        #[cfg(feature = "datetime")]
68        crate::contrib::datetime::add_to_environment(&mut env);
69
70        Self {
71            env: RwLock::new(env),
72            reload_before_render: AtomicBool::new(false),
73        }
74    }
75
76    /// Get a new empty environment
77    pub(crate) fn empty() -> Self {
78        let env = Environment::empty();
79
80        Self {
81            env: RwLock::new(env),
82            reload_before_render: AtomicBool::new(false),
83        }
84    }
85
86    pub(crate) fn reload_before_render(&self) -> bool {
87        self.reload_before_render.load(Ordering::Relaxed)
88    }
89
90    pub(crate) fn set_reload_before_render(&self, enable: bool) {
91        self.reload_before_render.store(enable, Ordering::Relaxed);
92    }
93
94    /// Get a read-only lock on the underlying `minijinja::Environment`
95    pub(crate) fn read_env(&self) -> Result<RwLockReadGuard<'_, Environment<'static>>, LuaError> {
96        self.env
97            .read()
98            .map_err(|_| LuaError::runtime("environment lock poisoned"))
99    }
100
101    /// Get a write lock on the underlying [`minijinja::Environment`]
102    pub(crate) fn write_env(&self) -> Result<RwLockWriteGuard<'_, Environment<'static>>, LuaError> {
103        self.env
104            .write()
105            .map_err(|_| LuaError::runtime("environment lock poisoned"))
106    }
107}
108
109impl Default for LuaEnvironment {
110    fn default() -> Self {
111        Self::new()
112    }
113}
114
115impl fmt::Display for LuaEnvironment {
116    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
117        write!(f, "Environment")
118    }
119}
120
121impl LuaUserData for LuaEnvironment {
122    fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
123        fields.add_field_method_get(
124            "reload_before_render",
125            |_, this: &LuaEnvironment| -> Result<bool, LuaError> {
126                Ok(this.reload_before_render())
127            },
128        );
129
130        fields.add_field_method_set(
131            "reload_before_render",
132            |_, this: &mut LuaEnvironment, val: bool| -> Result<(), LuaError> {
133                this.set_reload_before_render(val);
134
135                Ok(())
136            },
137        );
138
139        fields.add_field_method_get(
140            "keep_trailing_newline",
141            |_, this: &LuaEnvironment| -> Result<bool, LuaError> {
142                Ok(this.read_env()?.keep_trailing_newline())
143            },
144        );
145
146        fields.add_field_method_set(
147            "keep_trailing_newline",
148            |_, this: &mut LuaEnvironment, val: bool| -> Result<(), LuaError> {
149                this.write_env()?.set_keep_trailing_newline(val);
150
151                Ok(())
152            },
153        );
154
155        fields.add_field_method_get(
156            "trim_blocks",
157            |_, this: &LuaEnvironment| -> Result<bool, LuaError> {
158                Ok(this.read_env()?.trim_blocks())
159            },
160        );
161
162        fields.add_field_method_set(
163            "trim_blocks",
164            |_, this: &mut LuaEnvironment, val: bool| -> Result<(), LuaError> {
165                this.write_env()?.set_trim_blocks(val);
166
167                Ok(())
168            },
169        );
170
171        fields.add_field_method_get(
172            "lstrip_blocks",
173            |_, this: &LuaEnvironment| -> Result<bool, LuaError> {
174                Ok(this.read_env()?.lstrip_blocks())
175            },
176        );
177
178        fields.add_field_method_set(
179            "lstrip_blocks",
180            |_, this: &mut LuaEnvironment, val: bool| -> Result<(), LuaError> {
181                this.write_env()?.set_lstrip_blocks(val);
182
183                Ok(())
184            },
185        );
186
187        fields.add_field_method_get(
188            "debug",
189            |_, this: &LuaEnvironment| -> Result<bool, LuaError> { Ok(this.read_env()?.debug()) },
190        );
191
192        fields.add_field_method_set(
193            "debug",
194            |_, this: &mut LuaEnvironment, val: bool| -> Result<(), LuaError> {
195                this.write_env()?.set_debug(val);
196
197                Ok(())
198            },
199        );
200
201        fields.add_field_method_get(
202            "fuel",
203            |_, this: &LuaEnvironment| -> Result<Option<u64>, LuaError> {
204                Ok(this.read_env()?.fuel())
205            },
206        );
207
208        fields.add_field_method_set(
209            "fuel",
210            |_, this: &mut LuaEnvironment, val: Option<u64>| -> Result<(), LuaError> {
211                this.write_env()?.set_fuel(val);
212
213                Ok(())
214            },
215        );
216
217        fields.add_field_method_get(
218            "recursion_limit",
219            |_, this: &LuaEnvironment| -> Result<usize, LuaError> {
220                Ok(this.read_env()?.recursion_limit())
221            },
222        );
223
224        fields.add_field_method_set(
225            "recursion_limit",
226            |_, this: &mut LuaEnvironment, val: usize| -> Result<(), LuaError> {
227                this.write_env()?.set_recursion_limit(val);
228
229                Ok(())
230            },
231        );
232
233        fields.add_field_method_get(
234            "undefined_behavior",
235            |_, this: &LuaEnvironment| -> Result<Option<String>, LuaError> {
236                let ub = this.read_env()?.undefined_behavior();
237
238                Ok(undefined_behavior_to_lua(ub))
239            },
240        );
241
242        fields.add_field_method_set(
243            "undefined_behavior",
244            |_, this: &mut LuaEnvironment, val: String| -> Result<(), LuaError> {
245                let behavior = lua_to_undefined_behavior(&val)?;
246
247                this.write_env()?.set_undefined_behavior(behavior);
248
249                Ok(())
250            },
251        );
252    }
253
254    fn add_methods<M: LuaUserDataMethods<Self>>(methods: &mut M) {
255        methods.add_function("new", |_, _: LuaMultiValue| -> Result<LuaEnvironment, _> {
256            Ok(LuaEnvironment::new())
257        });
258
259        methods.add_function(
260            "empty",
261            |_, _: LuaMultiValue| -> Result<LuaEnvironment, _> { Ok(LuaEnvironment::empty()) },
262        );
263
264        methods.add_method(
265            "add_template",
266            |lua: &Lua,
267             this: &LuaEnvironment,
268             (name, source): (String, String)|
269             -> Result<(), LuaError> {
270                bind_lua(lua, || {
271                    this.write_env()?
272                        .add_template_owned(name, source)
273                        .map_err(LuaError::external)
274                })
275            },
276        );
277
278        methods.add_method(
279            "remove_template",
280            |lua: &Lua, this: &LuaEnvironment, name: String| -> Result<(), LuaError> {
281                bind_lua(lua, || {
282                    this.write_env()?.remove_template(&name);
283                    Ok(())
284                })
285            },
286        );
287
288        methods.add_method(
289            "clear_templates",
290            |lua: &Lua, this: &LuaEnvironment, _: LuaValue| -> Result<(), LuaError> {
291                bind_lua(lua, || {
292                    this.write_env()?.clear_templates();
293
294                    Ok(())
295                })
296            },
297        );
298
299        methods.add_method(
300            "undeclared_variables",
301            |lua: &Lua,
302             this: &LuaEnvironment,
303             (name, nested): (String, Option<bool>)|
304             -> Result<LuaValue, LuaError> {
305                bind_lua(lua, || {
306                    if this.reload_before_render() {
307                        this.write_env()?.clear_templates();
308                    }
309
310                    let nested = nested.unwrap_or(false);
311
312                    let vars = this
313                        .read_env()?
314                        .get_template(&name)
315                        .map_err(LuaError::external)?
316                        .undeclared_variables(nested);
317
318                    lua.to_value(&vars)
319                })
320            },
321        );
322
323        methods.add_method(
324            "set_loader",
325            |lua: &Lua, this: &LuaEnvironment, callback: LuaFunction| -> Result<(), LuaError> {
326                let key = lua.create_registry_value(callback)?;
327                let func = LuaFunctionObject::new(key);
328
329                this.write_env()?.set_loader(move |name| {
330                    let source = func.with_func(args!(name), None)?;
331                    Ok(source.and_then(|v| {
332                        // If the lua function returns nil, i.e., no path found
333                        // it is mapped as `minijinja::value::ValueKind::Undefined`, however
334                        // we need to return a `None` to indicate no path was found.
335                        if v.is_undefined() {
336                            None
337                        } else {
338                            Some(v.to_string())
339                        }
340                    }))
341                });
342
343                Ok(())
344            },
345        );
346
347        methods.add_method(
348            "set_path_join_callback",
349            |lua: &Lua, this: &LuaEnvironment, callback: LuaFunction| -> Result<(), LuaError> {
350                let key = lua.create_registry_value(callback)?;
351                let func = LuaFunctionObject::new(key);
352
353                this.write_env()?
354                    .set_path_join_callback(move |name, parent| {
355                        func.with_func(args!(name, parent), None)
356                            .ok()
357                            .flatten()
358                            .and_then(|v| v.as_str().map(|s| Cow::Owned(s.to_string())))
359                            .unwrap_or(Cow::Borrowed(name))
360                    });
361                Ok(())
362            },
363        );
364
365        methods.add_method(
366            "set_unknown_method_callback",
367            |lua: &Lua, this: &LuaEnvironment, callback: LuaFunction| -> Result<(), LuaError> {
368                let key = lua.create_registry_value(callback)?;
369                let mut func = LuaFunctionObject::new(key);
370                func.set_pass_state(true);
371
372                this.write_env()?
373                    .set_unknown_method_callback(move |state, value, method, args| {
374                        func.with_func(args!(value, method, ..args), Some(state))
375                            .map(|v| v.unwrap_or(JinjaValue::UNDEFINED))
376                    });
377
378                Ok(())
379            },
380        );
381
382        methods.add_method(
383            "set_pycompat",
384            |_, this: &LuaEnvironment, enable: Option<bool>| {
385                match enable {
386                    Some(true) | None => {
387                        this.write_env()?.set_unknown_method_callback(
388                            minijinja_contrib::pycompat::unknown_method_callback,
389                        );
390                    },
391                    Some(false) => {
392                        this.write_env()?.set_unknown_method_callback(|_, _, _, _| {
393                            Err(JinjaError::from(JinjaErrorKind::UnknownMethod))
394                        });
395                    },
396                }
397
398                Ok(())
399            },
400        );
401
402        methods.add_method(
403            "set_auto_escape_callback",
404            |lua: &Lua, this: &LuaEnvironment, callback: LuaFunction| -> Result<(), LuaError> {
405                let key = lua.create_registry_value(callback)?;
406                let func = LuaFunctionObject::new(key);
407
408                this.write_env()?.set_auto_escape_callback(move |name| {
409                    func.with_func(args!(name), None)
410                        .ok()
411                        .flatten()
412                        .and_then(|v| {
413                            let s = v.as_str()?.to_string();
414                            lua_to_auto_escape(&s).ok()
415                        })
416                        .unwrap_or(minijinja::AutoEscape::None)
417                });
418                Ok(())
419            },
420        );
421
422        methods.add_method(
423            "set_formatter",
424            |lua: &Lua, this: &LuaEnvironment, callback: LuaFunction| -> Result<(), LuaError> {
425                let key = lua.create_registry_value(callback)?;
426                let mut func = LuaFunctionObject::new(key);
427                func.set_pass_state(true);
428
429                this.write_env()?.set_formatter(move |out, state, value| {
430                    let Some(val) = func.with_func(args!(value), Some(state)).ok().flatten() else {
431                        return Ok(());
432                    };
433
434                    let Some(s) = val.as_str() else {
435                        return Err(JinjaError::new(
436                            JinjaErrorKind::WriteFailure,
437                            "formatter must return a string",
438                        ));
439                    };
440
441                    out.write_str(s)
442                        .map_err(|_| JinjaError::new(JinjaErrorKind::WriteFailure, "write failed"))
443                });
444
445                Ok(())
446            },
447        );
448
449        methods.add_method(
450            "set_syntax",
451            |_, this: &LuaEnvironment, syntax: LuaTable| -> Result<(), LuaError> {
452                let config = lua_to_syntax_config(syntax).map_err(LuaError::external)?;
453                this.write_env()?.set_syntax(config);
454
455                Ok(())
456            },
457        );
458
459        methods.add_method(
460            "render_template",
461            |lua: &Lua,
462             this: &LuaEnvironment,
463             (name, ctx): (String, Option<LuaTable>)|
464             -> Result<String, LuaError> {
465                if this.reload_before_render() {
466                    this.write_env()?.clear_templates();
467                }
468
469                let ctx = ctx.unwrap_or(lua.create_table()?);
470
471                let context = lua_to_minijinja(lua, &LuaValue::Table(ctx));
472
473                bind_lua(lua, || {
474                    this.read_env()?
475                        .get_template(&name)
476                        .map_err(LuaError::external)?
477                        .render(context)
478                        .map_err(LuaError::external)
479                })
480            },
481        );
482
483        methods.add_method(
484            "render_str",
485            |lua: &Lua,
486             this: &LuaEnvironment,
487             (source, ctx, name): (String, Option<LuaTable>, Option<String>)|
488             -> Result<String, LuaError> {
489                let ctx = ctx.unwrap_or(lua.create_table()?);
490
491                let name = name.unwrap_or("<string>".to_string());
492                let context = lua_to_minijinja(lua, &LuaValue::Table(ctx));
493
494                bind_lua(lua, || {
495                    this.read_env()?
496                        .render_named_str(&name, &source, context)
497                        .map_err(LuaError::external)
498                })
499            },
500        );
501
502        methods.add_method(
503            "eval",
504            |lua: &Lua,
505             this: &LuaEnvironment,
506             (source, ctx): (String, Option<LuaTable>)|
507             -> Result<LuaValue, LuaError> {
508                let ctx = ctx.unwrap_or(lua.create_table()?);
509
510                let context = lua_to_minijinja(lua, &LuaValue::Table(ctx));
511
512                bind_lua(lua, || {
513                    let expr = this
514                        .read_env()?
515                        .compile_expression(&source)
516                        .map_err(LuaError::external)?
517                        .eval(&context)
518                        .map_err(LuaError::external)?;
519
520                    minijinja_to_lua(lua, &expr).ok_or_else(|| LuaError::ToLuaConversionError {
521                        from: "".to_string(),
522                        to: "",
523                        message: None,
524                    })
525                })
526            },
527        );
528
529        methods.add_method(
530            "add_filter",
531            |lua: &Lua,
532             this: &LuaEnvironment,
533             (name, filter, pass_state): (String, LuaFunction, Option<bool>)|
534             -> Result<(), LuaError> {
535                let key = lua.create_registry_value(filter)?;
536                let mut func = LuaFunctionObject::new(key);
537                func.set_pass_state(pass_state.unwrap_or(true));
538
539                this.write_env()?.add_filter(
540                    name,
541                    move |state: &State, args: JinjaRest<JinjaValue>| {
542                        func.with_func(&args, Some(state))
543                    },
544                );
545
546                Ok(())
547            },
548        );
549
550        methods.add_method(
551            "remove_filter",
552            |_, this: &LuaEnvironment, name: String| -> Result<(), LuaError> {
553                this.write_env()?.remove_filter(&name);
554
555                Ok(())
556            },
557        );
558
559        methods.add_method(
560            "add_test",
561            |lua: &Lua,
562             this: &LuaEnvironment,
563             (name, test, pass_state): (String, LuaFunction, Option<bool>)|
564             -> Result<(), LuaError> {
565                let key = lua.create_registry_value(test)?;
566                let mut func = LuaFunctionObject::new(key);
567                func.set_pass_state(pass_state.unwrap_or(true));
568
569                this.write_env()?.add_test(
570                    name,
571                    move |state: &State, args: JinjaRest<JinjaValue>| {
572                        func.with_func(&args, Some(state))
573                    },
574                );
575
576                Ok(())
577            },
578        );
579
580        methods.add_method(
581            "remove_test",
582            |_, this: &LuaEnvironment, name: String| -> Result<(), LuaError> {
583                this.write_env()?.remove_test(&name);
584
585                Ok(())
586            },
587        );
588
589        methods.add_method(
590            "add_global",
591            |lua: &Lua,
592             this: &LuaEnvironment,
593             (name, val, pass_state): (String, LuaValue, Option<bool>)|
594             -> Result<(), LuaError> {
595                match val {
596                    LuaValue::Function(f) => {
597                        let key = lua.create_registry_value(f)?;
598                        let mut func = LuaFunctionObject::new(key);
599                        func.set_pass_state(pass_state.unwrap_or(true));
600
601                        this.write_env()?.add_function(
602                            name,
603                            move |state: &State, args: JinjaRest<JinjaValue>| {
604                                func.with_func(&args, Some(state))
605                            },
606                        )
607                    },
608                    _ => this
609                        .write_env()?
610                        .add_global(name, lua_to_minijinja(lua, &val)),
611                };
612
613                Ok(())
614            },
615        );
616
617        methods.add_method(
618            "remove_global",
619            |_, this: &LuaEnvironment, name: String| -> Result<(), LuaError> {
620                this.write_env()?.remove_global(&name);
621
622                Ok(())
623            },
624        );
625
626        methods.add_method(
627            "globals",
628            |lua: &Lua, this: &LuaEnvironment, _val: LuaValue| -> Result<LuaTable, LuaError> {
629                let table = lua.create_table()?;
630
631                for (name, value) in this.read_env()?.globals() {
632                    let val = minijinja_to_lua(lua, &value);
633                    table.set(name, val)?;
634                }
635
636                Ok(table)
637            },
638        );
639    }
640}