nidus_core/provider/
mod.rs1use std::{
4 any::{Any, TypeId},
5 panic::{AssertUnwindSafe, catch_unwind},
6 sync::{Arc, Condvar, Mutex, MutexGuard},
7};
8
9use crate::{Container, NidusError, RequestScope, Result, resolution};
10
/// How long instances produced by a `ProviderEntry` live.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ProviderLifetime {
    /// Built once on first resolution, cached inside the entry, and handed out (as clones of
    /// the same `Arc`) to every later resolution. A failed or panicking factory call leaves
    /// the cache empty so the next resolution retries.
    Singleton,
    /// Built afresh by the entry's factory on every resolution.
    Transient,
    /// Built by the entry's request factory when resolved inside a `RequestScope` (falling
    /// back to the container factory when none was registered). Resolving directly from a
    /// `Container` builds a fresh instance each time, exactly like `Transient`.
    Request,
}
21
/// Marker for types that can be registered as providers: any `'static` type that may be
/// sent to and shared between threads.
pub trait Provider: 'static + Sync + Send {}

// Blanket implementation: every type satisfying the bounds is a provider, no opt-in needed.
impl<T: 'static + Sync + Send> Provider for T {}
26
27type ErasedProvider = dyn Any + Send + Sync;
28type ProviderFactory = dyn Fn(&Container) -> Result<Arc<ErasedProvider>> + Send + Sync;
29type RequestProviderFactory =
30 dyn for<'scope> Fn(&RequestScope<'scope>) -> Result<Arc<ErasedProvider>> + Send + Sync;
31
32pub struct ProviderEntry {
34 type_id: TypeId,
35 type_name: &'static str,
36 lifetime: ProviderLifetime,
37 factory: Arc<ProviderFactory>,
38 request_factory: Option<Arc<RequestProviderFactory>>,
39 singleton: Mutex<SingletonState>,
40 singleton_ready: Condvar,
41}
42
43enum SingletonState {
44 Empty,
45 Initializing,
46 Ready(Arc<ErasedProvider>),
47}
48
49impl ProviderEntry {
50 pub fn new(
52 type_id: TypeId,
53 type_name: &'static str,
54 lifetime: ProviderLifetime,
55 factory: Arc<ProviderFactory>,
56 ) -> Self {
57 Self {
58 type_id,
59 type_name,
60 lifetime,
61 factory,
62 request_factory: None,
63 singleton: Mutex::new(SingletonState::Empty),
64 singleton_ready: Condvar::new(),
65 }
66 }
67
68 pub fn new_request_scoped(
70 type_id: TypeId,
71 type_name: &'static str,
72 factory: Arc<ProviderFactory>,
73 request_factory: Arc<RequestProviderFactory>,
74 ) -> Self {
75 Self {
76 type_id,
77 type_name,
78 lifetime: ProviderLifetime::Request,
79 factory,
80 request_factory: Some(request_factory),
81 singleton: Mutex::new(SingletonState::Empty),
82 singleton_ready: Condvar::new(),
83 }
84 }
85
86 pub fn type_name(&self) -> &'static str {
88 self.type_name
89 }
90
91 pub fn lifetime(&self) -> ProviderLifetime {
93 self.lifetime
94 }
95
96 pub(crate) fn resolve_erased(&self, container: &Container) -> Result<Arc<ErasedProvider>> {
97 match self.lifetime {
98 ProviderLifetime::Singleton => self.resolve_singleton(container),
99 ProviderLifetime::Transient | ProviderLifetime::Request => {
100 self.create_erased(container)
101 }
102 }
103 }
104
105 pub(crate) fn resolve_erased_in_scope(
106 &self,
107 scope: &RequestScope<'_>,
108 ) -> Result<Arc<ErasedProvider>> {
109 match self.lifetime {
110 ProviderLifetime::Request => self.create_erased_in_scope(scope),
111 ProviderLifetime::Singleton | ProviderLifetime::Transient => {
112 self.resolve_erased(scope.container())
113 }
114 }
115 }
116
117 fn create_erased(&self, container: &Container) -> Result<Arc<ErasedProvider>> {
118 (self.factory)(container).map_err(|source| NidusError::ProviderFactory {
119 type_name: self.type_name,
120 source: Box::new(source),
121 })
122 }
123
124 fn resolve_singleton(&self, container: &Container) -> Result<Arc<ErasedProvider>> {
125 loop {
126 let mut singleton = lock_unpoisoned(&self.singleton);
127 match &*singleton {
128 SingletonState::Ready(instance) => return Ok(Arc::clone(instance)),
129 SingletonState::Initializing => {
130 if resolution::is_active(self.type_id) {
131 return Err(NidusError::CircularProviderResolution {
132 type_name: self.type_name,
133 });
134 }
135 drop(wait_unpoisoned(&self.singleton_ready, singleton));
136 }
137 SingletonState::Empty => {
138 let _guard = resolution::enter(self.type_id, self.type_name)?;
139 *singleton = SingletonState::Initializing;
140 drop(singleton);
141
142 let instance =
143 match catch_unwind(AssertUnwindSafe(|| self.create_erased(container))) {
144 Ok(outcome) => outcome,
145 Err(panic_payload) => {
146 let mut singleton = lock_unpoisoned(&self.singleton);
147 *singleton = SingletonState::Empty;
148 self.singleton_ready.notify_all();
149 drop(singleton);
150 std::panic::resume_unwind(panic_payload);
151 }
152 };
153 let mut singleton = lock_unpoisoned(&self.singleton);
154 match instance {
155 Ok(instance) => {
156 *singleton = SingletonState::Ready(Arc::clone(&instance));
157 self.singleton_ready.notify_all();
158 return Ok(instance);
159 }
160 Err(error) => {
161 *singleton = SingletonState::Empty;
162 self.singleton_ready.notify_all();
163 return Err(error);
164 }
165 }
166 }
167 }
168 }
169 }
170
171 fn create_erased_in_scope(&self, scope: &RequestScope<'_>) -> Result<Arc<ErasedProvider>> {
172 if let Some(factory) = &self.request_factory {
173 factory(scope).map_err(|source| NidusError::ProviderFactory {
174 type_name: self.type_name,
175 source: Box::new(source),
176 })
177 } else {
178 self.create_erased(scope.container())
179 }
180 }
181}
182
/// Locks `mutex` even if it is poisoned.
///
/// If a previous holder panicked, the guard is taken out of the `PoisonError` and returned
/// as if the lock had succeeded, so the protected value stays usable.
fn lock_unpoisoned<T>(mutex: &Mutex<T>) -> MutexGuard<'_, T> {
    match mutex.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
188
/// Blocks on `condvar` with `guard`, ignoring mutex poisoning.
///
/// The lock is released while waiting and held again on return. If the mutex was poisoned
/// in the meantime, the reacquired guard is still extracted and returned. Like any condvar
/// wait, this can wake spuriously, so callers should re-check their condition.
fn wait_unpoisoned<'a, T>(condvar: &Condvar, guard: MutexGuard<'a, T>) -> MutexGuard<'a, T> {
    match condvar.wait(guard) {
        Ok(reacquired) => reacquired,
        Err(poisoned) => poisoned.into_inner(),
    }
}
194
#[cfg(test)]
mod tests {
    use std::{
        any::{Any, TypeId, type_name},
        panic::{AssertUnwindSafe, catch_unwind},
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
        thread,
    };

    use super::{ProviderEntry, ProviderLifetime};
    use crate::{Container, NidusError};

    /// Builds a singleton entry for `String` around `factory`.
    fn string_singleton<F>(factory: F) -> ProviderEntry
    where
        F: Fn(&Container) -> crate::Result<Arc<dyn Any + Send + Sync>> + Send + Sync + 'static,
    {
        ProviderEntry::new(
            TypeId::of::<String>(),
            type_name::<String>(),
            ProviderLifetime::Singleton,
            Arc::new(factory),
        )
    }

    /// Resolves `provider` against a fresh container and downcasts the instance to `String`.
    fn resolve_string(provider: &ProviderEntry) -> Arc<String> {
        provider
            .resolve_erased(&Container::new())
            .unwrap()
            .downcast::<String>()
            .unwrap()
    }

    #[test]
    fn singleton_provider_recovers_from_poisoned_cache() {
        let provider = Arc::new(string_singleton(|_container: &Container| {
            Ok(Arc::new("ready".to_owned()) as Arc<dyn Any + Send + Sync>)
        }));
        let poisoner = Arc::clone(&provider);

        let panicked = thread::spawn(move || {
            let _singleton = poisoner.singleton.lock().unwrap();
            panic!("poison singleton cache");
        })
        .join();
        assert!(panicked.is_err());

        assert_eq!(&*resolve_string(&provider), "ready");
    }

    #[test]
    fn singleton_factory_runs_once_across_threads() {
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        let provider = string_singleton(move |_container: &Container| {
            counter.fetch_add(1, Ordering::SeqCst);
            Ok(Arc::new("shared".to_owned()) as Arc<dyn Any + Send + Sync>)
        });

        // Spawn every resolver before joining any of them, so they genuinely race.
        let resolved: Vec<Arc<String>> = thread::scope(|scope| {
            let handles: Vec<_> = (0..8)
                .map(|_| scope.spawn(|| resolve_string(&provider)))
                .collect();
            handles
                .into_iter()
                .map(|handle| handle.join().unwrap())
                .collect()
        });

        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert_eq!(&*resolved[0], "shared");
        assert!(resolved.iter().all(|value| Arc::ptr_eq(value, &resolved[0])));
    }

    #[test]
    fn failed_singleton_factory_is_retried() {
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        let provider = string_singleton(move |_container: &Container| {
            if counter.fetch_add(1, Ordering::SeqCst) == 0 {
                return Err(NidusError::CircularProviderResolution {
                    type_name: "first attempt",
                });
            }
            Ok(Arc::new("recovered".to_owned()) as Arc<dyn Any + Send + Sync>)
        });

        assert!(matches!(
            provider.resolve_erased(&Container::new()),
            Err(NidusError::ProviderFactory { .. })
        ));
        assert_eq!(&*resolve_string(&provider), "recovered");
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn panicking_singleton_factory_leaves_cache_retryable() {
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        let provider = string_singleton(move |_container: &Container| {
            if counter.fetch_add(1, Ordering::SeqCst) == 0 {
                panic!("first attempt panics");
            }
            Ok(Arc::new("recovered".to_owned()) as Arc<dyn Any + Send + Sync>)
        });

        let first = catch_unwind(AssertUnwindSafe(|| {
            provider.resolve_erased(&Container::new())
        }));
        assert!(first.is_err());

        assert_eq!(&*resolve_string(&provider), "recovered");
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}
229}