dependency_injector/
factory.rs1use crate::Injectable;
13use once_cell::sync::OnceCell;
14use std::any::Any;
15use std::sync::Arc;
16
17#[cfg(feature = "logging")]
18use tracing::{debug, trace};
19
/// A source of type-erased service instances.
///
/// Every implementor hands back an `Arc<dyn Any + Send + Sync>`; callers are
/// expected to downcast it to the concrete service type they registered.
pub trait Factory: Send + Sync {
    /// Produces (or hands out a shared handle to) the managed service instance.
    fn resolve(&self) -> Arc<dyn Any + Send + Sync>;

    /// Tells whether this factory is transient.
    ///
    /// Unless overridden this answers `false`; `TransientFactory`, which
    /// builds a new instance on every resolve, overrides it to answer `true`.
    fn is_transient(&self) -> bool {
        false
    }
}
30
31pub struct SingletonFactory {
40 pub(crate) instance: Arc<dyn Any + Send + Sync>,
42}
43
44impl SingletonFactory {
45 #[inline]
47 pub fn new<T: Injectable>(instance: T) -> Self {
48 Self {
49 instance: Arc::new(instance) as Arc<dyn Any + Send + Sync>,
50 }
51 }
52
53 #[inline]
55 pub fn from_arc<T: Injectable>(instance: Arc<T>) -> Self {
56 Self {
57 instance: instance as Arc<dyn Any + Send + Sync>,
58 }
59 }
60
61 #[inline]
63 pub fn resolve(&self) -> Arc<dyn Any + Send + Sync> {
64 Arc::clone(&self.instance)
65 }
66}
67
68impl Factory for SingletonFactory {
69 #[inline]
70 fn resolve(&self) -> Arc<dyn Any + Send + Sync> {
71 self.resolve()
72 }
73}
74
75type LazyInitFn = Arc<dyn Fn() -> Arc<dyn Any + Send + Sync> + Send + Sync>;
81
82pub struct LazyFactory {
87 init: LazyInitFn,
89 instance: OnceCell<Arc<dyn Any + Send + Sync>>,
91 #[cfg(feature = "logging")]
93 type_name: &'static str,
94}
95
96impl LazyFactory {
97 #[inline]
99 pub fn new<T: Injectable, F>(factory: F) -> Self
100 where
101 F: Fn() -> T + Send + Sync + 'static,
102 {
103 Self {
104 init: Arc::new(move || Arc::new(factory()) as Arc<dyn Any + Send + Sync>),
105 instance: OnceCell::new(),
106 #[cfg(feature = "logging")]
107 type_name: std::any::type_name::<T>(),
108 }
109 }
110
111 #[inline]
113 pub fn resolve(&self) -> Arc<dyn Any + Send + Sync> {
114 #[cfg(feature = "logging")]
115 let was_empty = self.instance.get().is_none();
116
117 let result = Arc::clone(self.instance.get_or_init(|| {
118 #[cfg(feature = "logging")]
119 debug!(
120 target: "dependency_injector",
121 service = self.type_name,
122 "Lazy singleton initializing on first access"
123 );
124
125 (self.init)()
126 }));
127
128 #[cfg(feature = "logging")]
129 if !was_empty {
130 trace!(
131 target: "dependency_injector",
132 service = self.type_name,
133 "Lazy singleton already initialized, returning cached instance"
134 );
135 }
136
137 result
138 }
139}
140
141impl Factory for LazyFactory {
142 #[inline]
143 fn resolve(&self) -> Arc<dyn Any + Send + Sync> {
144 self.resolve()
145 }
146}
147
148type TransientFn = Arc<dyn Fn() -> Arc<dyn Any + Send + Sync> + Send + Sync>;
154
155pub struct TransientFactory {
159 factory: TransientFn,
161 #[cfg(feature = "logging")]
163 type_name: &'static str,
164}
165
166impl TransientFactory {
167 #[inline]
169 pub fn new<T: Injectable, F>(factory: F) -> Self
170 where
171 F: Fn() -> T + Send + Sync + 'static,
172 {
173 Self {
174 factory: Arc::new(move || Arc::new(factory()) as Arc<dyn Any + Send + Sync>),
175 #[cfg(feature = "logging")]
176 type_name: std::any::type_name::<T>(),
177 }
178 }
179
180 #[inline]
182 pub fn create(&self) -> Arc<dyn Any + Send + Sync> {
183 #[cfg(feature = "logging")]
184 trace!(
185 target: "dependency_injector",
186 service = self.type_name,
187 "Creating new transient instance"
188 );
189
190 (self.factory)()
191 }
192}
193
194impl Factory for TransientFactory {
195 #[inline]
196 fn resolve(&self) -> Arc<dyn Any + Send + Sync> {
197 self.create()
198 }
199
200 #[inline]
201 fn is_transient(&self) -> bool {
202 true
203 }
204}
205
206pub(crate) enum AnyFactory {
224 Singleton(SingletonFactory),
226 Lazy(LazyFactory),
228 Transient(TransientFactory),
230}
231
232impl Clone for AnyFactory {
233 fn clone(&self) -> Self {
236 match self {
237 AnyFactory::Singleton(f) => AnyFactory::Singleton(SingletonFactory {
238 instance: Arc::clone(&f.instance),
239 }),
240 AnyFactory::Lazy(f) => {
241 let instance = f.resolve();
244 AnyFactory::Singleton(SingletonFactory { instance })
245 }
246 AnyFactory::Transient(f) => AnyFactory::Transient(TransientFactory {
247 factory: Arc::clone(&f.factory),
248 #[cfg(feature = "logging")]
249 type_name: f.type_name,
250 }),
251 }
252 }
253}
254
255impl AnyFactory {
256 #[inline]
258 pub fn singleton<T: Injectable>(instance: T) -> Self {
259 AnyFactory::Singleton(SingletonFactory::new(instance))
260 }
261
262 #[inline]
264 pub fn lazy<T: Injectable, F>(factory: F) -> Self
265 where
266 F: Fn() -> T + Send + Sync + 'static,
267 {
268 AnyFactory::Lazy(LazyFactory::new(factory))
269 }
270
271 #[inline]
273 pub fn transient<T: Injectable, F>(factory: F) -> Self
274 where
275 F: Fn() -> T + Send + Sync + 'static,
276 {
277 AnyFactory::Transient(TransientFactory::new(factory))
278 }
279
280 #[inline]
282 pub fn resolve(&self) -> Arc<dyn Any + Send + Sync> {
283 match self {
284 AnyFactory::Singleton(f) => f.resolve(),
285 AnyFactory::Lazy(f) => f.resolve(),
286 AnyFactory::Transient(f) => f.create(),
287 }
288 }
289
290 #[inline]
292 pub fn is_transient(&self) -> bool {
293 matches!(self, AnyFactory::Transient(_))
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[derive(Clone)]
    struct TestService {
        id: u32,
    }

    #[test]
    fn test_singleton_factory() {
        let factory = AnyFactory::singleton(TestService { id: 42 });

        let a = factory.resolve();
        let b = factory.resolve();

        let a = a.downcast::<TestService>().unwrap();
        let b = b.downcast::<TestService>().unwrap();

        assert_eq!(a.id, 42);
        assert!(Arc::ptr_eq(&a, &b));
    }

    #[test]
    fn test_lazy_factory() {
        static COUNTER: AtomicU32 = AtomicU32::new(0);

        let factory = AnyFactory::lazy(|| TestService {
            id: COUNTER.fetch_add(1, Ordering::SeqCst),
        });

        assert_eq!(COUNTER.load(Ordering::SeqCst), 0);

        let a = factory.resolve().downcast::<TestService>().unwrap();
        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);
        assert_eq!(a.id, 0);

        let b = factory.resolve().downcast::<TestService>().unwrap();
        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);
        assert!(Arc::ptr_eq(&a, &b));
    }

    #[test]
    fn test_transient_factory() {
        static COUNTER: AtomicU32 = AtomicU32::new(0);

        let factory = AnyFactory::transient(|| TestService {
            id: COUNTER.fetch_add(1, Ordering::SeqCst),
        });

        let a = factory.resolve().downcast::<TestService>().unwrap();
        let b = factory.resolve().downcast::<TestService>().unwrap();

        assert_eq!(a.id, 0);
        assert_eq!(b.id, 1);
        assert!(!Arc::ptr_eq(&a, &b));
    }

    #[test]
    fn test_is_transient() {
        let singleton = AnyFactory::singleton(TestService { id: 1 });
        let lazy = AnyFactory::lazy(|| TestService { id: 2 });
        let transient = AnyFactory::transient(|| TestService { id: 3 });

        assert!(!singleton.is_transient());
        assert!(!lazy.is_transient());
        assert!(transient.is_transient());
    }

    /// A cloned singleton must hand out the very same instance as the original.
    #[test]
    fn test_clone_singleton_shares_instance() {
        let original = AnyFactory::singleton(TestService { id: 7 });
        let copy = original.clone();

        let a = original.resolve().downcast::<TestService>().unwrap();
        let b = copy.resolve().downcast::<TestService>().unwrap();

        assert_eq!(b.id, 7);
        assert!(Arc::ptr_eq(&a, &b));
        assert!(!copy.is_transient());
    }

    /// A cloned lazy factory must share one instance with the original, and
    /// the initializer must run exactly once across both.
    #[test]
    fn test_clone_lazy_shares_instance() {
        static COUNTER: AtomicU32 = AtomicU32::new(0);

        let original = AnyFactory::lazy(|| TestService {
            id: COUNTER.fetch_add(1, Ordering::SeqCst),
        });
        let copy = original.clone();

        let a = original.resolve().downcast::<TestService>().unwrap();
        let b = copy.resolve().downcast::<TestService>().unwrap();

        assert!(Arc::ptr_eq(&a, &b));
        assert_eq!(a.id, 0);
        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);
        assert!(!copy.is_transient());
    }

    /// A cloned transient factory stays transient and still builds a fresh
    /// instance per resolve, sharing the same constructor.
    #[test]
    fn test_clone_transient_stays_transient() {
        static COUNTER: AtomicU32 = AtomicU32::new(0);

        let original = AnyFactory::transient(|| TestService {
            id: COUNTER.fetch_add(1, Ordering::SeqCst),
        });
        let copy = original.clone();

        assert!(copy.is_transient());

        let a = original.resolve().downcast::<TestService>().unwrap();
        let b = copy.resolve().downcast::<TestService>().unwrap();

        assert!(!Arc::ptr_eq(&a, &b));
        assert_eq!(COUNTER.load(Ordering::SeqCst), 2);
    }
}