use crate::Injectable;
use once_cell::sync::OnceCell;
use std::any::Any;
use std::sync::Arc;
#[cfg(feature = "logging")]
use tracing::{debug, trace};
/// A type-erased provider of service instances.
///
/// Implementors hand out instances as `Arc<dyn Any + Send + Sync>`; callers
/// are expected to downcast the result to the concrete service type.
pub trait Factory: Send + Sync {
    /// Produces the instance managed by this factory.
    fn resolve(&self) -> Arc<dyn Any + Send + Sync>;

    /// Reports whether this factory is transient.
    ///
    /// The default answer is `false`; implementors that build a new instance
    /// on every `resolve` call override this to return `true`.
    fn is_transient(&self) -> bool {
        false
    }
}
/// Factory that wraps one already-constructed, type-erased instance.
pub struct SingletonFactory {
    /// The shared instance; crate-visible so sibling code can clone the handle directly.
    pub(crate) instance: Arc<dyn Any + Send + Sync>,
}
impl SingletonFactory {
    /// Wraps `instance` in a fresh `Arc` and stores it type-erased.
    #[inline]
    pub fn new<T: Injectable>(instance: T) -> Self {
        // The typed binding performs the unsizing coercion to the erased handle.
        let erased: Arc<dyn Any + Send + Sync> = Arc::new(instance);
        Self { instance: erased }
    }

    /// Stores an existing `Arc<T>` without re-wrapping the value.
    ///
    /// The factory shares ownership with whoever else holds the `Arc`.
    #[inline]
    pub fn from_arc<T: Injectable>(instance: Arc<T>) -> Self {
        let erased: Arc<dyn Any + Send + Sync> = instance;
        Self { instance: erased }
    }

    /// Returns another handle to the stored instance.
    ///
    /// Only the reference count is bumped; the value itself is not copied.
    #[inline]
    pub fn resolve(&self) -> Arc<dyn Any + Send + Sync> {
        let shared = &self.instance;
        Arc::clone(shared)
    }
}
impl Factory for SingletonFactory {
#[inline]
fn resolve(&self) -> Arc<dyn Any + Send + Sync> {
self.resolve()
}
}
/// Shared, thread-safe closure that builds a type-erased instance.
type LazyInitFn = Arc<dyn Fn() -> Arc<dyn Any + Send + Sync> + Send + Sync>;

/// Factory that defers construction until the instance is first requested
/// and then keeps handing out that same instance.
pub struct LazyFactory {
    /// Builder invoked to populate `instance`.
    init: LazyInitFn,
    /// Cell holding the instance once it has been built.
    instance: OnceCell<Arc<dyn Any + Send + Sync>>,
    /// Name of the concrete service type, used only in log events.
    #[cfg(feature = "logging")]
    type_name: &'static str,
}
impl LazyFactory {
#[inline]
pub fn new<T: Injectable, F>(factory: F) -> Self
where
F: Fn() -> T + Send + Sync + 'static,
{
Self {
init: Arc::new(move || Arc::new(factory()) as Arc<dyn Any + Send + Sync>),
instance: OnceCell::new(),
#[cfg(feature = "logging")]
type_name: std::any::type_name::<T>(),
}
}
#[inline]
pub fn resolve(&self) -> Arc<dyn Any + Send + Sync> {
#[cfg(feature = "logging")]
let was_empty = self.instance.get().is_none();
let result = Arc::clone(self.instance.get_or_init(|| {
#[cfg(feature = "logging")]
debug!(
target: "dependency_injector",
service = self.type_name,
"Lazy singleton initializing on first access"
);
(self.init)()
}));
#[cfg(feature = "logging")]
if !was_empty {
trace!(
target: "dependency_injector",
service = self.type_name,
"Lazy singleton already initialized, returning cached instance"
);
}
result
}
}
impl Factory for LazyFactory {
#[inline]
fn resolve(&self) -> Arc<dyn Any + Send + Sync> {
self.resolve()
}
}
/// Shared, thread-safe closure that builds a type-erased instance.
type TransientFn = Arc<dyn Fn() -> Arc<dyn Any + Send + Sync> + Send + Sync>;

/// Factory that stores a builder closure instead of an instance.
pub struct TransientFactory {
    /// Builder invoked to produce an instance.
    factory: TransientFn,
    /// Name of the concrete service type, used only in log events.
    #[cfg(feature = "logging")]
    type_name: &'static str,
}
impl TransientFactory {
#[inline]
pub fn new<T: Injectable, F>(factory: F) -> Self
where
F: Fn() -> T + Send + Sync + 'static,
{
Self {
factory: Arc::new(move || Arc::new(factory()) as Arc<dyn Any + Send + Sync>),
#[cfg(feature = "logging")]
type_name: std::any::type_name::<T>(),
}
}
#[inline]
pub fn create(&self) -> Arc<dyn Any + Send + Sync> {
#[cfg(feature = "logging")]
trace!(
target: "dependency_injector",
service = self.type_name,
"Creating new transient instance"
);
(self.factory)()
}
}
impl Factory for TransientFactory {
#[inline]
fn resolve(&self) -> Arc<dyn Any + Send + Sync> {
self.create()
}
#[inline]
fn is_transient(&self) -> bool {
true
}
}
/// Closed set of the factory kinds defined in this file, dispatched by `match`.
pub(crate) enum AnyFactory {
    /// Wraps an instance that already exists.
    Singleton(SingletonFactory),
    /// Builds its instance on first access and then reuses it.
    Lazy(LazyFactory),
    /// Runs its builder closure on every resolution.
    Transient(TransientFactory),
}
impl Clone for AnyFactory {
    /// Clones the factory so that the copy resolves consistently with the original.
    ///
    /// * `Singleton` — the copy shares the same instance handle.
    /// * `Lazy` — the original is resolved (running its initializer if it has
    ///   not run yet) and the copy becomes a `Singleton` holding that
    ///   instance. Cloning a lazy factory can therefore trigger construction.
    /// * `Transient` — the copy shares the same builder closure.
    fn clone(&self) -> Self {
        match self {
            Self::Singleton(existing) => {
                let instance = Arc::clone(&existing.instance);
                Self::Singleton(SingletonFactory { instance })
            }
            Self::Lazy(lazy) => Self::Singleton(SingletonFactory {
                instance: lazy.resolve(),
            }),
            Self::Transient(transient) => {
                let factory = Arc::clone(&transient.factory);
                Self::Transient(TransientFactory {
                    factory,
                    #[cfg(feature = "logging")]
                    type_name: transient.type_name,
                })
            }
        }
    }
}
impl AnyFactory {
    /// Builds the `Singleton` variant from an already-constructed value.
    #[inline]
    pub fn singleton<T: Injectable>(instance: T) -> Self {
        let inner = SingletonFactory::new(instance);
        Self::Singleton(inner)
    }

    /// Builds the `Lazy` variant; `factory` is not invoked here.
    #[inline]
    pub fn lazy<T: Injectable, F>(factory: F) -> Self
    where
        F: Fn() -> T + Send + Sync + 'static,
    {
        let inner = LazyFactory::new(factory);
        Self::Lazy(inner)
    }

    /// Builds the `Transient` variant; `factory` is not invoked here.
    #[inline]
    pub fn transient<T: Injectable, F>(factory: F) -> Self
    where
        F: Fn() -> T + Send + Sync + 'static,
    {
        let inner = TransientFactory::new(factory);
        Self::Transient(inner)
    }

    /// Produces an instance by dispatching to the wrapped factory.
    #[inline]
    pub fn resolve(&self) -> Arc<dyn Any + Send + Sync> {
        match self {
            Self::Transient(transient) => transient.create(),
            Self::Lazy(lazy) => lazy.resolve(),
            Self::Singleton(single) => single.resolve(),
        }
    }

    /// Returns `true` only for the `Transient` variant.
    #[inline]
    pub fn is_transient(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Transient(_) => true,
            Self::Singleton(_) | Self::Lazy(_) => false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Minimal service whose `id` lets tests tell instances apart.
    #[derive(Clone)]
    struct TestService {
        id: u32,
    }

    /// Resolves `factory` and downcasts the result to `TestService`.
    fn resolve_service(factory: &AnyFactory) -> Arc<TestService> {
        factory.resolve().downcast::<TestService>().unwrap()
    }

    /// A singleton hands out the same allocation on every resolution.
    #[test]
    fn test_singleton_factory() {
        let factory = AnyFactory::singleton(TestService { id: 42 });

        let first = resolve_service(&factory);
        let second = resolve_service(&factory);

        assert_eq!(first.id, 42);
        assert!(Arc::ptr_eq(&first, &second));
    }

    /// A lazy factory runs its initializer once, on first access only.
    #[test]
    fn test_lazy_factory() {
        static COUNTER: AtomicU32 = AtomicU32::new(0);

        let factory = AnyFactory::lazy(|| TestService {
            id: COUNTER.fetch_add(1, Ordering::SeqCst),
        });
        // Constructing the factory must not run the initializer.
        assert_eq!(COUNTER.load(Ordering::SeqCst), 0);

        let first = resolve_service(&factory);
        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);
        assert_eq!(first.id, 0);

        // The second resolution reuses the cached instance.
        let second = resolve_service(&factory);
        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);
        assert!(Arc::ptr_eq(&first, &second));
    }

    /// A transient factory builds a distinct instance per resolution.
    #[test]
    fn test_transient_factory() {
        static COUNTER: AtomicU32 = AtomicU32::new(0);

        let factory = AnyFactory::transient(|| TestService {
            id: COUNTER.fetch_add(1, Ordering::SeqCst),
        });

        let first = resolve_service(&factory);
        let second = resolve_service(&factory);

        assert_eq!(first.id, 0);
        assert_eq!(second.id, 1);
        assert!(!Arc::ptr_eq(&first, &second));
    }

    /// Only the transient variant reports itself as transient.
    #[test]
    fn test_is_transient() {
        let singleton = AnyFactory::singleton(TestService { id: 1 });
        let lazy = AnyFactory::lazy(|| TestService { id: 2 });
        let transient = AnyFactory::transient(|| TestService { id: 3 });

        assert!(!singleton.is_transient());
        assert!(!lazy.is_transient());
        assert!(transient.is_transient());
    }
}