use std::any;
use std::any::Any;
use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::ops::{Index, IndexMut};
use crate::prelude::{ProjectError, ProjectResult};
use thiserror::Error;
/// Marker trait for values that may be stored as extensions.
///
/// It declares no methods; it only bundles the `Send + Sync + 'static` requirements
/// under a single name.
pub trait Extension: Send + Sync + 'static {}

// Blanket implementation: every type that satisfies the bounds is an `Extension`,
// so implementors never have to opt in by hand.
impl<E> Extension for E where E: Send + Sync + 'static {}
/// Implemented by types that own an [`ExtensionContainer`] and want to expose
/// type-based access to the extensions stored in it.
pub trait ExtensionAware {
    /// Returns a shared reference to the underlying extension container.
    fn extensions(&self) -> &ExtensionContainer;

    /// Returns an exclusive reference to the underlying extension container.
    fn extensions_mut(&mut self) -> &mut ExtensionContainer;

    /// Looks up the extension whose concrete type is `E`.
    ///
    /// Delegates to [`ExtensionContainer::get_by_type`] on the container returned by
    /// [`ExtensionAware::extensions`].
    ///
    /// # Errors
    ///
    /// Propagates the error from the container's typed lookup.
    fn extension<E: Extension>(&self) -> ProjectResult<&E> {
        let container = self.extensions();
        container.get_by_type::<E>()
    }

    /// Mutable counterpart of [`ExtensionAware::extension`].
    ///
    /// Delegates to [`ExtensionContainer::get_by_type_mut`] on the container returned by
    /// [`ExtensionAware::extensions_mut`].
    ///
    /// # Errors
    ///
    /// Propagates the error from the container's typed lookup.
    fn extension_mut<E: Extension>(&mut self) -> ProjectResult<&mut E> {
        let container = self.extensions_mut();
        container.get_by_type_mut::<E>()
    }
}
/// A boxed, type-erased extension value. The `Send + Sync` bounds are part of the
/// trait object type, and the concrete type is recovered with the `Any` downcast methods.
type AnyExtension = Box<dyn Any + Send + Sync>;
/// A collection of type-erased extensions, each stored under a string name.
///
/// The derived `Default` produces a container with an empty map.
#[derive(Default)]
pub struct ExtensionContainer {
// Maps an extension's name to its boxed, type-erased value.
ob_map: HashMap<String, AnyExtension>,
}
impl ExtensionContainer {
pub fn add<E: Extension, S: AsRef<str>>(
&mut self,
name: S,
value: E,
) -> Result<(), ExtensionError> {
let name = name.as_ref();
if self.ob_map.contains_key(name) {
return Err(ExtensionError::AlreadyRegistered(name.to_string()));
}
let boxed = Box::new(value) as AnyExtension;
self.ob_map.insert(name.to_string(), boxed);
Ok(())
}
pub fn get<S: AsRef<str>>(&self, name: S) -> ProjectResult<&AnyExtension> {
self.ob_map
.get(name.as_ref())
.ok_or(ProjectError::ExtensionNotRegistered(name.as_ref().to_string()).into())
}
pub fn get_mut<S: AsRef<str>>(&mut self, name: S) -> ProjectResult<&mut AnyExtension> {
self.ob_map
.get_mut(name.as_ref())
.ok_or(ProjectError::ExtensionNotRegistered(name.as_ref().to_string()).into())
}
pub fn get_by_type<E: Extension>(&self) -> ProjectResult<&E> {
let mut output: Vec<&E> = vec![];
for value in self.ob_map.values() {
if let Some(ext) = value.downcast_ref::<E>() {
output.push(ext);
}
}
match output.len() {
1 => Ok(output.remove(0)),
_ => {
Err(ProjectError::ExtensionNotRegistered(any::type_name::<E>().to_string()).into())
}
}
}
pub fn get_by_type_mut<E: Extension>(&mut self) -> ProjectResult<&mut E> {
let mut output: Vec<String> = vec![];
for (name, ext) in &self.ob_map {
if ext.is::<E>() {
output.push(name.clone());
}
}
match output.len() {
1 => {
let index = output.remove(0);
self.ob_map
.get_mut(&index)
.and_then(|b| b.downcast_mut())
.ok_or_else(|| unreachable!())
}
_ => {
Err(ProjectError::ExtensionNotRegistered(any::type_name::<E>().to_string()).into())
}
}
}
}
/// Indexing by `&str`: `container["name"]` yields the type-erased extension stored
/// under that name.
///
/// # Panics
///
/// Panics when [`ExtensionContainer::get`] returns an error for `index`.
impl Index<&str> for ExtensionContainer {
    type Output = AnyExtension;

    fn index(&self, index: &str) -> &Self::Output {
        let lookup = self.get(index);
        lookup.unwrap()
    }
}
/// Mutable indexing by `&str`: `container["name"]` yields a mutable reference to the
/// type-erased extension stored under that name.
///
/// # Panics
///
/// Panics when [`ExtensionContainer::get_mut`] returns an error for `index`.
impl IndexMut<&str> for ExtensionContainer {
    fn index_mut(&mut self, index: &str) -> &mut Self::Output {
        let lookup = self.get_mut(index);
        lookup.unwrap()
    }
}
/// Indexing by an owned `String` key; behaves like indexing by `&str`.
///
/// # Panics
///
/// Panics when [`ExtensionContainer::get`] returns an error for `index`.
impl Index<String> for ExtensionContainer {
    type Output = AnyExtension;

    fn index(&self, index: String) -> &Self::Output {
        // The lookup only needs a borrowed view of the key.
        let lookup = self.get(index.as_str());
        lookup.unwrap()
    }
}
/// Mutable indexing by an owned `String` key; behaves like mutable indexing by `&str`.
///
/// # Panics
///
/// Panics when [`ExtensionContainer::get_mut`] returns an error for `index`.
impl IndexMut<String> for ExtensionContainer {
    fn index_mut(&mut self, index: String) -> &mut Self::Output {
        // The lookup only needs a borrowed view of the key.
        let lookup = self.get_mut(index.as_str());
        lookup.unwrap()
    }
}
/// Debug output lists the names of the registered extensions.
impl Debug for ExtensionContainer {
    /// Writes `ExtensionContainer { extensions: [..names..] }`.
    ///
    /// The stored values are type-erased, so only their names are shown. Names are
    /// sorted because `HashMap` iteration order is unspecified and debug output should
    /// be stable between runs.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut names: Vec<&str> = self.ob_map.keys().map(String::as_str).collect();
        names.sort_unstable();
        f.debug_struct("ExtensionContainer")
            .field("extensions", &names)
            .finish()
    }
}
/// Errors produced while registering extensions in an `ExtensionContainer`.
#[derive(Debug, Error)]
pub enum ExtensionError {
/// An extension is already stored under the requested name; the payload is that name.
#[error("Extension with name {0:?} already registered")]
AlreadyRegistered(String),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A registered extension can be fetched by name and downcast back to its concrete type.
    #[test]
    fn use_extensions() {
        let mut container = ExtensionContainer::default();
        container.add("test", String::from("Hello, World")).unwrap();

        let stored = container.get("test").unwrap();
        let text = stored.downcast_ref::<String>().unwrap();
        assert_eq!(text, "Hello, World")
    }

    /// Registering a second extension under an already-used name is rejected.
    #[test]
    fn disallow_same_name_extensions() {
        let mut container = ExtensionContainer::default();
        container.add("test", String::from("Hello, World")).unwrap();

        let second_attempt = container.add("test", String::from("Hello, World"));
        assert!(matches!(
            second_attempt,
            Err(ExtensionError::AlreadyRegistered(_))
        ));
    }
}