Skip to main content

minijinja_lua/
state.rs

1// SPDX-License-Identifier: MIT
2
3use std::{
4    fmt,
5    ops::{Deref, DerefMut},
6    sync::atomic::{AtomicPtr, Ordering},
7};
8
9use minijinja::Value as JinjaValue;
10use mlua::LuaSerdeExt;
11
12use crate::convert::{
13    LuaAutoEscape,
14    LuaUndefinedBehavior,
15    lua_args_to_minijinja,
16    lua_to_minijinja,
17    minijinja_to_lua,
18};
19
20thread_local! {
21    static CURRENT_LUA: AtomicPtr<mlua::Lua> = const { AtomicPtr::new(std::ptr::null_mut()) };
22}
23
24trait LuaState<'template, 'env> {
25    fn state(&self) -> &minijinja::State<'template, 'env>;
26}
27
28/// A [`mlua::UserData`] wrapper around a [`minijinja::State`]. This is passed to
29/// filters and other callbacks in the Jinja environment. It can only be
30/// initialized within an [`mlua::Lua::scope`] callback, as it is not `'static`
31#[derive(Debug)]
32pub struct LuaStateRef<'scope, 'template, 'env>(&'scope minijinja::State<'template, 'env>);
33
34impl<'scope, 'template, 'env> From<&'scope minijinja::State<'template, 'env>>
35    for LuaStateRef<'scope, 'template, 'env>
36{
37    fn from(value: &'scope minijinja::State<'template, 'env>) -> Self {
38        LuaStateRef(value)
39    }
40}
41
42impl<'scope, 'template, 'env> From<LuaStateRef<'scope, 'template, 'env>>
43    for &'scope minijinja::State<'template, 'env>
44{
45    fn from(value: LuaStateRef<'scope, 'template, 'env>) -> Self {
46        value.0
47    }
48}
49
50impl<'scope, 'template, 'env> Deref for LuaStateRef<'scope, 'template, 'env> {
51    type Target = minijinja::State<'template, 'env>;
52
53    fn deref(&self) -> &Self::Target {
54        self.0
55    }
56}
57
58impl<'scope, 'template, 'env> LuaState<'template, 'env> for LuaStateRef<'scope, 'template, 'env> {
59    fn state(&self) -> &minijinja::State<'template, 'env> {
60        self.0
61    }
62}
63
64impl<'scope, 'template, 'env> fmt::Display for LuaStateRef<'scope, 'template, 'env> {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        write!(f, "State")
67    }
68}
69
70impl<'scope, 'template, 'env> mlua::UserData for LuaStateRef<'scope, 'template, 'env> {
71    fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
72        fields.add_meta_field("__name", "state");
73    }
74
75    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
76        add_common_methods(methods);
77    }
78}
79
80/// A [`mlua::UserData`] wrapper around a mutable [`minijinja::State`]. This is passed to
81/// the callback provided to [`LuaEnvironment.render_captured`](crate::environment::LuaEnvironment).
82/// It can only be initialized within an [`mlua::Lua::scope`] callback, as it is not `'static`
83#[derive(Debug)]
84pub struct LuaStateMut<'scope, 'template, 'env>(&'scope mut minijinja::State<'template, 'env>);
85
86impl<'scope, 'template, 'env> LuaStateMut<'scope, 'template, 'env> {
87    fn state_mut(&mut self) -> &mut minijinja::State<'template, 'env> {
88        self.0
89    }
90}
91
92impl<'scope, 'template, 'env> From<&'scope mut minijinja::State<'template, 'env>>
93    for LuaStateMut<'scope, 'template, 'env>
94{
95    fn from(value: &'scope mut minijinja::State<'template, 'env>) -> Self {
96        LuaStateMut(value)
97    }
98}
99
100impl<'scope, 'template, 'env> From<LuaStateMut<'scope, 'template, 'env>>
101    for &'scope mut minijinja::State<'template, 'env>
102{
103    fn from(value: LuaStateMut<'scope, 'template, 'env>) -> Self {
104        value.0
105    }
106}
107
108impl<'scope, 'template, 'env> Deref for LuaStateMut<'scope, 'template, 'env> {
109    type Target = minijinja::State<'template, 'env>;
110
111    fn deref(&self) -> &Self::Target {
112        self.0
113    }
114}
115
116impl<'scope, 'template, 'env> DerefMut for LuaStateMut<'scope, 'template, 'env> {
117    fn deref_mut(&mut self) -> &mut Self::Target {
118        self.0
119    }
120}
121
122impl<'scope, 'template, 'env> LuaState<'template, 'env> for LuaStateMut<'scope, 'template, 'env> {
123    fn state(&self) -> &minijinja::State<'template, 'env> {
124        self.0
125    }
126}
127
128impl<'scope, 'template, 'env> fmt::Display for LuaStateMut<'scope, 'template, 'env> {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        write!(f, "State")
131    }
132}
133
134impl<'scope, 'template, 'env> mlua::UserData for LuaStateMut<'scope, 'template, 'env> {
135    fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
136        fields.add_meta_field("__name", "state");
137    }
138
139    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
140        add_common_methods(methods);
141
142        // Render the named block
143        methods.add_method_mut(
144            "render_block",
145            |_, this, block: String| -> mlua::Result<String> {
146                this.state_mut()
147                    .render_block(&block)
148                    .map_err(mlua::Error::external)
149            },
150        );
151    }
152}
153
154/// A helper to add methods shared between [`LuaStateRef`] and [`LuaStateMut`].
155fn add_common_methods<'template, 'env, S, M>(methods: &mut M)
156where
157    S: LuaState<'template, 'env>,
158    M: mlua::UserDataMethods<S>,
159    'env: 'template,
160{
161    // The name of the current template
162    methods.add_method("name", |_, this, _: ()| -> mlua::Result<String> {
163        Ok(this.state().name().to_string())
164    });
165
166    // The current auto escape flag
167    methods.add_method(
168        "auto_escape",
169        |_, this, _: ()| -> mlua::Result<LuaAutoEscape> { Ok(this.state().auto_escape().into()) },
170    );
171
172    // The current undefined behavior
173    methods.add_method(
174        "undefined_behavior",
175        |_, this, _: ()| -> mlua::Result<LuaUndefinedBehavior> {
176            Ok(this.state().undefined_behavior().into())
177        },
178    );
179
180    // The name of the current block
181    methods.add_method(
182        "current_block",
183        |_, this, _: ()| -> mlua::Result<Option<String>> {
184            Ok(this.state().current_block().map(|s| s.to_string()))
185        },
186    );
187
188    // Lookup a value by key in the current context
189    methods.add_method(
190        "lookup",
191        |lua, this, name: String| -> mlua::Result<mlua::MultiValue> {
192            // Since the context may contain dynamic objects, convert the returned value
193            // through the custom layer before returning.
194            Ok(this
195                .state()
196                .lookup(&name)
197                .and_then(|v| minijinja_to_lua(lua, &v))
198                .unwrap_or_default())
199        },
200    );
201
202    // Call the named macro with the provided args.
203    methods.add_method(
204        "call_macro",
205        |lua, this, (name, mut args): (String, mlua::MultiValue)| -> mlua::Result<String> {
206            let args: Vec<JinjaValue> = lua_args_to_minijinja(lua, &mut args, true);
207
208            this.state()
209                .call_macro(&name, &args)
210                .map_err(mlua::Error::external)
211        },
212    );
213
214    // A list of exported variables
215    methods.add_method("exports", |_, this, _: ()| -> mlua::Result<Vec<String>> {
216        Ok(this
217            .state()
218            .exports()
219            .into_iter()
220            .map(|i| i.to_string())
221            .collect())
222    });
223
224    // A list of all known variables
225    methods.add_method(
226        "known_variables",
227        |_, this, _: ()| -> mlua::Result<Vec<String>> {
228            Ok(this
229                .state()
230                .known_variables()
231                .into_iter()
232                .map(|i| i.to_string())
233                .collect())
234        },
235    );
236
237    // Apply the named filter with the provided args
238    methods.add_method(
239        "apply_filter",
240        |lua,
241         this,
242         (filter, mut args): (String, mlua::MultiValue)|
243         -> mlua::Result<mlua::MultiValue> {
244            let args: Vec<JinjaValue> = lua_args_to_minijinja(lua, &mut args, true);
245
246            // Since the context may contain dynamic objects, convert the returned value
247            // through the custom layer before returning.
248            this.state()
249                .apply_filter(&filter, &args)
250                .map(|v| minijinja_to_lua(lua, &v).unwrap_or_default())
251                .map_err(mlua::Error::external)
252        },
253    );
254
255    // Perform the named test with the provided args
256    methods.add_method(
257        "perform_test",
258        |lua, this, (test, mut args): (String, mlua::MultiValue)| -> mlua::Result<bool> {
259            let args: Vec<JinjaValue> = lua_args_to_minijinja(lua, &mut args, true);
260
261            this.state()
262                .perform_test(&test, &args)
263                .map_err(mlua::Error::external)
264        },
265    );
266
267    // Format a value to a string
268    methods.add_method(
269        "format",
270        |lua, this, val: mlua::Value| -> mlua::Result<String> {
271            let val = lua_to_minijinja(lua, &val).unwrap_or_default();
272
273            this.state().format(val).map_err(mlua::Error::external)
274        },
275    );
276
277    // A tuple of the current and remaining fuel usage
278    methods.add_method(
279        "fuel_levels",
280        |lua, this, _: ()| -> mlua::Result<mlua::Value> {
281            lua.to_value(&this.state().fuel_levels())
282        },
283    );
284
285    // Get a temp value.
286    // See: https://docs.rs/minijinja/latest/minijinja/struct.State.html#method.get_temp
287    methods.add_method(
288        "get_temp",
289        |lua, this, name: String| -> mlua::Result<mlua::MultiValue> {
290            // Since the context may contain dynamic objects, convert the returned value
291            // through the custom layer before returning.
292            Ok(this
293                .state()
294                .get_temp(&name)
295                .and_then(|v| minijinja_to_lua(lua, &v))
296                .unwrap_or_default())
297        },
298    );
299
300    // Set a temp value and return the old value
301    methods.add_method(
302        "set_temp",
303        |lua, this, (name, val): (String, mlua::Value)| -> mlua::Result<mlua::MultiValue> {
304            if let Some(val) = lua_to_minijinja(lua, &val) {
305                Ok(this
306                    .state()
307                    .set_temp(&name, val)
308                    .and_then(|v| minijinja_to_lua(lua, &v))
309                    .unwrap_or_default())
310            } else {
311                Err(mlua::Error::FromLuaConversionError {
312                    from: val.type_name(),
313                    to: "minijinja::Value".to_string(),
314                    message: None,
315                })
316            }
317        },
318    );
319
320    // Get a temp value or call `func` to add the value
321    methods.add_method(
322        "get_or_set_temp",
323        |lua, this, (name, func): (String, mlua::Function)| -> mlua::Result<mlua::MultiValue> {
324            let val = match this.state().get_temp(&name) {
325                Some(v) => v,
326                None => {
327                    let val = func.call::<mlua::Value>(mlua::Value::Nil)?;
328
329                    if let Some(val) = lua_to_minijinja(lua, &val) {
330                        this.state().set_temp(&name, val.clone());
331                        val
332                    } else {
333                        return Err(mlua::Error::FromLuaConversionError {
334                            from: val.type_name(),
335                            to: "minijinja::Value".to_string(),
336                            message: None,
337                        });
338                    }
339                },
340            };
341
342            Ok(minijinja_to_lua(lua, &val).unwrap_or_default())
343        },
344    );
345}
346
347/// Allow access to a [`mlua::Lua`] instance across a `Send + Sync` boundary in module mode.
348///
349/// This code mirrors the [`minijinja-py`](https://github.com/mitsuhiko/minijinja/blob/29ac0b2936eacf83ebf781c52f4f4ffc3add4c52/minijinja-py/src/state.rs) implementation.
350pub(crate) fn with_lua<R, F: FnOnce(&mlua::Lua) -> Result<R, mlua::Error>>(
351    f: F,
352) -> Result<R, mlua::Error> {
353    CURRENT_LUA.with(|handle| {
354        let ptr = unsafe { (handle.load(Ordering::Relaxed) as *const mlua::Lua).as_ref() };
355
356        match ptr {
357            Some(lua) => f(lua),
358            None => Err(mlua::Error::runtime(
359                "mlua::Lua state accessed outside of a render context.",
360            )),
361        }
362    })
363}
364
365/// Invokes a function with the state stashed away.
366///
367/// This code mirrors the [`minijinja-py`](https://github.com/mitsuhiko/minijinja/blob/29ac0b2936eacf83ebf781c52f4f4ffc3add4c52/minijinja-py/src/state.rs) implementation.
368pub(crate) fn bind_lua<R, F: FnOnce() -> R>(lua: &mlua::Lua, f: F) -> R {
369    let old_handle = CURRENT_LUA
370        .with(|handle| handle.swap(lua as *const mlua::Lua as *mut mlua::Lua, Ordering::Relaxed));
371
372    let rv = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f));
373
374    CURRENT_LUA.with(|handle| handle.store(old_handle, Ordering::Relaxed));
375    match rv {
376        Ok(rv) => rv,
377        Err(payload) => std::panic::resume_unwind(payload),
378    }
379}