1use std::{
4 fmt,
5 sync::atomic::{AtomicPtr, Ordering},
6};
7
8use minijinja::Value as JinjaValue;
9use mlua::{
10 LuaSerdeExt,
11 prelude::{Lua, LuaError, LuaFunction, LuaUserData, LuaValue, LuaVariadic},
12};
13
14use crate::convert::{
15 auto_escape_to_lua,
16 lua_args_to_minijinja,
17 lua_to_minijinja,
18 minijinja_to_lua,
19 undefined_behavior_to_lua,
20};
21
22thread_local! {
23 static CURRENT_LUA: AtomicPtr<Lua> = const { AtomicPtr::new(std::ptr::null_mut()) };
24}
25
26#[derive(Debug)]
30pub struct LuaState<'scope> {
31 state: &'scope minijinja::State<'scope, 'scope>,
32}
33
34impl<'scope> LuaState<'scope> {
35 pub(crate) fn new(state: &'scope minijinja::State<'scope, 'scope>) -> Self {
37 Self { state }
38 }
39}
40
41impl<'scope> fmt::Display for LuaState<'scope> {
42 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43 write!(f, "State")
44 }
45}
46
47impl<'scope> LuaUserData for LuaState<'scope> {
48 fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
49 fields.add_meta_field("__name", "state");
50 }
51
52 fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
53 methods.add_method(
55 "name",
56 |_, this: &LuaState<'scope>, _: LuaValue| -> Result<String, _> {
57 Ok(this.state.name().to_string())
58 },
59 );
60
61 methods.add_method(
63 "auto_escape",
64 |_, this: &LuaState<'scope>, _: LuaValue| -> Result<Option<String>, _> {
65 Ok(auto_escape_to_lua(this.state.auto_escape()))
66 },
67 );
68
69 methods.add_method(
71 "undefined_behavior",
72 |_, this: &LuaState<'scope>, _: LuaValue| -> Result<Option<String>, _> {
73 Ok(undefined_behavior_to_lua(this.state.undefined_behavior()))
74 },
75 );
76
77 methods.add_method(
79 "current_block",
80 |_, this: &LuaState<'scope>, _: LuaValue| -> Result<Option<&str>, _> {
81 Ok(this.state.current_block())
82 },
83 );
84
85 methods.add_method(
87 "lookup",
88 |lua: &Lua, this: &LuaState<'scope>, name: String| -> Result<Option<LuaValue>, _> {
89 Ok(this
92 .state
93 .lookup(&name)
94 .and_then(|v| minijinja_to_lua(lua, &v)))
95 },
96 );
97
98 methods.add_method(
100 "call_macro",
101 |lua: &Lua,
102 this: &LuaState<'scope>,
103 (name, args): (String, LuaVariadic<LuaValue>)|
104 -> Result<String, LuaError> {
105 let args: Vec<JinjaValue> = lua_args_to_minijinja(lua, args, true);
106
107 this.state
108 .call_macro(&name, &args)
109 .map_err(LuaError::external)
110 },
111 );
112
113 methods.add_method(
115 "exports",
116 |_, this: &LuaState<'scope>, _: LuaValue| -> Result<Vec<&str>, _> {
117 Ok(this.state.exports())
118 },
119 );
120
121 methods.add_method(
123 "known_variables",
124 |_,
125 this: &LuaState<'scope>,
126 _: LuaValue|
127 -> Result<Vec<std::borrow::Cow<'_, str>>, _> {
128 Ok(this.state.known_variables())
129 },
130 );
131
132 methods.add_method(
134 "apply_filter",
135 |lua: &Lua,
136 this: &LuaState<'scope>,
137 (filter, args): (String, LuaVariadic<LuaValue>)|
138 -> Result<Option<LuaValue>, LuaError> {
139 let args: Vec<JinjaValue> = lua_args_to_minijinja(lua, args, true);
140
141 this.state
144 .apply_filter(&filter, &args)
145 .map(|v| minijinja_to_lua(lua, &v))
146 .map_err(LuaError::external)
147 },
148 );
149
150 methods.add_method(
152 "perform_test",
153 |lua: &Lua,
154 this: &LuaState<'scope>,
155 (test, args): (String, LuaVariadic<LuaValue>)|
156 -> Result<bool, LuaError> {
157 let args: Vec<JinjaValue> = lua_args_to_minijinja(lua, args, true);
158
159 this.state
160 .perform_test(&test, &args)
161 .map_err(LuaError::external)
162 },
163 );
164
165 methods.add_method(
167 "format",
168 |lua: &Lua, this: &LuaState<'scope>, val: LuaValue| -> Result<String, LuaError> {
169 let val = lua_to_minijinja(lua, &val).unwrap_or(JinjaValue::UNDEFINED);
170
171 this.state.format(val).map_err(LuaError::external)
172 },
173 );
174
175 methods.add_method(
177 "fuel_levels",
178 |lua: &Lua, this: &LuaState<'scope>, _: LuaValue| -> Result<LuaValue, _> {
179 lua.to_value(&this.state.fuel_levels())
180 },
181 );
182
183 methods.add_method(
186 "get_temp",
187 |lua: &Lua,
188 this: &LuaState<'scope>,
189 name: String|
190 -> Result<Option<LuaValue>, LuaError> {
191 Ok(this
194 .state
195 .get_temp(&name)
196 .and_then(|v| minijinja_to_lua(lua, &v)))
197 },
198 );
199
200 methods.add_method(
202 "set_temp",
203 |lua: &Lua,
204 this: &LuaState<'scope>,
205 (name, val): (String, LuaValue)|
206 -> Result<Option<LuaValue>, LuaError> {
207 if let Some(val) = lua_to_minijinja(lua, &val) {
208 Ok(this
209 .state
210 .set_temp(&name, val)
211 .and_then(|v| minijinja_to_lua(lua, &v)))
212 } else {
213 Err(LuaError::ToLuaConversionError {
214 from: val.type_name().to_string(),
215 to: "minijinja::Value",
216 message: None,
217 })
218 }
219 },
220 );
221
222 methods.add_method(
224 "get_or_set_temp",
225 |lua: &Lua,
226 this: &LuaState<'scope>,
227 (name, func): (String, LuaFunction)|
228 -> Result<Option<LuaValue>, LuaError> {
229 let val = match this.state.get_temp(&name) {
230 Some(v) => v,
231 None => {
232 let val = func.call::<LuaValue>(LuaValue::Nil)?;
233
234 if let Some(val) = lua_to_minijinja(lua, &val) {
235 this.state.set_temp(&name, val.clone());
236 val
237 } else {
238 return Err(LuaError::ToLuaConversionError {
239 from: val.type_name().to_string(),
240 to: "minijinja::Value",
241 message: None,
242 });
243 }
244 },
245 };
246
247 Ok(minijinja_to_lua(lua, &val))
248 },
249 );
250 }
251}
252
253pub(crate) fn with_lua<R, F: FnOnce(&Lua) -> Result<R, LuaError>>(f: F) -> Result<R, LuaError> {
257 CURRENT_LUA.with(|handle| {
258 let ptr = unsafe { (handle.load(Ordering::Relaxed) as *const Lua).as_ref() };
259
260 match ptr {
261 Some(lua) => f(lua),
262 None => Err(LuaError::runtime(
263 "mlua::Lua state accessed outside of a render context.",
264 )),
265 }
266 })
267}
268
269pub(crate) fn bind_lua<R, F: FnOnce() -> R>(lua: &Lua, f: F) -> R {
273 let old_handle =
274 CURRENT_LUA.with(|handle| handle.swap(lua as *const Lua as *mut Lua, Ordering::Relaxed));
275
276 let rv = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f));
277
278 CURRENT_LUA.with(|handle| handle.store(old_handle, Ordering::Relaxed));
279 match rv {
280 Ok(rv) => rv,
281 Err(payload) => std::panic::resume_unwind(payload),
282 }
283}