1use std::{
4 fmt,
5 ops::{Deref, DerefMut},
6 sync::atomic::{AtomicPtr, Ordering},
7};
8
9use minijinja::Value as JinjaValue;
10use mlua::LuaSerdeExt;
11
12use crate::convert::{
13 LuaAutoEscape,
14 LuaUndefinedBehavior,
15 lua_args_to_minijinja,
16 lua_to_minijinja,
17 minijinja_to_lua,
18};
19
// Per-thread slot holding a raw pointer to the `mlua::Lua` instance that is
// currently bound; it starts out null. In this file the slot is written by
// `bind_lua` and read by `with_lua`.
thread_local! {
    static CURRENT_LUA: AtomicPtr<mlua::Lua> = const { AtomicPtr::new(std::ptr::null_mut()) };
}
23
/// Shared read-only access to the `minijinja::State` held by a wrapper type.
///
/// `add_common_methods` is generic over this trait, which lets both
/// `LuaStateRef` and `LuaStateMut` register the same set of Lua methods.
trait LuaState<'template, 'env> {
    /// Returns a shared reference to the wrapped template state.
    fn state(&self) -> &minijinja::State<'template, 'env>;
}
27
/// Lua userdata wrapper around a shared borrow of a `minijinja::State`.
///
/// `'scope` is the lifetime of the borrow itself; `'template` and `'env` are
/// the lifetime parameters of the wrapped state.
#[derive(Debug)]
pub struct LuaStateRef<'scope, 'template, 'env>(&'scope minijinja::State<'template, 'env>);
33
34impl<'scope, 'template, 'env> From<&'scope minijinja::State<'template, 'env>>
35 for LuaStateRef<'scope, 'template, 'env>
36{
37 fn from(value: &'scope minijinja::State<'template, 'env>) -> Self {
38 LuaStateRef(value)
39 }
40}
41
42impl<'scope, 'template, 'env> From<LuaStateRef<'scope, 'template, 'env>>
43 for &'scope minijinja::State<'template, 'env>
44{
45 fn from(value: LuaStateRef<'scope, 'template, 'env>) -> Self {
46 value.0
47 }
48}
49
50impl<'scope, 'template, 'env> Deref for LuaStateRef<'scope, 'template, 'env> {
51 type Target = minijinja::State<'template, 'env>;
52
53 fn deref(&self) -> &Self::Target {
54 self.0
55 }
56}
57
58impl<'scope, 'template, 'env> LuaState<'template, 'env> for LuaStateRef<'scope, 'template, 'env> {
59 fn state(&self) -> &minijinja::State<'template, 'env> {
60 self.0
61 }
62}
63
64impl<'scope, 'template, 'env> fmt::Display for LuaStateRef<'scope, 'template, 'env> {
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 write!(f, "State")
67 }
68}
69
70impl<'scope, 'template, 'env> mlua::UserData for LuaStateRef<'scope, 'template, 'env> {
71 fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
72 fields.add_meta_field("__name", "state");
73 }
74
75 fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
76 add_common_methods(methods);
77 }
78}
79
/// Lua userdata wrapper around an exclusive borrow of a `minijinja::State`.
///
/// `'scope` is the lifetime of the borrow itself; `'template` and `'env` are
/// the lifetime parameters of the wrapped state.
#[derive(Debug)]
pub struct LuaStateMut<'scope, 'template, 'env>(&'scope mut minijinja::State<'template, 'env>);
85
86impl<'scope, 'template, 'env> LuaStateMut<'scope, 'template, 'env> {
87 fn state_mut(&mut self) -> &mut minijinja::State<'template, 'env> {
88 self.0
89 }
90}
91
92impl<'scope, 'template, 'env> From<&'scope mut minijinja::State<'template, 'env>>
93 for LuaStateMut<'scope, 'template, 'env>
94{
95 fn from(value: &'scope mut minijinja::State<'template, 'env>) -> Self {
96 LuaStateMut(value)
97 }
98}
99
100impl<'scope, 'template, 'env> From<LuaStateMut<'scope, 'template, 'env>>
101 for &'scope mut minijinja::State<'template, 'env>
102{
103 fn from(value: LuaStateMut<'scope, 'template, 'env>) -> Self {
104 value.0
105 }
106}
107
108impl<'scope, 'template, 'env> Deref for LuaStateMut<'scope, 'template, 'env> {
109 type Target = minijinja::State<'template, 'env>;
110
111 fn deref(&self) -> &Self::Target {
112 self.0
113 }
114}
115
116impl<'scope, 'template, 'env> DerefMut for LuaStateMut<'scope, 'template, 'env> {
117 fn deref_mut(&mut self) -> &mut Self::Target {
118 self.0
119 }
120}
121
122impl<'scope, 'template, 'env> LuaState<'template, 'env> for LuaStateMut<'scope, 'template, 'env> {
123 fn state(&self) -> &minijinja::State<'template, 'env> {
124 self.0
125 }
126}
127
128impl<'scope, 'template, 'env> fmt::Display for LuaStateMut<'scope, 'template, 'env> {
129 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130 write!(f, "State")
131 }
132}
133
134impl<'scope, 'template, 'env> mlua::UserData for LuaStateMut<'scope, 'template, 'env> {
135 fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
136 fields.add_meta_field("__name", "state");
137 }
138
139 fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
140 add_common_methods(methods);
141
142 methods.add_method_mut(
144 "render_block",
145 |_, this, block: mlua::BorrowedStr| -> mlua::Result<String> {
146 this.state_mut()
147 .render_block(&block)
148 .map_err(mlua::Error::external)
149 },
150 );
151 }
152}
153
154fn add_common_methods<'template, 'env, S, M>(methods: &mut M)
156where
157 S: LuaState<'template, 'env>,
158 M: mlua::UserDataMethods<S>,
159 'env: 'template,
160{
161 methods.add_method("name", |_, this, ()| -> mlua::Result<String> {
163 Ok(this.state().name().to_string())
164 });
165
166 methods.add_method(
168 "auto_escape",
169 |_, this, ()| -> mlua::Result<LuaAutoEscape> { Ok(this.state().auto_escape().into()) },
170 );
171
172 methods.add_method(
174 "undefined_behavior",
175 |_, this, ()| -> mlua::Result<LuaUndefinedBehavior> {
176 Ok(this.state().undefined_behavior().into())
177 },
178 );
179
180 methods.add_method(
182 "current_block",
183 |_, this, ()| -> mlua::Result<Option<String>> {
184 Ok(this.state().current_block().map(|s| s.to_string()))
185 },
186 );
187
188 methods.add_method(
190 "lookup",
191 |lua, this, name: mlua::BorrowedStr| -> mlua::Result<mlua::MultiValue> {
192 Ok(this
195 .state()
196 .lookup(&name)
197 .and_then(|v| minijinja_to_lua(lua, &v))
198 .unwrap_or_default())
199 },
200 );
201
202 methods.add_method(
204 "call_macro",
205 |lua,
206 this,
207 (name, mut args): (mlua::BorrowedStr, mlua::MultiValue)|
208 -> mlua::Result<String> {
209 let args: Vec<JinjaValue> = lua_args_to_minijinja(lua, &mut args, true);
210
211 this.state()
212 .call_macro(&name, &args)
213 .map_err(mlua::Error::external)
214 },
215 );
216
217 methods.add_method("exports", |_, this, ()| -> mlua::Result<Vec<String>> {
219 Ok(this
220 .state()
221 .exports()
222 .into_iter()
223 .map(|i| i.to_string())
224 .collect())
225 });
226
227 methods.add_method(
229 "known_variables",
230 |_, this, ()| -> mlua::Result<Vec<String>> {
231 Ok(this
232 .state()
233 .known_variables()
234 .into_iter()
235 .map(|i| i.to_string())
236 .collect())
237 },
238 );
239
240 methods.add_method(
242 "apply_filter",
243 |lua,
244 this,
245 (filter, mut args): (mlua::BorrowedStr, mlua::MultiValue)|
246 -> mlua::Result<mlua::MultiValue> {
247 let args: Vec<JinjaValue> = lua_args_to_minijinja(lua, &mut args, true);
248
249 this.state()
252 .apply_filter(&filter, &args)
253 .map(|v| minijinja_to_lua(lua, &v).unwrap_or_default())
254 .map_err(mlua::Error::external)
255 },
256 );
257
258 methods.add_method(
260 "perform_test",
261 |lua,
262 this,
263 (test, mut args): (mlua::BorrowedStr, mlua::MultiValue)|
264 -> mlua::Result<bool> {
265 let args: Vec<JinjaValue> = lua_args_to_minijinja(lua, &mut args, true);
266
267 this.state()
268 .perform_test(&test, &args)
269 .map_err(mlua::Error::external)
270 },
271 );
272
273 methods.add_method(
275 "format",
276 |lua, this, val: mlua::Value| -> mlua::Result<String> {
277 let val = lua_to_minijinja(lua, &val).unwrap_or_default();
278
279 this.state().format(val).map_err(mlua::Error::external)
280 },
281 );
282
283 methods.add_method(
285 "fuel_levels",
286 |lua, this, ()| -> mlua::Result<mlua::Value> { lua.to_value(&this.state().fuel_levels()) },
287 );
288
289 methods.add_method(
292 "get_temp",
293 |lua, this, name: mlua::BorrowedStr| -> mlua::Result<mlua::MultiValue> {
294 Ok(this
297 .state()
298 .get_temp(&name)
299 .and_then(|v| minijinja_to_lua(lua, &v))
300 .unwrap_or_default())
301 },
302 );
303
304 methods.add_method(
306 "set_temp",
307 |lua,
308 this,
309 (name, val): (mlua::BorrowedStr, mlua::Value)|
310 -> mlua::Result<mlua::MultiValue> {
311 if let Some(val) = lua_to_minijinja(lua, &val) {
312 Ok(this
313 .state()
314 .set_temp(&name, val)
315 .and_then(|v| minijinja_to_lua(lua, &v))
316 .unwrap_or_default())
317 } else {
318 Err(mlua::Error::FromLuaConversionError {
319 from: val.type_name(),
320 to: "minijinja::Value".to_string(),
321 message: None,
322 })
323 }
324 },
325 );
326
327 methods.add_method(
329 "get_or_set_temp",
330 |lua,
331 this,
332 (name, func): (mlua::BorrowedStr, mlua::Function)|
333 -> mlua::Result<mlua::MultiValue> {
334 let val = match this.state().get_temp(&name) {
335 Some(v) => v,
336 None => {
337 let val = func.call::<mlua::Value>(mlua::Value::Nil)?;
338
339 if let Some(val) = lua_to_minijinja(lua, &val) {
340 this.state().set_temp(&name, val.clone());
341 val
342 } else {
343 return Err(mlua::Error::FromLuaConversionError {
344 from: val.type_name(),
345 to: "minijinja::Value".to_string(),
346 message: None,
347 });
348 }
349 },
350 };
351
352 Ok(minijinja_to_lua(lua, &val).unwrap_or_default())
353 },
354 );
355}
356
357pub(crate) fn with_lua<R, F: FnOnce(&mlua::Lua) -> Result<R, mlua::Error>>(
361 f: F,
362) -> Result<R, mlua::Error> {
363 CURRENT_LUA.with(|handle| {
364 let ptr = unsafe { (handle.load(Ordering::Relaxed) as *const mlua::Lua).as_ref() };
369
370 match ptr {
371 Some(lua) => f(lua),
372 None => Err(mlua::Error::runtime(
373 "mlua::Lua state accessed outside of a render context.",
374 )),
375 }
376 })
377}
378
379pub(crate) fn bind_lua<R, F: FnOnce() -> R>(lua: &mlua::Lua, f: F) -> R {
383 let old_handle = CURRENT_LUA
384 .with(|handle| handle.swap(lua as *const mlua::Lua as *mut mlua::Lua, Ordering::Relaxed));
385
386 let rv = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f));
387
388 CURRENT_LUA.with(|handle| handle.store(old_handle, Ordering::Relaxed));
389 match rv {
390 Ok(rv) => rv,
391 Err(payload) => std::panic::resume_unwind(payload),
392 }
393}