use std::any::TypeId;
use std::collections::BTreeMap;
/// A concrete, `'static` type that can enumerate the pointer casts it supports.
///
/// The `impl_castable!` / `impl_castable_from!` macros in this module generate
/// implementations of this trait.
pub trait Castable: 'static {
    /// A name for the implementing type (generated impls return the type's identifier).
    fn name() -> &'static str;

    /// Adds every cast supported by `Self` to the `casts` registry.
    fn collect_casts(casts: &mut Casts) where Self: Sized;
}
/// Registry of type-erased pointer conversions ("casts") from a concrete source
/// type to a (possibly unsized) destination type such as `dyn Trait`.
///
/// Every registered cast gets a `u32` key equal to its insertion index. Two ordered
/// indexes allow lookup by `(dst, src)` and by `(src, dst)`. Their key components
/// are wrapped in `Option` because `None < Some(_)`. This makes `(Some(ty), None)`
/// the smallest key whose first component is `ty`, so the `find_keys_by_*` methods
/// can scan one contiguous range.
#[derive(Default)]
pub struct Casts {
    /// `(src, dst, type-erased fn pointer)` for every registered cast, indexed by key.
    casts: Vec<(TypeId, TypeId, *const ())>,
    /// `(dst, src) -> key`.
    casts_by_dst_src: BTreeMap<(Option<TypeId>, Option<TypeId>), u32>,
    /// `(src, dst) -> key`.
    casts_by_src_dst: BTreeMap<(Option<TypeId>, Option<TypeId>), u32>,
}

impl Casts {
    /// Registers `cast` as the conversion from `*const T` to `*const U` under the
    /// next free key.
    ///
    /// If `(T, U)` was already registered, both lookup indexes are redirected to
    /// the new entry. The old entry stays in storage under its old key.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact contract is not written down elsewhere in this file.
    /// [`Casts::cast`] later calls `cast` through a type-erased signature with an
    /// arbitrary pointer. `cast` is therefore presumably required to be a pure pointer
    /// conversion that never dereferences its argument, like the `|x| x` coercion
    /// generated by `add_cast!`. Confirm this.
    ///
    /// # Panics
    ///
    /// Panics if the number of registered casts would exceed the `u32` key space.
    pub unsafe fn add<T, U>(&mut self, cast: fn(*const T) -> *const U)
    where
        T: 'static,
        U: ?Sized + 'static,
    {
        use std::convert::TryFrom;

        let src = TypeId::of::<T>();
        let dst = TypeId::of::<U>();
        // Keys are `u32`; fail loudly instead of letting `as` wrap the index around.
        let key = u32::try_from(self.casts.len())
            .expect("cast registry exceeded the u32 key space");
        self.casts_by_src_dst.insert((Some(src), Some(dst)), key);
        self.casts_by_dst_src.insert((Some(dst), Some(src)), key);
        self.casts.push((src, dst, cast as *const ()));
    }

    /// Returns the key of the cast from `src_type` to `dst_type`, if one is registered.
    pub fn find_key(&self, src_type: TypeId, dst_type: TypeId) -> Option<u32> {
        self.casts_by_dst_src
            .get(&(Some(dst_type), Some(src_type)))
            .copied()
    }

    /// Applies the cast registered under `key` to `src`, producing a `*const Dst`.
    ///
    /// Only the destination type is checked. `src` is expected to point at a value
    /// of the cast's source type (see [`Casts::get_src`]). The returned pointer is
    /// only meaningful, and only safe to dereference, when that holds.
    ///
    /// # Panics
    ///
    /// Panics if `key` is out of range, or if the cast's destination type is not `Dst`.
    pub fn cast<Dst>(&self, key: u32, src: *const ()) -> *const Dst
    where
        Dst: ?Sized + 'static,
    {
        let (_src_ty, dst_ty, erased) = self.casts[key as usize];
        assert_eq!(
            dst_ty,
            TypeId::of::<Dst>(),
            "cast {} does not produce the requested destination type",
            key
        );
        // SAFETY: `erased` was produced in `add` from a `fn(*const T) -> *const U` with
        // `TypeId::of::<U>() == dst_ty`, which was just checked against `Dst`. So the
        // return type really is `*const Dst`. `T` is `Sized`, so `*const T` and
        // `*const ()` are both thin raw pointers and ABI-compatible as an argument.
        let func: fn(*const ()) -> *const Dst = unsafe { std::mem::transmute(erased) };
        func(src)
    }

    /// Returns the destination type of the cast registered under `key`.
    ///
    /// # Panics
    ///
    /// Panics if `key` is out of range.
    pub fn get_dst(&self, key: u32) -> TypeId {
        self.casts[key as usize].1
    }

    /// Returns the source type of the cast registered under `key`.
    ///
    /// # Panics
    ///
    /// Panics if `key` is out of range.
    pub fn get_src(&self, key: u32) -> TypeId {
        self.casts[key as usize].0
    }

    /// Iterates `(src, key)` for every registered cast whose destination is `ty`.
    pub fn find_keys_by_dst<'a>(&'a self, ty: TypeId) -> impl Iterator<Item = (TypeId, u32)> + 'a {
        self.casts_by_dst_src
            .range((Some(ty), None)..)
            .take_while(move |((dst, _src), _)| dst == &Some(ty))
            .map(|((_dst, src), ix)| (src.unwrap(), *ix))
    }

    /// Iterates `(dst, key)` for every registered cast whose source is `ty`.
    pub fn find_keys_by_src<'a>(&'a self, ty: TypeId) -> impl Iterator<Item = (TypeId, u32)> + 'a {
        self.casts_by_src_dst
            .range((Some(ty), None)..)
            .take_while(move |((src, _dst), _)| src == &Some(ty))
            .map(|((_src, dst), ix)| (dst.unwrap(), *ix))
    }
}

// SAFETY: the only field that is not automatically `Send + Sync` is the `*const ()`
// in `casts`. `casts` is private and only ever filled by `Casts::add`, which stores
// type-erased `fn` pointers there. Function pointers are `Send + Sync`, and the
// registry only calls them, never writes through them.
unsafe impl Send for Casts {}
unsafe impl Sync for Casts {}
/// Registers the cast `*const $src -> *const $dst` with `$casts`.
///
/// The registered function is the closure `|ptr| ptr`. Its result is coerced from
/// `*const $src` to `*const $dst`: an identity conversion, or an unsizing coercion
/// when `$dst` is `dyn Trait`. The expansion is a `()`-typed `unsafe` block, so the
/// macro works both as a statement and as a closure body.
#[doc(hidden)]
#[macro_export]
macro_rules! add_cast {
    ($casts:expr, $src:ty, $dst:ty) => {
        // SAFETY (NOTE(review): the contract of `Casts::add` is not written down):
        // the registered closure only converts the pointer it is given and never
        // dereferences it.
        unsafe {
            $casts.add::<$src, $dst>(|ptr| ptr);
        }
    };
}
/// Implements `Castable` for `$struct`.
///
/// The generated `collect_casts` registers the identity cast and one cast per
/// `into dyn Trait;` line. When this crate is built with the `inventory` feature,
/// it also registers the casts submitted for `$struct` via `impl_castable_into!`.
#[doc(hidden)]
#[macro_export]
macro_rules! impl_castable_from {
    {
        impl Castable for $struct:ident {
            $(
                into dyn $trait:path;
            )*
        }
    } => {
        impl $crate::Castable for $struct {
            fn name() -> &'static str {
                stringify!($struct)
            }
            fn collect_casts(casts: &mut $crate::Casts) {
                $crate::add_cast!(casts, $struct, $struct);
                $(
                    $crate::add_cast!(casts, $struct, dyn $trait);
                )*
                // This goes through a helper macro whose definition is chosen by *this*
                // crate's features. A `#[cfg(feature = "inventory")]` written directly
                // in this expansion would instead be evaluated against the invoking
                // crate's features, silently skipping inventory registrations there.
                $crate::__castable_collect_registered!(casts, $struct);
            }
        }
    };
}

/// With `inventory`: runs every `RegisterCast` entry whose source type is `$struct`.
#[cfg(feature = "inventory")]
#[doc(hidden)]
#[macro_export]
macro_rules! __castable_collect_registered {
    ($casts:expr, $struct:ty) => {
        for register in inventory::iter::<$crate::RegisterCast> {
            if (register.cast_from)() == std::any::TypeId::of::<$struct>() {
                (register.register)($casts)
            }
        }
    };
}

/// Without `inventory`: there are no out-of-line registrations, so this expands to nothing.
#[cfg(not(feature = "inventory"))]
#[doc(hidden)]
#[macro_export]
macro_rules! __castable_collect_registered {
    ($casts:expr, $struct:ty) => {};
}
/// Compile-time witness used by `impl_castable_into!`.
///
/// Evaluating `assert_implements_castable::<T>()` (e.g. in a `const _`) only
/// type-checks when `T: Castable`. The function itself does nothing.
#[doc(hidden)]
pub const fn assert_implements_castable<T: Castable>() {}
/// Registers `dyn $iface` as a cast target for each listed type. The registration
/// happens outside those types' own `impl_castable!` invocations.
///
/// For each `for Type;` line, the macro does two things:
/// - it statically checks that `Type: Castable`;
/// - it submits a `RegisterCast` entry to `inventory`.
///
/// The entry is picked up by the inventory loop in the `collect_casts` that
/// `impl_castable_from!` generates for `Type`.
#[cfg(feature = "inventory")]
#[macro_export]
#[doc(hidden)]
macro_rules! impl_castable_into {
    {
        impl Castable into dyn $iface:path {
            $(for $ty:path;)*
        }
    } => {
        $(
            // Fails to compile unless `$ty: Castable`.
            const _: () = $crate::assert_implements_castable::<$ty>();

            inventory::submit! {
                $crate::RegisterCast {
                    cast_from: || std::any::TypeId::of::<$ty>(),
                    register: |casts| $crate::add_cast!(casts, $ty, dyn $iface),
                }
            }
        )*
    };
}
/// Front-end macro (build with `inventory`): accepts any sequence of
/// `impl Castable for Type { into dyn Trait; ... }` and
/// `impl Castable into dyn Trait { for Type; ... }` items. It handles one item
/// per recursion step.
#[cfg(feature = "inventory")]
#[macro_export]
macro_rules! impl_castable {
    // Base case: nothing left to expand.
    {} => {};

    // `impl Castable for Type { ... }`: implement `Castable` for `Type`.
    {
        impl Castable for $ty:ident {
            $(
                into dyn $iface:path;
            )*
        } $($rest:tt)*
    } => {
        $crate::impl_castable_from! {
            impl Castable for $ty {
                $(
                    into dyn $iface;
                )*
            }
        }
        $crate::impl_castable! { $($rest)* }
    };

    // `impl Castable into dyn Trait { ... }`: register `Trait` for existing types.
    {
        impl Castable into dyn $iface:path {
            $(for $ty:path;)*
        } $($rest:tt)*
    } => {
        $crate::impl_castable_into! {
            impl Castable into dyn $iface {
                $(for $ty;)*
            }
        }
        $crate::impl_castable! { $($rest)* }
    };
}
/// Front-end macro (build without `inventory`): accepts any number of
/// `impl Castable for Type { into dyn Trait; ... }` items and forwards each to
/// `impl_castable_from!`.
#[cfg(not(feature = "inventory"))]
#[macro_export]
macro_rules! impl_castable {
    {
        $(
            impl Castable for $ty:ident {
                $(into dyn $iface:path;)*
            }
        )*
    } => {
        $(
            $crate::impl_castable_from! {
                impl Castable for $ty {
                    $(into dyn $iface;)*
                }
            }
        )*
    };
}
/// A deferred cast registration, collected through `inventory`.
///
/// `impl_castable_into!` submits one entry per `(type, trait)` pair. The
/// `collect_casts` generated by `impl_castable_from!` calls `register` for every
/// entry whose `cast_from()` equals the implementing type's `TypeId`.
#[cfg(feature = "inventory")]
#[doc(hidden)]
pub struct RegisterCast {
    /// Returns the `TypeId` of the source type this entry applies to.
    pub cast_from: fn() -> TypeId,
    /// Adds this entry's cast to the given registry.
    pub register: fn(&mut Casts),
}

#[cfg(feature = "inventory")]
inventory::collect! { RegisterCast }
#[cfg(test)]
mod tests {
    use std::any::TypeId;
    use std::collections::BTreeSet;

    use super::{Castable, Casts};

    trait Object {}

    #[derive(PartialEq, Eq, Debug)]
    struct MyObjectA {
        a: String,
        b: i32,
    }

    #[derive(PartialEq, Eq, Debug)]
    struct MyObjectB {
        x: u32,
    }

    crate::impl_castable! {
        impl Castable for MyObjectA {
            into dyn std::fmt::Debug;
            into dyn std::any::Any;
            into dyn Object;
        }
    }

    crate::impl_castable! {
        impl Castable for MyObjectB {
            into dyn Object;
        }
    }

    impl Object for MyObjectA {}
    impl Object for MyObjectB {}

    /// Builds a registry with the casts of `MyObjectA` (keys 0..=3) followed by
    /// those of `MyObjectB` (keys 4..=5).
    fn registry() -> Casts {
        let mut casts = Casts::default();
        MyObjectA::collect_casts(&mut casts);
        MyObjectB::collect_casts(&mut casts);
        casts
    }

    #[test]
    fn test_casts() {
        assert_eq!(MyObjectA::name(), "MyObjectA");
        let casts = registry();
        assert_eq!(
            casts.find_key(TypeId::of::<MyObjectA>(), TypeId::of::<MyObjectA>()),
            Some(0)
        );
        assert_eq!(
            casts.find_key(TypeId::of::<MyObjectA>(), TypeId::of::<dyn Object>()),
            Some(3)
        );
        assert_eq!(
            casts.find_key(TypeId::of::<MyObjectB>(), TypeId::of::<dyn Object>()),
            Some(5)
        );
        assert_eq!(
            casts
                .find_keys_by_dst(TypeId::of::<dyn Object>())
                .collect::<BTreeSet<(TypeId, u32)>>(),
            vec![
                (TypeId::of::<MyObjectA>(), 3),
                (TypeId::of::<MyObjectB>(), 5)
            ]
            .into_iter()
            .collect()
        );
        assert_eq!(
            casts
                .find_keys_by_src(TypeId::of::<MyObjectA>())
                .collect::<BTreeSet<(TypeId, u32)>>(),
            vec![
                (TypeId::of::<MyObjectA>(), 0),
                (TypeId::of::<dyn std::fmt::Debug>(), 1),
                (TypeId::of::<dyn std::any::Any>(), 2),
                (TypeId::of::<dyn Object>(), 3)
            ]
            .into_iter()
            .collect()
        );
        assert_eq!(
            casts
                .find_keys_by_src(TypeId::of::<MyObjectB>())
                .collect::<BTreeSet<(TypeId, u32)>>(),
            vec![
                (TypeId::of::<MyObjectB>(), 4),
                (TypeId::of::<dyn Object>(), 5)
            ]
            .into_iter()
            .collect()
        );
    }

    #[test]
    fn unregistered_cast_has_no_key() {
        let casts = registry();
        assert_eq!(
            casts.find_key(TypeId::of::<MyObjectB>(), TypeId::of::<dyn std::fmt::Debug>()),
            None
        );
        assert_eq!(
            casts
                .find_keys_by_dst(TypeId::of::<dyn std::any::Any>())
                .count(),
            1
        );
    }

    #[test]
    fn cast_produces_usable_trait_objects() {
        let casts = registry();
        let obj = MyObjectA {
            a: "hello".to_string(),
            b: 7,
        };
        let erased = &obj as *const MyObjectA as *const ();

        let debug_key = casts
            .find_key(TypeId::of::<MyObjectA>(), TypeId::of::<dyn std::fmt::Debug>())
            .unwrap();
        let debug = casts.cast::<dyn std::fmt::Debug>(debug_key, erased);
        // SAFETY: `erased` points at a live `MyObjectA`, the source type of `debug_key`.
        assert_eq!(format!("{:?}", unsafe { &*debug }), format!("{:?}", obj));

        let any_key = casts
            .find_key(TypeId::of::<MyObjectA>(), TypeId::of::<dyn std::any::Any>())
            .unwrap();
        let any = casts.cast::<dyn std::any::Any>(any_key, erased);
        // SAFETY: as above; `erased` still points at the live `obj`.
        assert_eq!(unsafe { &*any }.downcast_ref::<MyObjectA>(), Some(&obj));
    }

    #[test]
    #[should_panic]
    fn cast_to_wrong_destination_panics() {
        let casts = registry();
        let obj = MyObjectB { x: 1 };
        let key = casts
            .find_key(TypeId::of::<MyObjectB>(), TypeId::of::<dyn Object>())
            .unwrap();
        let _ = casts.cast::<dyn std::fmt::Debug>(key, &obj as *const MyObjectB as *const ());
    }
}