Skip to main content

ruest/di/
provider.rs

1use std::any::Any;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7
8use super::Container;
9use super::Scope;
10
11/// Describes how a type is provided to the DI container.
12#[derive(Clone)]
13pub struct ProviderDescriptor {
14    pub type_name: &'static str,
15    pub scope: Scope,
16    pub factory: Arc<dyn ProviderFactory>,
17}
18
19/// Async factory that builds a service instance.
20pub type AsyncFactoryFn =
21    dyn Fn(Container) -> Pin<Box<dyn Future<Output = Arc<dyn Any + Send + Sync>> + Send>>
22        + Send
23        + Sync;
24
/// Signature of a plain, argument-free factory.
///
/// Each call returns a type-erased, thread-shareable instance.
pub type SyncFactoryFn = dyn Fn() -> Arc<dyn Any + Send + Sync> + Sync + Send;
27
28pub trait ProviderFactory: Send + Sync {
29    fn create_sync(&self, _container: &Container) -> Option<Arc<dyn Any + Send + Sync>> {
30        None
31    }
32
33    fn create_async(
34        &self,
35        _container: &Container,
36    ) -> Option<Pin<Box<dyn Future<Output = Arc<dyn Any + Send + Sync>> + Send>>> {
37        None
38    }
39}
40
41struct SyncFactoryWrapper<F>(F);
42
43impl<F> ProviderFactory for SyncFactoryWrapper<F>
44where
45    F: Fn() -> Arc<dyn Any + Send + Sync> + Send + Sync,
46{
47    fn create_sync(&self, _container: &Container) -> Option<Arc<dyn Any + Send + Sync>> {
48        Some((self.0)())
49    }
50}
51
52struct AsyncFactoryWrapper<F>(F);
53
54impl<F> ProviderFactory for AsyncFactoryWrapper<F>
55where
56    F: Fn(Container) -> Pin<Box<dyn Future<Output = Arc<dyn Any + Send + Sync>> + Send>>
57        + Send
58        + Sync,
59{
60    fn create_async(
61        &self,
62        container: &Container,
63    ) -> Option<Pin<Box<dyn Future<Output = Arc<dyn Any + Send + Sync>> + Send>>> {
64        Some((self.0)(container.clone()))
65    }
66}
67
68/// Trait implemented by types that register themselves as providers.
69#[async_trait]
70pub trait Provider: Send + Sync + 'static {
71    type Output: Send + Sync + 'static;
72
73    async fn provide(container: &Container) -> Arc<Self::Output>;
74}
75
/// Namespace for constructors that produce [`ProviderDescriptor`]s for `T`.
///
/// Its only field is private, so code outside this module cannot construct
/// it. The type exists to carry `T` for the associated functions.
pub struct FactoryProvider<T> {
    _marker: std::marker::PhantomData<T>,
}
80
81impl<T: Send + Sync + 'static> FactoryProvider<T> {
82    pub fn sync<F>(type_name: &'static str, scope: Scope, factory: F) -> ProviderDescriptor
83    where
84        F: Fn() -> Arc<T> + Send + Sync + 'static,
85    {
86        ProviderDescriptor {
87            type_name,
88            scope,
89            factory: Arc::new(SyncFactoryWrapper(move || {
90                let value: Arc<T> = factory();
91                value as Arc<dyn Any + Send + Sync>
92            })),
93        }
94    }
95
96    pub fn from_instance(type_name: &'static str, instance: Arc<T>) -> ProviderDescriptor {
97        let instance_any: Arc<dyn Any + Send + Sync> = instance;
98        ProviderDescriptor {
99            type_name,
100            scope: Scope::Singleton,
101            factory: Arc::new(SyncFactoryWrapper(move || Arc::clone(&instance_any))),
102        }
103    }
104
105    pub fn async_factory<F, Fut>(
106        type_name: &'static str,
107        scope: Scope,
108        factory: F,
109    ) -> ProviderDescriptor
110    where
111        F: Fn(Container) -> Fut + Send + Sync + 'static,
112        Fut: Future<Output = Arc<T>> + Send + 'static,
113    {
114        ProviderDescriptor {
115            type_name,
116            scope,
117            factory: Arc::new(AsyncFactoryWrapper(
118                move |container| -> Pin<Box<dyn Future<Output = Arc<dyn Any + Send + Sync>> + Send>> {
119                    let fut = factory(container);
120                    Box::pin(async move {
121                        let value: Arc<T> = fut.await;
122                        value as Arc<dyn Any + Send + Sync>
123                    })
124                },
125            )),
126        }
127    }
128}
129
130/// Register a type that implements `Default` as a singleton provider.
131pub fn default_provider<T>(type_name: &'static str, scope: Scope) -> ProviderDescriptor
132where
133    T: Default + Send + Sync + 'static,
134{
135    FactoryProvider::<T>::sync(type_name, scope, || Arc::new(T::default()))
136}