use std::any::{Any, TypeId, type_name};
use std::collections::HashMap;
use std::fmt;
use crate::task::RegisterableTask;
/// Type-keyed container of dependencies.
///
/// Every entry is a boxed, type-erased value (`dyn Any + Send + Sync`) stored
/// under the [`TypeId`] of its concrete type, so the container holds at most
/// one value per type.
#[derive(Default)]
pub struct Deps {
// Lookup table from a type's `TypeId` to the boxed value registered for it.
map: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}
impl Deps {
    /// Creates a container with no registered dependencies.
    #[must_use]
    pub fn new() -> Self {
        let map = HashMap::new();
        Self { map }
    }

    /// Starts a [`DepsBuilder`] that wraps an empty container.
    #[must_use]
    pub fn builder() -> DepsBuilder {
        let inner = Self::new();
        DepsBuilder { inner }
    }

    /// Returns a clone of the value registered for type `T`.
    ///
    /// Yields `None` when nothing is stored under `T`'s [`TypeId`].
    #[must_use]
    pub fn get<T>(&self) -> Option<T>
    where
        T: Clone + Send + Sync + 'static,
    {
        let key = TypeId::of::<T>();
        let erased = self.map.get(&key)?;
        erased.downcast_ref::<T>().cloned()
    }

    /// Fallible counterpart of [`Deps::get`].
    ///
    /// # Errors
    ///
    /// Returns a [`MissingDep`] naming `T` when no value is registered for it.
    pub fn try_get<T>(&self) -> Result<T, MissingDep>
    where
        T: Clone + Send + Sync + 'static,
    {
        match self.get::<T>() {
            Some(found) => Ok(found),
            None => Err(MissingDep::of::<T>()),
        }
    }

    /// Returns a clone of the value registered for type `T`.
    ///
    /// # Panics
    ///
    /// Panics with a message naming `T` when no value is registered for it.
    #[must_use]
    pub fn expect<T>(&self) -> T
    where
        T: Clone + Send + Sync + 'static,
    {
        // The panic lives in a separate cold function; this path only runs on a miss.
        self.get::<T>()
            .unwrap_or_else(|| missing_panic(type_name::<T>()))
    }

    /// Reports whether an entry is stored under the [`TypeId`] of `T`.
    #[must_use]
    pub fn contains<T>(&self) -> bool
    where
        T: 'static,
    {
        let key = TypeId::of::<T>();
        self.map.contains_key(&key)
    }

    /// Number of registered types.
    #[must_use]
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Returns `true` when no type is registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }

    /// Moves every entry of `other` into `self`.
    ///
    /// When both containers hold a value for the same type, the one coming
    /// from `other` replaces the one already in `self`.
    pub fn merge(&mut self, other: Self) {
        let Self { map: incoming } = other;
        self.map.extend(incoming);
    }
}
// Stored values are type-erased, so the debug output reports only how many
// types are registered.
impl fmt::Debug for Deps {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let registered = self.map.len();
        let mut rendering = f.debug_struct("Deps");
        rendering.field("registered_types", &registered);
        rendering.finish()
    }
}
/// Builder that accumulates dependencies and hands them out as a [`Deps`].
///
/// Obtain one through [`Deps::builder`] or [`Default::default`], chain
/// `insert` / `merge` calls, then call `build` to get the finished container.
///
/// `Debug` is derived so the builder can be logged or embedded in other
/// `Debug` types; the output defers to the `Debug` impl of [`Deps`].
#[derive(Debug)]
pub struct DepsBuilder {
    // Container under construction; returned as-is by `build`.
    inner: Deps,
}
impl DepsBuilder {
    /// Registers `dep` under its concrete type `T`.
    ///
    /// A value previously registered for the same `T` is replaced.
    #[must_use]
    pub fn insert<T>(mut self, dep: T) -> Self
    where
        T: Clone + Send + Sync + 'static,
    {
        let key = TypeId::of::<T>();
        self.inner.map.insert(key, Box::new(dep));
        self
    }

    /// Folds every entry of `other` into the container being built.
    ///
    /// Delegates to [`Deps::merge`], so entries of `other` win on overlap.
    #[must_use]
    pub fn merge(self, other: Deps) -> Self {
        let Self { mut inner } = self;
        inner.merge(other);
        Self { inner }
    }

    /// Finishes the builder and returns the assembled [`Deps`].
    #[must_use]
    pub fn build(self) -> Deps {
        let Self { inner } = self;
        inner
    }
}
// The default builder wraps an empty container, exactly like `Deps::builder()`.
impl Default for DepsBuilder {
    fn default() -> Self {
        Self { inner: Deps::new() }
    }
}
/// Error value naming a dependency type that was not found in a [`Deps`]
/// container.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MissingDep {
/// Name of the missing type; `MissingDep::of` fills it from `std::any::type_name`.
pub type_name: &'static str,
}
impl MissingDep {
    /// Builds the error for type `T`, recording the name that
    /// `std::any::type_name` reports for it.
    #[must_use]
    pub fn of<T: ?Sized + 'static>() -> Self {
        let name = type_name::<T>();
        Self { type_name: name }
    }
}
// Human-readable message that names the missing type.
impl fmt::Display for MissingDep {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!(
            "missing dependency `{}` in Deps container",
            self.type_name
        ))
    }
}
// Marker impl: builds on the `Debug` and `Display` impls of `MissingDep` and
// keeps every default method of `Error` (so `source()` reports no cause).
impl std::error::Error for MissingDep {}
/// Panics with the message used by `Deps::expect` when a dependency is absent.
///
/// `type_name` is interpolated into the message to identify the missing type.
/// The function never returns; it is marked cold and never inlined so the
/// panic machinery stays out of the caller's body.
#[cold]
#[inline(never)]
// Panicking is this helper's only job, so the lint is silenced on purpose.
#[allow(clippy::panic)]
fn missing_panic(type_name: &'static str) -> ! {
    panic!(
        "Deps::expect: missing dependency `{type_name}` (verify_deps should have caught this at workflow build time)"
    )
}
/// Contract for tasks that are constructed from a [`Deps`] container.
///
/// Implementors must also implement [`RegisterableTask`], and their `Input`,
/// `Output` and `Future` associated types must be `Send + 'static`.
pub trait DepsInjectable: RegisterableTask
where
Self::Input: Send + 'static,
Self::Output: Send + 'static,
Self::Future: Send + 'static,
{
/// Builds the implementor from `deps`.
///
/// NOTE(review): presumably implementations resolve their fields through
/// `Deps::expect` and may therefore panic on a missing dependency — confirm
/// against the implementors.
fn from_deps(deps: &Deps) -> Self;
/// Returns the [`MissingDep`] entries the implementor reports for `deps`.
///
/// NOTE(review): the panic message emitted by `Deps::expect` implies this is
/// meant to run at workflow build time, before `from_deps`, and that an
/// empty vector means everything required is present — confirm with callers.
fn verify_deps(deps: &Deps) -> ::std::vec::Vec<MissingDep>;
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use std::sync::Arc;

    // Two distinct fixture types so tests can register more than one entry.
    #[derive(Clone, Debug, PartialEq, Eq)]
    struct ServiceA(u32);

    #[derive(Clone, Debug, PartialEq, Eq)]
    struct ServiceB(&'static str);

    // A value inserted through the builder is returned by `get`.
    #[test]
    fn insert_and_get_concrete() {
        let container = Deps::builder().insert(ServiceA(7)).build();
        let found = container.get::<ServiceA>();
        assert_eq!(found, Some(ServiceA(7)));
    }

    // An `Arc<T>` is keyed as `Arc<T>`, not as `T`.
    #[test]
    fn insert_arc_keeps_arc_key() {
        let shared = Arc::new(ServiceA(7));
        let container = Deps::builder().insert(shared).build();
        assert!(container.contains::<Arc<ServiceA>>());
        assert!(!container.contains::<ServiceA>());
        let resolved = container.expect::<Arc<ServiceA>>();
        assert_eq!(*resolved, ServiceA(7));
    }

    // Different types live side by side in one container.
    #[test]
    fn multiple_types_coexist() {
        let builder = Deps::builder();
        let container = builder.insert(ServiceA(1)).insert(ServiceB("hi")).build();
        assert_eq!(container.len(), 2);
        assert_eq!(container.get::<ServiceA>(), Some(ServiceA(1)));
        assert_eq!(container.get::<ServiceB>(), Some(ServiceB("hi")));
    }

    // Looking up an unregistered type yields `None`.
    #[test]
    fn missing_type_returns_none() {
        let container = Deps::new();
        assert_eq!(container.get::<ServiceA>(), None);
    }

    // The error produced by `try_get` names the requested type.
    #[test]
    fn try_get_reports_type_name() {
        let container = Deps::new();
        let failure = container.try_get::<ServiceA>().unwrap_err();
        assert!(failure.type_name.contains("ServiceA"));
    }

    // Re-inserting the same type overwrites the earlier value.
    #[test]
    fn last_insert_wins_for_same_type() {
        let container = Deps::builder()
            .insert(ServiceA(1))
            .insert(ServiceA(2))
            .build();
        assert_eq!(container.get::<ServiceA>(), Some(ServiceA(2)));
    }

    // `expect` hands back the registered value.
    #[test]
    fn expect_returns_value() {
        let container = Deps::builder().insert(ServiceA(42)).build();
        let resolved = container.expect::<ServiceA>();
        assert_eq!(resolved, ServiceA(42));
    }

    // `expect` on an empty container panics with the documented wording.
    #[test]
    #[should_panic(expected = "missing dependency")]
    fn expect_panics_with_message() {
        let container = Deps::new();
        let _never = container.expect::<ServiceA>();
    }

    // `Display` output carries both the type name and the fixed prefix.
    #[test]
    fn missing_dep_display() {
        let rendered = MissingDep::of::<ServiceA>().to_string();
        assert!(rendered.contains("ServiceA"));
        assert!(rendered.contains("missing dependency"));
    }

    // `is_empty` and `len` track the number of registered types.
    #[test]
    fn empty_and_len() {
        let blank = Deps::new();
        assert!(blank.is_empty());

        let populated = Deps::builder().insert(ServiceA(0)).build();
        assert!(!populated.is_empty());
        assert_eq!(populated.len(), 1);
    }

    // Merging disjoint containers keeps every entry.
    #[test]
    fn merge_non_overlapping() {
        let mut target = Deps::builder().insert(ServiceA(1)).build();
        let addition = Deps::builder().insert(ServiceB("x")).build();
        target.merge(addition);
        assert_eq!(target.len(), 2);
        assert_eq!(target.get::<ServiceA>(), Some(ServiceA(1)));
        assert_eq!(target.get::<ServiceB>(), Some(ServiceB("x")));
    }

    // On overlap the merged-in container's value replaces the existing one.
    #[test]
    fn merge_overlap_other_wins() {
        let mut target = Deps::builder().insert(ServiceA(1)).build();
        let addition = Deps::builder().insert(ServiceA(99)).build();
        target.merge(addition);
        assert_eq!(target.len(), 1);
        assert_eq!(target.get::<ServiceA>(), Some(ServiceA(99)));
    }

    // Merging an empty container changes nothing.
    #[test]
    fn merge_empty_into_populated() {
        let mut target = Deps::builder().insert(ServiceA(1)).build();
        let nothing = Deps::new();
        target.merge(nothing);
        assert_eq!(target.len(), 1);
        assert_eq!(target.get::<ServiceA>(), Some(ServiceA(1)));
    }

    // Merging into an empty container adopts the other's entries.
    #[test]
    fn merge_populated_into_empty() {
        let mut target = Deps::new();
        let addition = Deps::builder().insert(ServiceA(7)).build();
        target.merge(addition);
        assert_eq!(target.len(), 1);
        assert_eq!(target.get::<ServiceA>(), Some(ServiceA(7)));
    }

    // The builder can layer a prebuilt container on top of local inserts.
    #[test]
    fn builder_merge_layers_containers() {
        let library = Deps::builder().insert(ServiceA(1)).build();
        let local = Deps::builder().insert(ServiceB("local"));
        let combined = local.merge(library).build();
        assert_eq!(combined.len(), 2);
        assert_eq!(combined.get::<ServiceA>(), Some(ServiceA(1)));
        assert_eq!(combined.get::<ServiceB>(), Some(ServiceB("local")));
    }

    // Builder-level merge follows the same "other wins" rule.
    #[test]
    fn builder_merge_other_wins_on_overlap() {
        let library = Deps::builder().insert(ServiceA(2)).build();
        let local = Deps::builder().insert(ServiceA(1));
        let combined = local.merge(library).build();
        assert_eq!(combined.get::<ServiceA>(), Some(ServiceA(2)));
    }
}