Skip to main content

nidus_core/provider/
mod.rs

1//! Provider registration primitives.
2
3use std::{
4    any::{Any, TypeId},
5    panic::{AssertUnwindSafe, catch_unwind},
6    sync::{Arc, Condvar, Mutex, MutexGuard},
7};
8
9use crate::{Container, NidusError, RequestScope, Result, resolution};
10
/// How a provider's instances are created and reused.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderLifetime {
    /// Built once on first resolution, then shared by every later resolution.
    Singleton,
    /// Built anew every time the provider is resolved.
    Transient,
    /// Tied to request scopes.
    ///
    /// Resolved through a request scope, the value comes from the request-scoped factory when
    /// one was registered. Resolved without a scope, a fresh value is built on every
    /// resolution, exactly as for `Transient`.
    Request,
}
21
/// Marker trait for values that can be injected as providers.
///
/// It is blanket-implemented for every `Send + Sync + 'static` type, so it never needs a
/// manual impl.
pub trait Provider: Send + Sync + 'static {}

impl<T: Send + Sync + 'static> Provider for T {}
26
27type ErasedProvider = dyn Any + Send + Sync;
28type ProviderFactory = dyn Fn(&Container) -> Result<Arc<ErasedProvider>> + Send + Sync;
29type RequestProviderFactory =
30    dyn for<'scope> Fn(&RequestScope<'scope>) -> Result<Arc<ErasedProvider>> + Send + Sync;
31
32/// A typed provider registration stored by the container.
33pub struct ProviderEntry {
34    type_id: TypeId,
35    type_name: &'static str,
36    lifetime: ProviderLifetime,
37    factory: Arc<ProviderFactory>,
38    request_factory: Option<Arc<RequestProviderFactory>>,
39    singleton: Mutex<SingletonState>,
40    singleton_ready: Condvar,
41}
42
43enum SingletonState {
44    Empty,
45    Initializing,
46    Ready(Arc<ErasedProvider>),
47}
48
49impl ProviderEntry {
50    /// Creates a provider entry from an erased factory.
51    pub fn new(
52        type_id: TypeId,
53        type_name: &'static str,
54        lifetime: ProviderLifetime,
55        factory: Arc<ProviderFactory>,
56    ) -> Self {
57        Self {
58            type_id,
59            type_name,
60            lifetime,
61            factory,
62            request_factory: None,
63            singleton: Mutex::new(SingletonState::Empty),
64            singleton_ready: Condvar::new(),
65        }
66    }
67
68    /// Creates a request-scoped provider entry from an erased request-scope factory.
69    pub fn new_request_scoped(
70        type_id: TypeId,
71        type_name: &'static str,
72        factory: Arc<ProviderFactory>,
73        request_factory: Arc<RequestProviderFactory>,
74    ) -> Self {
75        Self {
76            type_id,
77            type_name,
78            lifetime: ProviderLifetime::Request,
79            factory,
80            request_factory: Some(request_factory),
81            singleton: Mutex::new(SingletonState::Empty),
82            singleton_ready: Condvar::new(),
83        }
84    }
85
86    /// Returns the registered provider type name.
87    pub fn type_name(&self) -> &'static str {
88        self.type_name
89    }
90
91    /// Returns the configured provider lifetime.
92    pub fn lifetime(&self) -> ProviderLifetime {
93        self.lifetime
94    }
95
96    pub(crate) fn resolve_erased(&self, container: &Container) -> Result<Arc<ErasedProvider>> {
97        match self.lifetime {
98            ProviderLifetime::Singleton => self.resolve_singleton(container),
99            ProviderLifetime::Transient | ProviderLifetime::Request => {
100                self.create_erased(container)
101            }
102        }
103    }
104
105    pub(crate) fn resolve_erased_in_scope(
106        &self,
107        scope: &RequestScope<'_>,
108    ) -> Result<Arc<ErasedProvider>> {
109        match self.lifetime {
110            ProviderLifetime::Request => self.create_erased_in_scope(scope),
111            ProviderLifetime::Singleton | ProviderLifetime::Transient => {
112                self.resolve_erased(scope.container())
113            }
114        }
115    }
116
117    fn create_erased(&self, container: &Container) -> Result<Arc<ErasedProvider>> {
118        (self.factory)(container).map_err(|source| NidusError::ProviderFactory {
119            type_name: self.type_name,
120            source: Box::new(source),
121        })
122    }
123
124    fn resolve_singleton(&self, container: &Container) -> Result<Arc<ErasedProvider>> {
125        loop {
126            let mut singleton = lock_unpoisoned(&self.singleton);
127            match &*singleton {
128                SingletonState::Ready(instance) => return Ok(Arc::clone(instance)),
129                SingletonState::Initializing => {
130                    if resolution::is_active(self.type_id) {
131                        return Err(NidusError::CircularProviderResolution {
132                            type_name: self.type_name,
133                        });
134                    }
135                    drop(wait_unpoisoned(&self.singleton_ready, singleton));
136                }
137                SingletonState::Empty => {
138                    let _guard = resolution::enter(self.type_id, self.type_name)?;
139                    *singleton = SingletonState::Initializing;
140                    drop(singleton);
141
142                    let instance =
143                        match catch_unwind(AssertUnwindSafe(|| self.create_erased(container))) {
144                            Ok(outcome) => outcome,
145                            Err(panic_payload) => {
146                                let mut singleton = lock_unpoisoned(&self.singleton);
147                                *singleton = SingletonState::Empty;
148                                self.singleton_ready.notify_all();
149                                drop(singleton);
150                                std::panic::resume_unwind(panic_payload);
151                            }
152                        };
153                    let mut singleton = lock_unpoisoned(&self.singleton);
154                    match instance {
155                        Ok(instance) => {
156                            *singleton = SingletonState::Ready(Arc::clone(&instance));
157                            self.singleton_ready.notify_all();
158                            return Ok(instance);
159                        }
160                        Err(error) => {
161                            *singleton = SingletonState::Empty;
162                            self.singleton_ready.notify_all();
163                            return Err(error);
164                        }
165                    }
166                }
167            }
168        }
169    }
170
171    fn create_erased_in_scope(&self, scope: &RequestScope<'_>) -> Result<Arc<ErasedProvider>> {
172        if let Some(factory) = &self.request_factory {
173            factory(scope).map_err(|source| NidusError::ProviderFactory {
174                type_name: self.type_name,
175                source: Box::new(source),
176            })
177        } else {
178            self.create_erased(scope.container())
179        }
180    }
181}
182
/// Locks `mutex`, returning the guard even if a previous holder panicked.
///
/// Poisoning is deliberately ignored: the inner guard is taken out of the `PoisonError`
/// instead of panicking. Only use this for data that stays consistent if a holder panics.
fn lock_unpoisoned<T>(mutex: &Mutex<T>) -> MutexGuard<'_, T> {
    match mutex.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
188
/// Blocks on `condvar` until woken, returning the re-acquired guard.
///
/// A poisoned mutex is treated as usable: the guard inside the `PoisonError` is returned
/// instead of panicking. Wake-ups may be spurious, so callers must re-check their condition.
fn wait_unpoisoned<'a, T>(condvar: &Condvar, guard: MutexGuard<'a, T>) -> MutexGuard<'a, T> {
    match condvar.wait(guard) {
        Ok(reacquired) => reacquired,
        Err(poisoned) => poisoned.into_inner(),
    }
}
194
#[cfg(test)]
mod tests {
    use std::{
        any::{Any, TypeId, type_name},
        sync::Arc,
        thread,
    };

    use super::{ProviderEntry, ProviderLifetime};
    use crate::Container;

    /// A singleton whose cache mutex was poisoned by a panicking holder must still resolve.
    #[test]
    fn singleton_provider_recovers_from_poisoned_cache() {
        let entry = Arc::new(ProviderEntry::new(
            TypeId::of::<String>(),
            type_name::<String>(),
            ProviderLifetime::Singleton,
            Arc::new(|_container| {
                Ok(Arc::new(String::from("ready")) as Arc<dyn Any + Send + Sync>)
            }),
        ));

        // Poison the cache by panicking on another thread while holding the singleton lock.
        let poisoner = {
            let entry = Arc::clone(&entry);
            thread::spawn(move || {
                let _held = entry.singleton.lock().unwrap();
                panic!("poison singleton cache");
            })
        };
        assert!(poisoner.join().is_err());

        let resolved = entry
            .resolve_erased(&Container::new())
            .expect("singleton should resolve after the cache was poisoned");
        let value = resolved
            .downcast::<String>()
            .expect("singleton value should be a String");
        assert_eq!(value.as_str(), "ready");
    }
}