ioc/
async.rs

1use super::common::*;
2use std::collections::HashMap;
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6use async_trait::async_trait;
7use variadic_generics::va_expand;
8
9// Resolve, ResolveStart --------------------------------------------------
10
11#[async_trait]
12pub trait Resolve: Send + Sized + 'static {
13    type Deps: Send;
14    async fn resolve(deps: Self::Deps) -> Result<Self>;
15}
16
17#[async_trait]
18pub trait ResolveStart<R>: Sync {
19    async fn resolve_start(&self) -> Result<R>;
20}
21
22#[async_trait]
23impl<X: Sync> ResolveStart<()> for X {
24    async fn resolve_start(&self) -> Result<()> { Ok(()) }
25}
26
27
28#[async_trait]
29impl<R, X> ResolveStart<R> for X
30    where R: Resolve, X: ResolveStart<R::Deps>
31{
32    async fn resolve_start(&self) -> Result<R> {
33        R::resolve(<X as ResolveStart<R::Deps>>::resolve_start(self).await?).await
34    }
35}
36
37// tuples
38va_expand!{ ($va_len:tt) ($($va_idents:ident),+) ($($va_indices:tt),+)
39    #[async_trait]
40    impl<$($va_idents,)+ X> ResolveStart<($($va_idents,)+)> for X
41    where 
42        $($va_idents: Resolve,)+
43        $(X: ResolveStart<$va_idents::Deps>,)+
44    {
45        async fn resolve_start(&self) -> Result<($($va_idents,)+)> { 
46            Ok(($(
47                $va_idents::resolve(<X as ResolveStart<$va_idents::Deps>>::resolve_start(self).await?).await?,
48            )+))
49        }
50    }
51}
52
53// Middleware --------------------------------------------------
54
55#[derive(Clone)]
56pub struct InstantiationRequest {
57    pub top: Arc<dyn Middleware>,
58    pub service_name: String,
59    pub shadow_levels: HashMap<String, usize>,
60}
61
62impl InstantiationRequest {
63    fn increment_shadow(&mut self, service_name: &str){
64        let level = self.shadow_levels.entry(service_name.to_owned())
65            .or_insert(0);
66        *level += 1;
67    }
68    
69    /// returns true if successfully decremented
70    fn decrement_shadow(&mut self, service_name: &str) -> bool {
71        self.shadow_levels.get_mut(service_name)
72            .map(|level| level.saturating_sub(1))
73            .unwrap_or(1) != 0
74    }
75}
76
77#[async_trait]
78pub trait Middleware: Send + Sync + 'static {
79    async fn instantiate(&self, req: InstantiationRequest) -> Result<InstanceRef>;
80}
81
82#[async_trait]
83impl ResolveStart<Arc<dyn Middleware>> for Arc<dyn Middleware> {
84    async fn resolve_start(&self) -> Result<Arc<dyn Middleware>> { Ok(self.clone()) }
85}
86
87#[async_trait]
88impl<S> Resolve for TypedInstanceRef<S>
89    where S: Service + ?Sized
90{
91    type Deps = Arc<dyn Middleware>;
92
93    async fn resolve(top: Self::Deps) -> Result<Self> {
94        let req = InstantiationRequest{
95            top: top.clone(),
96            service_name: S::service_name(),
97            shadow_levels: Some((S::service_name(), 1)).into_iter().collect(),
98        };
99        top.instantiate(req).await?
100            .downcast_arc::<Box<S>>()
101            .map_err(|err| InstanceTypeError::new(S::service_name(), err.type_mismatch()).into())
102    }
103}
104
105// Middleware: ContainerRoot --------------------------------------------------
106
107struct ContainerRoot;
108
109#[async_trait]
110impl Middleware for ContainerRoot {
111    async fn instantiate(&self, req: InstantiationRequest) -> Result<InstanceRef> {
112        Err(InstancerNotFoundError::new(req.service_name).into())
113    }
114}
115
116// Middleware: InstancerShadow --------------------------------------------------
117
118struct InstancerShadow {
119    prev: Arc<dyn Middleware>,
120    shadowed_service_name: String
121}
122
123impl InstancerShadow {
124    fn new(prev: Arc<dyn Middleware>, shadowed_service_name: String) -> Self {
125        Self{ prev, shadowed_service_name }
126    }
127}
128
129#[async_trait]
130impl Middleware for InstancerShadow {
131    async fn instantiate(&self, mut req: InstantiationRequest) -> Result<InstanceRef> {
132        if self.shadowed_service_name == req.service_name {
133            req.increment_shadow(&self.shadowed_service_name)
134        }
135        self.prev.instantiate(req).await
136    }
137}
138
139// Middleware: SingletonInstancer  --------------------------------------------------
140
141#[allow(type_alias_bounds)]
142type CreationFn<T: ?Sized> = Arc<dyn (Fn(&'_ Arc<dyn Middleware>) -> Pin<Box<dyn Future<Output = Result<Box<T>>> + Send + '_>>) + Send + Sync>;
143
144struct SingletonInstancer<T: ?Sized> {
145    prev: Arc<dyn Middleware>,
146    creation_fn: CreationFn<T>,
147    #[allow(clippy::redundant_allocation)]
148    instance: futures::lock::Mutex<Option<Arc<Box<T>>>>,
149    service_name: String,
150}
151
152impl<T> SingletonInstancer<T>
153    where T: Service + ?Sized
154{
155    fn new(prev: Arc<dyn Middleware>, creation_fn: CreationFn<T>) -> Self {
156        let service_name = T::service_name();
157        Self{ prev, creation_fn, instance: futures::lock::Mutex::new(None), service_name } 
158    }
159}
160
161#[async_trait]
162impl<T> Middleware for SingletonInstancer<T>
163    where T: Service + ?Sized
164{
165    async fn instantiate(&self, mut req: InstantiationRequest) -> Result<InstanceRef> {
166        // if different service or shadowed, pass request (with one less shadow level) up the chain
167        if req.service_name != self.service_name
168        || req.decrement_shadow(&self.service_name)
169        {
170    	    return self.prev.instantiate(req).await
171        }
172        
173        // increase shadow level
174        req.increment_shadow(&self.service_name);
175        let shadowed_top: Arc<dyn Middleware> = Arc::new(InstancerShadow::new(req.top, self.service_name.clone()));
176        
177        // recall or create instance
178        let mut guard = self.instance.lock().await;
179        if guard.is_none() {
180            let inst = (self.creation_fn)(&shadowed_top).await
181                .map(Arc::new)
182                .map_err(|err| InstanceCreationError::new(self.service_name.clone(), err))?;
183            *guard = Some(inst);
184        }
185        Ok(guard.as_ref().cloned().unwrap())
186    }
187}
188
189// Container --------------------------------------------------
190
191#[derive(Clone)]
192pub struct Container {
193    top: Arc<dyn Middleware>,
194}
195
196#[async_trait]
197impl Resolve for Container {
198    type Deps = Arc<dyn Middleware>;
199    async fn resolve(top: Self::Deps) -> Result<Self> {
200        Ok(Container{ top })
201    }
202}
203
204impl Default for Container {
205    fn default() -> Self{ Self::new(Arc::new(ContainerRoot)) }
206}
207
208impl Container {
209    pub fn new(top: Arc<dyn Middleware>) -> Self {
210        Self{ top }
211    }
212
213    pub fn with_singleton<S, Args, F>(&self, creation_fn: F) -> Self
214    where 
215        S: Service + ?Sized,
216        Arc<dyn Middleware>: ResolveStart<Args>,
217        Args: Send,
218        F: Fn(Args) -> Result<Box<S>> + Send + Sync + Copy + 'static
219    {
220    	let creation_fn: CreationFn<S> = Arc::new(move |mw| {
221            Box::pin(async move { creation_fn(mw.resolve_start().await?) })
222        });
223        Self::new(Arc::new(SingletonInstancer::new(self.top.clone(), creation_fn)))
224    }
225
226    pub fn with_singleton_ok<S, Args, F>(&self, creation_fn: F) -> Self
227    where
228        S: Service + ?Sized,
229        Arc<dyn Middleware>: ResolveStart<Args>,
230        F: Fn(Args) -> Box<S> + Send + Sync + Copy + 'static
231    {
232    	let creation_fn: CreationFn<S> = Arc::new(move |mw| {
233            Box::pin(async move { Ok(creation_fn(mw.resolve_start().await?)) })
234        });
235        Self::new(Arc::new(SingletonInstancer::new(self.top.clone(), creation_fn)))
236    }
237
238    pub fn with_singleton_async<S, Args, Fut, F>(&self, creation_fn: F) -> Self
239    where 
240        S: Service + ?Sized,
241        Arc<dyn Middleware>: ResolveStart<Args>,
242        Args: Send,
243        Fut: Future<Output = Result<Box<S>>> + Send,
244        F: Fn(Args) -> Fut + Send + Sync + Copy + 'static
245    {
246    	let creation_fn: CreationFn<S> = Arc::new(move |mw| {
247            Box::pin(async move { creation_fn(mw.resolve_start().await?).await })
248        });
249        Self::new(Arc::new(SingletonInstancer::<S>::new(self.top.clone(), creation_fn)))
250    }
251
252    pub fn with_singleton_async_ok<S, Args, Fut, F>(&self, creation_fn: F) -> Self
253    where 
254        S: Service + ?Sized,
255        Arc<dyn Middleware>: ResolveStart<Args>,
256        Args: Send,
257        Fut: Future<Output = Box<S>> + Send,
258        F: Fn(Args) -> Fut + Send + Sync + Copy + 'static
259    {
260    	let creation_fn: CreationFn<S> = Arc::new(move |mw| {
261            Box::pin(async move { Ok(creation_fn(mw.resolve_start().await?).await) })
262        });
263        Self::new(Arc::new(SingletonInstancer::<S>::new(self.top.clone(), creation_fn)))
264    }
265
266    pub async fn resolve<X>(&self) -> Result<X>
267        where Arc<dyn Middleware>: ResolveStart<X>
268    {
269        self.top.resolve_start().await
270    }
271}
272
273pub fn container() -> Container { Default::default() }