use std::{
fmt,
sync::atomic::{AtomicPtr, Ordering},
};
use minijinja::Value as JinjaValue;
use mlua::{
LuaSerdeExt,
prelude::{Lua, LuaError, LuaFunction, LuaUserData, LuaValue, LuaVariadic},
};
use crate::convert::{
auto_escape_to_lua,
lua_to_minijinja,
minijinja_to_lua,
undefined_behavior_to_lua,
};
thread_local! {
static CURRENT_LUA: AtomicPtr<Lua> = const { AtomicPtr::new(std::ptr::null_mut()) };
}
/// Lua-facing wrapper around a borrowed [`minijinja::State`].
///
/// The wrapper holds only a reference, so it cannot outlive the `'scope` borrow of the
/// state it was created from.
#[derive(Debug)]
pub(crate) struct JinjaState<'scope> {
    // The render state that every Lua method on this wrapper delegates to.
    state: &'scope minijinja::State<'scope, 'scope>,
}

impl<'scope> JinjaState<'scope> {
    /// Wraps `state` so it can be exposed to Lua as userdata.
    pub(crate) fn new(state: &'scope minijinja::State<'scope, 'scope>) -> Self {
        JinjaState { state }
    }
}
/// Formats as the constant label `State`; the wrapped state's contents are not shown.
impl fmt::Display for JinjaState<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("State")
    }
}
impl<'scope> LuaUserData for JinjaState<'scope> {
fn add_fields<F: mlua::UserDataFields<Self>>(fields: &mut F) {
fields.add_meta_field("__name", "state");
}
fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
methods.add_method(
"name",
|_, this: &JinjaState<'scope>, _: LuaValue| -> Result<String, _> {
Ok(this.state.name().to_string())
},
);
methods.add_method(
"auto_escape",
|_, this: &JinjaState<'scope>, _: LuaValue| -> Result<Option<String>, _> {
Ok(auto_escape_to_lua(this.state.auto_escape()))
},
);
methods.add_method(
"undefined_behavior",
|_, this: &JinjaState<'scope>, _: LuaValue| -> Result<Option<String>, _> {
Ok(undefined_behavior_to_lua(this.state.undefined_behavior()))
},
);
methods.add_method(
"current_block",
|_, this: &JinjaState<'scope>, _: LuaValue| -> Result<Option<&str>, _> {
Ok(this.state.current_block())
},
);
methods.add_method(
"lookup",
|lua: &Lua, this: &JinjaState<'scope>, name: String| -> Result<Option<LuaValue>, _> {
Ok(this
.state
.lookup(&name)
.and_then(|v| minijinja_to_lua(lua, &v)))
},
);
methods.add_method(
"call_macro",
|lua: &Lua,
this: &JinjaState<'scope>,
(name, args): (String, LuaVariadic<LuaValue>)|
-> Result<String, LuaError> {
let args: Vec<JinjaValue> = args
.into_iter()
.map(|v| lua_to_minijinja(lua, &v).unwrap_or(JinjaValue::UNDEFINED))
.collect();
this.state
.call_macro(&name, &args)
.map_err(LuaError::external)
},
);
methods.add_method(
"exports",
|_, this: &JinjaState<'scope>, _: LuaValue| -> Result<Vec<&str>, _> {
Ok(this.state.exports())
},
);
methods.add_method(
"known_variables",
|_,
this: &JinjaState<'scope>,
_: LuaValue|
-> Result<Vec<std::borrow::Cow<'_, str>>, _> {
Ok(this.state.known_variables())
},
);
methods.add_method(
"apply_filter",
|lua: &Lua,
this: &JinjaState<'scope>,
(filter, args): (String, LuaVariadic<LuaValue>)|
-> Result<Option<LuaValue>, LuaError> {
let args: Vec<JinjaValue> = args
.into_iter()
.map(|v| lua_to_minijinja(lua, &v).unwrap_or(JinjaValue::UNDEFINED))
.collect();
this.state
.apply_filter(&filter, &args)
.map(|v| minijinja_to_lua(lua, &v))
.map_err(LuaError::external)
},
);
methods.add_method(
"perform_test",
|lua: &Lua,
this: &JinjaState<'scope>,
(test, args): (String, LuaVariadic<LuaValue>)|
-> Result<bool, LuaError> {
let args: Vec<JinjaValue> = args
.into_iter()
.map(|v| lua_to_minijinja(lua, &v).unwrap_or(JinjaValue::UNDEFINED))
.collect();
this.state
.perform_test(&test, &args)
.map_err(LuaError::external)
},
);
methods.add_method(
"format",
|lua: &Lua, this: &JinjaState<'scope>, val: LuaValue| -> Result<String, LuaError> {
let val = lua_to_minijinja(lua, &val).unwrap_or(JinjaValue::UNDEFINED);
this.state.format(val).map_err(LuaError::external)
},
);
methods.add_method(
"fuel_levels",
|lua: &Lua, this: &JinjaState<'scope>, _: LuaValue| -> Result<LuaValue, _> {
lua.to_value(&this.state.fuel_levels())
},
);
methods.add_method(
"get_temp",
|lua: &Lua,
this: &JinjaState<'scope>,
name: String|
-> Result<Option<LuaValue>, LuaError> {
Ok(this
.state
.get_temp(&name)
.and_then(|v| minijinja_to_lua(lua, &v)))
},
);
methods.add_method(
"set_temp",
|lua: &Lua,
this: &JinjaState<'scope>,
(name, val): (String, LuaValue)|
-> Result<Option<LuaValue>, LuaError> {
if let Some(val) = lua_to_minijinja(lua, &val) {
Ok(this
.state
.set_temp(&name, val)
.and_then(|v| minijinja_to_lua(lua, &v)))
} else {
Err(LuaError::ToLuaConversionError {
from: val.type_name().to_string(),
to: "minijinja::value::Value",
message: None,
})
}
},
);
methods.add_method(
"get_or_set_temp",
|lua: &Lua,
this: &JinjaState<'scope>,
(name, func): (String, LuaFunction)|
-> Result<Option<LuaValue>, LuaError> {
let val = match this.state.get_temp(&name) {
Some(v) => v,
None => {
let val = func.call::<LuaValue>(LuaValue::Nil)?;
if let Some(val) = lua_to_minijinja(lua, &val) {
this.state.set_temp(&name, val.clone());
val
} else {
return Err(LuaError::ToLuaConversionError {
from: val.type_name().to_string(),
to: "minijinja::value::Value",
message: None,
});
}
},
};
Ok(minijinja_to_lua(lua, &val))
},
);
}
}
/// Runs `f` with the `Lua` state currently bound to this thread by [`bind_lua`].
///
/// # Errors
///
/// Returns a runtime error when no state is bound on this thread; otherwise returns
/// whatever `f` returns.
pub(crate) fn with_lua<R, F: FnOnce(&Lua) -> Result<R, LuaError>>(f: F) -> Result<R, LuaError> {
    CURRENT_LUA.with(|slot| {
        let raw = slot.load(Ordering::Relaxed) as *const Lua;
        if raw.is_null() {
            return Err(LuaError::runtime(
                "mlua::Lua state accessed outside of a render context.",
            ));
        }
        // SAFETY: in this module the slot is only set to a non-null value by `bind_lua`,
        // which keeps the corresponding `&Lua` borrowed for as long as the pointer stays
        // installed and restores the previous value afterwards. The reference given to
        // `f` cannot escape this call, so it never outlives that borrow.
        let lua = unsafe { &*raw };
        f(lua)
    })
}
/// Makes `lua` the current thread's state for the duration of `f` and returns `f`'s result.
///
/// While `f` runs, [`with_lua`] on this thread resolves to `lua`. The previously bound
/// pointer (possibly null) is restored when `f` returns or unwinds, so bindings nest.
pub(crate) fn bind_lua<R, F: FnOnce() -> R>(lua: &Lua, f: F) -> R {
    // Drop guard that reinstates the previous pointer. Running in `Drop` covers both the
    // normal return path and panic unwinding, without `catch_unwind`/`AssertUnwindSafe`.
    struct RestoreOnDrop(*mut Lua);

    impl Drop for RestoreOnDrop {
        fn drop(&mut self) {
            CURRENT_LUA.with(|handle| handle.store(self.0, Ordering::Relaxed));
        }
    }

    let previous =
        CURRENT_LUA.with(|handle| handle.swap(lua as *const Lua as *mut Lua, Ordering::Relaxed));
    let _restore = RestoreOnDrop(previous);
    f()
}