use std::{any::TypeId, collections::hash_map::Entry, marker::PhantomData};
use ahash::AHashMap as HashMap;
use crate::cell::{AtomicRef, AtomicRefMut};
use crate::{Resource, ResourceId, World};
#[cfg(feature = "nightly")]
use core::ptr::{DynMetadata, Pointee};
/// Type-level marker used as `PhantomData<Invariant<T>>` by the table and its
/// iterators: the `*mut T` makes the containing type invariant over `T`
/// without storing any `T`.
struct Invariant<T: ?Sized>(*mut T);

// SAFETY: within this module `Invariant` only appears wrapped in `PhantomData`,
// so no `*mut T` is ever created, shared or dereferenced through it; the marker
// carries no data that could be raced on. Without these impls the raw pointer
// would make `MetaTable<T>` and its iterators `!Send`/`!Sync`.
unsafe impl<T: ?Sized> Send for Invariant<T> {}
unsafe impl<T: ?Sized> Sync for Invariant<T> {}
/// Conversion from a pointer to a concrete type `T` into a pointer to `Self`,
/// which is normally an unsized trait object such as `dyn Trait`.
///
/// [`MetaTable`] uses this to obtain the vtable that lets it view a
/// type-erased resource as `Self`. Implementations are usually nothing more
/// than an unsizing coercion:
///
/// ```ignore
/// unsafe impl<T> CastFrom<T> for dyn Object
/// where
///     T: Object + 'static,
/// {
///     fn cast(t: *mut T) -> *mut Self {
///         t
///     }
/// }
/// ```
///
/// # Safety
///
/// Implementors must guarantee that `cast` returns a pointer to the *same*
/// object it was given:
///
/// * the returned pointer has the same address as `t`. [`MetaTable`] checks
///   this with an assertion, but the check cannot catch wrong metadata;
/// * its metadata (vtable) is valid for a value of type `T` viewed as `Self`.
///
/// `cast` must not dereference `t`. [`MetaTable::register`] may call it with a
/// pointer that does not point to a live `T` (only the resulting metadata is
/// kept).
pub unsafe trait CastFrom<T> {
    /// Converts `t` into a pointer to `Self` without changing its address.
    fn cast(t: *mut T) -> *mut Self;
}
/// Iterator over the resources of a [`World`] whose types are registered in a
/// [`MetaTable`], yielding each one as a shared borrow of the trait object `T`.
///
/// Created by [`MetaTable::iter`]. Types are visited in registration order;
/// registered types with no resource in the world are skipped.
pub struct MetaIter<'a, T>
where
    T: ?Sized + 'a,
{
    /// World the resources are fetched from.
    world: &'a World,
    /// Registered type ids, parallel to the vtable slice.
    tys: &'a [TypeId],
    /// Position of the next entry of `tys` to examine.
    index: usize,
    /// Per-type functions attaching `T`'s vtable to a type-erased pointer.
    #[cfg(not(feature = "nightly"))]
    vtable_fns: &'a [fn(*mut ()) -> *mut T],
    /// Per-type `T` metadata captured at registration.
    #[cfg(feature = "nightly")]
    vtables: &'a [DynMetadata<T>],
    marker: PhantomData<Invariant<T>>,
}
#[cfg(not(feature = "nightly"))]
impl<'a, T> Iterator for MetaIter<'a, T>
where
    T: ?Sized + 'a,
{
    type Item = AtomicRef<'a, T>;

    /// Yields the next registered resource present in the world, borrowed
    /// immutably and viewed as `T`; absent resources are skipped.
    ///
    /// NOTE(review): `cell.borrow()` presumably panics if the resource is
    /// currently borrowed mutably — confirm against the cell type.
    // The closure parameter must be annotated as `&Box<dyn Resource>` (the
    // cell's content type), which is what `clippy::borrowed_box` flags.
    #[allow(clippy::borrowed_box)]
    fn next(&mut self) -> Option<Self::Item> {
        while let Some(&type_id) = self.tys.get(self.index) {
            let id = ResourceId::from_type_id(type_id);
            let slot = self.index;
            self.index += 1;

            // SAFETY: the cell is only accessed through its runtime-checked
            // `borrow`. NOTE(review): confirm this satisfies the documented
            // contract of `World::try_fetch_internal`.
            if let Some(cell) = unsafe { self.world.try_fetch_internal(id) } {
                let attach = self.vtable_fns[slot];
                let object = AtomicRef::map(cell.borrow(), |boxed: &Box<dyn Resource>| {
                    let data: *const dyn Resource = &**boxed;
                    let trait_ptr = attach(data.cast::<()>().cast_mut());
                    // SAFETY: `attach` was built by `register::<R>` for the
                    // type whose id is `tys[slot]`, and `data` points into the
                    // live box borrowed for the closure's duration (presuming
                    // the world stores that type under that id).
                    unsafe { &*trait_ptr }
                });
                return Some(object);
            }
        }
        None
    }
}
#[cfg(feature = "nightly")]
impl<'a, T> Iterator for MetaIter<'a, T>
where
    T: ?Sized + 'a,
    T: Pointee<Metadata = DynMetadata<T>>,
{
    type Item = AtomicRef<'a, T>;

    /// Yields the next registered resource present in the world, borrowed
    /// immutably and viewed as `T`; absent resources are skipped.
    ///
    /// NOTE(review): `cell.borrow()` presumably panics if the resource is
    /// currently borrowed mutably — confirm against the cell type.
    // The closure parameter must be annotated as `&Box<dyn Resource>` (the
    // cell's content type), which is what `clippy::borrowed_box` flags.
    #[allow(clippy::borrowed_box)]
    fn next(&mut self) -> Option<Self::Item> {
        while let Some(&type_id) = self.tys.get(self.index) {
            let id = ResourceId::from_type_id(type_id);
            let slot = self.index;
            self.index += 1;

            // SAFETY: the cell is only accessed through its runtime-checked
            // `borrow`. NOTE(review): confirm this satisfies the documented
            // contract of `World::try_fetch_internal`.
            if let Some(cell) = unsafe { self.world.try_fetch_internal(id) } {
                let vtable = self.vtables[slot];
                let object = AtomicRef::map(cell.borrow(), |boxed: &Box<dyn Resource>| {
                    let data: *const dyn Resource = &**boxed;
                    let trait_ptr = core::ptr::from_raw_parts(data.cast::<()>(), vtable);
                    // SAFETY: `vtable` was captured by `register::<R>` for the
                    // type whose id is `tys[slot]`, and `data` points into the
                    // live box borrowed for the closure's duration (presuming
                    // the world stores that type under that id).
                    unsafe { &*trait_ptr }
                });
                return Some(object);
            }
        }
        None
    }
}
/// Iterator over the resources of a [`World`] whose types are registered in a
/// [`MetaTable`], yielding each one as a mutable borrow of the trait object `T`.
///
/// Created by [`MetaTable::iter_mut`]. Types are visited in registration
/// order; registered types with no resource in the world are skipped.
pub struct MetaIterMut<'a, T>
where
    T: ?Sized + 'a,
{
    /// World the resources are fetched from.
    world: &'a World,
    /// Registered type ids, parallel to the vtable slice.
    tys: &'a [TypeId],
    /// Position of the next entry of `tys` to examine.
    index: usize,
    /// Per-type functions attaching `T`'s vtable to a type-erased pointer.
    #[cfg(not(feature = "nightly"))]
    vtable_fns: &'a [fn(*mut ()) -> *mut T],
    /// Per-type `T` metadata captured at registration.
    #[cfg(feature = "nightly")]
    vtables: &'a [DynMetadata<T>],
    marker: PhantomData<Invariant<T>>,
}
#[cfg(not(feature = "nightly"))]
impl<'a, T> Iterator for MetaIterMut<'a, T>
where
    T: ?Sized + 'a,
{
    type Item = AtomicRefMut<'a, T>;

    /// Yields the next registered resource present in the world, borrowed
    /// mutably and viewed as `T`; absent resources are skipped.
    ///
    /// NOTE(review): `cell.borrow_mut()` presumably panics if the resource is
    /// already borrowed — confirm against the cell type.
    fn next(&mut self) -> Option<Self::Item> {
        while let Some(&type_id) = self.tys.get(self.index) {
            let id = ResourceId::from_type_id(type_id);
            let slot = self.index;
            self.index += 1;

            // SAFETY: the cell is only accessed through its runtime-checked
            // `borrow_mut`. NOTE(review): confirm this satisfies the documented
            // contract of `World::try_fetch_internal`.
            if let Some(cell) = unsafe { self.world.try_fetch_internal(id) } {
                let attach = self.vtable_fns[slot];
                let object =
                    AtomicRefMut::map(cell.borrow_mut(), |boxed: &mut Box<dyn Resource>| {
                        let data: *mut dyn Resource = &mut **boxed;
                        let trait_ptr = attach(data.cast::<()>());
                        // SAFETY: `attach` was built by `register::<R>` for the
                        // type whose id is `tys[slot]`, and `data` points into
                        // the box mutably borrowed for the closure's duration
                        // (presuming the world stores that type under that id).
                        unsafe { &mut *trait_ptr }
                    });
                return Some(object);
            }
        }
        None
    }
}
#[cfg(feature = "nightly")]
impl<'a, T> Iterator for MetaIterMut<'a, T>
where
    T: ?Sized + 'a,
    T: Pointee<Metadata = DynMetadata<T>>,
{
    type Item = AtomicRefMut<'a, T>;

    /// Yields the next registered resource present in the world, borrowed
    /// mutably and viewed as `T`; absent resources are skipped.
    ///
    /// NOTE(review): `cell.borrow_mut()` presumably panics if the resource is
    /// already borrowed — confirm against the cell type.
    fn next(&mut self) -> Option<Self::Item> {
        while let Some(&type_id) = self.tys.get(self.index) {
            let id = ResourceId::from_type_id(type_id);
            let slot = self.index;
            self.index += 1;

            // SAFETY: the cell is only accessed through its runtime-checked
            // `borrow_mut`. NOTE(review): confirm this satisfies the documented
            // contract of `World::try_fetch_internal`.
            if let Some(cell) = unsafe { self.world.try_fetch_internal(id) } {
                let vtable = self.vtables[slot];
                let object =
                    AtomicRefMut::map(cell.borrow_mut(), |boxed: &mut Box<dyn Resource>| {
                        let data: *mut dyn Resource = &mut **boxed;
                        let trait_ptr =
                            core::ptr::from_raw_parts_mut(data.cast::<()>(), vtable);
                        // SAFETY: `vtable` was captured by `register::<R>` for
                        // the type whose id is `tys[slot]`, and `data` points
                        // into the box mutably borrowed for the closure's
                        // duration (presuming the world stores that type under
                        // that id).
                        unsafe { &mut *trait_ptr }
                    });
                return Some(object);
            }
        }
        None
    }
}
/// Reinterprets a type-erased pointer to a `T` as a pointer to `TraitObject`,
/// using `TraitObject`'s [`CastFrom<T>`] impl to attach the vtable.
///
/// On stable, [`MetaTable`] stores this function (monomorphised per registered
/// type) instead of the raw vtable, which cannot be extracted there.
///
/// # Panics
///
/// Panics if the `CastFrom` impl returned a pointer to a different address
/// than it was given, which indicates a buggy impl.
#[cfg(not(feature = "nightly"))]
fn attach_vtable<TraitObject, T>(value: *mut ()) -> *mut TraitObject
where
    TraitObject: CastFrom<T> + 'static + ?Sized,
    T: core::any::Any,
{
    let typed = value.cast::<T>();
    let fat = <TraitObject as CastFrom<T>>::cast(typed);
    // Catch `CastFrom` impls that return some other object instead of `value`.
    let same_address = core::ptr::eq(value, fat.cast::<()>());
    assert!(same_address, "Bug: `CastFrom` did not cast `self`");
    fat
}
/// Maps concrete resource types to a trait object `T` they implement, so that
/// type-erased resources (`dyn Resource`) can be viewed as `T`.
///
/// Types are added with [`MetaTable::register`]. A registered resource can then
/// be viewed as `T` with [`MetaTable::get`] / [`MetaTable::get_mut`], and every
/// registered resource present in a [`World`] can be visited, in registration
/// order, with [`MetaTable::iter`] / [`MetaTable::iter_mut`].
pub struct MetaTable<T: ?Sized> {
    /// Registered type id -> its slot in `tys` and in the vtable vector.
    indices: HashMap<TypeId, usize>,
    /// Registered type ids in registration order; parallel to the vtable vector.
    tys: Vec<TypeId>,
    /// Stable: per-type functions attaching `T`'s vtable to a data pointer.
    #[cfg(not(feature = "nightly"))]
    vtable_fns: Vec<fn(*mut ()) -> *mut T>,
    /// Nightly: per-type `T` metadata captured at registration.
    #[cfg(feature = "nightly")]
    vtables: Vec<DynMetadata<T>>,
    marker: PhantomData<Invariant<T>>,
}
impl<T: ?Sized> MetaTable<T> {
    /// Creates an empty table.
    ///
    /// # Panics
    ///
    /// Panics if `&T` is not exactly two pointer-widths wide, i.e. if `T` is
    /// not a fat-pointer type such as `dyn Trait`.
    pub fn new() -> Self {
        assert_unsized::<T>();
        Self::default()
    }

    /// Registers `R` so that resources of that type can be viewed as `T`.
    ///
    /// Registering a type again replaces its stored vtable function but keeps
    /// its original position in iteration order.
    ///
    /// The `CastFrom<R>` impl is not invoked here; it runs (and its address
    /// check may panic) whenever `get`, `get_mut` or the iterators use the
    /// stored function.
    #[cfg(not(feature = "nightly"))]
    pub fn register<R>(&mut self)
    where
        R: Resource,
        T: CastFrom<R> + 'static,
    {
        let ty_id = TypeId::of::<R>();
        let attach: fn(*mut ()) -> *mut T = attach_vtable::<T, R>;
        let next_slot = self.indices.len();
        match self.indices.entry(ty_id) {
            Entry::Occupied(existing) => self.vtable_fns[*existing.get()] = attach,
            Entry::Vacant(fresh) => {
                fresh.insert(next_slot);
                self.vtable_fns.push(attach);
                self.tys.push(ty_id);
            }
        }
    }

    /// Registers `R` so that resources of that type can be viewed as `T`.
    ///
    /// Registering a type again replaces its stored vtable but keeps its
    /// original position in iteration order.
    ///
    /// # Panics
    ///
    /// Panics if `T`'s `CastFrom<R>` impl returns a pointer with a different
    /// address than the one it was given.
    #[cfg(feature = "nightly")]
    pub fn register<R>(&mut self)
    where
        R: Resource,
        T: CastFrom<R> + 'static,
        T: Pointee<Metadata = DynMetadata<T>>,
    {
        let ty_id = TypeId::of::<R>();
        // Only the metadata of the cast result is kept, so `cast` is given a
        // provenance-free pointer that is not dereferenced here. The address of
        // `self` merely supplies a non-null value for the address check below.
        let probe = core::ptr::without_provenance_mut::<R>((self as *mut Self).addr());
        let trait_ptr = <T as CastFrom<R>>::cast(probe);
        assert_eq!(
            probe.addr(),
            trait_ptr.addr(),
            "Bug: `CastFrom` did not cast `self`"
        );
        let vtable = core::ptr::metadata(trait_ptr);
        let next_slot = self.indices.len();
        match self.indices.entry(ty_id) {
            Entry::Occupied(existing) => self.vtables[*existing.get()] = vtable,
            Entry::Vacant(fresh) => {
                fresh.insert(next_slot);
                self.vtables.push(vtable);
                self.tys.push(ty_id);
            }
        }
    }

    /// Views `res` as a `&T` if its concrete type has been registered, and
    /// returns `None` otherwise.
    ///
    /// # Panics
    ///
    /// Panics if the registered `CastFrom` impl changes the pointer's address.
    #[cfg(not(feature = "nightly"))]
    pub fn get<'a>(&self, res: &'a dyn Resource) -> Option<&'a T> {
        // Deref first so `type_id` is dispatched through `dyn Resource` and
        // yields the id of the concrete resource type.
        let slot = *self.indices.get(&(*res).type_id())?;
        let attach = self.vtable_fns[slot];
        let trait_ptr = attach(<*const dyn Resource>::cast::<()>(res).cast_mut());
        // SAFETY: `slot` belongs to the concrete type of `res`, so `attach`
        // pairs `res`'s data pointer with that type's `T` vtable; the result is
        // only read and lives no longer than the borrow of `res`.
        Some(unsafe { &*trait_ptr })
    }

    /// Views `res` as a `&T` if its concrete type has been registered, and
    /// returns `None` otherwise.
    #[cfg(feature = "nightly")]
    pub fn get<'a>(&self, res: &'a dyn Resource) -> Option<&'a T>
    where
        T: Pointee<Metadata = DynMetadata<T>>,
    {
        // Deref first so `type_id` is dispatched through `dyn Resource` and
        // yields the id of the concrete resource type.
        let slot = *self.indices.get(&(*res).type_id())?;
        let data = <*const dyn Resource>::cast::<()>(res);
        let trait_ptr: *const T = core::ptr::from_raw_parts(data, self.vtables[slot]);
        // SAFETY: `slot` belongs to the concrete type of `res`, so its vtable
        // matches the data pointer; the result lives no longer than `res`.
        Some(unsafe { &*trait_ptr })
    }

    /// Views `res` as a `&mut T` if its concrete type has been registered,
    /// and returns `None` otherwise.
    ///
    /// # Panics
    ///
    /// Panics if the registered `CastFrom` impl changes the pointer's address.
    #[cfg(not(feature = "nightly"))]
    pub fn get_mut<'a>(&self, res: &'a mut dyn Resource) -> Option<&'a mut T> {
        // Deref first so `type_id` is dispatched through `dyn Resource` and
        // yields the id of the concrete resource type.
        let slot = *self.indices.get(&(*res).type_id())?;
        let attach = self.vtable_fns[slot];
        let trait_ptr = attach(<*mut dyn Resource>::cast::<()>(res));
        // SAFETY: `slot` belongs to the concrete type of `res`, and the unique
        // borrow of `res` is handed over to the returned reference.
        Some(unsafe { &mut *trait_ptr })
    }

    /// Views `res` as a `&mut T` if its concrete type has been registered,
    /// and returns `None` otherwise.
    #[cfg(feature = "nightly")]
    pub fn get_mut<'a>(&self, res: &'a mut dyn Resource) -> Option<&'a mut T>
    where
        T: Pointee<Metadata = DynMetadata<T>>,
    {
        // Deref first so `type_id` is dispatched through `dyn Resource` and
        // yields the id of the concrete resource type.
        let slot = *self.indices.get(&(*res).type_id())?;
        let data = <*mut dyn Resource>::cast::<()>(res);
        let trait_ptr: *mut T = core::ptr::from_raw_parts_mut(data, self.vtables[slot]);
        // SAFETY: `slot` belongs to the concrete type of `res`, and the unique
        // borrow of `res` is handed over to the returned reference.
        Some(unsafe { &mut *trait_ptr })
    }

    /// Iterates, in registration order, over every registered resource
    /// present in `res`, borrowing each immutably as `T`.
    pub fn iter<'a>(&'a self, res: &'a World) -> MetaIter<'a, T> {
        MetaIter {
            world: res,
            tys: &self.tys,
            index: 0,
            #[cfg(not(feature = "nightly"))]
            vtable_fns: &self.vtable_fns,
            #[cfg(feature = "nightly")]
            vtables: &self.vtables,
            marker: PhantomData,
        }
    }

    /// Iterates, in registration order, over every registered resource
    /// present in `res`, borrowing each mutably (via its cell's `borrow_mut`)
    /// as `T`.
    pub fn iter_mut<'a>(&'a self, res: &'a World) -> MetaIterMut<'a, T> {
        MetaIterMut {
            world: res,
            tys: &self.tys,
            index: 0,
            #[cfg(not(feature = "nightly"))]
            vtable_fns: &self.vtable_fns,
            #[cfg(feature = "nightly")]
            vtables: &self.vtables,
            marker: PhantomData,
        }
    }
}
impl<T> Default for MetaTable<T>
where
T: ?Sized,
{
fn default() -> Self {
MetaTable {
#[cfg(not(feature = "nightly"))]
vtable_fns: Default::default(),
#[cfg(feature = "nightly")]
vtables: Default::default(),
indices: Default::default(),
tys: Default::default(),
marker: Default::default(),
}
}
}
/// Asserts that `T` is a "fat" pointee: a reference to it spans two
/// pointer-sized words (data pointer plus metadata), as for `dyn Trait`.
///
/// # Panics
///
/// Panics if `&T` has any other size, e.g. when `T` is `Sized`.
fn assert_unsized<T: ?Sized>() {
    let reference_size = core::mem::size_of::<&T>();
    let two_words = 2 * core::mem::size_of::<usize>();
    assert_eq!(reference_size, two_words);
}
#[cfg(test)]
mod tests {
    use super::*;

    trait Object {
        fn method1(&self) -> i32;
        fn method2(&mut self, x: i32);
    }

    // A plain unsizing coercion keeps the address unchanged, as `CastFrom`
    // requires.
    unsafe impl<T> CastFrom<T> for dyn Object
    where
        T: Object + 'static,
    {
        fn cast(t: *mut T) -> *mut Self {
            t
        }
    }

    struct ImplementorA(i32);

    impl Object for ImplementorA {
        fn method1(&self) -> i32 {
            self.0
        }

        fn method2(&mut self, x: i32) {
            self.0 += x;
        }
    }

    struct ImplementorB(i32);

    impl Object for ImplementorB {
        fn method1(&self) -> i32 {
            self.0
        }

        fn method2(&mut self, x: i32) {
            self.0 *= x;
        }
    }

    #[test]
    fn test_iter_all() {
        let mut world = World::empty();
        world.insert(ImplementorA(3));
        world.insert(ImplementorB(1));
        let mut table = MetaTable::<dyn Object>::new();
        table.register::<ImplementorA>();
        table.register::<ImplementorB>();
        {
            let mut iter = table.iter(&world);
            assert_eq!(iter.next().unwrap().method1(), 3);
            assert_eq!(iter.next().unwrap().method1(), 1);
            assert!(iter.next().is_none());
        }
        {
            let mut iter_mut = table.iter_mut(&world);
            let mut obj = iter_mut.next().unwrap();
            obj.method2(3);
            assert_eq!(obj.method1(), 6);
            let mut obj = iter_mut.next().unwrap();
            obj.method2(4);
            assert_eq!(obj.method1(), 4);
            assert!(iter_mut.next().is_none());
        }
    }

    #[test]
    fn test_iter_all_after_removal() {
        let mut world = World::empty();
        world.insert(ImplementorA(3));
        world.insert(ImplementorB(1));
        let mut table = MetaTable::<dyn Object>::new();
        table.register::<ImplementorA>();
        table.register::<ImplementorB>();
        {
            let mut iter = table.iter(&world);
            assert_eq!(iter.next().unwrap().method1(), 3);
            assert_eq!(iter.next().unwrap().method1(), 1);
            assert!(iter.next().is_none());
        }
        world.remove::<ImplementorA>().unwrap();
        {
            let mut iter = table.iter(&world);
            assert_eq!(iter.next().unwrap().method1(), 1);
            assert!(iter.next().is_none());
        }
        world.remove::<ImplementorB>().unwrap();
        // With every registered resource gone, iteration yields nothing.
        assert!(table.iter(&world).next().is_none());
    }

    #[test]
    fn register_twice_keeps_single_entry() {
        let mut world = World::empty();
        world.insert(ImplementorA(3));
        let mut table = MetaTable::<dyn Object>::new();
        table.register::<ImplementorA>();
        table.register::<ImplementorA>();
        let mut iter = table.iter(&world);
        assert_eq!(iter.next().unwrap().method1(), 3);
        assert!(iter.next().is_none());
    }

    struct ImplementorC;

    impl Object for ImplementorC {
        fn method1(&self) -> i32 {
            33
        }

        fn method2(&mut self, _x: i32) {
            unimplemented!()
        }
    }

    struct ImplementorD;

    impl Object for ImplementorD {
        fn method1(&self) -> i32 {
            42
        }

        fn method2(&mut self, _x: i32) {
            unimplemented!()
        }
    }

    #[test]
    fn get() {
        let mut world = World::empty();
        world.insert(ImplementorC);
        world.insert(ImplementorD);
        let mut table = MetaTable::<dyn Object>::new();
        table.register::<ImplementorC>();
        table.register::<ImplementorD>();
        assert_eq!(
            table
                .get(&*world.fetch::<ImplementorC>())
                .unwrap()
                .method1(),
            33
        );
        assert_eq!(
            table
                .get(&*world.fetch::<ImplementorD>())
                .unwrap()
                .method1(),
            42
        );
        // The table itself must be storable as a resource (`Send + Sync`).
        world.insert(table);
    }

    #[test]
    fn get_unregistered_returns_none() {
        let mut world = World::empty();
        world.insert(ImplementorC);
        let table = MetaTable::<dyn Object>::new();
        assert!(table.get(&*world.fetch::<ImplementorC>()).is_none());
    }

    #[test]
    fn get_mut() {
        let mut world = World::empty();
        world.insert(ImplementorA(5));
        let mut table = MetaTable::<dyn Object>::new();
        table.register::<ImplementorA>();
        {
            let mut fetched = world.fetch_mut::<ImplementorA>();
            let obj = table.get_mut(&mut *fetched).unwrap();
            obj.method2(2);
            assert_eq!(obj.method1(), 7);
        }
        // The mutation went through to the stored resource.
        assert_eq!(world.fetch::<ImplementorA>().0, 7);
    }
}