use std::any::{Any, TypeId, type_name};
use std::collections::HashMap;
use std::ops::Deref;
use crate::error::EventBusError;
/// Container of heterogeneous values indexed by their concrete type.
///
/// Each value is boxed as `dyn Any + Send + Sync` and filed under the
/// [`TypeId`] of the type it was stored as, so the map holds at most one
/// entry per type.
#[derive(Default)]
pub struct Deps {
    /// Boxed values keyed by the `TypeId` of their concrete type.
    map: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}
impl Deps {
pub fn new() -> Self {
Self::default()
}
pub fn insert<T: Send + Sync + 'static>(mut self, val: T) -> Self {
self.map.insert(TypeId::of::<T>(), Box::new(val));
self
}
pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
self.map.get(&TypeId::of::<T>()).and_then(|v| v.downcast_ref::<T>())
}
pub fn get_required<T: Send + Sync + 'static>(&self) -> Result<&T, EventBusError> {
self.get::<T>()
.ok_or_else(|| EventBusError::MissingDependency(type_name::<T>().to_owned()))
}
}
/// Debug output shows only how many values are stored (as `count`); the
/// boxed values themselves are not printed.
impl std::fmt::Debug for Deps {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.map.len();
        let mut out = f.debug_struct("Deps");
        out.field("count", &count);
        out.finish()
    }
}
/// Thin newtype wrapper around a value of type `T`.
///
/// The inner value is public (`.0`) and is also reachable through [`Deref`].
/// `Debug` is available whenever `T: Debug` and renders as `Dep(<inner>)`.
#[derive(Debug)]
pub struct Dep<T>(pub T);

impl<T> Deref for Dep<T> {
    type Target = T;

    /// Borrows the wrapped value.
    fn deref(&self) -> &Self::Target {
        let Dep(inner) = self;
        inner
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A stored value can be read back through its type.
    #[test]
    fn insert_and_get() {
        struct Foo(u32);

        let deps = Deps::new().insert(Foo(42));
        let found = deps.get::<Foo>().expect("Foo present");
        assert_eq!(found.0, 42);
    }

    /// Looking up a type that was never inserted yields `None`.
    #[test]
    fn get_missing_returns_none() {
        struct Bar;

        let deps = Deps::new();
        assert!(deps.get::<Bar>().is_none());
    }

    /// `get_required` reports the missing type by name.
    #[test]
    fn get_required_missing_returns_error() {
        // `Debug` is needed so the fallback arm can print an unexpected `Ok`.
        #[derive(Debug)]
        struct Baz;

        let deps = Deps::new();
        let outcome = deps.get_required::<Baz>();
        match outcome {
            Err(EventBusError::MissingDependency(name)) => {
                assert!(name.contains("Baz"), "expected type name in error, got: {name}");
            }
            other => panic!("expected MissingDependency, got {other:?}"),
        }
    }

    /// Inserting the same type twice keeps only the later value.
    #[test]
    fn insert_overwrites_same_type() {
        struct Num(u32);

        let first = Deps::new().insert(Num(1));
        let deps = first.insert(Num(2));
        let latest = deps.get::<Num>().expect("Num present");
        assert_eq!(latest.0, 2);
    }

    /// Values of different types do not interfere with each other.
    #[test]
    fn multiple_types_independent() {
        struct A(u32);
        struct B(&'static str);

        let deps = Deps::new().insert(A(10)).insert(B("hello"));
        let a = deps.get::<A>().unwrap();
        assert_eq!(a.0, 10);
        let b = deps.get::<B>().unwrap();
        assert_eq!(b.0, "hello");
    }

    /// `Dep` dereferences to the wrapped value.
    #[test]
    fn dep_deref() {
        let wrapped = Dep(99u32);
        assert_eq!(*wrapped, 99u32);
    }
}