Skip to main content

minijinja_lua/
state.rs

1// SPDX-License-Identifier: MIT
2
3use std::{
4    fmt,
5    sync::atomic::{AtomicPtr, Ordering},
6};
7
8use minijinja::Value as JinjaValue;
9use mlua::{
10    LuaSerdeExt,
11    prelude::{Lua, LuaError, LuaFunction, LuaUserData, LuaValue, LuaVariadic},
12};
13
14use crate::convert::{
15    auto_escape_to_lua,
16    lua_args_to_minijinja,
17    lua_to_minijinja,
18    minijinja_to_lua,
19    undefined_behavior_to_lua,
20};
21
22thread_local! {
23    static CURRENT_LUA: AtomicPtr<Lua> = const { AtomicPtr::new(std::ptr::null_mut()) };
24}
25
26/// A [`mlua::UserData`] wrapper around a [`minijinja::State`]. This is passed to
27/// filters and other callbacks in the Jinja environment. It can only be
28/// initialized within an [`mlua::Lua::scope`] callback, as it is not `'static`
29#[derive(Debug)]
30pub struct LuaState<'scope> {
31    state: &'scope minijinja::State<'scope, 'scope>,
32}
33
34impl<'scope> LuaState<'scope> {
35    /// Get a new state
36    pub(crate) fn new(state: &'scope minijinja::State<'scope, 'scope>) -> Self {
37        Self { state }
38    }
39}
40
41impl<'scope> fmt::Display for LuaState<'scope> {
42    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43        write!(f, "State")
44    }
45}
46
47impl<'scope> LuaUserData for LuaState<'scope> {
48    fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
49        fields.add_meta_field("__name", "state");
50    }
51
52    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
53        // The name of the current template
54        methods.add_method(
55            "name",
56            |_, this: &LuaState<'scope>, _: LuaValue| -> Result<String, _> {
57                Ok(this.state.name().to_string())
58            },
59        );
60
61        // The current auto escape flag
62        methods.add_method(
63            "auto_escape",
64            |_, this: &LuaState<'scope>, _: LuaValue| -> Result<Option<String>, _> {
65                Ok(auto_escape_to_lua(this.state.auto_escape()))
66            },
67        );
68
69        // The current undefined behavior
70        methods.add_method(
71            "undefined_behavior",
72            |_, this: &LuaState<'scope>, _: LuaValue| -> Result<Option<String>, _> {
73                Ok(undefined_behavior_to_lua(this.state.undefined_behavior()))
74            },
75        );
76
77        // The name of the current block
78        methods.add_method(
79            "current_block",
80            |_, this: &LuaState<'scope>, _: LuaValue| -> Result<Option<&str>, _> {
81                Ok(this.state.current_block())
82            },
83        );
84
85        // Lookup a value by key in the current context
86        methods.add_method(
87            "lookup",
88            |lua: &Lua, this: &LuaState<'scope>, name: String| -> Result<Option<LuaValue>, _> {
89                // Since the context may contain dynamic objects, convert the returned value
90                // through the custom layer before returning.
91                Ok(this
92                    .state
93                    .lookup(&name)
94                    .and_then(|v| minijinja_to_lua(lua, &v)))
95            },
96        );
97
98        // Call the named macro with the provided args.
99        methods.add_method(
100            "call_macro",
101            |lua: &Lua,
102             this: &LuaState<'scope>,
103             (name, args): (String, LuaVariadic<LuaValue>)|
104             -> Result<String, LuaError> {
105                let args: Vec<JinjaValue> = lua_args_to_minijinja(lua, args, true);
106
107                this.state
108                    .call_macro(&name, &args)
109                    .map_err(LuaError::external)
110            },
111        );
112
113        // A list of exported variables
114        methods.add_method(
115            "exports",
116            |_, this: &LuaState<'scope>, _: LuaValue| -> Result<Vec<&str>, _> {
117                Ok(this.state.exports())
118            },
119        );
120
121        // A list of all known variables
122        methods.add_method(
123            "known_variables",
124            |_,
125             this: &LuaState<'scope>,
126             _: LuaValue|
127             -> Result<Vec<std::borrow::Cow<'_, str>>, _> {
128                Ok(this.state.known_variables())
129            },
130        );
131
132        // Apply the named filter with the provided args
133        methods.add_method(
134            "apply_filter",
135            |lua: &Lua,
136             this: &LuaState<'scope>,
137             (filter, args): (String, LuaVariadic<LuaValue>)|
138             -> Result<Option<LuaValue>, LuaError> {
139                let args: Vec<JinjaValue> = lua_args_to_minijinja(lua, args, true);
140
141                // Since the context may contain dynamic objects, convert the returned value
142                // through the custom layer before returning.
143                this.state
144                    .apply_filter(&filter, &args)
145                    .map(|v| minijinja_to_lua(lua, &v))
146                    .map_err(LuaError::external)
147            },
148        );
149
150        // Perform the named test with the provided args
151        methods.add_method(
152            "perform_test",
153            |lua: &Lua,
154             this: &LuaState<'scope>,
155             (test, args): (String, LuaVariadic<LuaValue>)|
156             -> Result<bool, LuaError> {
157                let args: Vec<JinjaValue> = lua_args_to_minijinja(lua, args, true);
158
159                this.state
160                    .perform_test(&test, &args)
161                    .map_err(LuaError::external)
162            },
163        );
164
165        // Format a value to a string
166        methods.add_method(
167            "format",
168            |lua: &Lua, this: &LuaState<'scope>, val: LuaValue| -> Result<String, LuaError> {
169                let val = lua_to_minijinja(lua, &val).unwrap_or(JinjaValue::UNDEFINED);
170
171                this.state.format(val).map_err(LuaError::external)
172            },
173        );
174
175        // A tuple of the current and remaining fuel usage
176        methods.add_method(
177            "fuel_levels",
178            |lua: &Lua, this: &LuaState<'scope>, _: LuaValue| -> Result<LuaValue, _> {
179                lua.to_value(&this.state.fuel_levels())
180            },
181        );
182
183        // Get a temp value.
184        // See: https://docs.rs/minijinja/latest/minijinja/struct.State.html#method.get_temp
185        methods.add_method(
186            "get_temp",
187            |lua: &Lua,
188             this: &LuaState<'scope>,
189             name: String|
190             -> Result<Option<LuaValue>, LuaError> {
191                // Since the context may contain dynamic objects, convert the returned value
192                // through the custom layer before returning.
193                Ok(this
194                    .state
195                    .get_temp(&name)
196                    .and_then(|v| minijinja_to_lua(lua, &v)))
197            },
198        );
199
200        // Set a temp value and return the old value
201        methods.add_method(
202            "set_temp",
203            |lua: &Lua,
204             this: &LuaState<'scope>,
205             (name, val): (String, LuaValue)|
206             -> Result<Option<LuaValue>, LuaError> {
207                if let Some(val) = lua_to_minijinja(lua, &val) {
208                    Ok(this
209                        .state
210                        .set_temp(&name, val)
211                        .and_then(|v| minijinja_to_lua(lua, &v)))
212                } else {
213                    Err(LuaError::ToLuaConversionError {
214                        from: val.type_name().to_string(),
215                        to: "minijinja::Value",
216                        message: None,
217                    })
218                }
219            },
220        );
221
222        // Get a temp value or call `func` to add the value
223        methods.add_method(
224            "get_or_set_temp",
225            |lua: &Lua,
226             this: &LuaState<'scope>,
227             (name, func): (String, LuaFunction)|
228             -> Result<Option<LuaValue>, LuaError> {
229                let val = match this.state.get_temp(&name) {
230                    Some(v) => v,
231                    None => {
232                        let val = func.call::<LuaValue>(LuaValue::Nil)?;
233
234                        if let Some(val) = lua_to_minijinja(lua, &val) {
235                            this.state.set_temp(&name, val.clone());
236                            val
237                        } else {
238                            return Err(LuaError::ToLuaConversionError {
239                                from: val.type_name().to_string(),
240                                to: "minijinja::Value",
241                                message: None,
242                            });
243                        }
244                    },
245                };
246
247                Ok(minijinja_to_lua(lua, &val))
248            },
249        );
250    }
251}
252
253/// Allow access to a [`mlua::Lua`] instance across a `Send + Sync` boundary in module mode.
254///
255/// This code mirrors the [`minijinja-py`](https://github.com/mitsuhiko/minijinja/blob/29ac0b2936eacf83ebf781c52f4f4ffc3add4c52/minijinja-py/src/state.rs) implementation.
256pub(crate) fn with_lua<R, F: FnOnce(&Lua) -> Result<R, LuaError>>(f: F) -> Result<R, LuaError> {
257    CURRENT_LUA.with(|handle| {
258        let ptr = unsafe { (handle.load(Ordering::Relaxed) as *const Lua).as_ref() };
259
260        match ptr {
261            Some(lua) => f(lua),
262            None => Err(LuaError::runtime(
263                "mlua::Lua state accessed outside of a render context.",
264            )),
265        }
266    })
267}
268
269/// Invokes a function with the state stashed away.
270///
271/// This code mirrors the [`minijinja-py`](https://github.com/mitsuhiko/minijinja/blob/29ac0b2936eacf83ebf781c52f4f4ffc3add4c52/minijinja-py/src/state.rs) implementation.
272pub(crate) fn bind_lua<R, F: FnOnce() -> R>(lua: &Lua, f: F) -> R {
273    let old_handle =
274        CURRENT_LUA.with(|handle| handle.swap(lua as *const Lua as *mut Lua, Ordering::Relaxed));
275
276    let rv = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f));
277
278    CURRENT_LUA.with(|handle| handle.store(old_handle, Ordering::Relaxed));
279    match rv {
280        Ok(rv) => rv,
281        Err(payload) => std::panic::resume_unwind(payload),
282    }
283}