use tetcore_std::{
collections::btree_map::{BTreeMap, Entry}, any::{Any, TypeId}, ops::DerefMut, boxed::Box,
};
use crate::Error;
/// A value that can be stored in an extension registry and later recovered
/// as its concrete type.
///
/// The `Any` supertrait restricts implementors to `'static` types and enables
/// runtime type identification; `Send` allows extensions to be moved across
/// threads.
pub trait Extension: Any + Send {
    /// Expose this extension as a mutable `Any` trait object so callers can
    /// `downcast_mut` it back to the concrete extension type.
    ///
    /// Implementations should simply return `self`, as the ones generated by
    /// `decl_extension!` do.
    fn as_mut_any(&mut self) -> &mut dyn Any;
}
/// Declare a newtype wrapper that can be registered as an extension.
///
/// `decl_extension! { $vis struct Name(Inner); }` expands to:
/// - `$vis struct Name(pub Inner);` carrying every attribute (including doc
///   comments) written before `struct`;
/// - an implementation of `$crate::Extension` whose `as_mut_any` returns `self`;
/// - `Deref`/`DerefMut` with `Target = Inner`;
/// - `From<Inner> for Name`.
///
/// Standard items are referenced through absolute `::core` paths. As a result,
/// the expansion also compiles in `no_std` crates, and a local item named `std`
/// or `From` at the call site cannot hijack it.
///
/// # Example
///
/// ```ignore
/// decl_extension! {
///     /// Counts something.
///     pub struct CounterExt(u32);
/// }
/// ```
#[macro_export]
macro_rules! decl_extension {
    (
        $( #[ $attr:meta ] )*
        $vis:vis struct $ext_name:ident ($inner:ty);
    ) => {
        $( #[ $attr ] )*
        $vis struct $ext_name (pub $inner);

        impl $crate::Extension for $ext_name {
            fn as_mut_any(&mut self) -> &mut dyn ::core::any::Any {
                self
            }
        }

        impl ::core::ops::Deref for $ext_name {
            type Target = $inner;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl ::core::ops::DerefMut for $ext_name {
            fn deref_mut(&mut self) -> &mut Self::Target {
                &mut self.0
            }
        }

        impl ::core::convert::From<$inner> for $ext_name {
            fn from(inner: $inner) -> Self {
                Self(inner)
            }
        }
    }
}
pub trait ExtensionStore {
fn extension_by_type_id(&mut self, type_id: TypeId) -> Option<&mut dyn Any>;
fn register_extension_with_type_id(&mut self, type_id: TypeId, extension: Box<dyn Extension>) -> Result<(), Error>;
fn deregister_extension_by_type_id(&mut self, type_id: TypeId) -> Result<(), Error>;
}
/// A registry of boxed extensions keyed by `TypeId`.
///
/// At most one extension is stored per `TypeId`; the default value is an
/// empty registry.
#[derive(Default)]
pub struct Extensions {
    /// Registered extensions, keyed by the `TypeId` they were registered under.
    extensions: BTreeMap<TypeId, Box<dyn Extension>>,
}
/// Summarise the registry by its size only. `Extension` does not require
/// `Debug`, so the individual extensions cannot be printed.
#[cfg(feature = "std")]
impl std::fmt::Debug for Extensions {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let registered = self.extensions.len();
        formatter.write_fmt(format_args!("Extensions: ({})", registered))
    }
}
impl Extensions {
pub fn new() -> Self {
Self::default()
}
pub fn register<E: Extension>(
&mut self,
ext: E,
) {
let type_id = ext.type_id();
self.extensions.insert(type_id, Box::new(ext));
}
pub fn register_with_type_id(
&mut self,
type_id: TypeId,
extension: Box<dyn Extension>,
) -> Result<(), Error> {
match self.extensions.entry(type_id) {
Entry::Vacant(vacant) => {
vacant.insert(extension);
Ok(())
},
Entry::Occupied(_) => Err(Error::ExtensionAlreadyRegistered),
}
}
pub fn get_mut(&mut self, ext_type_id: TypeId) -> Option<&mut dyn Any> {
self.extensions.get_mut(&ext_type_id).map(DerefMut::deref_mut).map(Extension::as_mut_any)
}
pub fn deregister(&mut self, type_id: TypeId) -> bool {
self.extensions.remove(&type_id).is_some()
}
pub fn iter_mut<'a>(&'a mut self) -> impl Iterator<Item = (&'a TypeId, &'a mut Box<dyn Extension>)> {
self.extensions.iter_mut()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    decl_extension! {
        struct DummyExt(u32);
    }

    decl_extension! {
        struct DummyExt2(u32);
    }

    #[test]
    fn register_and_retrieve_extension() {
        let mut exts = Extensions::new();
        exts.register(DummyExt(1));
        exts.register(DummyExt2(2));

        let ext = exts.get_mut(TypeId::of::<DummyExt>()).expect("Extension is registered");
        let ext_ty = ext.downcast_mut::<DummyExt>().expect("Downcasting works");

        assert_eq!(ext_ty.0, 1);
    }

    /// `register` keys by type, so a second registration of the same type
    /// replaces the first instead of adding an entry.
    #[test]
    fn register_replaces_extension_of_same_type() {
        let mut exts = Extensions::new();
        exts.register(DummyExt(1));
        exts.register(DummyExt(2));

        let ext = exts.get_mut(TypeId::of::<DummyExt>()).expect("Extension is registered");
        assert_eq!(ext.downcast_mut::<DummyExt>().expect("Downcasting works").0, 2);
        assert_eq!(exts.iter_mut().count(), 1);
    }

    /// `register_with_type_id` must refuse an occupied slot and keep the
    /// original extension.
    #[test]
    fn register_with_type_id_rejects_duplicates() {
        let mut exts = Extensions::new();
        let type_id = TypeId::of::<DummyExt>();

        assert!(exts.register_with_type_id(type_id, Box::new(DummyExt(1))).is_ok());
        assert!(matches!(
            exts.register_with_type_id(type_id, Box::new(DummyExt(2))),
            Err(Error::ExtensionAlreadyRegistered)
        ));

        let ext = exts.get_mut(type_id).expect("Extension is registered");
        assert_eq!(ext.downcast_mut::<DummyExt>().expect("Downcasting works").0, 1);
    }

    /// `deregister` reports whether something was removed, and the removed
    /// extension is no longer retrievable.
    #[test]
    fn deregister_removes_extension() {
        let mut exts = Extensions::new();
        exts.register(DummyExt(1));

        assert!(exts.deregister(TypeId::of::<DummyExt>()));
        assert!(exts.get_mut(TypeId::of::<DummyExt>()).is_none());
        assert!(!exts.deregister(TypeId::of::<DummyExt>()));
    }
}