use std::collections::HashSet;
use std::marker::PhantomData;

use crate::lua_value::LuaValue;
use crate::lua_value::LuaValueKind;
use crate::lua_value::lua_convert::collect_into_lua_values;
use crate::lua_value::lua_convert::{FromLua, FromLuaMulti, IntoLua};
use crate::{LuaResult, LuaVM};
/// Integer handle identifying a slot in the Lua registry.
pub type RefId = i32;

/// Sentinel id for a reference to `nil` (same value as Lua's `LUA_REFNIL`).
pub const LUA_REFNIL: RefId = -1;

/// Sentinel id meaning "no registry reference" (same value as Lua's `LUA_NOREF`).
pub const LUA_NOREF: RefId = -2;

/// Allocator for registry reference ids.
///
/// Issued ids are strictly positive. Released ids are recycled (most recently
/// released first) before any fresh id is minted.
pub(crate) struct RefManager {
    /// Smallest id never issued so far; every id in `1..next_ref_id` has been issued.
    next_ref_id: RefId,
    /// Released ids awaiting reuse, popped from the back.
    free_list: Vec<RefId>,
    /// Same contents as `free_list`, kept for O(1) double-release detection.
    free_set: HashSet<RefId>,
}

impl RefManager {
    /// Creates an allocator whose first fresh id is `1`.
    pub fn new() -> Self {
        RefManager {
            next_ref_id: 1,
            free_list: Vec::new(),
            free_set: HashSet::new(),
        }
    }

    /// Returns an id that is not currently in use, preferring recycled ids.
    ///
    /// # Panics
    ///
    /// Panics when no recycled id is available and the positive `RefId` range
    /// is used up. Wrapping back to `1` is not an option: the free list is
    /// empty at that point, so every lower id still belongs to a live
    /// reference and reusing one would overwrite its registry slot.
    pub fn alloc_ref_id(&mut self) -> RefId {
        if let Some(ref_id) = self.free_list.pop() {
            self.free_set.remove(&ref_id);
            return ref_id;
        }
        let ref_id = self.next_ref_id;
        self.next_ref_id = ref_id
            .checked_add(1)
            .expect("registry reference ids exhausted");
        ref_id
    }

    /// Returns `ref_id` to the pool so a later `alloc_ref_id` can reuse it.
    ///
    /// The call is ignored in three cases, so that no id is ever handed out twice:
    /// - non-positive ids (sentinels such as [`LUA_REFNIL`] / [`LUA_NOREF`]);
    /// - ids this allocator never issued;
    /// - ids that are already released.
    pub fn free_ref_id(&mut self, ref_id: RefId) {
        let issued = ref_id > 0 && ref_id < self.next_ref_id;
        // `insert` returns false when the id is already pending reuse.
        if issued && self.free_set.insert(ref_id) {
            self.free_list.push(ref_id);
        }
    }
}
/// Handle to a Lua value that is either held inline or stored in the registry.
///
/// This type has no `Drop` impl: a registry slot stays allocated until it is
/// released explicitly (the tests in this file use `LuaVM::release_ref`).
pub struct LuaRefValue {
    inner: LuaRefInner,
}

/// Storage behind a [`LuaRefValue`].
enum LuaRefInner {
    /// The value itself, copied inline; no registry slot is used.
    Direct(LuaValue),
    /// Id of the registry slot that holds the value.
    Registry { ref_id: RefId },
}

impl LuaRefValue {
    /// Wraps `value` inline without touching the registry.
    pub(crate) fn new_direct(value: LuaValue) -> Self {
        Self {
            inner: LuaRefInner::Direct(value),
        }
    }

    /// Wraps an existing registry slot identified by `ref_id`.
    pub(crate) fn new_registry(ref_id: RefId) -> Self {
        Self {
            inner: LuaRefInner::Registry { ref_id },
        }
    }

    /// Registry id of this reference, or `None` for an inline value.
    pub fn ref_id(&self) -> Option<RefId> {
        if let LuaRefInner::Registry { ref_id } = &self.inner {
            Some(*ref_id)
        } else {
            None
        }
    }

    /// Resolves the referenced value.
    ///
    /// Inline values are copied out. Registry values are read from `vm`; a
    /// failed lookup yields `LuaValue::default()`.
    pub fn get(&self, vm: &super::LuaVM) -> LuaValue {
        let slot = match &self.inner {
            LuaRefInner::Direct(value) => return *value,
            LuaRefInner::Registry { ref_id } => i64::from(*ref_id),
        };
        vm.registry_geti(slot).unwrap_or_default()
    }

    /// Borrows the inline value, or `None` for a registry-backed reference.
    pub fn get_direct(&self) -> Option<&LuaValue> {
        if let LuaRefInner::Direct(value) = &self.inner {
            Some(value)
        } else {
            None
        }
    }

    /// `false` only for a registry reference whose id is not positive.
    pub fn is_valid(&self) -> bool {
        !matches!(&self.inner, LuaRefInner::Registry { ref_id } if *ref_id <= 0)
    }

    /// `true` when the value lives in a registry slot.
    pub fn is_registry_ref(&self) -> bool {
        self.ref_id().is_some()
    }

    /// Encodes this reference as a single raw id.
    ///
    /// A registry reference yields its id, an inline `nil` yields
    /// [`LUA_REFNIL`], and any other inline value yields [`LUA_NOREF`].
    pub fn to_raw_ref(&self) -> RefId {
        match &self.inner {
            LuaRefInner::Registry { ref_id } => *ref_id,
            LuaRefInner::Direct(value) if value.is_nil() => LUA_REFNIL,
            LuaRefInner::Direct(_) => LUA_NOREF,
        }
    }
}
impl Clone for LuaRefValue {
fn clone(&self) -> Self {
match &self.inner {
LuaRefInner::Direct(value) => LuaRefValue::new_direct(*value),
LuaRefInner::Registry { ref_id } => {
LuaRefValue::new_registry(*ref_id)
}
}
}
}
impl std::fmt::Debug for LuaRefValue {
    /// Renders either `LuaRefValue::Direct(<value:?>)` or
    /// `LuaRefValue::Registry(ref_id=<id>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match &self.inner {
            LuaRefInner::Registry { ref_id } => {
                f.write_fmt(format_args!("LuaRefValue::Registry(ref_id={})", ref_id))
            }
            LuaRefInner::Direct(value) => {
                f.write_fmt(format_args!("LuaRefValue::Direct({:?})", value))
            }
        }
    }
}
/// Shared core of the owning typed handles: a registry id plus the VM that
/// holds it.
///
/// `vm` is a raw pointer with no lifetime tracking. Callers must keep the VM
/// alive for as long as any `RefInner` points at it; nothing here enforces
/// that. Because of the raw-pointer fields, the type is neither `Send` nor
/// `Sync`.
struct RefInner {
    ref_id: RefId,
    vm: *mut super::LuaVM,
    _marker: PhantomData<*const ()>,
}

impl RefInner {
    fn new(ref_id: RefId, vm: *mut super::LuaVM) -> Self {
        Self {
            ref_id,
            vm,
            _marker: PhantomData,
        }
    }

    /// Reads the current contents of the slot; a failed lookup reads as `nil`.
    #[inline]
    fn to_value(&self) -> LuaValue {
        let slot = i64::from(self.ref_id);
        self.vm().registry_geti(slot).unwrap_or(LuaValue::nil())
    }

    /// Shared view of the owning VM.
    #[inline]
    fn vm(&self) -> &super::LuaVM {
        // SAFETY: relies on the caller-side invariant that `vm` is non-null and
        // points at a VM that outlives `self` (not checked here).
        unsafe { &*self.vm }
    }

    /// Exclusive view of the owning VM, obtained through a shared `&self`.
    ///
    /// Callers must not hold any other reference into the VM (including one
    /// from `vm()` / `to_value()`) while the returned `&mut` is in use.
    // The lint is silenced on purpose: mutation goes through the raw pointer,
    // not through `&self`.
    #[allow(clippy::mut_from_ref)]
    #[inline]
    fn vm_mut(&self) -> &mut super::LuaVM {
        // SAFETY: same liveness invariant as `vm()`, plus the exclusivity
        // requirement documented above.
        unsafe { &mut *self.vm }
    }
}
impl Drop for RefInner {
    /// Gives the registry slot back to the VM.
    ///
    /// Skipped for non-positive (sentinel) ids and for a null VM pointer.
    fn drop(&mut self) {
        if self.ref_id <= 0 {
            return;
        }
        // SAFETY: a non-null `vm` is assumed to still point at a live VM; the
        // VM must outlive every handle it issued.
        if let Some(vm) = unsafe { self.vm.as_mut() } {
            vm.release_ref_id(self.ref_id);
        }
    }
}
impl std::fmt::Debug for RefInner {
    /// Shows only the registry id; the VM pointer is deliberately omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("RefInner(ref_id={})", self.ref_id))
    }
}
/// Allocates a fresh registry id, stores `value` under it, and returns the id.
///
/// The caller owns the returned id. `RefInner`'s `Drop` releases it for the
/// typed handles.
pub(crate) fn store_in_registry(vm: &mut super::LuaVM, value: LuaValue) -> RefId {
    let slot = vm.ref_manager.alloc_ref_id();
    vm.registry_seti(i64::from(slot), value);
    slot
}
/// Owning handle to a Lua table stored in the registry.
///
/// The registry slot is released when the handle is dropped. Accessors go
/// through the VM's `raw_*` table operations (presumably bypassing
/// metamethods, like Lua's `rawget`/`rawset` — confirm against `LuaVM`).
pub struct LuaTableRef {
    inner: RefInner,
}

impl LuaTableRef {
    /// Wraps an existing registry slot; the handle takes ownership of `ref_id`.
    pub(crate) fn from_raw(ref_id: RefId, vm: *mut super::LuaVM) -> Self {
        LuaTableRef {
            inner: RefInner::new(ref_id, vm),
        }
    }

    /// Reads `table[key]` for a string key; a missing entry yields `LuaValue::default()`.
    ///
    /// # Errors
    /// Propagates failures from creating the Lua string for `key`.
    pub fn get(&self, key: &str) -> super::LuaResult<LuaValue> {
        // Resolve the table before taking `&mut LuaVM`. `to_value` reborrows the
        // VM through the same raw pointer, which would invalidate a live `&mut`.
        let table = self.inner.to_value();
        let vm = self.inner.vm_mut();
        let key_val = vm.create_string(key)?;
        Ok(vm.raw_get(&table, &key_val).unwrap_or_default())
    }

    /// Reads `table[key]` for an integer key; a missing entry yields `LuaValue::default()`.
    pub fn geti(&self, key: i64) -> super::LuaResult<LuaValue> {
        let vm = self.inner.vm();
        let table = self.inner.to_value();
        Ok(vm.raw_geti(&table, key).unwrap_or_default())
    }

    /// Reads `table[key]` for an arbitrary key; a missing entry yields `LuaValue::default()`.
    pub fn get_value(&self, key: &LuaValue) -> super::LuaResult<LuaValue> {
        let vm = self.inner.vm();
        let table = self.inner.to_value();
        Ok(vm.raw_get(&table, key).unwrap_or_default())
    }

    /// Reads `table[key]` and converts it with [`crate::FromLua`].
    ///
    /// # Errors
    /// Lookup errors propagate; conversion failures are wrapped with `LuaVM::error`.
    pub fn get_as<T: crate::FromLua>(&self, key: &str) -> super::LuaResult<T> {
        let val = self.get(key)?;
        let vm = self.inner.vm_mut();
        T::from_lua(val, vm.main_state()).map_err(|msg| vm.error(msg))
    }

    /// Writes `table[key] = value` for a string key.
    ///
    /// # Errors
    /// Propagates failures from creating the Lua string for `key`.
    pub fn set(&self, key: &str, value: LuaValue) -> super::LuaResult<()> {
        let table = self.inner.to_value(); // before `vm_mut`, see `get`
        let vm = self.inner.vm_mut();
        let key_val = vm.create_string(key)?;
        vm.raw_set(&table, key_val, value);
        Ok(())
    }

    /// Writes `table[key] = value` for an integer key.
    pub fn seti(&self, key: i64, value: LuaValue) -> super::LuaResult<()> {
        let table = self.inner.to_value(); // before `vm_mut`, see `get`
        let vm = self.inner.vm_mut();
        vm.raw_seti(&table, key, value);
        Ok(())
    }

    /// Writes `table[key] = value` for an arbitrary key.
    pub fn set_value(&self, key: LuaValue, value: LuaValue) -> super::LuaResult<()> {
        let table = self.inner.to_value(); // before `vm_mut`, see `get`
        let vm = self.inner.vm_mut();
        vm.raw_set(&table, key, value);
        Ok(())
    }

    /// Collects all key/value pairs, as returned by `LuaVM::table_pairs`.
    pub fn pairs(&self) -> super::LuaResult<Vec<(LuaValue, LuaValue)>> {
        let vm = self.inner.vm();
        let table = self.inner.to_value();
        vm.table_pairs(&table)
    }

    /// Table length, as computed by `LuaVM::table_length`.
    pub fn len(&self) -> super::LuaResult<usize> {
        let vm = self.inner.vm();
        let table = self.inner.to_value();
        vm.table_length(&table)
    }

    /// Appends `value` at index `len() + 1`.
    pub fn push(&self, value: LuaValue) -> super::LuaResult<()> {
        let current_len = self.len()?;
        self.seti((current_len + 1) as i64, value)
    }

    /// Current value of the referenced slot.
    pub fn to_value(&self) -> LuaValue {
        self.inner.to_value()
    }

    /// Registry id owned by this handle.
    pub fn ref_id(&self) -> RefId {
        self.inner.ref_id
    }
}
impl std::fmt::Debug for LuaTableRef {
    /// Shows only the registry id, without touching the VM.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ref_id = self.ref_id();
        f.write_fmt(format_args!("LuaTableRef(ref_id={})", ref_id))
    }
}
/// Owning handle to a Lua function stored in the registry.
///
/// The registry slot is released when the handle is dropped.
pub struct LuaFunctionRef {
    inner: RefInner,
}

impl LuaFunctionRef {
    /// Wraps an existing registry slot; the handle takes ownership of `ref_id`.
    pub(crate) fn from_raw(ref_id: RefId, vm: *mut LuaVM) -> Self {
        LuaFunctionRef {
            inner: RefInner::new(ref_id, vm),
        }
    }

    /// Calls the function with already-converted arguments, via `LuaVM::call`.
    pub fn call_raw(&self, args: Vec<LuaValue>) -> LuaResult<Vec<LuaValue>> {
        // Read the function before taking `&mut LuaVM`. `to_value` reborrows
        // the VM through the same raw pointer, which would invalidate a live `&mut`.
        let func = self.inner.to_value();
        let vm = self.inner.vm_mut();
        vm.call(func, args)
    }

    /// Like [`Self::call_raw`], returning only the first result (`nil` if there is none).
    pub fn call1_raw(&self, args: Vec<LuaValue>) -> LuaResult<LuaValue> {
        let results = self.call_raw(args)?;
        Ok(results.into_iter().next().unwrap_or(LuaValue::nil()))
    }

    /// Converts `args` with [`IntoLua`], calls the function via
    /// `LuaVM::call_raw`, and converts the results with [`FromLuaMulti`].
    ///
    /// # Errors
    /// Call errors propagate; conversion failures are wrapped with `LuaVM::error`.
    pub fn call<A: IntoLua, R: FromLuaMulti>(&self, args: A) -> LuaResult<R> {
        let func = self.inner.to_value(); // before `vm_mut`, see `call_raw`
        let vm = self.inner.vm_mut();
        let args = collect_into_lua_values(vm.main_state(), args).map_err(|msg| vm.error(msg))?;
        let results = vm.call_raw(func, args)?;
        R::from_lua_multi(results, vm.main_state_ref()).map_err(|msg| vm.error(msg))
    }

    /// Like [`Self::call`], converting only the first result (`nil` if there is none).
    ///
    /// # Errors
    /// Call errors propagate; conversion failures are wrapped with `LuaVM::error`.
    pub fn call1<A: IntoLua, R: FromLua>(&self, args: A) -> LuaResult<R> {
        let func = self.inner.to_value(); // before `vm_mut`, see `call_raw`
        let vm = self.inner.vm_mut();
        let args = collect_into_lua_values(vm.main_state(), args).map_err(|msg| vm.error(msg))?;
        let result = vm
            .call_raw(func, args)?
            .into_iter()
            .next()
            .unwrap_or(LuaValue::nil());
        R::from_lua(result, vm.main_state_ref()).map_err(|msg| vm.error(msg))
    }

    /// Asynchronous call via `LuaVM::call_async`.
    ///
    /// The exclusive `&mut LuaVM` borrow is held across the `.await`.
    pub async fn call_async(&self, args: Vec<LuaValue>) -> LuaResult<Vec<LuaValue>> {
        let func = self.inner.to_value(); // before `vm_mut`, see `call_raw`
        let vm = self.inner.vm_mut();
        vm.call_async(func, args).await
    }

    /// Current value of the referenced slot.
    pub fn to_value(&self) -> LuaValue {
        self.inner.to_value()
    }

    /// Registry id owned by this handle.
    pub fn ref_id(&self) -> RefId {
        self.inner.ref_id
    }
}
impl std::fmt::Debug for LuaFunctionRef {
    /// Shows only the registry id, without touching the VM.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ref_id = self.ref_id();
        f.write_fmt(format_args!("LuaFunctionRef(ref_id={})", ref_id))
    }
}
/// Owning handle to a Lua string stored in the registry.
///
/// The borrowed views (`as_str`, `as_bytes`) are tied to `&self`. They rely on
/// this handle's registry slot, which is released only on drop, to keep the
/// string data alive.
pub struct LuaStringRef {
    inner: RefInner,
}

impl LuaStringRef {
    /// Wraps an existing registry slot; the handle takes ownership of `ref_id`.
    pub(crate) fn from_raw(ref_id: RefId, vm: *mut super::LuaVM) -> Self {
        Self {
            inner: RefInner::new(ref_id, vm),
        }
    }

    /// The string as `&str`, or `None` when `LuaValue::as_str` provides none.
    pub fn as_str(&self) -> Option<&str> {
        let value = self.inner.to_value();
        let text = value.as_str()? as *const str;
        // SAFETY: the pointee is string data owned by the VM, not by the local
        // copy `value`. It stays reachable through this handle's registry slot
        // while `&self` is borrowed. Assumes the VM never moves string data
        // (TODO confirm).
        Some(unsafe { &*text })
    }

    /// Raw bytes of the string, or `None` when `LuaValue::as_bytes` provides none.
    pub fn as_bytes(&self) -> Option<&[u8]> {
        let value = self.inner.to_value();
        let bytes = value.as_bytes()? as *const [u8];
        // SAFETY: same reasoning as in `as_str`.
        Some(unsafe { &*bytes })
    }

    /// Owned copy of the string, with invalid UTF-8 replaced by U+FFFD.
    ///
    /// A value that yields no bytes at all produces an empty `String`.
    pub fn to_string_lossy(&self) -> String {
        if let Some(text) = self.as_str() {
            return text.to_owned();
        }
        let bytes = self.as_bytes().unwrap_or(&[]);
        String::from_utf8_lossy(bytes).into_owned()
    }

    /// Length in bytes, or `0` when no bytes are available.
    pub fn byte_len(&self) -> usize {
        self.as_bytes().map_or(0, <[u8]>::len)
    }

    /// Current value of the referenced slot.
    pub fn to_value(&self) -> LuaValue {
        self.inner.to_value()
    }

    /// Registry id owned by this handle.
    pub fn ref_id(&self) -> RefId {
        self.inner.ref_id
    }
}
impl std::fmt::Debug for LuaStringRef {
    /// Shows the registry id and the `Option<&str>` view of the contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ref_id = self.ref_id();
        let text = self.as_str();
        f.write_fmt(format_args!("LuaStringRef(ref_id={}, {:?})", ref_id, text))
    }
}
impl std::fmt::Display for LuaStringRef {
    /// Writes the string contents, with invalid UTF-8 replaced by U+FFFD.
    ///
    /// Formatter flags (width, fill, alignment, precision) are honoured the
    /// same way `str`'s `Display` honours them. For example, `{:>8}` pads on
    /// the left and `{:.3}` truncates.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(&self.to_string_lossy())
    }
}
/// Owning handle to a userdata value whose payload is expected to be a `T`.
///
/// `from_raw` does not verify `T`; every accessor re-checks the payload type.
/// The registry slot is released when the handle is dropped.
pub struct UserDataRef<T: 'static> {
    inner: RefInner,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> UserDataRef<T> {
    /// Wraps an existing registry slot; the handle takes ownership of `ref_id`.
    pub(crate) fn from_raw(ref_id: RefId, vm: *mut super::LuaVM) -> Self {
        Self {
            inner: RefInner::new(ref_id, vm),
            _marker: PhantomData,
        }
    }

    /// Builds the VM error reported when the slot does not hold a `T` userdata.
    fn type_mismatch_error<R>(&self, actual: impl std::fmt::Display) -> super::LuaResult<R> {
        let message = format!(
            "expected userdata {}, got {}",
            std::any::type_name::<T>(),
            actual
        );
        Err(self.inner.vm_mut().error(message))
    }

    /// Borrows the payload as `&T`.
    ///
    /// # Errors
    /// Fails if the slot holds a non-userdata value or a userdata of another type.
    pub fn get(&self) -> super::LuaResult<&T> {
        let value = self.inner.to_value();
        let userdata = match value.as_userdata_mut() {
            Some(userdata) => userdata,
            None => return self.type_mismatch_error(value.type_name()),
        };
        match userdata.downcast_ref::<T>() {
            // SAFETY: the payload is owned by the VM and kept alive by this
            // handle's registry slot while `&self` is borrowed. It is not owned
            // by the local copy `value`.
            Some(payload) => Ok(unsafe { &*(payload as *const T) }),
            None => self.type_mismatch_error(userdata.type_name()),
        }
    }

    /// Borrows the payload as `&mut T`.
    ///
    /// # Errors
    /// Fails if the slot holds a non-userdata value or a userdata of another type.
    pub fn get_mut(&mut self) -> super::LuaResult<&mut T> {
        let value = self.inner.to_value();
        let userdata = match value.as_userdata_mut() {
            Some(userdata) => userdata,
            None => return self.type_mismatch_error(value.type_name()),
        };
        match userdata.downcast_mut::<T>() {
            // SAFETY: same liveness argument as `get`. Exclusivity holds only
            // with respect to this handle; other handles to the same userdata
            // are not prevented.
            Some(payload) => Ok(unsafe { &mut *(payload as *mut T) }),
            None => self.type_mismatch_error(userdata.type_name()),
        }
    }

    /// Type name reported by the stored userdata itself.
    ///
    /// # Errors
    /// Fails if the slot does not hold a userdata value.
    pub fn type_name(&self) -> super::LuaResult<&'static str> {
        let value = self.inner.to_value();
        match value.as_userdata_mut() {
            Some(userdata) => Ok(userdata.type_name()),
            None => self.type_mismatch_error(value.type_name()),
        }
    }

    /// Current value of the referenced slot.
    pub fn to_value(&self) -> LuaValue {
        self.inner.to_value()
    }

    /// Registry id owned by this handle.
    pub fn ref_id(&self) -> RefId {
        self.inner.ref_id
    }
}
impl<T: 'static> FromLua for UserDataRef<T> {
    /// Checks that `value` is a `T` userdata, then pins it under a new
    /// registry id owned by the returned handle.
    fn from_lua(value: LuaValue, state: &super::LuaState) -> Result<Self, String> {
        let expected = std::any::type_name::<T>();
        let userdata = value.as_userdata_mut().ok_or_else(|| {
            format!(
                "expected userdata {}, got {}",
                expected,
                value.type_name()
            )
        })?;
        if userdata.downcast_ref::<T>().is_none() {
            return Err(format!(
                "expected userdata {}, got {}",
                expected,
                userdata.type_name()
            ));
        }
        let vm_ptr = state.vm_ptr();
        // SAFETY: assumes `vm_ptr` points at the live VM that owns `state`, and
        // that no other reference into that VM is active during this call.
        let ref_id = store_in_registry(unsafe { &mut *vm_ptr }, value);
        Ok(UserDataRef::from_raw(ref_id, vm_ptr))
    }
}
impl<T: 'static> IntoLua for UserDataRef<T> {
    /// Pushes the referenced value onto `state` and reports one pushed value.
    ///
    /// `self` is dropped when this returns, which releases its registry slot.
    fn into_lua(self, state: &mut super::LuaState) -> Result<usize, String> {
        let value = self.to_value();
        match state.push_value(value) {
            Ok(_) => Ok(1),
            Err(e) => Err(format!("{:?}", e)),
        }
    }
}
impl<T: 'static> std::fmt::Debug for UserDataRef<T> {
    /// Shows the expected payload type and the registry id, without touching the VM.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let payload_type = std::any::type_name::<T>();
        let ref_id = self.ref_id();
        f.write_fmt(format_args!("UserDataRef<{}>(ref_id={})", payload_type, ref_id))
    }
}
/// Owning handle to a registry slot holding a Lua value of any type.
///
/// Each successful `as_*` call stores the value under a new registry id owned
/// by the returned typed handle. This handle's own slot is left untouched.
pub struct LuaAnyRef {
    inner: RefInner,
}

impl LuaAnyRef {
    /// Wraps an existing registry slot; the handle takes ownership of `ref_id`.
    pub(crate) fn from_raw(ref_id: RefId, vm: *mut super::LuaVM) -> Self {
        Self {
            inner: RefInner::new(ref_id, vm),
        }
    }

    /// Current value of the referenced slot.
    pub fn to_value(&self) -> LuaValue {
        self.inner.to_value()
    }

    /// Stores `value` under a freshly allocated registry id in this handle's VM.
    fn register_copy(&self, value: LuaValue) -> RefId {
        store_in_registry(self.inner.vm_mut(), value)
    }

    /// A new table handle, or `None` if the value is not a table.
    pub fn as_table(&self) -> Option<LuaTableRef> {
        let value = self.to_value();
        if value.is_table() {
            Some(LuaTableRef::from_raw(self.register_copy(value), self.inner.vm))
        } else {
            None
        }
    }

    /// A new function handle, or `None` if the value is not a function.
    pub fn as_function(&self) -> Option<LuaFunctionRef> {
        let value = self.to_value();
        if value.is_function() {
            Some(LuaFunctionRef::from_raw(self.register_copy(value), self.inner.vm))
        } else {
            None
        }
    }

    /// A new string handle, or `None` if the value is not a string.
    pub fn as_string(&self) -> Option<LuaStringRef> {
        let value = self.to_value();
        if value.is_string() {
            Some(LuaStringRef::from_raw(self.register_copy(value), self.inner.vm))
        } else {
            None
        }
    }

    /// A new userdata handle, or `None` unless the value is a userdata whose
    /// payload downcasts to `T`.
    pub fn as_userdata<T: 'static>(&self) -> Option<UserDataRef<T>> {
        let value = self.to_value();
        value.as_userdata_mut()?.downcast_ref::<T>()?;
        Some(UserDataRef::from_raw(self.register_copy(value), self.inner.vm))
    }

    /// Kind of the currently stored value.
    pub fn kind(&self) -> LuaValueKind {
        self.to_value().kind()
    }

    /// Converts the stored value with [`crate::FromLua`].
    ///
    /// # Errors
    /// Conversion failures are wrapped with `LuaVM::error`.
    pub fn get_as<T: crate::FromLua>(&self) -> super::LuaResult<T> {
        let value = self.to_value();
        let vm = self.inner.vm_mut();
        T::from_lua(value, vm.main_state()).map_err(|msg| vm.error(msg))
    }

    /// Registry id owned by this handle.
    pub fn ref_id(&self) -> RefId {
        self.inner.ref_id
    }
}
impl std::fmt::Debug for LuaAnyRef {
    /// Shows the registry id and the kind of the value currently stored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ref_id = self.ref_id();
        let kind = self.kind();
        f.write_fmt(format_args!("LuaAnyRef(ref_id={}, kind={:?})", ref_id, kind))
    }
}
// Unit tests for `LuaVM::create_ref` / `get_ref_value` / `release_ref` and
// registry-id recycling.
#[cfg(test)]
mod tests {
use crate::{LuaVM, lua_vm::SafeOption};
use super::*;
// Checks that tables go to the registry while numbers and nil are held inline,
// that all three round-trip through `get_ref_value`, and that a released
// registry id reads back as nil.
#[test]
fn test_lua_ref_mechanism() {
let mut vm = LuaVM::new(SafeOption::default());
// Build the table {num = 42, str = "hello"}.
let table = vm.create_table(0, 2).unwrap();
let num_key = vm.create_string("num").unwrap();
let str_key = vm.create_string("str").unwrap();
let str_val = vm.create_string("hello").unwrap();
vm.raw_set(&table, num_key, LuaValue::number(42.0));
vm.raw_set(&table, str_key, str_val);
let number = LuaValue::number(123.456);
let nil_val = LuaValue::nil();
let table_ref = vm.create_ref(table);
let number_ref = vm.create_ref(number);
let nil_ref = vm.create_ref(nil_val);
// Storage strategy chosen by `create_ref`.
assert!(table_ref.is_registry_ref(), "Table should use registry");
assert!(!number_ref.is_registry_ref(), "Number should be direct");
assert!(!nil_ref.is_registry_ref(), "Nil should be direct");
// Round-trip each reference back to a value.
let retrieved_table = vm.get_ref_value(&table_ref);
assert!(retrieved_table.is_table(), "Should retrieve table");
let retrieved_num = vm.get_ref_value(&number_ref);
assert_eq!(
retrieved_num.as_number(),
Some(123.456),
"Should retrieve number"
);
let retrieved_nil = vm.get_ref_value(&nil_ref);
assert!(retrieved_nil.is_nil(), "Should retrieve nil");
// The retrieved table is the original one, contents included.
let num_key2 = vm.create_string("num").unwrap();
let val = vm.raw_get(&retrieved_table, &num_key2);
assert_eq!(
val.and_then(|v| v.as_number()),
Some(42.0),
"Table content should be preserved"
);
// Only registry-backed references expose a (positive) id.
let table_ref_id = table_ref.ref_id();
assert!(table_ref_id.is_some(), "Table ref should have ID");
assert!(table_ref_id.unwrap() > 0, "Ref ID should be positive");
let number_ref_id = number_ref.ref_id();
assert!(number_ref_id.is_none(), "Number ref should not have ID");
vm.release_ref(table_ref);
vm.release_ref(number_ref);
vm.release_ref(nil_ref);
// After release, looking the old id up directly yields nil.
let after_release = vm.get_ref_value_by_id(table_ref_id.unwrap());
assert!(after_release.is_nil(), "Released ref should return nil");
println!("✓ Lua ref mechanism test passed");
}
// A released registry id is handed out again by the next `create_ref`.
// Depends on the exact create/release ordering below.
#[test]
fn test_ref_id_reuse() {
let mut vm = LuaVM::new(SafeOption::default());
let t1 = vm.create_table(0, 0).unwrap();
let ref1 = vm.create_ref(t1);
let id1 = ref1.ref_id().unwrap();
vm.release_ref(ref1);
let t2 = vm.create_table(0, 0).unwrap();
let ref2 = vm.create_ref(t2);
let id2 = ref2.ref_id().unwrap();
assert_eq!(id1, id2, "Ref IDs should be reused");
vm.release_ref(ref2);
println!("✓ Ref ID reuse test passed");
}
// Ten simultaneously live registry references each resolve to their own table.
#[test]
fn test_multiple_refs() {
let mut vm = LuaVM::new(SafeOption::default());
let mut refs = Vec::new();
// Table i holds {value = i}.
for i in 0..10 {
let table = vm.create_table(0, 1).unwrap();
let key = vm.create_string("value").unwrap();
let num_val = LuaValue::number(i as f64);
vm.raw_set(&table, key, num_val);
refs.push(vm.create_ref(table));
}
for (i, lua_ref) in refs.iter().enumerate() {
let table = vm.get_ref_value(lua_ref);
let key = vm.create_string("value").unwrap();
let val = vm.raw_get(&table, &key);
assert_eq!(
val.and_then(|v| v.as_number()),
Some(i as f64),
"Ref {} should have correct value",
i
);
}
for lua_ref in refs {
vm.release_ref(lua_ref);
}
println!("✓ Multiple refs test passed");
}
}