Skip to main content

minijinja_lua/
state.rs

1// SPDX-License-Identifier: MIT
2
3use std::{
4    fmt,
5    sync::atomic::{AtomicPtr, Ordering},
6};
7
8use minijinja::Value as JinjaValue;
9use mlua::LuaSerdeExt;
10
11use crate::convert::{
12    auto_escape_to_lua,
13    lua_args_to_minijinja,
14    lua_to_minijinja,
15    minijinja_to_lua,
16    undefined_behavior_to_lua,
17};
18
19thread_local! {
20    static CURRENT_LUA: AtomicPtr<mlua::Lua> = const { AtomicPtr::new(std::ptr::null_mut()) };
21}
22
23trait LuaState<'env, 'render> {
24    fn state(&self) -> &minijinja::State<'env, 'render>;
25}
26
27/// A [`mlua::UserData`] wrapper around a [`minijinja::State`]. This is passed to
28/// filters and other callbacks in the Jinja environment. It can only be
29/// initialized within an [`mlua::Lua::scope`] callback, as it is not `'static`
30#[derive(Debug)]
31pub struct LuaStateRef<'scope, 'env, 'render> {
32    state: &'scope minijinja::State<'env, 'render>,
33}
34
35impl<'scope, 'env, 'render> LuaStateRef<'scope, 'env, 'render> {
36    /// Get a new state
37    pub(crate) fn new(state: &'scope minijinja::State<'env, 'render>) -> Self {
38        Self { state }
39    }
40}
41
42impl<'scope, 'env, 'render> LuaState<'env, 'render> for LuaStateRef<'scope, 'env, 'render> {
43    fn state(&self) -> &minijinja::State<'env, 'render> {
44        self.state
45    }
46}
47
48impl<'scope, 'env, 'render> fmt::Display for LuaStateRef<'scope, 'env, 'render> {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        write!(f, "State")
51    }
52}
53
54impl<'scope, 'env, 'render> mlua::UserData for LuaStateRef<'scope, 'env, 'render> {
55    fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
56        fields.add_meta_field("__name", "state");
57    }
58
59    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
60        add_common_methods(methods);
61    }
62}
63
64/// A [`mlua::UserData`] wrapper around a mutable [`minijinja::State`]. This is passed to
65/// the callback provided to [`LuaEnvironment.render_captured`](crate::environment::LuaEnvironment).
66/// It can only be initialized within an [`mlua::Lua::scope`] callback, as it is not `'static`
67#[derive(Debug)]
68pub struct LuaStateMut<'scope, 'env, 'render> {
69    state: &'scope mut minijinja::State<'env, 'render>,
70}
71
72impl<'scope, 'env, 'render> LuaStateMut<'scope, 'env, 'render> {
73    /// Get a new state
74    pub(crate) fn new(state: &'scope mut minijinja::State<'env, 'render>) -> Self {
75        Self { state }
76    }
77
78    fn state_mut(&mut self) -> &mut minijinja::State<'env, 'render> {
79        self.state
80    }
81}
82
83impl<'scope, 'env, 'render> LuaState<'env, 'render> for LuaStateMut<'scope, 'env, 'render> {
84    fn state(&self) -> &minijinja::State<'env, 'render> {
85        self.state
86    }
87}
88
89impl<'scope, 'env, 'render> fmt::Display for LuaStateMut<'scope, 'env, 'render> {
90    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
91        write!(f, "State")
92    }
93}
94
95impl<'scope, 'env, 'render> mlua::UserData for LuaStateMut<'scope, 'env, 'render> {
96    fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
97        fields.add_meta_field("__name", "state");
98    }
99
100    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
101        add_common_methods(methods);
102
103        // Render the named block
104        methods.add_method_mut(
105            "render_block",
106            |_, this, block: String| -> mlua::Result<String> {
107                this.state_mut()
108                    .render_block(&block)
109                    .map_err(mlua::Error::external)
110            },
111        );
112    }
113}
114
115/// A helper to add methods shared between [`LuaStateRef`] and [`LuaStateMut`].
116fn add_common_methods<'env, 'render, S, M>(methods: &mut M)
117where
118    S: LuaState<'env, 'render>,
119    M: mlua::UserDataMethods<S>,
120    'render: 'env,
121{
122    // The name of the current template
123    methods.add_method("name", |_, this, _: mlua::Value| -> mlua::Result<String> {
124        Ok(this.state().name().to_string())
125    });
126
127    // The current auto escape flag
128    methods.add_method(
129        "auto_escape",
130        |_, this, _: mlua::Value| -> mlua::Result<Option<String>> {
131            Ok(auto_escape_to_lua(this.state().auto_escape()))
132        },
133    );
134
135    // The current undefined behavior
136    methods.add_method(
137        "undefined_behavior",
138        |_, this, _: mlua::Value| -> mlua::Result<Option<String>> {
139            Ok(undefined_behavior_to_lua(this.state().undefined_behavior()))
140        },
141    );
142
143    // The name of the current block
144    methods.add_method(
145        "current_block",
146        |_, this, _: mlua::Value| -> mlua::Result<Option<String>> {
147            Ok(this.state().current_block().map(|s| s.to_string()))
148        },
149    );
150
151    // Lookup a value by key in the current context
152    methods.add_method(
153        "lookup",
154        |lua, this, name: String| -> mlua::Result<Option<mlua::Value>> {
155            // Since the context may contain dynamic objects, convert the returned value
156            // through the custom layer before returning.
157            Ok(this
158                .state()
159                .lookup(&name)
160                .and_then(|v| minijinja_to_lua(lua, &v)))
161        },
162    );
163
164    // Call the named macro with the provided args.
165    methods.add_method(
166        "call_macro",
167        |lua, this, (name, args): (String, mlua::Variadic<mlua::Value>)| -> mlua::Result<String> {
168            let args: Vec<JinjaValue> = lua_args_to_minijinja(lua, args, true);
169
170            this.state()
171                .call_macro(&name, &args)
172                .map_err(mlua::Error::external)
173        },
174    );
175
176    // A list of exported variables
177    methods.add_method(
178        "exports",
179        |_, this, _: mlua::Value| -> mlua::Result<Vec<String>> {
180            Ok(this
181                .state()
182                .exports()
183                .into_iter()
184                .map(|i| i.to_string())
185                .collect())
186        },
187    );
188
189    // A list of all known variables
190    methods.add_method(
191        "known_variables",
192        |_, this, _: mlua::Value| -> mlua::Result<Vec<String>> {
193            Ok(this
194                .state()
195                .known_variables()
196                .into_iter()
197                .map(|i| i.to_string())
198                .collect())
199        },
200    );
201
202    // Apply the named filter with the provided args
203    methods.add_method(
204        "apply_filter",
205        |lua,
206         this,
207         (filter, args): (String, mlua::Variadic<mlua::Value>)|
208         -> mlua::Result<Option<mlua::Value>> {
209            let args: Vec<JinjaValue> = lua_args_to_minijinja(lua, args, true);
210
211            // Since the context may contain dynamic objects, convert the returned value
212            // through the custom layer before returning.
213            this.state()
214                .apply_filter(&filter, &args)
215                .map(|v| minijinja_to_lua(lua, &v))
216                .map_err(mlua::Error::external)
217        },
218    );
219
220    // Perform the named test with the provided args
221    methods.add_method(
222        "perform_test",
223        |lua, this, (test, args): (String, mlua::Variadic<mlua::Value>)| -> mlua::Result<bool> {
224            let args: Vec<JinjaValue> = lua_args_to_minijinja(lua, args, true);
225
226            this.state()
227                .perform_test(&test, &args)
228                .map_err(mlua::Error::external)
229        },
230    );
231
232    // Format a value to a string
233    methods.add_method(
234        "format",
235        |lua, this, val: mlua::Value| -> mlua::Result<String> {
236            let val = lua_to_minijinja(lua, &val).unwrap_or(JinjaValue::UNDEFINED);
237
238            this.state().format(val).map_err(mlua::Error::external)
239        },
240    );
241
242    // A tuple of the current and remaining fuel usage
243    methods.add_method(
244        "fuel_levels",
245        |lua, this, _: mlua::Value| -> mlua::Result<mlua::Value> {
246            lua.to_value(&this.state().fuel_levels())
247        },
248    );
249
250    // Get a temp value.
251    // See: https://docs.rs/minijinja/latest/minijinja/struct.State.html#method.get_temp
252    methods.add_method(
253        "get_temp",
254        |lua, this, name: String| -> mlua::Result<Option<mlua::Value>> {
255            // Since the context may contain dynamic objects, convert the returned value
256            // through the custom layer before returning.
257            Ok(this
258                .state()
259                .get_temp(&name)
260                .and_then(|v| minijinja_to_lua(lua, &v)))
261        },
262    );
263
264    // Set a temp value and return the old value
265    methods.add_method(
266        "set_temp",
267        |lua, this, (name, val): (String, mlua::Value)| -> mlua::Result<Option<mlua::Value>> {
268            if let Some(val) = lua_to_minijinja(lua, &val) {
269                Ok(this
270                    .state()
271                    .set_temp(&name, val)
272                    .and_then(|v| minijinja_to_lua(lua, &v)))
273            } else {
274                Err(mlua::Error::FromLuaConversionError {
275                    from: val.type_name(),
276                    to: "minijinja::Value".to_string(),
277                    message: None,
278                })
279            }
280        },
281    );
282
283    // Get a temp value or call `func` to add the value
284    methods.add_method(
285        "get_or_set_temp",
286        |lua, this, (name, func): (String, mlua::Function)| -> mlua::Result<Option<mlua::Value>> {
287            let val = match this.state().get_temp(&name) {
288                Some(v) => v,
289                None => {
290                    let val = func.call::<mlua::Value>(mlua::Value::Nil)?;
291
292                    if let Some(val) = lua_to_minijinja(lua, &val) {
293                        this.state().set_temp(&name, val.clone());
294                        val
295                    } else {
296                        return Err(mlua::Error::FromLuaConversionError {
297                            from: val.type_name(),
298                            to: "minijinja::Value".to_string(),
299                            message: None,
300                        });
301                    }
302                },
303            };
304
305            Ok(minijinja_to_lua(lua, &val))
306        },
307    );
308}
309
310/// Allow access to a [`mlua::Lua`] instance across a `Send + Sync` boundary in module mode.
311///
312/// This code mirrors the [`minijinja-py`](https://github.com/mitsuhiko/minijinja/blob/29ac0b2936eacf83ebf781c52f4f4ffc3add4c52/minijinja-py/src/state.rs) implementation.
313pub(crate) fn with_lua<R, F: FnOnce(&mlua::Lua) -> Result<R, mlua::Error>>(
314    f: F,
315) -> Result<R, mlua::Error> {
316    CURRENT_LUA.with(|handle| {
317        let ptr = unsafe { (handle.load(Ordering::Relaxed) as *const mlua::Lua).as_ref() };
318
319        match ptr {
320            Some(lua) => f(lua),
321            None => Err(mlua::Error::runtime(
322                "mlua::Lua state accessed outside of a render context.",
323            )),
324        }
325    })
326}
327
328/// Invokes a function with the state stashed away.
329///
330/// This code mirrors the [`minijinja-py`](https://github.com/mitsuhiko/minijinja/blob/29ac0b2936eacf83ebf781c52f4f4ffc3add4c52/minijinja-py/src/state.rs) implementation.
331pub(crate) fn bind_lua<R, F: FnOnce() -> R>(lua: &mlua::Lua, f: F) -> R {
332    let old_handle = CURRENT_LUA
333        .with(|handle| handle.swap(lua as *const mlua::Lua as *mut mlua::Lua, Ordering::Relaxed));
334
335    let rv = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f));
336
337    CURRENT_LUA.with(|handle| handle.store(old_handle, Ordering::Relaxed));
338    match rv {
339        Ok(rv) => rv,
340        Err(payload) => std::panic::resume_unwind(payload),
341    }
342}