rok-core 0.6.1

Core primitives for the rok ecosystem — errors, crypto, i18n, config, DI, and more
Documentation
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

use super::error::ContainerError;

/// Type-erased, reference-counted handle to a stored instance; `Send + Sync` so it can be
/// shared across threads and recovered with `Arc::downcast`.
type AnyArc = Arc<dyn Any + Sync + Send>;

/// How a registered type is resolved.
///
/// Both variants hold reference-counted handles, so a binding can be cloned out of the map
/// cheaply and used after the container's lock has been released.
#[derive(Clone)]
enum Binding {
    /// One shared instance, handed out (via `Arc::clone`) on every resolution.
    Singleton(AnyArc),
    /// A constructor invoked on every resolution to build a fresh instance.
    Factory(Arc<dyn Fn() -> AnyArc + Send + Sync>),
}

/// A minimal type-keyed dependency-injection container.
///
/// Each concrete type `T` has at most one binding, keyed by its [`TypeId`]: either a shared
/// singleton or a factory. Registering a type again replaces its previous binding. All methods
/// take `&self`; the binding table is guarded by an internal [`RwLock`].
pub struct Container {
    bindings: RwLock<HashMap<TypeId, Binding>>,
}

impl Default for Container {
    fn default() -> Self {
        Self::new()
    }
}

impl Container {
    /// Creates an empty container with no bindings.
    pub fn new() -> Self {
        Self {
            bindings: RwLock::new(HashMap::new()),
        }
    }

    /// Registers `factory` as the constructor for `T`, replacing any existing binding for `T`.
    ///
    /// Every [`Container::make`] call for `T` invokes `factory` and returns a freshly built
    /// instance. The factory runs without the container's lock held, so it may itself resolve
    /// or register types on the same container.
    pub fn bind<T, F>(&self, factory: F)
    where
        T: Any + Send + Sync + 'static,
        F: Fn() -> T + Send + Sync + 'static,
    {
        self.insert_binding(
            TypeId::of::<T>(),
            Binding::Factory(Arc::new(move || -> AnyArc { Arc::new(factory()) })),
        );
    }

    /// Registers `instance` as the shared singleton for `T`, replacing any existing binding.
    pub fn singleton<T>(&self, instance: T)
    where
        T: Any + Send + Sync + 'static,
    {
        self.insert_binding(TypeId::of::<T>(), Binding::Singleton(Arc::new(instance)));
    }

    /// Resolves `T`: returns the shared singleton, or the result of a fresh factory call.
    ///
    /// # Errors
    /// - [`ContainerError::NotRegistered`] if nothing is bound for `T`.
    /// - [`ContainerError::TypeMismatch`] if the stored value is not a `T` (defensive; bindings
    ///   are always stored under their own `TypeId`).
    pub fn make<T>(&self) -> Result<Arc<T>, ContainerError>
    where
        T: Any + Send + Sync + 'static,
    {
        // Clone the Arc-backed binding out; the temporary read guard is released at the end
        // of this statement, before any user factory runs. Holding it across the factory call
        // would deadlock a factory that registers types, and re-entrant reads on std's RwLock
        // may panic or deadlock.
        let binding = self.read_map().get(&TypeId::of::<T>()).cloned();
        let instance = match binding {
            Some(Binding::Singleton(instance)) => instance,
            Some(Binding::Factory(factory)) => factory(),
            None => return Err(ContainerError::NotRegistered(std::any::type_name::<T>())),
        };
        instance
            .downcast::<T>()
            .map_err(|_| ContainerError::TypeMismatch(std::any::type_name::<T>()))
    }

    /// Replaces the binding for `T` with a singleton built by `extender` from the current
    /// instance.
    ///
    /// The current instance is resolved via [`Container::make`] (so a factory binding's
    /// factory runs once). It is passed to `extender`, and the result is registered as the new
    /// singleton for `T`, which means a factory binding becomes a singleton binding. The
    /// resolve and register steps are separate lock acquisitions, so a concurrent
    /// registration of `T` made between them is overwritten.
    ///
    /// # Errors
    /// Propagates any error from [`Container::make`]; the container is then left unchanged.
    pub fn extend<T, F>(&self, extender: F) -> Result<(), ContainerError>
    where
        T: Any + Send + Sync + 'static,
        F: FnOnce(Arc<T>) -> T,
    {
        let existing = self.make::<T>()?;
        let new_instance = extender(existing);
        self.singleton(new_instance);
        Ok(())
    }

    /// Replaces whatever is bound for `T` with the singleton `instance`.
    ///
    /// Equivalent to [`Container::singleton`]; the name documents intent, e.g. swapping in
    /// a test double.
    pub fn swap<T>(&self, instance: T)
    where
        T: Any + Send + Sync + 'static,
    {
        self.singleton(instance);
    }

    /// Stores `binding` under `key`. Any replaced binding is dropped only after the write
    /// lock is released, because its destructor may run user `Drop` code that re-enters the
    /// container.
    fn insert_binding(&self, key: TypeId, binding: Binding) {
        let replaced = self.write_map().insert(key, binding);
        drop(replaced);
    }

    /// Acquires the read lock, recovering from poisoning.
    ///
    /// Recovery is sound here because the only work done under the write lock is a single
    /// `HashMap::insert` keyed by `TypeId`, which cannot leave the table half-updated.
    fn read_map(&self) -> std::sync::RwLockReadGuard<'_, HashMap<TypeId, Binding>> {
        self.bindings
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Acquires the write lock, recovering from poisoning (see [`Container::read_map`]).
    fn write_map(&self) -> std::sync::RwLockWriteGuard<'_, HashMap<TypeId, Binding>> {
        self.bindings
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}