1use std::{
2 any::{Any, TypeId},
3 collections::HashMap,
4 future::Future,
5 pin::Pin,
6 sync::{Arc, Mutex, OnceLock},
7};
8
9pub type DynArc = Arc<dyn Any + Send + Sync>;
10pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
11type Factory =
12 for<'a> fn(&'a Container, ResolutionContext) -> BoxFuture<'a, Result<DynArc, DiError>>;
13type Destroy = fn(DynArc) -> BoxFuture<'static, Result<(), DiError>>;
14
/// Lifetime policy applied to the instances a provider produces.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Scope {
    /// One instance per `Container`, created on first resolution and then cached.
    Singleton,
    /// A brand-new instance for every resolution; nothing is cached.
    Prototype,
    /// One instance per `RequestContext`; resolving outside a request context is an error.
    Request,
}
21
22#[doc(hidden)]
23pub struct ProviderDescriptor {
24 type_id: fn() -> TypeId,
25 type_name: fn() -> &'static str,
26 factory: Factory,
27 pub name: Option<&'static str>,
28 pub primary: bool,
29 pub scope: Scope,
30 pub eager: bool,
31 pub profile: Option<&'static str>,
32 pub condition_key: Option<&'static str>,
33 pub condition_value: Option<&'static str>,
34 pub destroy: Option<Destroy>,
35}
36
37impl ProviderDescriptor {
38 #[doc(hidden)]
39 pub const fn new(
40 type_id: fn() -> TypeId,
41 type_name: fn() -> &'static str,
42 factory: Factory,
43 ) -> Self {
44 Self::configured(
45 type_id,
46 type_name,
47 factory,
48 None,
49 false,
50 Scope::Singleton,
51 false,
52 None,
53 None,
54 None,
55 None,
56 )
57 }
58
59 #[allow(clippy::too_many_arguments)]
60 #[doc(hidden)]
61 pub const fn configured(
62 type_id: fn() -> TypeId,
63 type_name: fn() -> &'static str,
64 factory: Factory,
65 name: Option<&'static str>,
66 primary: bool,
67 scope: Scope,
68 eager: bool,
69 profile: Option<&'static str>,
70 condition_key: Option<&'static str>,
71 condition_value: Option<&'static str>,
72 destroy: Option<Destroy>,
73 ) -> Self {
74 Self {
75 type_id,
76 type_name,
77 factory,
78 name,
79 primary,
80 scope,
81 eager,
82 profile,
83 condition_key,
84 condition_value,
85 destroy,
86 }
87 }
88
89 fn active(&self, profiles: &[String]) -> bool {
90 let profile_matches = self
91 .profile
92 .is_none_or(|required| profiles.iter().any(|p| p == required));
93 let condition_matches = self.condition_key.is_none_or(|key| {
94 let actual = std::env::var(key).ok();
95 self.condition_value.map_or(actual.is_some(), |expected| {
96 actual.as_deref() == Some(expected)
97 })
98 });
99 profile_matches && condition_matches
100 }
101}
102
// Declares `ProviderDescriptor` as an `inventory` collection so every submitted descriptor
// can be enumerated with `inventory::iter::<ProviderDescriptor>` (done when a container is
// built).
inventory::collect!(ProviderDescriptor);

/// Instance cells keyed by provider identity (`provider_key`). Each cell sits behind an
/// `Arc` so a resolver can clone it out of the locked map, release the lock, and then await
/// the cell's initialisation.
type InstanceMap = HashMap<usize, Arc<tokio::sync::OnceCell<DynArc>>>;
106
107#[derive(Clone, Default)]
108pub struct ResolutionContext {
109 chain: Vec<&'static str>,
110 request_instances: Option<Arc<Mutex<InstanceMap>>>,
111}
112
113#[derive(Debug, thiserror::Error)]
114pub enum DiError {
115 #[error("no active provider is registered for {0}")]
116 MissingProvider(&'static str),
117 #[error("multiple providers match {0}; add a name/qualifier or mark one primary")]
118 AmbiguousProvider(&'static str),
119 #[error("multiple primary providers are registered for {0}")]
120 MultiplePrimary(&'static str),
121 #[error("circular dependency detected: {0}")]
122 CircularDependency(String),
123 #[error("provider for {0} returned an incompatible type")]
124 TypeMismatch(&'static str),
125 #[error("request-scoped dependency {0} was resolved outside RequestContext")]
126 RequestScopeUnavailable(&'static str),
127 #[error("configuration property {key} is missing or invalid: {message}")]
128 Configuration { key: String, message: String },
129 #[error("lifecycle hook failed for {0}")]
130 Lifecycle(&'static str),
131}
132
133pub struct Container {
134 providers: HashMap<TypeId, Vec<&'static ProviderDescriptor>>,
135 instances: Mutex<InstanceMap>,
136}
137
138static GLOBAL_CONTAINER: OnceLock<Container> = OnceLock::new();
139
140pub fn global_container() -> Result<&'static Container, DiError> {
141 if let Some(container) = GLOBAL_CONTAINER.get() {
142 return Ok(container);
143 }
144 let container = Container::new()?;
145 let _ = GLOBAL_CONTAINER.set(container);
146 Ok(GLOBAL_CONTAINER
147 .get()
148 .expect("global DI container initialized"))
149}
150
151pub async fn resolve<T>() -> Result<Arc<T>, DiError>
152where
153 T: Any + Send + Sync,
154{
155 global_container()?.resolve::<T>().await
156}
157
158impl Container {
159 pub fn new() -> Result<Self, DiError> {
160 let profiles = std::env::var("APP_PROFILES")
161 .unwrap_or_default()
162 .split(',')
163 .map(str::trim)
164 .filter(|p| !p.is_empty())
165 .map(str::to_owned)
166 .collect::<Vec<_>>();
167 Self::with_profiles(profiles)
168 }
169
170 pub fn with_profiles(
171 profiles: impl IntoIterator<Item = impl Into<String>>,
172 ) -> Result<Self, DiError> {
173 let profiles = profiles.into_iter().map(Into::into).collect::<Vec<_>>();
174 let mut providers: HashMap<TypeId, Vec<&'static ProviderDescriptor>> = HashMap::new();
175 for provider in inventory::iter::<ProviderDescriptor> {
176 if provider.active(&profiles) {
177 providers
178 .entry((provider.type_id)())
179 .or_default()
180 .push(provider);
181 }
182 }
183 for group in providers.values() {
184 if group.iter().filter(|p| p.primary).count() > 1 {
185 return Err(DiError::MultiplePrimary((group[0].type_name)()));
186 }
187 }
188 Ok(Self {
189 providers,
190 instances: Mutex::new(HashMap::new()),
191 })
192 }
193
194 pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
195 where
196 T: Any + Send + Sync,
197 {
198 self.resolve_dependency::<T>(&ResolutionContext::default())
199 .await
200 }
201
202 pub async fn resolve_named<T>(&self, name: &str) -> Result<Arc<T>, DiError>
203 where
204 T: Any + Send + Sync,
205 {
206 self.resolve_named_dependency::<T>(name, &ResolutionContext::default())
207 .await
208 }
209
210 pub async fn resolve_optional<T>(&self) -> Result<Option<Arc<T>>, DiError>
211 where
212 T: Any + Send + Sync,
213 {
214 match self.resolve::<T>().await {
215 Ok(value) => Ok(Some(value)),
216 Err(DiError::MissingProvider(_)) => Ok(None),
217 Err(error) => Err(error),
218 }
219 }
220
221 pub async fn resolve_all<T>(&self) -> Result<Vec<Arc<T>>, DiError>
222 where
223 T: Any + Send + Sync,
224 {
225 self.resolve_all_dependency::<T>(&ResolutionContext::default())
226 .await
227 }
228
229 pub fn request_context(&self) -> RequestContext<'_> {
230 RequestContext {
231 container: self,
232 context: ResolutionContext {
233 chain: vec![],
234 request_instances: Some(Arc::new(Mutex::new(HashMap::new()))),
235 },
236 }
237 }
238
239 pub async fn initialize_eager(&self) -> Result<(), DiError> {
240 for providers in self.providers.values() {
241 for provider in providers.iter().filter(|p| p.eager) {
242 self.resolve_provider(provider, ResolutionContext::default())
243 .await?;
244 }
245 }
246 Ok(())
247 }
248
249 pub async fn shutdown(&self) -> Result<(), DiError> {
250 let initialized = self
251 .instances
252 .lock()
253 .expect("DI instance lock poisoned")
254 .iter()
255 .filter_map(|(key, cell)| cell.get().map(|value| (*key, value.clone())))
256 .collect::<Vec<_>>();
257 for providers in self.providers.values() {
258 for provider in providers {
259 let key = provider_key(provider);
260 if let (Some(destroy), Some((_, value))) = (
261 provider.destroy,
262 initialized.iter().find(|(k, _)| *k == key),
263 ) {
264 destroy(value.clone()).await?;
265 }
266 }
267 }
268 Ok(())
269 }
270
271 #[doc(hidden)]
272 pub async fn resolve_dependency<T>(
273 &self,
274 context: &ResolutionContext,
275 ) -> Result<Arc<T>, DiError>
276 where
277 T: Any + Send + Sync,
278 {
279 self.resolve_selected::<T>(None, context).await
280 }
281
282 #[doc(hidden)]
283 pub async fn resolve_named_dependency<T>(
284 &self,
285 name: &str,
286 context: &ResolutionContext,
287 ) -> Result<Arc<T>, DiError>
288 where
289 T: Any + Send + Sync,
290 {
291 self.resolve_selected::<T>(Some(name), context).await
292 }
293
294 #[doc(hidden)]
295 pub async fn resolve_optional_dependency<T>(
296 &self,
297 context: &ResolutionContext,
298 ) -> Result<Option<Arc<T>>, DiError>
299 where
300 T: Any + Send + Sync,
301 {
302 match self.resolve_dependency::<T>(context).await {
303 Ok(value) => Ok(Some(value)),
304 Err(DiError::MissingProvider(_)) => Ok(None),
305 Err(error) => Err(error),
306 }
307 }
308
309 #[doc(hidden)]
310 pub async fn resolve_all_dependency<T>(
311 &self,
312 context: &ResolutionContext,
313 ) -> Result<Vec<Arc<T>>, DiError>
314 where
315 T: Any + Send + Sync,
316 {
317 let Some(providers) = self.providers.get(&TypeId::of::<T>()) else {
318 return Ok(vec![]);
319 };
320 let mut values = Vec::with_capacity(providers.len());
321 for provider in providers {
322 let value = self.resolve_provider(provider, context.clone()).await?;
323 values.push(
324 value
325 .downcast::<T>()
326 .map_err(|_| DiError::TypeMismatch(std::any::type_name::<T>()))?,
327 );
328 }
329 Ok(values)
330 }
331
332 async fn resolve_selected<T>(
333 &self,
334 name: Option<&str>,
335 context: &ResolutionContext,
336 ) -> Result<Arc<T>, DiError>
337 where
338 T: Any + Send + Sync,
339 {
340 let type_name = std::any::type_name::<T>();
341 let providers = self
342 .providers
343 .get(&TypeId::of::<T>())
344 .ok_or(DiError::MissingProvider(type_name))?;
345 let selected = if let Some(name) = name {
346 providers
347 .iter()
348 .copied()
349 .find(|p| p.name == Some(name))
350 .ok_or(DiError::MissingProvider(type_name))?
351 } else if providers.len() == 1 {
352 providers[0]
353 } else {
354 providers
355 .iter()
356 .copied()
357 .find(|p| p.primary)
358 .ok_or(DiError::AmbiguousProvider(type_name))?
359 };
360 let value = self.resolve_provider(selected, context.clone()).await?;
361 value
362 .downcast::<T>()
363 .map_err(|_| DiError::TypeMismatch(type_name))
364 }
365
366 fn resolve_provider<'a>(
367 &'a self,
368 provider: &'static ProviderDescriptor,
369 mut context: ResolutionContext,
370 ) -> BoxFuture<'a, Result<DynArc, DiError>> {
371 Box::pin(async move {
372 let type_name = (provider.type_name)();
373 if context.chain.contains(&type_name) {
374 context.chain.push(type_name);
375 return Err(DiError::CircularDependency(context.chain.join(" -> ")));
376 }
377 context.chain.push(type_name);
378 if provider.scope == Scope::Prototype {
379 return (provider.factory)(self, context).await;
380 }
381 let map = match provider.scope {
382 Scope::Singleton => &self.instances,
383 Scope::Request => context
384 .request_instances
385 .as_deref()
386 .ok_or(DiError::RequestScopeUnavailable(type_name))?,
387 Scope::Prototype => unreachable!(),
388 };
389 let cell = {
390 let mut instances = map.lock().expect("DI instance lock poisoned");
391 instances
392 .entry(provider_key(provider))
393 .or_insert_with(|| Arc::new(tokio::sync::OnceCell::new()))
394 .clone()
395 };
396 let value = cell
397 .get_or_try_init(|| async move { (provider.factory)(self, context).await })
398 .await?;
399 Ok(value.clone())
400 })
401 }
402}
403
404fn provider_key(provider: &'static ProviderDescriptor) -> usize {
405 provider as *const ProviderDescriptor as usize
406}
407
408pub struct RequestContext<'a> {
409 container: &'a Container,
410 context: ResolutionContext,
411}
412impl RequestContext<'_> {
413 pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
414 where
415 T: Any + Send + Sync,
416 {
417 self.container.resolve_dependency::<T>(&self.context).await
418 }
419}
420
// Re-exports of this crate's dependencies. Presumably they let macro-generated code name
// them through this crate (e.g. `<crate>::__private::tokio`) without the user's crate
// depending on them directly. Hidden from docs and the `__` prefix marks it as
// non-public API (assumption — confirm with the macro crate).
#[doc(hidden)]
pub mod __private {
    pub use inventory;
    pub use tokio;
}
426
// Integration-style tests that register providers via the crate's attribute macros and
// resolve them through freshly built containers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Lazy, Provider, configuration_properties, provider, singleton};
    use std::sync::atomic::{AtomicUsize, Ordering};
    // Counts how many times the `SyncDependency` factory body runs.
    static CREATIONS: AtomicUsize = AtomicUsize::new(0);
    struct SyncDependency;
    #[singleton]
    fn sync_dependency() -> SyncDependency {
        CREATIONS.fetch_add(1, Ordering::SeqCst);
        SyncDependency
    }
    struct AsyncDependency {
        _sync: Arc<SyncDependency>,
    }
    // Async factory that sleeps, widening the window in which concurrent resolvers overlap.
    #[singleton]
    async fn async_dependency(sync: Arc<SyncDependency>) -> AsyncDependency {
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        AsyncDependency { _sync: sync }
    }
    // 32 tasks on 4 worker threads resolve the same singleton concurrently: all must receive
    // the same `Arc`, and the transitive `SyncDependency` factory must run exactly once.
    // Relies on no other test in this binary resolving `SyncDependency`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn singleton_is_concurrent_safe() {
        // NOTE(review): `Container::new()` reads environment variables (`APP_PROFILES` and
        // provider conditions) while the other test may be calling `std::env::set_var`
        // on another thread; that concurrent access is exactly why `set_var` is `unsafe`.
        let container = Arc::new(Container::new().unwrap());
        let mut tasks = tokio::task::JoinSet::new();
        for _ in 0..32 {
            let c = container.clone();
            tasks.spawn(async move { c.resolve::<AsyncDependency>().await.unwrap() });
        }
        let mut values = vec![];
        while let Some(v) = tasks.join_next().await {
            values.push(v.unwrap());
        }
        assert!(values.iter().all(|v| Arc::ptr_eq(&values[0], v)));
        assert_eq!(CREATIONS.load(Ordering::SeqCst), 1);
    }

    // Prototype scope: each resolution runs the factory again, yielding a new counter value.
    static PROTOTYPES: AtomicUsize = AtomicUsize::new(0);
    struct PrototypeBean(usize);
    #[singleton(scope = "prototype")]
    fn prototype_bean() -> PrototypeBean {
        PrototypeBean(PROTOTYPES.fetch_add(1, Ordering::SeqCst))
    }

    // Request scope: only resolvable through a `RequestContext`.
    struct RequestBean;
    #[singleton(scope = "request")]
    fn request_bean() -> RequestBean {
        RequestBean
    }

    // Trait-object providers: two implementations of the same `Arc<dyn Greeting>` type,
    // distinguished by name, with "english" marked primary.
    trait Greeting: Send + Sync {
        fn text(&self) -> &'static str;
    }
    struct English;
    impl Greeting for English {
        fn text(&self) -> &'static str {
            "hello"
        }
    }
    struct Hindi;
    impl Greeting for Hindi {
        fn text(&self) -> &'static str {
            "namaste"
        }
    }

    #[singleton(name = "english", primary)]
    fn english_greeting() -> Arc<dyn Greeting> {
        Arc::new(English)
    }
    #[singleton(name = "hindi")]
    fn hindi_greeting() -> Arc<dyn Greeting> {
        Arc::new(Hindi)
    }

    // `MissingOptional` has no provider, so the optional constructor argument stays `None`.
    struct MissingOptional;
    struct Greeter {
        greeting: Arc<dyn Greeting>,
        optional: Option<Arc<MissingOptional>>,
    }
    struct GreetingLabel(&'static str);
    // `#[singleton]` on an impl uses `new` as the factory; `#[provider]` on a `&self`
    // method appears to register an extra provider derived from the built instance.
    #[singleton]
    impl Greeter {
        fn new(greeting: Arc<dyn Greeting>, optional: Option<Arc<MissingOptional>>) -> Self {
            Self { greeting, optional }
        }

        #[provider]
        fn label(&self) -> GreetingLabel {
            GreetingLabel(self.greeting.text())
        }
    }
    // `#[qualifier]` selects the named provider instead of the primary one.
    struct QualifiedGreeter(Arc<dyn Greeting>);
    #[singleton]
    impl QualifiedGreeter {
        fn new(#[qualifier("hindi")] greeting: Arc<dyn Greeting>) -> Self {
            Self(greeting)
        }
    }

    // A receiver-less `#[provider]` inside the impl supplies a dependency of `new` itself.
    struct StaticConfig(&'static str);
    struct ServiceWithStaticBean(Arc<StaticConfig>);
    #[singleton]
    impl ServiceWithStaticBean {
        fn new(config: Arc<StaticConfig>) -> Self {
            Self(config)
        }

        #[provider]
        fn config() -> StaticConfig {
            StaticConfig("static-bean")
        }
    }

    // Lifecycle: presumably `start` runs after construction and `stop` on shutdown; the
    // test asserts each counter reaches exactly 1.
    static STARTED: AtomicUsize = AtomicUsize::new(0);
    static STOPPED: AtomicUsize = AtomicUsize::new(0);
    struct Managed;
    #[singleton(eager, post_construct = "start", pre_destroy = "stop")]
    impl Managed {
        fn new() -> Self {
            Self
        }
        async fn start(&self) {
            STARTED.fetch_add(1, Ordering::SeqCst);
        }
        async fn stop(&self) {
            STOPPED.fetch_add(1, Ordering::SeqCst);
        }
    }

    // Active only when the "test" profile is enabled.
    struct ProfileBean;
    #[singleton(profile = "test")]
    fn profile_bean() -> ProfileBean {
        ProfileBean
    }

    // Two named providers of one type with no primary: injected together as a `Vec`.
    #[derive(Debug)]
    struct Handler(&'static str);
    #[singleton(name = "first")]
    fn first_handler() -> Handler {
        Handler("first")
    }
    #[singleton(name = "second")]
    fn second_handler() -> Handler {
        Handler("second")
    }
    struct Pipeline(Vec<Arc<Handler>>);
    #[singleton]
    impl Pipeline {
        fn new(handlers: Vec<Arc<Handler>>) -> Self {
            Self(handlers)
        }
    }

    // Appears to be populated from `TESTING_DEP_*` environment variables (the test sets
    // `TESTING_DEP_PORT`) — mapping defined by the macro, not visible here.
    #[configuration_properties("testing_dep")]
    struct TestProperties {
        port: u16,
    }

    // Deferred injection through `Provider<T>` and `Lazy<T>` handles.
    struct DeferredTarget;
    #[singleton]
    fn deferred_target() -> DeferredTarget {
        DeferredTarget
    }
    struct DeferredConsumer {
        provider: Provider<DeferredTarget>,
        lazy: Lazy<DeferredTarget>,
    }
    #[singleton]
    impl DeferredConsumer {
        fn new(provider: Provider<DeferredTarget>, lazy: Lazy<DeferredTarget>) -> Self {
            Self { provider, lazy }
        }
    }

    // Active only if `TESTING_DEP_FEATURE` equals "enabled" when the container is built.
    struct ConditionalBean;
    #[singleton(condition = "TESTING_DEP_FEATURE=enabled")]
    fn conditional_bean() -> ConditionalBean {
        ConditionalBean
    }

    // A free-function `#[provider]` outside any `#[singleton]` impl.
    struct StandaloneBean;
    #[provider]
    fn standalone_bean() -> StandaloneBean {
        StandaloneBean
    }

    // End-to-end walk through scopes, trait objects, primary/qualifier selection, optional
    // and collection injection, profiles, conditions, configuration, deferred handles and
    // lifecycle hooks on a single container built with the "test" profile.
    #[tokio::test]
    async fn scopes_traits_primary_profiles_and_lifecycle_work() {
        // Must happen before the container is built: conditions are evaluated at build time.
        // NOTE(review): races with env reads in the concurrently running test (see above).
        unsafe { std::env::set_var("TESTING_DEP_FEATURE", "enabled") };
        let container = Container::with_profiles(["test"]).unwrap();
        let first = container.resolve::<PrototypeBean>().await.unwrap();
        let second = container.resolve::<PrototypeBean>().await.unwrap();
        assert_ne!(first.0, second.0);

        assert!(matches!(
            container.resolve::<RequestBean>().await,
            Err(DiError::RequestScopeUnavailable(_))
        ));
        let request = container.request_context();
        let request_first = request.resolve::<RequestBean>().await.unwrap();
        let request_second = request.resolve::<RequestBean>().await.unwrap();
        assert!(Arc::ptr_eq(&request_first, &request_second));

        let greeter = container.resolve::<Greeter>().await.unwrap();
        assert_eq!(greeter.greeting.text(), "hello");
        assert_eq!(
            container.resolve::<GreetingLabel>().await.unwrap().0,
            "hello"
        );
        assert!(greeter.optional.is_none());
        assert_eq!(
            container
                .resolve::<QualifiedGreeter>()
                .await
                .unwrap()
                .0
                .text(),
            "namaste"
        );
        let hindi = container
            .resolve_named::<Arc<dyn Greeting>>("hindi")
            .await
            .unwrap();
        assert_eq!(hindi.text(), "namaste");
        let static_bean_service = container.resolve::<ServiceWithStaticBean>().await.unwrap();
        assert_eq!(static_bean_service.0.0, "static-bean");
        container.resolve::<ProfileBean>().await.unwrap();
        let pipeline = container.resolve::<Pipeline>().await.unwrap();
        // Provider order within a type is unspecified, so compare sorted names.
        let mut handler_names = pipeline.0.iter().map(|h| h.0).collect::<Vec<_>>();
        handler_names.sort_unstable();
        assert_eq!(handler_names, ["first", "second"]);

        unsafe { std::env::set_var("TESTING_DEP_PORT", "8080") };
        let properties = container.resolve::<TestProperties>().await.unwrap();
        assert_eq!(properties.port, 8080);
        container.resolve::<ConditionalBean>().await.unwrap();
        container.resolve::<StandaloneBean>().await.unwrap();

        let deferred = container.resolve::<DeferredConsumer>().await.unwrap();
        deferred.provider.get().await.unwrap();
        deferred.lazy.get().await.unwrap();

        container.initialize_eager().await.unwrap();
        assert_eq!(STARTED.load(Ordering::SeqCst), 1);
        container.shutdown().await.unwrap();
        assert_eq!(STOPPED.load(Ordering::SeqCst), 1);
    }
}