#![feature(ptr_metadata)]
#![feature(unsize)]
use std::any::TypeId;
use std::cell::{RefCell, RefMut};
use std::collections::HashMap;
use std::marker::Unsize;
use std::mem::transmute;
use std::ptr::{self, DynMetadata, NonNull, Pointee};
#[cfg(feature = "trait-map-derive")]
#[allow(unused_imports)]
pub use trait_map_derive::TraitMapEntry;
/// Implemented by every type that can be stored in a [`TraitMap`].
///
/// The map invokes the two hooks below and hands each of them a [`Context`]
/// through which the entry can change the set of traits it is registered under.
pub trait TraitMapEntry: 'static {
/// Invoked by [`TraitMap::add_entry`] after the entry has been stored and
/// registered under `dyn TraitMapEntry`. Use `context` to register the
/// additional traits this entry should be reachable through.
fn on_create<'a>(&mut self, context: Context<'a>);
/// Invoked by [`TraitMap::update_entry`]. The default implementation does
/// nothing.
#[allow(unused_variables)]
fn on_update<'a>(&mut self, context: Context<'a>) {}
}
/// Identifier of one entry in a [`TraitMap`], as returned by
/// [`TraitMap::add_entry`]. Wraps the `u64` counter value assigned at insertion.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
#[repr(transparent)]
pub struct EntryID(u64);
/// Stores entries of arbitrary concrete types and records, per trait object
/// type, which entries can be retrieved as that trait.
#[derive(Debug, Default)]
pub struct TraitMap {
/// Value used for the next [`EntryID`]; `add_entry` increments it after use.
unique_entry_id: u64,
/// Outer key: `TypeId` of a trait object type. Inner map: the stored pointer
/// (data pointer plus trait metadata) of every entry registered under it.
/// Wrapped in a `RefCell` so a `Context` can hold a mutable borrow of it.
traits: RefCell<HashMap<TypeId, HashMap<EntryID, PointerWithMetadata>>>,
/// `TypeId` of the concrete type each entry was added as.
concrete_types: HashMap<EntryID, TypeId>,
}
/// Type-erased handle for a single entry, passed to
/// [`TraitMapEntry::on_create`] and [`TraitMapEntry::on_update`].
///
/// Convert it into a [`TypedContext`] with `downcast` / `try_downcast` to add
/// or remove trait registrations.
#[derive(Debug)]
pub struct Context<'a> {
/// Entry this context refers to.
entry_id: EntryID,
/// Thin (type-erased) pointer to the entry.
pointer: NonNull<()>,
/// `TypeId` of the entry's concrete type; compared by `try_downcast`.
type_id: TypeId,
/// Mutable borrow of the owning map's per-trait tables.
traits: RefMut<'a, HashMap<TypeId, HashMap<EntryID, PointerWithMetadata>>>,
}
/// Handle for a single entry whose concrete type `E` is known, which allows
/// trait registrations for that entry to be added, removed and queried.
#[derive(Debug)]
pub struct TypedContext<'a, E: ?Sized> {
/// Entry this context refers to.
entry_id: EntryID,
/// Typed pointer to the entry.
pointer: NonNull<E>,
/// Mutable borrow of the owning map's per-trait tables.
traits: RefMut<'a, HashMap<TypeId, HashMap<EntryID, PointerWithMetadata>>>,
}
impl TraitMap {
    /// Creates an empty map.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `entry` in the map and returns its newly assigned [`EntryID`].
    ///
    /// The entry is moved to the heap, registered under `dyn TraitMapEntry`
    /// (the registration through which the allocation is later freed) and then
    /// handed a [`Context`] via [`TraitMapEntry::on_create`] so it can register
    /// the other traits it wants to be reachable through.
    ///
    /// # Panics
    ///
    /// Panics if the per-trait table is already borrowed, or if `on_create`
    /// panics.
    pub fn add_entry<Entry: TraitMapEntry + 'static>(&mut self, entry: Entry) -> EntryID {
        // The allocation is intentionally leaked here: from now on it is owned
        // by the raw pointer stored in `traits` and is reclaimed by
        // `remove_entry`, `try_take_entry_downcast` or `Drop`.
        let entry_ref = Box::leak(Box::new(entry));
        let entry_id = EntryID(self.unique_entry_id);
        self.unique_entry_id += 1;
        self.concrete_types.insert(entry_id, TypeId::of::<Entry>());
        let mut context = TypedContext {
            entry_id,
            pointer: entry_ref.into(),
            traits: self.traits.borrow_mut(),
        };
        context.add_trait::<dyn TraitMapEntry>();
        entry_ref.on_create(context.upcast());
        entry_id
    }

    /// Removes the entry from the map and drops it.
    ///
    /// Returns `true` if an entry with this id was stored and has been
    /// dropped, `false` otherwise.
    pub fn remove_entry(&mut self, entry_id: EntryID) -> bool {
        // `&mut self` already proves exclusive access, so no runtime borrow of
        // the `RefCell` is needed.
        let traits = self.traits.get_mut();
        let owner = traits
            .get_mut(&TypeId::of::<dyn TraitMapEntry>())
            .and_then(|entries| entries.remove(&entry_id));
        // Purge every remaining registration *before* running the entry's
        // destructor: should that destructor panic, no table may be left
        // holding a pointer to the destroyed value.
        for entries in traits.values_mut() {
            entries.remove(&entry_id);
        }
        self.concrete_types.remove(&entry_id);
        match owner {
            Some(pointer) => {
                // SAFETY: pointers in the table are derived from the `Box`
                // leaked in `add_entry`, and this one has just been removed
                // from the table, so the allocation is reclaimed exactly once.
                drop(unsafe { pointer.into_boxed::<dyn TraitMapEntry>() });
                true
            }
            None => false,
        }
    }

    /// Calls [`TraitMapEntry::on_update`] on the entry, giving it a fresh
    /// [`Context`].
    ///
    /// Returns `true` if the entry exists and the hook was invoked, `false`
    /// otherwise.
    ///
    /// # Panics
    ///
    /// Panics if the per-trait table is already borrowed.
    pub fn update_entry(&mut self, entry_id: EntryID) -> bool {
        let Some(type_id) = self.get_entry_type(entry_id) else {
            return false;
        };
        // The shared borrow is a temporary of this statement, so it is
        // released before the mutable borrow taken for the `Context` below.
        let found = self
            .traits
            .borrow()
            .get(&TypeId::of::<dyn TraitMapEntry>())
            .and_then(|entries| entries.get(&entry_id))
            .map(|stored| unsafe {
                // SAFETY: the stored data pointer originates from a `Box`, so
                // it is non-null, and `&mut self` guarantees no other
                // reference to the entry is alive.
                (
                    NonNull::new_unchecked(stored.pointer),
                    stored.reconstruct_mut::<dyn TraitMapEntry>(),
                )
            });
        let Some((pointer, entry)) = found else {
            return false;
        };
        entry.on_update(Context {
            entry_id,
            pointer,
            type_id,
            traits: self.traits.borrow_mut(),
        });
        true
    }

    /// Returns the `TypeId` of the concrete type the entry was added as, or
    /// `None` if the id is unknown.
    pub fn get_entry_type(&self, entry_id: EntryID) -> Option<TypeId> {
        self.concrete_types.get(&entry_id).copied()
    }

    /// Returns every stored entry as `&dyn TraitMapEntry`, keyed by id.
    pub fn all_entries(&self) -> HashMap<EntryID, &dyn TraitMapEntry> {
        self.get_entries()
    }

    /// Returns every stored entry as `&mut dyn TraitMapEntry`, keyed by id.
    pub fn all_entries_mut(&mut self) -> HashMap<EntryID, &mut dyn TraitMapEntry> {
        self.get_entries_mut()
    }

    /// Returns all entries registered under `Trait`, keyed by id. The result
    /// is empty when nothing is registered under that trait.
    ///
    /// # Panics
    ///
    /// Panics if the per-trait table is currently mutably borrowed.
    pub fn get_entries<Trait>(&self) -> HashMap<EntryID, &Trait>
    where
        Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
    {
        self.traits
            .borrow()
            .get(&TypeId::of::<Trait>())
            .map(|entries| {
                entries
                    .iter()
                    // SAFETY: the pointer was registered under `Trait`, so its
                    // metadata is a `DynMetadata<Trait>`; the entry stays
                    // alive for as long as `self` is borrowed.
                    .map(|(entry_id, pointer)| (*entry_id, unsafe { pointer.reconstruct_ref() }))
                    .collect()
            })
            .unwrap_or_default()
    }

    /// Mutable counterpart of [`TraitMap::get_entries`].
    pub fn get_entries_mut<Trait>(&mut self) -> HashMap<EntryID, &mut Trait>
    where
        Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
    {
        self.traits
            .get_mut()
            .get(&TypeId::of::<Trait>())
            .map(|entries| {
                entries
                    .iter()
                    // SAFETY: as in `get_entries`; each id maps to a distinct
                    // entry and `&mut self` rules out any other reference.
                    .map(|(entry_id, pointer)| (*entry_id, unsafe { pointer.reconstruct_mut() }))
                    .collect()
            })
            .unwrap_or_default()
    }

    /// Returns the entry as `&Trait`, or `None` if the entry does not exist or
    /// is not registered under `Trait`.
    ///
    /// # Panics
    ///
    /// Panics if the per-trait table is currently mutably borrowed.
    pub fn get_entry<Trait>(&self, entry_id: EntryID) -> Option<&Trait>
    where
        Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
    {
        self.traits
            .borrow()
            .get(&TypeId::of::<Trait>())
            .and_then(|entries| entries.get(&entry_id))
            // SAFETY: see `get_entries`.
            .map(|pointer| unsafe { pointer.reconstruct_ref() })
    }

    /// Mutable counterpart of [`TraitMap::get_entry`].
    pub fn get_entry_mut<Trait>(&mut self, entry_id: EntryID) -> Option<&mut Trait>
    where
        Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
    {
        self.traits
            .get_mut()
            .get(&TypeId::of::<Trait>())
            .and_then(|entries| entries.get(&entry_id))
            // SAFETY: see `get_entries_mut`.
            .map(|pointer| unsafe { pointer.reconstruct_mut() })
    }

    /// Returns the entry as its concrete type `T`, or `None` if the id is
    /// unknown.
    ///
    /// # Panics
    ///
    /// Panics with "Invalid downcast" if the entry exists but is not a `T`.
    pub fn get_entry_downcast<T: TraitMapEntry + 'static>(&self, entry_id: EntryID) -> Option<&T> {
        self.try_get_entry_downcast(entry_id)
            .map(|entry| entry.expect("Invalid downcast"))
    }

    /// Mutable counterpart of [`TraitMap::get_entry_downcast`].
    ///
    /// # Panics
    ///
    /// Panics with "Invalid downcast" if the entry exists but is not a `T`.
    pub fn get_entry_downcast_mut<T: TraitMapEntry + 'static>(&mut self, entry_id: EntryID) -> Option<&mut T> {
        self.try_get_entry_downcast_mut(entry_id)
            .map(|entry| entry.expect("Invalid downcast"))
    }

    /// Removes the entry from the map and returns it by value, or `None` if
    /// the id is unknown.
    ///
    /// # Panics
    ///
    /// Panics with "Invalid downcast" if the entry exists but is not a `T`.
    pub fn take_entry_downcast<T: TraitMapEntry + 'static>(&mut self, entry_id: EntryID) -> Option<T> {
        self.try_take_entry_downcast(entry_id)
            .map(|entry| entry.expect("Invalid downcast"))
    }

    /// Non-panicking form of [`TraitMap::get_entry_downcast`].
    ///
    /// Returns `None` if the id is unknown and `Some(None)` if the entry's
    /// concrete type is not `T`.
    pub fn try_get_entry_downcast<T: TraitMapEntry + 'static>(&self, entry_id: EntryID) -> Option<Option<&T>> {
        if self.get_entry_type(entry_id)? != TypeId::of::<T>() {
            return Some(None);
        }
        Some(self.get_entry::<dyn TraitMapEntry>(entry_id).map(|entry| {
            let (pointer, _) = (entry as *const dyn TraitMapEntry).to_raw_parts();
            // SAFETY: the `TypeId` check above established that the value
            // behind the data pointer is a `T`.
            unsafe { &*(pointer as *const T) }
        }))
    }

    /// Non-panicking form of [`TraitMap::get_entry_downcast_mut`].
    ///
    /// Returns `None` if the id is unknown and `Some(None)` if the entry's
    /// concrete type is not `T`.
    pub fn try_get_entry_downcast_mut<T: TraitMapEntry + 'static>(
        &mut self,
        entry_id: EntryID,
    ) -> Option<Option<&mut T>> {
        if self.get_entry_type(entry_id)? != TypeId::of::<T>() {
            return Some(None);
        }
        Some(self.get_entry_mut::<dyn TraitMapEntry>(entry_id).map(|entry| {
            let (pointer, _) = (entry as *mut dyn TraitMapEntry).to_raw_parts();
            // SAFETY: the `TypeId` check above established that the value
            // behind the data pointer is a `T`.
            unsafe { &mut *(pointer as *mut T) }
        }))
    }

    /// Non-panicking form of [`TraitMap::take_entry_downcast`].
    ///
    /// Returns `None` if the id is unknown, `Some(None)` if the entry's
    /// concrete type is not `T` (the entry then stays in the map), and
    /// `Some(Some(entry))` once the entry has been removed.
    pub fn try_take_entry_downcast<T: TraitMapEntry + 'static>(&mut self, entry_id: EntryID) -> Option<Option<T>> {
        if self.get_entry_type(entry_id)? != TypeId::of::<T>() {
            return Some(None);
        }
        let entry = self
            .traits
            .get_mut()
            .get_mut(&TypeId::of::<dyn TraitMapEntry>())
            .and_then(|entries| entries.remove(&entry_id))
            // SAFETY: the pointer comes from the `Box<T>` leaked in
            // `add_entry` (type confirmed above) and was just removed from the
            // table, so the box is rebuilt exactly once. Moving the value out
            // frees the allocation without running `T`'s destructor.
            .map(|pointer| unsafe { *Box::from_raw(pointer.pointer as *mut T) })?;
        // The owning registration is already gone; this clears the remaining
        // per-trait registrations and the concrete-type record.
        self.remove_entry(entry_id);
        Some(Some(entry))
    }
}
impl Drop for TraitMap {
    /// Drops every entry that is still stored in the map.
    fn drop(&mut self) {
        // Use `get_mut` rather than a runtime borrow: a `Context` holds a
        // `RefMut` of this cell, and if one was ever leaked (e.g. with
        // `mem::forget`) the cell stays flagged as borrowed, so `borrow()`
        // would panic here. `&mut self` already guarantees exclusive access.
        if let Some(owners) = self.traits.get_mut().get(&TypeId::of::<dyn TraitMapEntry>()) {
            for pointer in owners.values() {
                // SAFETY: pointers registered under `dyn TraitMapEntry` are
                // derived from the `Box` leaked in `add_entry`, and the map is
                // being destroyed, so each allocation is reclaimed once.
                drop(unsafe { pointer.into_boxed::<dyn TraitMapEntry>() })
            }
        }
    }
}
/// A trait-object pointer split into its two halves so that pointers to
/// different trait object types can be stored in one collection.
#[derive(Debug)]
struct PointerWithMetadata {
/// Data (thin) half of the trait-object pointer.
pointer: *mut (),
/// Metadata half of the trait-object pointer, boxed and stored type-erased;
/// `from_trait_pointer` writes it and `reconstruct_ptr` reads it back.
boxed_metadata: Box<*const ()>,
}
impl PointerWithMetadata {
    /// Bundles an already split data pointer and type-erased metadata box.
    #[inline]
    pub fn new(pointer: *mut (), boxed_metadata: Box<*const ()>) -> Self {
        Self { pointer, boxed_metadata }
    }

    /// Views `pointer` as a `*mut Trait` and stores its data pointer and
    /// metadata separately.
    pub fn from_trait_pointer<T, Trait>(pointer: *mut T) -> Self
    where
        Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
        T: Unsize<Trait>,
    {
        let fat = pointer as *mut Trait;
        let (thin, vtable) = fat.to_raw_parts();
        // SAFETY: both source and target are boxes of sized values, i.e. thin
        // pointers; the metadata is only ever read back as a
        // `DynMetadata<Trait>` by `reconstruct_ptr`. NOTE(review): this relies
        // on `DynMetadata<Trait>` having the size and alignment of
        // `*const ()` — confirm.
        let erased = unsafe { transmute::<_, Box<*const ()>>(Box::new(vtable)) };
        Self::new(thin, erased)
    }

    /// Rebuilds the owning `Box<Trait>` for the pointed-to value.
    ///
    /// # Safety
    ///
    /// The data pointer must come from a `Box` allocation that has not been
    /// reclaimed yet, and `Trait` must be the trait this value was created
    /// with. Calling this more than once for the same allocation frees it
    /// twice.
    pub unsafe fn into_boxed<Trait>(&self) -> Box<Trait>
    where
        Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
    {
        let raw = self.reconstruct_ptr();
        Box::from_raw(raw)
    }

    /// Rebuilds a shared reference to the pointed-to value.
    ///
    /// # Safety
    ///
    /// The pointee must be alive and not mutably aliased for all of `'a`
    /// (which is chosen by the caller), and `Trait` must be the trait this
    /// value was created with.
    pub unsafe fn reconstruct_ref<'a, Trait>(&self) -> &'a Trait
    where
        Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
    {
        let raw: *mut Trait = self.reconstruct_ptr();
        &*raw
    }

    /// Rebuilds an exclusive reference to the pointed-to value.
    ///
    /// # Safety
    ///
    /// The pointee must be alive and not aliased at all for all of `'a`
    /// (which is chosen by the caller), and `Trait` must be the trait this
    /// value was created with.
    pub unsafe fn reconstruct_mut<'a, Trait>(&self) -> &'a mut Trait
    where
        Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
    {
        let raw: *mut Trait = self.reconstruct_ptr();
        &mut *raw
    }

    /// Reassembles the raw trait-object pointer from its two stored halves.
    pub fn reconstruct_ptr<Trait>(&self) -> *mut Trait
    where
        Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
    {
        let slot: &*const () = self.boxed_metadata.as_ref();
        // SAFETY: the boxed slot holds the bits of the metadata written by
        // `from_trait_pointer`. Reading it as `DynMetadata<Trait>` is only
        // meaningful when `Trait` matches the trait used at creation, which
        // callers are expected to ensure.
        let metadata: <Trait as Pointee>::Metadata =
            unsafe { *transmute::<&*const (), *const <Trait as Pointee>::Metadata>(slot) };
        ptr::from_raw_parts_mut::<Trait>(self.pointer, metadata)
    }
}
impl<'a> Context<'a> {
pub fn downcast<T>(self) -> TypedContext<'a, T>
where
T: 'static,
{
self.try_downcast::<T>().expect("Invalid downcast")
}
pub fn try_downcast<T>(self) -> Result<TypedContext<'a, T>, Self>
where
T: 'static,
{
if self.type_id != TypeId::of::<T>() {
Err(self)
} else {
Ok(TypedContext {
entry_id: self.entry_id,
pointer: self.pointer.cast(),
traits: self.traits,
})
}
}
}
impl<'a, Entry> TypedContext<'a, Entry>
where
    Entry: 'static,
{
    /// Erases the concrete type again, recording its `TypeId` so the result
    /// can later be downcast back.
    pub fn upcast(self) -> Context<'a> {
        let Self {
            entry_id,
            pointer,
            traits,
        } = self;
        Context {
            entry_id,
            pointer: pointer.cast(),
            type_id: TypeId::of::<Entry>(),
            traits,
        }
    }

    /// Registers the entry under `Trait`, replacing any registration it
    /// already had for that trait. Returns `self` to allow chaining.
    pub fn add_trait<Trait>(&mut self) -> &mut Self
    where
        Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
        Entry: Unsize<Trait>,
    {
        let erased = PointerWithMetadata::from_trait_pointer::<Entry, Trait>(self.pointer.as_ptr());
        self.traits
            .entry(TypeId::of::<Trait>())
            .or_default()
            .insert(self.entry_id, erased);
        self
    }

    /// Removes the entry's registration under `Trait`, if there is one.
    /// Returns `self` to allow chaining.
    ///
    /// A request to remove `dyn TraitMapEntry` is ignored.
    pub fn remove_trait<Trait>(&mut self) -> &mut Self
    where
        Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
        Entry: Unsize<Trait>,
    {
        let type_id = TypeId::of::<Trait>();
        // The `dyn TraitMapEntry` registration is never removed from here.
        if type_id != TypeId::of::<dyn TraitMapEntry>() {
            if let Some(registered) = self.traits.get_mut(&type_id) {
                registered.remove(&self.entry_id);
            }
        }
        self
    }

    /// Returns whether the entry is currently registered under `Trait`.
    pub fn has_trait<Trait>(&self) -> bool
    where
        Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
        Entry: Unsize<Trait>,
    {
        self.traits
            .get(&TypeId::of::<Trait>())
            .is_some_and(|registered| registered.contains_key(&self.entry_id))
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Boxed callback the fixtures use to script their `on_create` and
    /// `on_update` hooks.
    type Hook<T> = Box<dyn FnMut(&mut T, Context)>;

    trait TraitOne {
        fn add_with_offset(&self, a: u32, b: u32) -> u32;
        fn mul_with_mut(&mut self, a: u32, b: u32) -> u32;
    }

    trait TraitTwo {
        fn compute(&self) -> f64;
    }

    /// Never implemented by any fixture; used to query a trait with no entries.
    trait TraitThree {
        fn unused(&self) -> (i8, i8);
    }

    /// Fixture that registers itself under `TraitOne` and `TraitTwo` on creation.
    struct OneAndTwo {
        offset: u32,
        compute: f64,
        on_create_fn: Option<Hook<Self>>,
        on_update_fn: Option<Hook<Self>>,
    }

    impl OneAndTwo {
        pub fn new(offset: u32, compute: f64) -> Self {
            Self {
                offset,
                compute,
                on_create_fn: Some(Box::new(|_, context| {
                    let mut typed = context.downcast::<Self>();
                    typed.add_trait::<dyn TraitOne>();
                    typed.add_trait::<dyn TraitTwo>();
                })),
                on_update_fn: None,
            }
        }
    }

    /// Fixture that registers itself under `TraitTwo` only.
    struct TwoOnly {
        compute: f64,
        on_create_fn: Option<Hook<Self>>,
        on_update_fn: Option<Hook<Self>>,
    }

    impl TwoOnly {
        pub fn new(compute: f64) -> Self {
            Self {
                compute,
                on_create_fn: Some(Box::new(|_, context| {
                    let mut typed = context.downcast::<Self>();
                    typed.add_trait::<dyn TraitTwo>();
                })),
                on_update_fn: None,
            }
        }
    }

    impl TraitOne for OneAndTwo {
        fn add_with_offset(&self, a: u32, b: u32) -> u32 {
            a + b + self.offset
        }

        fn mul_with_mut(&mut self, a: u32, b: u32) -> u32 {
            self.offset = a * b;
            a + b + self.offset
        }
    }

    impl TraitTwo for OneAndTwo {
        fn compute(&self) -> f64 {
            self.compute
        }
    }

    impl TraitTwo for TwoOnly {
        fn compute(&self) -> f64 {
            self.compute * self.compute
        }
    }

    impl TraitMapEntry for OneAndTwo {
        fn on_create<'a>(&mut self, context: Context<'a>) {
            // Take the hook out while it runs so it can receive `&mut self`.
            let Some(mut hook) = self.on_create_fn.take() else {
                return;
            };
            hook(self, context);
            self.on_create_fn = Some(hook);
        }

        fn on_update<'a>(&mut self, context: Context<'a>) {
            let Some(mut hook) = self.on_update_fn.take() else {
                return;
            };
            hook(self, context);
            self.on_update_fn = Some(hook);
        }
    }

    impl TraitMapEntry for TwoOnly {
        fn on_create<'a>(&mut self, context: Context<'a>) {
            // Take the hook out while it runs so it can receive `&mut self`.
            let Some(mut hook) = self.on_create_fn.take() else {
                return;
            };
            hook(self, context);
            self.on_create_fn = Some(hook);
        }

        fn on_update<'a>(&mut self, context: Context<'a>) {
            let Some(mut hook) = self.on_update_fn.take() else {
                return;
            };
            hook(self, context);
            self.on_update_fn = Some(hook);
        }
    }

    #[test]
    fn test_adding_and_queries_traits() {
        let mut map = TraitMap::new();
        let first_id = map.add_entry(OneAndTwo::new(3, 10.0));
        let second_id = map.add_entry(TwoOnly::new(10.0));
        assert_eq!(map.all_entries().len(), 2);

        let ones = map.get_entries_mut::<dyn TraitOne>();
        assert_eq!(ones.len(), 1);
        for (id, one) in ones {
            assert_eq!(id, first_id);
            assert_eq!(one.add_with_offset(1, 2), 6);
            assert_eq!(one.mul_with_mut(1, 2), 5);
            assert_eq!(one.add_with_offset(1, 2), 5);
        }

        let twos = map.get_entries::<dyn TraitTwo>();
        let first = twos.get(&first_id);
        let second = twos.get(&second_id);
        assert_eq!(twos.len(), 2);
        assert!(first.is_some());
        assert_eq!(first.unwrap().compute(), 10.0);
        assert!(second.is_some());
        assert_eq!(second.unwrap().compute(), 100.0);
    }

    #[test]
    fn test_removing_traits() {
        let mut map = TraitMap::new();
        let mut entry = OneAndTwo::new(3, 10.0);
        entry.on_update_fn = Some(Box::new(|_, context| {
            let mut typed = context.downcast::<OneAndTwo>();
            typed.remove_trait::<dyn TraitOne>();
        }));
        let entry_id = map.add_entry(entry);
        assert_eq!(map.get_entries::<dyn TraitOne>().len(), 1);
        assert_eq!(map.get_entries::<dyn TraitTwo>().len(), 1);

        map.update_entry(entry_id);
        assert_eq!(map.get_entries::<dyn TraitOne>().len(), 0);
        assert_eq!(map.get_entries::<dyn TraitTwo>().len(), 1);
    }

    #[test]
    fn test_adding_and_removing_entry() {
        let mut map = TraitMap::new();
        let first_id = map.add_entry(TwoOnly::new(10.0));
        let second_id = map.add_entry(TwoOnly::new(20.0));
        let third_id = map.add_entry(TwoOnly::new(30.0));
        assert_eq!(map.get_entries::<dyn TraitTwo>().len(), 3);
        assert!(map.get_entry::<dyn TraitTwo>(first_id).is_some());
        assert!(map.get_entry::<dyn TraitTwo>(second_id).is_some());
        assert!(map.get_entry::<dyn TraitTwo>(third_id).is_some());

        map.remove_entry(second_id);
        assert_eq!(map.get_entries::<dyn TraitTwo>().len(), 2);
        assert!(map.get_entry::<dyn TraitTwo>(first_id).is_some());
        assert!(map.get_entry::<dyn TraitTwo>(second_id).is_none());
        assert!(map.get_entry::<dyn TraitTwo>(third_id).is_some());

        let fourth_id = map.add_entry(TwoOnly::new(40.0));
        assert_eq!(map.get_entries::<dyn TraitTwo>().len(), 3);
        assert!(map.get_entry::<dyn TraitTwo>(first_id).is_some());
        assert!(map.get_entry::<dyn TraitTwo>(second_id).is_none());
        assert!(map.get_entry::<dyn TraitTwo>(third_id).is_some());
        assert!(map.get_entry::<dyn TraitTwo>(fourth_id).is_some());
    }

    #[test]
    #[should_panic]
    fn test_context_invalid_downcast_panics() {
        let mut map = TraitMap::new();
        let mut entry = OneAndTwo::new(3, 10.0);
        // Downcasting to the wrong concrete type must panic inside `add_entry`.
        entry.on_create_fn = Some(Box::new(|_, context| {
            let mut typed = context.downcast::<TwoOnly>();
            typed.add_trait::<dyn TraitTwo>();
        }));
        map.add_entry::<OneAndTwo>(entry);
    }

    #[test]
    fn test_get_entry() {
        let mut map = TraitMap::new();
        let two_only_id = map.add_entry(TwoOnly::new(10.0));
        let one_and_two_id = map.add_entry(OneAndTwo::new(1, 20.0));

        assert!(map.get_entry::<dyn TraitOne>(two_only_id).is_none());
        assert!(map.get_entry::<dyn TraitTwo>(two_only_id).is_some());
        assert!(map.get_entry::<dyn TraitThree>(two_only_id).is_none());

        assert!(map.get_entry_mut::<dyn TraitOne>(one_and_two_id).is_some());
        assert!(map.get_entry_mut::<dyn TraitTwo>(one_and_two_id).is_some());
        assert!(map.get_entry_mut::<dyn TraitThree>(one_and_two_id).is_none());
    }

    #[test]
    #[should_panic]
    fn test_get_entry_invalid_downcast_panics() {
        let mut map = TraitMap::new();
        let entry_id = map.add_entry(OneAndTwo::new(1, 4.5));
        map.get_entry_downcast::<TwoOnly>(entry_id);
    }

    #[test]
    fn test_take_entry_downcast() {
        let mut map = TraitMap::new();
        let entry_id = map.add_entry(OneAndTwo::new(1, 4.5));
        let taken = map.take_entry_downcast::<OneAndTwo>(entry_id);
        assert!(taken.is_some());
        assert_eq!(taken.unwrap().offset, 1);
    }

    #[test]
    #[should_panic]
    fn test_take_entry_invalid_downcast_panics() {
        let mut map = TraitMap::new();
        let entry_id = map.add_entry(OneAndTwo::new(1, 4.5));
        map.take_entry_downcast::<TwoOnly>(entry_id);
    }

    #[test]
    fn test_cannot_remove_trait_map_entry() {
        let mut map = TraitMap::new();
        let mut entry = OneAndTwo::new(3, 10.0);
        // Removing `dyn TraitMapEntry` is expected to be ignored by the map.
        entry.on_update_fn = Some(Box::new(|_, context| {
            let mut typed = context.downcast::<OneAndTwo>();
            typed.remove_trait::<dyn TraitOne>();
            typed.remove_trait::<dyn TraitMapEntry>();
        }));
        let entry_id = map.add_entry(entry);
        map.add_entry(TwoOnly::new(1.5));
        assert_eq!(map.all_entries().len(), 2);
        assert_eq!(map.get_entries::<dyn TraitOne>().len(), 1);
        assert_eq!(map.get_entries::<dyn TraitTwo>().len(), 2);

        map.update_entry(entry_id);
        assert_eq!(map.all_entries().len(), 2);
        assert_eq!(map.get_entries::<dyn TraitOne>().len(), 0);
        assert_eq!(map.get_entries::<dyn TraitTwo>().len(), 2);
    }
}