use serde::{de, Serialize};
use std::{
any::{Any, TypeId},
collections::HashMap,
ops::Deref,
sync::Arc,
};
/// Type-keyed store of shared state handles.
///
/// Every entry is keyed by the [`TypeId`] of an [`Inject<T>`] and holds that
/// `Inject<T>` boxed as `dyn Any`, so at most one handle per `T` is kept.
/// The boxed values carry no `Send`/`Sync` bounds, which makes `Context`
/// itself neither `Send` nor `Sync`.
#[derive(Default)]
pub struct Context {
    // TypeId::of::<Inject<T>>()  ->  Box<Inject<T>>, type-erased as `dyn Any`.
    map: HashMap<TypeId, Box<dyn Any>>,
}

impl Context {
    /// Registers `state`, replacing (and dropping) any handle previously
    /// stored for the same `T`.
    pub fn insert<T: 'static>(&mut self, state: Inject<T>) {
        let key = TypeId::of::<Inject<T>>();
        let erased: Box<dyn Any> = Box::new(state);
        self.map.insert(key, erased);
    }

    /// Looks up the entry stored under `TypeId::of::<T>()` and downcasts it
    /// to `T`.
    ///
    /// Because [`Context::insert`] only ever stores `Inject<_>` values, state
    /// must be requested as `get::<Inject<U>>()`. Asking for a type that is
    /// not an `Inject<_>` always yields `None`.
    pub fn get<T: 'static>(&self) -> Option<&T> {
        let key = TypeId::of::<T>();
        self.map.get(&key)?.downcast_ref::<T>()
    }
}
/// Types that can be built from the state registered in a [`Context`].
pub trait FromContext {
    /// Constructs a value of this type from the state held by `ctx`.
    fn from_context(ctx: &Context) -> Self;
}
/// Shared, reference-counted handle to a piece of state.
///
/// This is a thin newtype over [`Arc<T>`]. The `?Sized` bound lets the type
/// name unsized payloads, although [`Inject::new`] only accepts sized values.
pub struct Inject<T: ?Sized>(Arc<T>);

impl<T> Inject<T> {
    /// Moves `state` into a freshly allocated `Arc` and wraps it.
    pub fn new(state: T) -> Self {
        let shared = Arc::new(state);
        Self(shared)
    }
}
impl<T: ?Sized> Deref for Inject<T> {
type Target = Arc<T>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T: ?Sized> Clone for Inject<T> {
fn clone(&self) -> Inject<T> {
Inject(Arc::clone(&self.0))
}
}
impl<T: Default> Default for Inject<T> {
fn default() -> Self {
Inject::new(T::default())
}
}
impl<T> Serialize for Inject<T>
where
T: Serialize,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
self.0.serialize(serializer)
}
}
impl<'de, T> de::Deserialize<'de> for Inject<T>
where
T: de::Deserialize<'de>,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: de::Deserializer<'de>,
{
Ok(Inject::new(T::deserialize(deserializer)?))
}
}
/// Resolves an `Inject<T>` by cloning the handle registered in the context.
/// The clone shares the same underlying `Arc`.
impl<T: 'static> FromContext for Inject<T> {
    /// # Panics
    ///
    /// Panics if no `Inject<T>` was registered via `Context::insert`. The
    /// message names the missing type so misconfigured services can be
    /// diagnosed.
    fn from_context(ctx: &Context) -> Self {
        match ctx.get::<Inject<T>>() {
            Some(obj) => obj.clone(),
            // Keep the original wording as a prefix so existing
            // `should_panic(expected = ...)` substring checks still match.
            None => panic!(
                "Tried to inject an object not in the MCPService's state! (missing `{}`)",
                std::any::type_name::<T>()
            ),
        }
    }
}