1use std::{
4 fmt,
5 ops::{Deref, DerefMut},
6 sync::atomic::{AtomicPtr, Ordering},
7};
8
9use minijinja::Value as JinjaValue;
10use mlua::LuaSerdeExt;
11
12use crate::convert::{
13 LuaAutoEscape,
14 LuaUndefinedBehavior,
15 lua_args_to_minijinja,
16 lua_to_minijinja,
17 minijinja_to_lua,
18};
19
20thread_local! {
21 static CURRENT_LUA: AtomicPtr<mlua::Lua> = const { AtomicPtr::new(std::ptr::null_mut()) };
22}
23
24trait LuaState<'template, 'env> {
25 fn state(&self) -> &minijinja::State<'template, 'env>;
26}
27
28#[derive(Debug)]
32pub struct LuaStateRef<'scope, 'template, 'env>(&'scope minijinja::State<'template, 'env>);
33
34impl<'scope, 'template, 'env> From<&'scope minijinja::State<'template, 'env>>
35 for LuaStateRef<'scope, 'template, 'env>
36{
37 fn from(value: &'scope minijinja::State<'template, 'env>) -> Self {
38 LuaStateRef(value)
39 }
40}
41
42impl<'scope, 'template, 'env> From<LuaStateRef<'scope, 'template, 'env>>
43 for &'scope minijinja::State<'template, 'env>
44{
45 fn from(value: LuaStateRef<'scope, 'template, 'env>) -> Self {
46 value.0
47 }
48}
49
50impl<'scope, 'template, 'env> Deref for LuaStateRef<'scope, 'template, 'env> {
51 type Target = minijinja::State<'template, 'env>;
52
53 fn deref(&self) -> &Self::Target {
54 self.0
55 }
56}
57
58impl<'scope, 'template, 'env> LuaState<'template, 'env> for LuaStateRef<'scope, 'template, 'env> {
59 fn state(&self) -> &minijinja::State<'template, 'env> {
60 self.0
61 }
62}
63
64impl<'scope, 'template, 'env> fmt::Display for LuaStateRef<'scope, 'template, 'env> {
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 write!(f, "State")
67 }
68}
69
70impl<'scope, 'template, 'env> mlua::UserData for LuaStateRef<'scope, 'template, 'env> {
71 fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
72 fields.add_meta_field("__name", "state");
73 }
74
75 fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
76 add_common_methods(methods);
77 }
78}
79
80#[derive(Debug)]
84pub struct LuaStateMut<'scope, 'template, 'env>(&'scope mut minijinja::State<'template, 'env>);
85
86impl<'scope, 'template, 'env> LuaStateMut<'scope, 'template, 'env> {
87 fn state_mut(&mut self) -> &mut minijinja::State<'template, 'env> {
88 self.0
89 }
90}
91
92impl<'scope, 'template, 'env> From<&'scope mut minijinja::State<'template, 'env>>
93 for LuaStateMut<'scope, 'template, 'env>
94{
95 fn from(value: &'scope mut minijinja::State<'template, 'env>) -> Self {
96 LuaStateMut(value)
97 }
98}
99
100impl<'scope, 'template, 'env> From<LuaStateMut<'scope, 'template, 'env>>
101 for &'scope mut minijinja::State<'template, 'env>
102{
103 fn from(value: LuaStateMut<'scope, 'template, 'env>) -> Self {
104 value.0
105 }
106}
107
108impl<'scope, 'template, 'env> Deref for LuaStateMut<'scope, 'template, 'env> {
109 type Target = minijinja::State<'template, 'env>;
110
111 fn deref(&self) -> &Self::Target {
112 self.0
113 }
114}
115
116impl<'scope, 'template, 'env> DerefMut for LuaStateMut<'scope, 'template, 'env> {
117 fn deref_mut(&mut self) -> &mut Self::Target {
118 self.0
119 }
120}
121
122impl<'scope, 'template, 'env> LuaState<'template, 'env> for LuaStateMut<'scope, 'template, 'env> {
123 fn state(&self) -> &minijinja::State<'template, 'env> {
124 self.0
125 }
126}
127
128impl<'scope, 'template, 'env> fmt::Display for LuaStateMut<'scope, 'template, 'env> {
129 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130 write!(f, "State")
131 }
132}
133
134impl<'scope, 'template, 'env> mlua::UserData for LuaStateMut<'scope, 'template, 'env> {
135 fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
136 fields.add_meta_field("__name", "state");
137 }
138
139 fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
140 add_common_methods(methods);
141
142 methods.add_method_mut(
144 "render_block",
145 |_, this, block: String| -> mlua::Result<String> {
146 this.state_mut()
147 .render_block(&block)
148 .map_err(mlua::Error::external)
149 },
150 );
151 }
152}
153
154fn add_common_methods<'template, 'env, S, M>(methods: &mut M)
156where
157 S: LuaState<'template, 'env>,
158 M: mlua::UserDataMethods<S>,
159 'env: 'template,
160{
161 methods.add_method("name", |_, this, _: ()| -> mlua::Result<String> {
163 Ok(this.state().name().to_string())
164 });
165
166 methods.add_method(
168 "auto_escape",
169 |_, this, _: ()| -> mlua::Result<LuaAutoEscape> { Ok(this.state().auto_escape().into()) },
170 );
171
172 methods.add_method(
174 "undefined_behavior",
175 |_, this, _: ()| -> mlua::Result<LuaUndefinedBehavior> {
176 Ok(this.state().undefined_behavior().into())
177 },
178 );
179
180 methods.add_method(
182 "current_block",
183 |_, this, _: ()| -> mlua::Result<Option<String>> {
184 Ok(this.state().current_block().map(|s| s.to_string()))
185 },
186 );
187
188 methods.add_method(
190 "lookup",
191 |lua, this, name: String| -> mlua::Result<mlua::MultiValue> {
192 Ok(this
195 .state()
196 .lookup(&name)
197 .and_then(|v| minijinja_to_lua(lua, &v))
198 .unwrap_or_default())
199 },
200 );
201
202 methods.add_method(
204 "call_macro",
205 |lua, this, (name, mut args): (String, mlua::MultiValue)| -> mlua::Result<String> {
206 let args: Vec<JinjaValue> = lua_args_to_minijinja(lua, &mut args, true);
207
208 this.state()
209 .call_macro(&name, &args)
210 .map_err(mlua::Error::external)
211 },
212 );
213
214 methods.add_method("exports", |_, this, _: ()| -> mlua::Result<Vec<String>> {
216 Ok(this
217 .state()
218 .exports()
219 .into_iter()
220 .map(|i| i.to_string())
221 .collect())
222 });
223
224 methods.add_method(
226 "known_variables",
227 |_, this, _: ()| -> mlua::Result<Vec<String>> {
228 Ok(this
229 .state()
230 .known_variables()
231 .into_iter()
232 .map(|i| i.to_string())
233 .collect())
234 },
235 );
236
237 methods.add_method(
239 "apply_filter",
240 |lua,
241 this,
242 (filter, mut args): (String, mlua::MultiValue)|
243 -> mlua::Result<mlua::MultiValue> {
244 let args: Vec<JinjaValue> = lua_args_to_minijinja(lua, &mut args, true);
245
246 this.state()
249 .apply_filter(&filter, &args)
250 .map(|v| minijinja_to_lua(lua, &v).unwrap_or_default())
251 .map_err(mlua::Error::external)
252 },
253 );
254
255 methods.add_method(
257 "perform_test",
258 |lua, this, (test, mut args): (String, mlua::MultiValue)| -> mlua::Result<bool> {
259 let args: Vec<JinjaValue> = lua_args_to_minijinja(lua, &mut args, true);
260
261 this.state()
262 .perform_test(&test, &args)
263 .map_err(mlua::Error::external)
264 },
265 );
266
267 methods.add_method(
269 "format",
270 |lua, this, val: mlua::Value| -> mlua::Result<String> {
271 let val = lua_to_minijinja(lua, &val).unwrap_or_default();
272
273 this.state().format(val).map_err(mlua::Error::external)
274 },
275 );
276
277 methods.add_method(
279 "fuel_levels",
280 |lua, this, _: ()| -> mlua::Result<mlua::Value> {
281 lua.to_value(&this.state().fuel_levels())
282 },
283 );
284
285 methods.add_method(
288 "get_temp",
289 |lua, this, name: String| -> mlua::Result<mlua::MultiValue> {
290 Ok(this
293 .state()
294 .get_temp(&name)
295 .and_then(|v| minijinja_to_lua(lua, &v))
296 .unwrap_or_default())
297 },
298 );
299
300 methods.add_method(
302 "set_temp",
303 |lua, this, (name, val): (String, mlua::Value)| -> mlua::Result<mlua::MultiValue> {
304 if let Some(val) = lua_to_minijinja(lua, &val) {
305 Ok(this
306 .state()
307 .set_temp(&name, val)
308 .and_then(|v| minijinja_to_lua(lua, &v))
309 .unwrap_or_default())
310 } else {
311 Err(mlua::Error::FromLuaConversionError {
312 from: val.type_name(),
313 to: "minijinja::Value".to_string(),
314 message: None,
315 })
316 }
317 },
318 );
319
320 methods.add_method(
322 "get_or_set_temp",
323 |lua, this, (name, func): (String, mlua::Function)| -> mlua::Result<mlua::MultiValue> {
324 let val = match this.state().get_temp(&name) {
325 Some(v) => v,
326 None => {
327 let val = func.call::<mlua::Value>(mlua::Value::Nil)?;
328
329 if let Some(val) = lua_to_minijinja(lua, &val) {
330 this.state().set_temp(&name, val.clone());
331 val
332 } else {
333 return Err(mlua::Error::FromLuaConversionError {
334 from: val.type_name(),
335 to: "minijinja::Value".to_string(),
336 message: None,
337 });
338 }
339 },
340 };
341
342 Ok(minijinja_to_lua(lua, &val).unwrap_or_default())
343 },
344 );
345}
346
347pub(crate) fn with_lua<R, F: FnOnce(&mlua::Lua) -> Result<R, mlua::Error>>(
351 f: F,
352) -> Result<R, mlua::Error> {
353 CURRENT_LUA.with(|handle| {
354 let ptr = unsafe { (handle.load(Ordering::Relaxed) as *const mlua::Lua).as_ref() };
355
356 match ptr {
357 Some(lua) => f(lua),
358 None => Err(mlua::Error::runtime(
359 "mlua::Lua state accessed outside of a render context.",
360 )),
361 }
362 })
363}
364
365pub(crate) fn bind_lua<R, F: FnOnce() -> R>(lua: &mlua::Lua, f: F) -> R {
369 let old_handle = CURRENT_LUA
370 .with(|handle| handle.swap(lua as *const mlua::Lua as *mut mlua::Lua, Ordering::Relaxed));
371
372 let rv = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f));
373
374 CURRENT_LUA.with(|handle| handle.store(old_handle, Ordering::Relaxed));
375 match rv {
376 Ok(rv) => rv,
377 Err(payload) => std::panic::resume_unwind(payload),
378 }
379}