Skip to main content

minijinja_lua/
state.rs

1// SPDX-License-Identifier: MIT
2
3use std::{
4    fmt,
5    ops::{Deref, DerefMut},
6    sync::atomic::{AtomicPtr, Ordering},
7};
8
9use minijinja::Value as JinjaValue;
10use mlua::LuaSerdeExt;
11
12use crate::convert::{
13    LuaAutoEscape,
14    LuaUndefinedBehavior,
15    lua_args_to_minijinja,
16    lua_to_minijinja,
17    minijinja_to_lua,
18};
19
20thread_local! {
21    static CURRENT_LUA: AtomicPtr<mlua::Lua> = const { AtomicPtr::new(std::ptr::null_mut()) };
22}
23
24trait LuaState<'template, 'env> {
25    fn state(&self) -> &minijinja::State<'template, 'env>;
26}
27
28/// A [`mlua::UserData`] wrapper around a [`minijinja::State`]. This is passed to
29/// filters and other callbacks in the Jinja environment. It can only be
30/// initialized within an [`mlua::Lua::scope`] callback, as it is not `'static`
31#[derive(Debug)]
32pub struct LuaStateRef<'scope, 'template, 'env>(&'scope minijinja::State<'template, 'env>);
33
34impl<'scope, 'template, 'env> From<&'scope minijinja::State<'template, 'env>>
35    for LuaStateRef<'scope, 'template, 'env>
36{
37    fn from(value: &'scope minijinja::State<'template, 'env>) -> Self {
38        LuaStateRef(value)
39    }
40}
41
42impl<'scope, 'template, 'env> From<LuaStateRef<'scope, 'template, 'env>>
43    for &'scope minijinja::State<'template, 'env>
44{
45    fn from(value: LuaStateRef<'scope, 'template, 'env>) -> Self {
46        value.0
47    }
48}
49
50impl<'scope, 'template, 'env> Deref for LuaStateRef<'scope, 'template, 'env> {
51    type Target = minijinja::State<'template, 'env>;
52
53    fn deref(&self) -> &Self::Target {
54        self.0
55    }
56}
57
58impl<'scope, 'template, 'env> LuaState<'template, 'env> for LuaStateRef<'scope, 'template, 'env> {
59    fn state(&self) -> &minijinja::State<'template, 'env> {
60        self.0
61    }
62}
63
64impl<'scope, 'template, 'env> fmt::Display for LuaStateRef<'scope, 'template, 'env> {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        write!(f, "State")
67    }
68}
69
70impl<'scope, 'template, 'env> mlua::UserData for LuaStateRef<'scope, 'template, 'env> {
71    fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
72        fields.add_meta_field("__name", "state");
73    }
74
75    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
76        add_common_methods(methods);
77    }
78}
79
80/// A [`mlua::UserData`] wrapper around a mutable [`minijinja::State`]. This is passed to
81/// the callback provided to [`LuaEnvironment.render_captured`](crate::environment::LuaEnvironment).
82/// It can only be initialized within an [`mlua::Lua::scope`] callback, as it is not `'static`
83#[derive(Debug)]
84pub struct LuaStateMut<'scope, 'template, 'env>(&'scope mut minijinja::State<'template, 'env>);
85
86impl<'scope, 'template, 'env> LuaStateMut<'scope, 'template, 'env> {
87    fn state_mut(&mut self) -> &mut minijinja::State<'template, 'env> {
88        self.0
89    }
90}
91
92impl<'scope, 'template, 'env> From<&'scope mut minijinja::State<'template, 'env>>
93    for LuaStateMut<'scope, 'template, 'env>
94{
95    fn from(value: &'scope mut minijinja::State<'template, 'env>) -> Self {
96        LuaStateMut(value)
97    }
98}
99
100impl<'scope, 'template, 'env> From<LuaStateMut<'scope, 'template, 'env>>
101    for &'scope mut minijinja::State<'template, 'env>
102{
103    fn from(value: LuaStateMut<'scope, 'template, 'env>) -> Self {
104        value.0
105    }
106}
107
108impl<'scope, 'template, 'env> Deref for LuaStateMut<'scope, 'template, 'env> {
109    type Target = minijinja::State<'template, 'env>;
110
111    fn deref(&self) -> &Self::Target {
112        self.0
113    }
114}
115
116impl<'scope, 'template, 'env> DerefMut for LuaStateMut<'scope, 'template, 'env> {
117    fn deref_mut(&mut self) -> &mut Self::Target {
118        self.0
119    }
120}
121
122impl<'scope, 'template, 'env> LuaState<'template, 'env> for LuaStateMut<'scope, 'template, 'env> {
123    fn state(&self) -> &minijinja::State<'template, 'env> {
124        self.0
125    }
126}
127
128impl<'scope, 'template, 'env> fmt::Display for LuaStateMut<'scope, 'template, 'env> {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        write!(f, "State")
131    }
132}
133
134impl<'scope, 'template, 'env> mlua::UserData for LuaStateMut<'scope, 'template, 'env> {
135    fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
136        fields.add_meta_field("__name", "state");
137    }
138
139    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
140        add_common_methods(methods);
141
142        // Render the named block
143        methods.add_method_mut(
144            "render_block",
145            |_, this, block: mlua::BorrowedStr| -> mlua::Result<String> {
146                this.state_mut()
147                    .render_block(&block)
148                    .map_err(mlua::Error::external)
149            },
150        );
151    }
152}
153
154/// A helper to add methods shared between [`LuaStateRef`] and [`LuaStateMut`].
155fn add_common_methods<'template, 'env, S, M>(methods: &mut M)
156where
157    S: LuaState<'template, 'env>,
158    M: mlua::UserDataMethods<S>,
159    'env: 'template,
160{
161    // The name of the current template
162    methods.add_method("name", |_, this, ()| -> mlua::Result<String> {
163        Ok(this.state().name().to_string())
164    });
165
166    // The current auto escape flag
167    methods.add_method(
168        "auto_escape",
169        |_, this, ()| -> mlua::Result<LuaAutoEscape> { Ok(this.state().auto_escape().into()) },
170    );
171
172    // The current undefined behavior
173    methods.add_method(
174        "undefined_behavior",
175        |_, this, ()| -> mlua::Result<LuaUndefinedBehavior> {
176            Ok(this.state().undefined_behavior().into())
177        },
178    );
179
180    // The name of the current block
181    methods.add_method(
182        "current_block",
183        |_, this, ()| -> mlua::Result<Option<String>> {
184            Ok(this.state().current_block().map(|s| s.to_string()))
185        },
186    );
187
188    // Lookup a value by key in the current context
189    methods.add_method(
190        "lookup",
191        |lua, this, name: mlua::BorrowedStr| -> mlua::Result<mlua::MultiValue> {
192            // Since the context may contain dynamic objects, convert the returned value
193            // through the custom layer before returning.
194            Ok(this
195                .state()
196                .lookup(&name)
197                .and_then(|v| minijinja_to_lua(lua, &v))
198                .unwrap_or_default())
199        },
200    );
201
202    // Call the named macro with the provided args.
203    methods.add_method(
204        "call_macro",
205        |lua,
206         this,
207         (name, mut args): (mlua::BorrowedStr, mlua::MultiValue)|
208         -> mlua::Result<String> {
209            let args: Vec<JinjaValue> = lua_args_to_minijinja(lua, &mut args, true);
210
211            this.state()
212                .call_macro(&name, &args)
213                .map_err(mlua::Error::external)
214        },
215    );
216
217    // A list of exported variables
218    methods.add_method("exports", |_, this, ()| -> mlua::Result<Vec<String>> {
219        Ok(this
220            .state()
221            .exports()
222            .into_iter()
223            .map(|i| i.to_string())
224            .collect())
225    });
226
227    // A list of all known variables
228    methods.add_method(
229        "known_variables",
230        |_, this, ()| -> mlua::Result<Vec<String>> {
231            Ok(this
232                .state()
233                .known_variables()
234                .into_iter()
235                .map(|i| i.to_string())
236                .collect())
237        },
238    );
239
240    // Apply the named filter with the provided args
241    methods.add_method(
242        "apply_filter",
243        |lua,
244         this,
245         (filter, mut args): (mlua::BorrowedStr, mlua::MultiValue)|
246         -> mlua::Result<mlua::MultiValue> {
247            let args: Vec<JinjaValue> = lua_args_to_minijinja(lua, &mut args, true);
248
249            // Since the context may contain dynamic objects, convert the returned value
250            // through the custom layer before returning.
251            this.state()
252                .apply_filter(&filter, &args)
253                .map(|v| minijinja_to_lua(lua, &v).unwrap_or_default())
254                .map_err(mlua::Error::external)
255        },
256    );
257
258    // Perform the named test with the provided args
259    methods.add_method(
260        "perform_test",
261        |lua,
262         this,
263         (test, mut args): (mlua::BorrowedStr, mlua::MultiValue)|
264         -> mlua::Result<bool> {
265            let args: Vec<JinjaValue> = lua_args_to_minijinja(lua, &mut args, true);
266
267            this.state()
268                .perform_test(&test, &args)
269                .map_err(mlua::Error::external)
270        },
271    );
272
273    // Format a value to a string
274    methods.add_method(
275        "format",
276        |lua, this, val: mlua::Value| -> mlua::Result<String> {
277            let val = lua_to_minijinja(lua, &val).unwrap_or_default();
278
279            this.state().format(val).map_err(mlua::Error::external)
280        },
281    );
282
283    // A tuple of the current and remaining fuel usage
284    methods.add_method(
285        "fuel_levels",
286        |lua, this, ()| -> mlua::Result<mlua::Value> { lua.to_value(&this.state().fuel_levels()) },
287    );
288
289    // Get a temp value.
290    // See: https://docs.rs/minijinja/latest/minijinja/struct.State.html#method.get_temp
291    methods.add_method(
292        "get_temp",
293        |lua, this, name: mlua::BorrowedStr| -> mlua::Result<mlua::MultiValue> {
294            // Since the context may contain dynamic objects, convert the returned value
295            // through the custom layer before returning.
296            Ok(this
297                .state()
298                .get_temp(&name)
299                .and_then(|v| minijinja_to_lua(lua, &v))
300                .unwrap_or_default())
301        },
302    );
303
304    // Set a temp value and return the old value
305    methods.add_method(
306        "set_temp",
307        |lua,
308         this,
309         (name, val): (mlua::BorrowedStr, mlua::Value)|
310         -> mlua::Result<mlua::MultiValue> {
311            if let Some(val) = lua_to_minijinja(lua, &val) {
312                Ok(this
313                    .state()
314                    .set_temp(&name, val)
315                    .and_then(|v| minijinja_to_lua(lua, &v))
316                    .unwrap_or_default())
317            } else {
318                Err(mlua::Error::FromLuaConversionError {
319                    from: val.type_name(),
320                    to: "minijinja::Value".to_string(),
321                    message: None,
322                })
323            }
324        },
325    );
326
327    // Get a temp value or call `func` to add the value
328    methods.add_method(
329        "get_or_set_temp",
330        |lua,
331         this,
332         (name, func): (mlua::BorrowedStr, mlua::Function)|
333         -> mlua::Result<mlua::MultiValue> {
334            let val = match this.state().get_temp(&name) {
335                Some(v) => v,
336                None => {
337                    let val = func.call::<mlua::Value>(mlua::Value::Nil)?;
338
339                    if let Some(val) = lua_to_minijinja(lua, &val) {
340                        this.state().set_temp(&name, val.clone());
341                        val
342                    } else {
343                        return Err(mlua::Error::FromLuaConversionError {
344                            from: val.type_name(),
345                            to: "minijinja::Value".to_string(),
346                            message: None,
347                        });
348                    }
349                },
350            };
351
352            Ok(minijinja_to_lua(lua, &val).unwrap_or_default())
353        },
354    );
355}
356
357/// Allow access to a [`mlua::Lua`] instance across a `Send + Sync` boundary in module mode.
358///
359/// This code mirrors the [`minijinja-py`](https://github.com/mitsuhiko/minijinja/blob/29ac0b2936eacf83ebf781c52f4f4ffc3add4c52/minijinja-py/src/state.rs) implementation.
360pub(crate) fn with_lua<R, F: FnOnce(&mlua::Lua) -> Result<R, mlua::Error>>(
361    f: F,
362) -> Result<R, mlua::Error> {
363    CURRENT_LUA.with(|handle| {
364        // SAFETY: The stored Lua pointer is only valid within the context of the `bind_lua` call
365        // on the same thread which stored it. Callers must not attempt or otherwise retain
366        // the `&Lua` reference, or any references to it, that could outlive the scope of the
367        // `bind_lua` call.
368        let ptr = unsafe { (handle.load(Ordering::Relaxed) as *const mlua::Lua).as_ref() };
369
370        match ptr {
371            Some(lua) => f(lua),
372            None => Err(mlua::Error::runtime(
373                "mlua::Lua state accessed outside of a render context.",
374            )),
375        }
376    })
377}
378
379/// Invokes a function with the state stashed away.
380///
381/// This code mirrors the [`minijinja-py`](https://github.com/mitsuhiko/minijinja/blob/29ac0b2936eacf83ebf781c52f4f4ffc3add4c52/minijinja-py/src/state.rs) implementation.
382pub(crate) fn bind_lua<R, F: FnOnce() -> R>(lua: &mlua::Lua, f: F) -> R {
383    let old_handle = CURRENT_LUA
384        .with(|handle| handle.swap(lua as *const mlua::Lua as *mut mlua::Lua, Ordering::Relaxed));
385
386    let rv = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f));
387
388    CURRENT_LUA.with(|handle| handle.store(old_handle, Ordering::Relaxed));
389    match rv {
390        Ok(rv) => rv,
391        Err(payload) => std::panic::resume_unwind(payload),
392    }
393}