ioc/
lib.rs

1use downcast::AnySync;
2use once_cell::sync::OnceCell;
3use variadic_generics::va_expand;
4use std::collections::HashMap;
5use std::sync::Arc;
6
7pub extern crate anyhow;
8
9mod common;
10pub use common::*;
11
12#[cfg(feature = "async")]
13pub mod r#async;
14
15/* Current TODOs:
16 *
17 * 1. write tests
18 * 2. figure out which parts of the library should be hidden
19 * 3. detect and prevent hangups caused by mutual dependencies
20 * 4. do we want TransientInstancers besides SingletonInstancers?
21 */
22
23// Resolve, ResolveStart --------------------------------------------------
24
25pub trait Resolve: Send + Sized + 'static {
26    type Deps: Send;
27    fn resolve(deps: Self::Deps) -> Result<Self>;
28}
29
30/// Careful when using this trait, or you'll be in for a world of stack
31/// overflows.
32pub trait ResolveStart<R>: Sync {
33    fn resolve_start(&self) -> Result<R>;
34}
35
36impl<X: Sync> ResolveStart<()> for X {
37    fn resolve_start(&self) -> Result<()> { Ok(()) }
38}
39
40impl<R, X> ResolveStart<R> for X
41    where R: Resolve, X: ResolveStart<R::Deps>
42{
43    fn resolve_start(&self) -> Result<R> {
44        R::resolve(<X as ResolveStart<R::Deps>>::resolve_start(self)?)
45    }
46}
47
48// tuples
49va_expand!{ ($va_len:tt) ($($va_idents:ident),+) ($($va_indices:tt),+)
50    impl<$($va_idents,)+ X> ResolveStart<($($va_idents,)+)> for X
51    where 
52        $($va_idents: Resolve,)+
53        $(X: ResolveStart<$va_idents::Deps>,)+
54    {
55        fn resolve_start(&self) -> Result<($($va_idents,)+)> { 
56            Ok(($(
57                $va_idents::resolve(<X as ResolveStart<$va_idents::Deps>>::resolve_start(self)?)?,
58            )+))
59        }
60    }
61}
62
63// Middleware --------------------------------------------------
64
65#[derive(Clone)]
66pub struct InstantiationRequest {
67    pub top: Arc<dyn Middleware>,
68    pub service_name: String,
69    pub shadow_levels: HashMap<String, usize>,
70}
71
72impl InstantiationRequest {
73    fn increment_shadow(&mut self, service_name: &str){
74        let level = self.shadow_levels.entry(service_name.to_owned())
75            .or_insert(0);
76        *level += 1;
77    }
78    
79    /// returns true if successfully decremented
80    fn decrement_shadow(&mut self, service_name: &str) -> bool {
81        self.shadow_levels.get_mut(service_name)
82            .map(|level| level.saturating_sub(1))
83            .unwrap_or(1) != 0
84    }
85}
86
87pub trait Middleware: Send + Sync + 'static {
88    fn instantiate(&self, req: InstantiationRequest) -> Result<InstanceRef>;
89}
90
91impl ResolveStart<Arc<dyn Middleware>> for Arc<dyn Middleware> {
92    fn resolve_start(&self) -> Result<Arc<dyn Middleware>> { Ok(self.clone()) }
93}
94
95impl<S> Resolve for TypedInstanceRef<S>
96    where S: Service + ?Sized
97{
98    type Deps = Arc<dyn Middleware>;
99
100    fn resolve(top: Self::Deps) -> Result<Self> {
101        let req = InstantiationRequest{
102            top: top.clone(),
103            service_name: S::service_name(),
104            shadow_levels: Some((S::service_name(), 1)).into_iter().collect(),
105        };
106        top.instantiate(req)?
107            .downcast_arc::<Box<S>>()
108            .map_err(|err| InstanceTypeError::new(S::service_name(), err.type_mismatch()).into())
109    }
110}
111
112// Middleware: ContainerRoot --------------------------------------------------
113
114struct ContainerRoot;
115
116impl Middleware for ContainerRoot {
117    fn instantiate(&self, req: InstantiationRequest) -> Result<InstanceRef> {
118        Err(InstancerNotFoundError::new(req.service_name).into())
119    }
120}
121
122// Middleware: InstancerShadow --------------------------------------------------
123
124struct InstancerShadow {
125    prev: Arc<dyn Middleware>,
126    shadowed_service_name: String
127}
128
129impl InstancerShadow {
130    fn new(prev: Arc<dyn Middleware>, shadowed_service_name: String) -> Self {
131        Self{ prev, shadowed_service_name }
132    }
133}
134
135impl Middleware for InstancerShadow {
136    fn instantiate(&self, mut req: InstantiationRequest) -> Result<InstanceRef> {
137        if self.shadowed_service_name == req.service_name {
138            req.increment_shadow(&self.shadowed_service_name)
139        }
140        self.prev.instantiate(req)
141    }
142}
143
144// Middleware: SingletonInstancer  --------------------------------------------------
145
146#[allow(type_alias_bounds)]
147type CreationFn<T: ?Sized> = Arc<dyn (Fn(&Arc<dyn Middleware>) -> Result<Box<T>>) + Send + Sync>;
148
149struct SingletonInstancer<T: ?Sized> {
150    prev: Arc<dyn Middleware>,
151    creation_fn: CreationFn<T>,
152    #[allow(clippy::redundant_allocation)]
153    instance: OnceCell<Arc<Box<T>>>,
154    service_name: String,
155}
156
157impl<T> SingletonInstancer<T>
158    where T: Service + ?Sized
159{
160    fn new(prev: Arc<dyn Middleware>, creation_fn: CreationFn<T>) -> Self {
161        let service_name = T::service_name();
162        Self{ prev, creation_fn, instance: OnceCell::new(), service_name } 
163    }
164}
165
166impl<T> Middleware for SingletonInstancer<T>
167    where T: Service + ?Sized
168{
169    fn instantiate(&self, mut req: InstantiationRequest) -> Result<InstanceRef> {
170        // if different service or shadowed, pass request (with one less shadow level) up the chain
171        if req.service_name != self.service_name
172        || req.decrement_shadow(&self.service_name)
173        {
174    	    return self.prev.instantiate(req)
175        }
176        
177        // increase shadow level
178        req.increment_shadow(&self.service_name);
179        let shadowed_top: Arc<dyn Middleware> = Arc::new(InstancerShadow::new(req.top, self.service_name.clone()));
180        
181        // recall or create instance
182        self.instance.get_or_try_init(move || (self.creation_fn)(&shadowed_top).map(Arc::new))
183            .map(|inst| inst.clone() as Arc<dyn AnySync>)
184            .map_err(|err| InstanceCreationError::new(self.service_name.clone(), err).into())
185    }
186}
187
188// Container --------------------------------------------------
189
190#[derive(Clone)]
191pub struct Container {
192    top: Arc<dyn Middleware>,
193}
194
195impl Resolve for Container {
196    type Deps = Arc<dyn Middleware>;
197    fn resolve(top: Self::Deps) -> Result<Self> {
198        Ok(Container{ top })
199    }
200}
201
202impl Default for Container {
203    fn default() -> Self{ Self::new(Arc::new(ContainerRoot)) }
204}
205
206impl Container {
207    pub fn new(top: Arc<dyn Middleware>) -> Self {
208        Self{ top }
209    }
210
211    pub fn with_singleton<S, Args, F>(&self, creation_fn: F) -> Self
212    where 
213        S: Service + ?Sized,
214        Arc<dyn Middleware>: ResolveStart<Args>,
215        F: Fn(Args) -> Result<Box<S>> + Send + Sync + 'static
216    {
217    	let creation_fn: CreationFn<S> = Arc::new(move |mw: &Arc<dyn Middleware>| {
218            creation_fn(mw.resolve_start()?)
219        });
220        Self::new(Arc::new(SingletonInstancer::new(self.top.clone(), creation_fn)))
221    }
222
223    pub fn with_singleton_ok<S, Args, F>(&self, creation_fn: F) -> Self
224    where
225        S: Service + ?Sized,
226        Arc<dyn Middleware>: ResolveStart<Args>,
227        F: Fn(Args) -> Box<S> + Send + Sync + 'static
228    {
229    	let creation_fn: CreationFn<S> = Arc::new(move |mw: &Arc<dyn Middleware>| {
230            Ok(creation_fn(mw.resolve_start()?))
231        });
232        Self::new(Arc::new(SingletonInstancer::new(self.top.clone(), creation_fn)))
233    }
234
235    pub fn resolve<X>(&self) -> Result<X>
236        where Arc<dyn Middleware>: ResolveStart<X>
237    {
238        self.top.resolve_start()
239    }
240}
241
242pub fn container() -> Container { Default::default() }