use std::cell::{Ref, RefCell, RefMut};
use std::marker::PhantomData;
use std::collections::HashMap;
use std::string::String as StdString;
use ffi;
use error::*;
use util::*;
use types::{Callback, LuaRef};
use value::{FromLua, FromLuaMulti, ToLuaMulti};
use lua::Lua;
/// Kinds of Lua metamethods a `UserData` type can implement through
/// `UserDataMethods::add_meta_method`, `add_meta_method_mut` or `add_meta_function`.
///
/// The variants line up with the Lua 5.3 metamethod set (`Add` ~ `__add`, `IDiv` ~ `__idiv`,
/// `ToString` ~ `__tostring`, ...). The actual variant-to-name mapping is defined elsewhere in
/// the crate — NOTE(review): confirm there. The operator notes on each variant describe the
/// corresponding Lua 5.3 operator under that assumption.
///
/// Derives `Hash`/`Eq` so it can key the `meta_methods` map of `UserDataMethods`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum MetaMethod {
    /// Binary `+`.
    Add,
    /// Binary `-`.
    Sub,
    /// Binary `*`.
    Mul,
    /// Binary `/`.
    Div,
    /// Binary `%`.
    Mod,
    /// Binary `^`.
    Pow,
    /// Unary `-`.
    Unm,
    /// Floor division `//`.
    IDiv,
    /// Bitwise `&`.
    BAnd,
    /// Bitwise `|`.
    BOr,
    /// Binary bitwise `~` (xor).
    BXor,
    /// Unary bitwise `~` (not).
    BNot,
    /// Shift left `<<`.
    Shl,
    /// Shift right `>>`.
    Shr,
    /// Concatenation `..`.
    Concat,
    /// Length operator `#`.
    Len,
    /// Equality `==`.
    Eq,
    /// Less-than `<`.
    Lt,
    /// Less-or-equal `<=`.
    Le,
    /// Field/index lookup (`ud.key`, `ud[key]`).
    Index,
    /// Field/index assignment (`ud.key = v`).
    NewIndex,
    /// Calling the value as a function (`ud(...)`).
    Call,
    /// Conversion via `tostring`.
    ToString,
}
/// Collects the methods and metamethods that a `UserData` type `T` exposes to Lua.
///
/// A `&mut UserDataMethods<Self>` is handed to `UserData::add_methods`. The registered
/// callbacks are stored as `pub(crate)` fields, presumably consumed when the userdata's
/// metatable is built elsewhere in the crate (NOTE(review): confirm in `lua.rs`).
pub struct UserDataMethods<'lua, T> {
    /// Named methods, keyed by the name Lua code uses to call them (e.g. `ud:name()`).
    pub(crate) methods: HashMap<StdString, Callback<'lua>>,
    /// Metamethod callbacks, at most one per `MetaMethod`.
    pub(crate) meta_methods: HashMap<MetaMethod, Callback<'lua>>,
    /// Ties this collection to the userdata type `T` without storing a `T`.
    pub(crate) _type: PhantomData<T>,
}
impl<'lua, T: UserData> UserDataMethods<'lua, T> {
pub fn add_method<A, R, M>(&mut self, name: &str, method: M)
where
A: FromLuaMulti<'lua>,
R: ToLuaMulti<'lua>,
M: 'static + for<'a> FnMut(&'lua Lua, &'a T, A) -> Result<R>,
{
self.methods
.insert(name.to_owned(), Self::box_method(method));
}
pub fn add_method_mut<A, R, M>(&mut self, name: &str, method: M)
where
A: FromLuaMulti<'lua>,
R: ToLuaMulti<'lua>,
M: 'static + for<'a> FnMut(&'lua Lua, &'a mut T, A) -> Result<R>,
{
self.methods
.insert(name.to_owned(), Self::box_method_mut(method));
}
pub fn add_function<A, R, F>(&mut self, name: &str, function: F)
where
A: FromLuaMulti<'lua>,
R: ToLuaMulti<'lua>,
F: 'static + FnMut(&'lua Lua, A) -> Result<R>,
{
self.methods
.insert(name.to_owned(), Self::box_function(function));
}
pub fn add_meta_method<A, R, M>(&mut self, meta: MetaMethod, method: M)
where
A: FromLuaMulti<'lua>,
R: ToLuaMulti<'lua>,
M: 'static + for<'a> FnMut(&'lua Lua, &'a T, A) -> Result<R>,
{
self.meta_methods.insert(meta, Self::box_method(method));
}
pub fn add_meta_method_mut<A, R, M>(&mut self, meta: MetaMethod, method: M)
where
A: FromLuaMulti<'lua>,
R: ToLuaMulti<'lua>,
M: 'static + for<'a> FnMut(&'lua Lua, &'a mut T, A) -> Result<R>,
{
self.meta_methods.insert(meta, Self::box_method_mut(method));
}
pub fn add_meta_function<A, R, F>(&mut self, meta: MetaMethod, function: F)
where
A: FromLuaMulti<'lua>,
R: ToLuaMulti<'lua>,
F: 'static + FnMut(&'lua Lua, A) -> Result<R>,
{
self.meta_methods.insert(meta, Self::box_function(function));
}
fn box_function<A, R, F>(mut function: F) -> Callback<'lua>
where
A: FromLuaMulti<'lua>,
R: ToLuaMulti<'lua>,
F: 'static + FnMut(&'lua Lua, A) -> Result<R>,
{
Box::new(move |lua, args| function(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua))
}
fn box_method<A, R, M>(mut method: M) -> Callback<'lua>
where
A: FromLuaMulti<'lua>,
R: ToLuaMulti<'lua>,
M: 'static + for<'a> FnMut(&'lua Lua, &'a T, A) -> Result<R>,
{
Box::new(move |lua, mut args| {
if let Some(front) = args.pop_front() {
let userdata = AnyUserData::from_lua(front, lua)?;
let userdata = userdata.borrow::<T>()?;
method(lua, &userdata, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
} else {
Err(Error::FromLuaConversionError {
from: "missing argument",
to: "userdata",
message: None,
})
}
})
}
fn box_method_mut<A, R, M>(mut method: M) -> Callback<'lua>
where
A: FromLuaMulti<'lua>,
R: ToLuaMulti<'lua>,
M: 'static + for<'a> FnMut(&'lua Lua, &'a mut T, A) -> Result<R>,
{
Box::new(move |lua, mut args| {
if let Some(front) = args.pop_front() {
let userdata = AnyUserData::from_lua(front, lua)?;
let mut userdata = userdata.borrow_mut::<T>()?;
method(lua, &mut userdata, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
} else {
Err(Error::FromLuaConversionError {
from: "missing argument",
to: "userdata",
message: None,
})
}
})
}
}
/// Trait for Rust types that can be stored in Lua as userdata.
///
/// Implementors must be `'static` and `Sized`. Override `add_methods` to expose methods and
/// metamethods to Lua; a bare `impl UserData for X {}` exposes none.
pub trait UserData: 'static + Sized {
    /// Registers this type's methods and metamethods into the given collection.
    ///
    /// The default implementation registers nothing.
    fn add_methods(_registry: &mut UserDataMethods<Self>) {
        // Intentionally empty: types opt in by overriding this method.
    }
}
/// Handle to a Lua userdata value whose Rust type is not known statically.
///
/// Wraps a crate-internal `LuaRef`; `Clone` is derived, so cloning clones that inner reference.
/// Use `is::<T>()` to test the stored type and `borrow` / `borrow_mut` to access it.
#[derive(Clone, Debug)]
pub struct AnyUserData<'lua>(pub(crate) LuaRef<'lua>);
impl<'lua> AnyUserData<'lua> {
pub fn is<T: UserData>(&self) -> Result<bool> {
match self.inspect(|_: &RefCell<T>| Ok(())) {
Ok(()) => Ok(true),
Err(Error::UserDataTypeMismatch) => Ok(false),
Err(err) => Err(err),
}
}
pub fn borrow<T: UserData>(&self) -> Result<Ref<T>> {
self.inspect(|cell| Ok(cell.try_borrow().map_err(|_| Error::UserDataBorrowError)?))
}
pub fn borrow_mut<T: UserData>(&self) -> Result<RefMut<T>> {
self.inspect(|cell| {
Ok(cell.try_borrow_mut()
.map_err(|_| Error::UserDataBorrowMutError)?)
})
}
fn inspect<'a, T, R, F>(&'a self, func: F) -> Result<R>
where
T: UserData,
F: FnOnce(&'a RefCell<T>) -> Result<R>,
{
unsafe {
let lua = self.0.lua;
stack_err_guard(lua.state, 0, move || {
check_stack(lua.state, 3);
lua.push_ref(lua.state, &self.0);
lua_assert!(
lua.state,
ffi::lua_getmetatable(lua.state, -1) != 0,
"AnyUserData missing metatable"
);
ffi::lua_rawgeti(
lua.state,
ffi::LUA_REGISTRYINDEX,
lua.userdata_metatable::<T>()? as ffi::lua_Integer,
);
if ffi::lua_rawequal(lua.state, -1, -2) == 0 {
ffi::lua_pop(lua.state, 3);
Err(Error::UserDataTypeMismatch)
} else {
let res = func(&*get_userdata::<RefCell<T>>(lua.state, -3)?);
ffi::lua_pop(lua.state, 3);
res
}
})
}
}
}
#[cfg(test)]
mod tests {
    use super::{MetaMethod, UserData, UserDataMethods};
    use error::{Error, ExternalError};
    use string::String;
    use function::Function;
    use lua::Lua;

    /// `is` distinguishes userdata types, and `borrow` reads back the stored value.
    #[test]
    fn test_user_data() {
        struct UserData1(i64);
        struct UserData2(Box<i64>);

        impl UserData for UserData1 {}
        impl UserData for UserData2 {}

        let lua = Lua::new();

        let userdata1 = lua.create_userdata(UserData1(1)).unwrap();
        let userdata2 = lua.create_userdata(UserData2(Box::new(2))).unwrap();

        assert!(userdata1.is::<UserData1>().unwrap());
        assert!(!userdata1.is::<UserData2>().unwrap());
        assert!(userdata2.is::<UserData2>().unwrap());
        assert!(!userdata2.is::<UserData1>().unwrap());

        assert_eq!(userdata1.borrow::<UserData1>().unwrap().0, 1);
        assert_eq!(*userdata2.borrow::<UserData2>().unwrap().0, 2);
    }

    /// Methods registered via `add_method`/`add_method_mut` observe and mutate the same value
    /// that Rust sees through `borrow_mut`.
    #[test]
    fn test_methods() {
        struct MyUserData(i64);

        impl UserData for MyUserData {
            fn add_methods(methods: &mut UserDataMethods<Self>) {
                methods.add_method("get_value", |_, data, ()| Ok(data.0));
                methods.add_method_mut("set_value", |_, data, args| {
                    data.0 = args;
                    Ok(())
                });
            }
        }

        let lua = Lua::new();
        let globals = lua.globals();
        let userdata = lua.create_userdata(MyUserData(42)).unwrap();
        globals.set("userdata", userdata.clone()).unwrap();
        lua.exec::<()>(
            r#"
                function get_it()
                    return userdata:get_value()
                end

                function set_it(i)
                    return userdata:set_value(i)
                end
            "#,
            None,
        ).unwrap();
        let get = globals.get::<_, Function>("get_it").unwrap();
        let set = globals.get::<_, Function>("set_it").unwrap();
        assert_eq!(get.call::<_, i64>(()).unwrap(), 42);
        userdata.borrow_mut::<MyUserData>().unwrap().0 = 64;
        assert_eq!(get.call::<_, i64>(()).unwrap(), 64);
        set.call::<_, ()>(100).unwrap();
        assert_eq!(get.call::<_, i64>(()).unwrap(), 100);
    }

    /// Arithmetic and `__index`-style metamethods dispatch to the registered callbacks.
    #[test]
    fn test_metamethods() {
        #[derive(Copy, Clone)]
        struct MyUserData(i64);

        impl UserData for MyUserData {
            fn add_methods(methods: &mut UserDataMethods<Self>) {
                methods.add_method("get", |_, data, ()| Ok(data.0));
                methods.add_meta_function(
                    MetaMethod::Add,
                    |_, (lhs, rhs): (MyUserData, MyUserData)| Ok(MyUserData(lhs.0 + rhs.0)),
                );
                methods.add_meta_function(
                    MetaMethod::Sub,
                    |_, (lhs, rhs): (MyUserData, MyUserData)| Ok(MyUserData(lhs.0 - rhs.0)),
                );
                methods.add_meta_method(MetaMethod::Index, |_, data, index: String| {
                    if index.to_str()? == "inner" {
                        Ok(data.0)
                    } else {
                        Err("no such custom index".to_lua_err())
                    }
                });
            }
        }

        let lua = Lua::new();
        let globals = lua.globals();
        globals.set("userdata1", MyUserData(7)).unwrap();
        globals.set("userdata2", MyUserData(3)).unwrap();
        assert_eq!(
            lua.eval::<MyUserData>("userdata1 + userdata2", None)
                .unwrap()
                .0,
            10
        );
        assert_eq!(
            lua.eval::<MyUserData>("userdata1 - userdata2", None)
                .unwrap()
                .0,
            4
        );
        assert_eq!(lua.eval::<i64>("userdata1:get()", None).unwrap(), 7);
        assert_eq!(lua.eval::<i64>("userdata2.inner", None).unwrap(), 3);
        assert!(lua.eval::<()>("userdata2.nonexist_field", None).is_err());
    }

    /// A userdata resurrected from a `__gc` finalizer after collection must report
    /// `ExpiredUserData` instead of exposing freed memory.
    #[test]
    fn test_expired_userdata() {
        struct MyUserdata {
            id: u8,
        }

        impl UserData for MyUserdata {
            fn add_methods(methods: &mut UserDataMethods<Self>) {
                methods.add_method("access", |_, this, ()| {
                    assert!(this.id == 123);
                    Ok(())
                });
            }
        }

        let lua = Lua::new();
        {
            let globals = lua.globals();
            globals.set("userdata", MyUserdata { id: 123 }).unwrap();
        }

        match lua.eval::<()>(
            r#"
                local tbl = setmetatable({
                    userdata = userdata
                }, { __gc = function(self)
                    -- resurrect userdata
                    hatch = self.userdata
                end })

                tbl = nil
                userdata = nil  -- make table and userdata collectable
                collectgarbage("collect")
                hatch:access()
            "#,
            None,
        ) {
            Err(Error::CallbackError { cause, .. }) => match *cause {
                Error::ExpiredUserData { .. } => {}
                ref other => panic!("incorrect result: {}", other),
            },
            other => panic!("incorrect result: {:?}", other),
        }
    }

    /// Dropping the `Lua` state runs `Drop` for userdata it still owns.
    #[test]
    fn destroys_userdata() {
        use std::sync::atomic::{AtomicBool, Ordering, ATOMIC_BOOL_INIT};

        static DROPPED: AtomicBool = ATOMIC_BOOL_INIT;

        struct MyUserdata;

        impl UserData for MyUserdata {}

        impl Drop for MyUserdata {
            fn drop(&mut self) {
                DROPPED.store(true, Ordering::SeqCst);
            }
        }

        let lua = Lua::new();
        {
            let globals = lua.globals();
            globals.set("userdata", MyUserdata).unwrap();
        }

        assert!(!DROPPED.load(Ordering::SeqCst));
        drop(lua);
        assert!(DROPPED.load(Ordering::SeqCst));
    }
}