1use std::{
2 any::{Any, TypeId},
3 collections::HashMap,
4 future::Future,
5 pin::Pin,
6 sync::{Arc, Mutex, OnceLock},
7};
8
/// Type-erased, shareable instance as stored and returned by the container.
pub type DynArc = Arc<dyn Any + Send + Sync>;
/// Boxed, pinned, `Send` future; the return type of provider factories and destroy hooks.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Provider constructor: receives the container and the current resolution context
/// (presumably so it can resolve its own dependencies) and yields a type-erased instance.
type Factory =
    for<'a> fn(&'a Container, ResolutionContext) -> BoxFuture<'a, Result<DynArc, DiError>>;
/// Shutdown hook that `Container::shutdown` runs on an already-created singleton instance.
type Destroy = fn(DynArc) -> BoxFuture<'static, Result<(), DiError>>;
14
/// Lifetime of the instances produced by a provider.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Scope {
    /// Created at most once per container; every resolution shares that instance.
    Singleton,
    /// Created afresh by the factory on every resolution.
    Prototype,
    /// Created at most once per `RequestContext`; unavailable outside of one.
    Request,
}
21
22#[doc(hidden)]
23pub struct ProviderDescriptor {
24 type_id: fn() -> TypeId,
25 type_name: fn() -> &'static str,
26 factory: Factory,
27 pub name: Option<&'static str>,
28 pub primary: bool,
29 pub scope: Scope,
30 pub eager: bool,
31 pub profile: Option<&'static str>,
32 pub condition_key: Option<&'static str>,
33 pub condition_value: Option<&'static str>,
34 pub destroy: Option<Destroy>,
35}
36
37impl ProviderDescriptor {
38 #[doc(hidden)]
39 pub const fn new(
40 type_id: fn() -> TypeId,
41 type_name: fn() -> &'static str,
42 factory: Factory,
43 ) -> Self {
44 Self::configured(
45 type_id,
46 type_name,
47 factory,
48 None,
49 false,
50 Scope::Singleton,
51 false,
52 None,
53 None,
54 None,
55 None,
56 )
57 }
58
59 #[allow(clippy::too_many_arguments)]
60 #[doc(hidden)]
61 pub const fn configured(
62 type_id: fn() -> TypeId,
63 type_name: fn() -> &'static str,
64 factory: Factory,
65 name: Option<&'static str>,
66 primary: bool,
67 scope: Scope,
68 eager: bool,
69 profile: Option<&'static str>,
70 condition_key: Option<&'static str>,
71 condition_value: Option<&'static str>,
72 destroy: Option<Destroy>,
73 ) -> Self {
74 Self {
75 type_id,
76 type_name,
77 factory,
78 name,
79 primary,
80 scope,
81 eager,
82 profile,
83 condition_key,
84 condition_value,
85 destroy,
86 }
87 }
88
89 fn active(&self, profiles: &[String]) -> bool {
90 let profile_matches = self
91 .profile
92 .is_none_or(|required| profiles.iter().any(|p| p == required));
93 let condition_matches = self.condition_key.is_none_or(|key| {
94 let actual = std::env::var(key).ok();
95 self.condition_value.map_or(actual.is_some(), |expected| {
96 actual.as_deref() == Some(expected)
97 })
98 });
99 profile_matches && condition_matches
100 }
101}
102
103inventory::collect!(ProviderDescriptor);
104
105type InstanceMap = HashMap<usize, Arc<tokio::sync::OnceCell<DynArc>>>;
106
107struct RuntimeProvider {
108 descriptor: &'static ProviderDescriptor,
109 singleton: tokio::sync::OnceCell<DynArc>,
110}
111
112impl RuntimeProvider {
113 fn new(descriptor: &'static ProviderDescriptor) -> Self {
114 Self {
115 descriptor,
116 singleton: tokio::sync::OnceCell::new(),
117 }
118 }
119}
120
121#[derive(Clone, Default)]
122pub struct ResolutionContext {
123 chain: Vec<&'static str>,
124 request_instances: Option<Arc<Mutex<InstanceMap>>>,
125}
126
127#[derive(Debug, thiserror::Error)]
128pub enum DiError {
129 #[error("no active provider is registered for {0}")]
130 MissingProvider(&'static str),
131 #[error("multiple providers match {0}; add a name/qualifier or mark one primary")]
132 AmbiguousProvider(&'static str),
133 #[error("multiple primary providers are registered for {0}")]
134 MultiplePrimary(&'static str),
135 #[error("circular dependency detected: {0}")]
136 CircularDependency(String),
137 #[error("provider for {0} returned an incompatible type")]
138 TypeMismatch(&'static str),
139 #[error("request-scoped dependency {0} was resolved outside RequestContext")]
140 RequestScopeUnavailable(&'static str),
141 #[error("configuration property {key} is missing or invalid: {message}")]
142 Configuration { key: String, message: String },
143 #[error("lifecycle hook failed for {0}")]
144 Lifecycle(&'static str),
145}
146
/// Dependency-injection container holding every active provider, grouped by provided type.
pub struct Container {
    /// Active providers keyed by the `TypeId` they provide; every group is non-empty.
    providers: HashMap<TypeId, Vec<RuntimeProvider>>,
}
150
151static GLOBAL_CONTAINER: OnceLock<Container> = OnceLock::new();
152
153pub fn global_container() -> Result<&'static Container, DiError> {
154 if let Some(container) = GLOBAL_CONTAINER.get() {
155 return Ok(container);
156 }
157 let container = Container::new()?;
158 let _ = GLOBAL_CONTAINER.set(container);
159 Ok(GLOBAL_CONTAINER
160 .get()
161 .expect("global DI container initialized"))
162}
163
164pub async fn resolve<T>() -> Result<Arc<T>, DiError>
165where
166 T: Any + Send + Sync,
167{
168 global_container()?.resolve::<T>().await
169}
170
171impl Container {
172 pub fn new() -> Result<Self, DiError> {
173 let profiles = std::env::var("APP_PROFILES")
174 .unwrap_or_default()
175 .split(',')
176 .map(str::trim)
177 .filter(|p| !p.is_empty())
178 .map(str::to_owned)
179 .collect::<Vec<_>>();
180 Self::with_profiles(profiles)
181 }
182
183 pub fn with_profiles(
184 profiles: impl IntoIterator<Item = impl Into<String>>,
185 ) -> Result<Self, DiError> {
186 let profiles = profiles.into_iter().map(Into::into).collect::<Vec<_>>();
187 let mut providers: HashMap<TypeId, Vec<RuntimeProvider>> = HashMap::new();
188 for provider in inventory::iter::<ProviderDescriptor> {
189 if provider.active(&profiles) {
190 providers
191 .entry((provider.type_id)())
192 .or_default()
193 .push(RuntimeProvider::new(provider));
194 }
195 }
196 for group in providers.values() {
197 if group.iter().filter(|p| p.descriptor.primary).count() > 1 {
198 return Err(DiError::MultiplePrimary((group[0].descriptor.type_name)()));
199 }
200 }
201 Ok(Self { providers })
202 }
203
204 pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
205 where
206 T: Any + Send + Sync,
207 {
208 self.resolve_dependency::<T>(&ResolutionContext::default())
209 .await
210 }
211
212 pub async fn resolve_named<T>(&self, name: &str) -> Result<Arc<T>, DiError>
213 where
214 T: Any + Send + Sync,
215 {
216 self.resolve_named_dependency::<T>(name, &ResolutionContext::default())
217 .await
218 }
219
220 pub async fn resolve_optional<T>(&self) -> Result<Option<Arc<T>>, DiError>
221 where
222 T: Any + Send + Sync,
223 {
224 match self.resolve::<T>().await {
225 Ok(value) => Ok(Some(value)),
226 Err(DiError::MissingProvider(_)) => Ok(None),
227 Err(error) => Err(error),
228 }
229 }
230
231 pub async fn resolve_all<T>(&self) -> Result<Vec<Arc<T>>, DiError>
232 where
233 T: Any + Send + Sync,
234 {
235 self.resolve_all_dependency::<T>(&ResolutionContext::default())
236 .await
237 }
238
239 pub fn request_context(&self) -> RequestContext<'_> {
240 RequestContext {
241 container: self,
242 context: ResolutionContext {
243 chain: vec![],
244 request_instances: Some(Arc::new(Mutex::new(HashMap::new()))),
245 },
246 }
247 }
248
249 pub async fn initialize_eager(&self) -> Result<(), DiError> {
250 for providers in self.providers.values() {
251 for provider in providers.iter().filter(|p| p.descriptor.eager) {
252 self.resolve_provider(provider, ResolutionContext::default())
253 .await?;
254 }
255 }
256 Ok(())
257 }
258
259 pub async fn shutdown(&self) -> Result<(), DiError> {
260 for providers in self.providers.values() {
261 for provider in providers {
262 if let (Some(destroy), Some(value)) =
263 (provider.descriptor.destroy, provider.singleton.get())
264 {
265 destroy(value.clone()).await?;
266 }
267 }
268 }
269 Ok(())
270 }
271
272 #[doc(hidden)]
273 pub async fn resolve_dependency<T>(
274 &self,
275 context: &ResolutionContext,
276 ) -> Result<Arc<T>, DiError>
277 where
278 T: Any + Send + Sync,
279 {
280 self.resolve_selected::<T>(None, context).await
281 }
282
283 #[doc(hidden)]
284 pub async fn resolve_named_dependency<T>(
285 &self,
286 name: &str,
287 context: &ResolutionContext,
288 ) -> Result<Arc<T>, DiError>
289 where
290 T: Any + Send + Sync,
291 {
292 self.resolve_selected::<T>(Some(name), context).await
293 }
294
295 #[doc(hidden)]
296 pub async fn resolve_optional_dependency<T>(
297 &self,
298 context: &ResolutionContext,
299 ) -> Result<Option<Arc<T>>, DiError>
300 where
301 T: Any + Send + Sync,
302 {
303 match self.resolve_dependency::<T>(context).await {
304 Ok(value) => Ok(Some(value)),
305 Err(DiError::MissingProvider(_)) => Ok(None),
306 Err(error) => Err(error),
307 }
308 }
309
310 #[doc(hidden)]
311 pub async fn resolve_all_dependency<T>(
312 &self,
313 context: &ResolutionContext,
314 ) -> Result<Vec<Arc<T>>, DiError>
315 where
316 T: Any + Send + Sync,
317 {
318 let Some(providers) = self.providers.get(&TypeId::of::<T>()) else {
319 return Ok(vec![]);
320 };
321 let mut values = Vec::with_capacity(providers.len());
322 for provider in providers {
323 let value = self.resolve_provider(provider, context.clone()).await?;
324 values.push(
325 value
326 .downcast::<T>()
327 .map_err(|_| DiError::TypeMismatch(std::any::type_name::<T>()))?,
328 );
329 }
330 Ok(values)
331 }
332
333 async fn resolve_selected<T>(
334 &self,
335 name: Option<&str>,
336 context: &ResolutionContext,
337 ) -> Result<Arc<T>, DiError>
338 where
339 T: Any + Send + Sync,
340 {
341 let type_name = std::any::type_name::<T>();
342 let providers = self
343 .providers
344 .get(&TypeId::of::<T>())
345 .ok_or(DiError::MissingProvider(type_name))?;
346 let selected = if let Some(name) = name {
347 providers
348 .iter()
349 .find(|p| p.descriptor.name == Some(name))
350 .ok_or(DiError::MissingProvider(type_name))?
351 } else if providers.len() == 1 {
352 &providers[0]
353 } else {
354 providers
355 .iter()
356 .find(|p| p.descriptor.primary)
357 .ok_or(DiError::AmbiguousProvider(type_name))?
358 };
359 let value = self.resolve_provider(selected, context.clone()).await?;
360 value
361 .downcast::<T>()
362 .map_err(|_| DiError::TypeMismatch(type_name))
363 }
364
365 fn resolve_provider<'a>(
366 &'a self,
367 provider: &'a RuntimeProvider,
368 mut context: ResolutionContext,
369 ) -> BoxFuture<'a, Result<DynArc, DiError>> {
370 Box::pin(async move {
371 let descriptor = provider.descriptor;
372 let type_name = (descriptor.type_name)();
373 if context.chain.contains(&type_name) {
374 context.chain.push(type_name);
375 return Err(DiError::CircularDependency(context.chain.join(" -> ")));
376 }
377 context.chain.push(type_name);
378 if descriptor.scope == Scope::Prototype {
379 return (descriptor.factory)(self, context).await;
380 }
381 match descriptor.scope {
382 Scope::Singleton => {
383 let value = provider
384 .singleton
385 .get_or_try_init(
386 || async move { (descriptor.factory)(self, context).await },
387 )
388 .await?;
389 Ok(value.clone())
390 }
391 Scope::Request => {
392 let map = context
393 .request_instances
394 .as_deref()
395 .ok_or(DiError::RequestScopeUnavailable(type_name))?;
396 let cell = {
397 let mut instances = map.lock().expect("DI instance lock poisoned");
398 instances
399 .entry(provider_key(descriptor))
400 .or_insert_with(|| Arc::new(tokio::sync::OnceCell::new()))
401 .clone()
402 };
403 let value = cell
404 .get_or_try_init(
405 || async move { (descriptor.factory)(self, context).await },
406 )
407 .await?;
408 Ok(value.clone())
409 }
410 Scope::Prototype => unreachable!(),
411 }
412 })
413 }
414}
415
416fn provider_key(provider: &'static ProviderDescriptor) -> usize {
417 provider as *const ProviderDescriptor as usize
418}
419
420pub struct RequestContext<'a> {
421 container: &'a Container,
422 context: ResolutionContext,
423}
424impl RequestContext<'_> {
425 pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
426 where
427 T: Any + Send + Sync,
428 {
429 self.container.resolve_dependency::<T>(&self.context).await
430 }
431}
432
/// Re-exported dependencies, hidden from the docs. Presumably referenced by
/// macro-generated code through a stable path; not part of the public API.
#[doc(hidden)]
pub mod __private {
    pub use inventory;
    pub use tokio;
}
438
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Lazy, Provider, configuration_properties, provider, singleton};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Serializes environment access between the tests in this module.
    ///
    /// The test harness runs tests on parallel threads, and `std::env::set_var` is unsound
    /// while another thread reads the environment. `Container::new` reads it (profiles and
    /// provider conditions) while `scopes_traits_primary_profiles_and_lifecycle_work`
    /// writes it, so both take this lock around those operations.
    /// NOTE(review): env readers in other test modules are not covered by this lock.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires [`ENV_LOCK`], tolerating poisoning so one failed test does not fail others.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    static CREATIONS: AtomicUsize = AtomicUsize::new(0);
    struct SyncDependency;
    #[singleton]
    fn sync_dependency() -> SyncDependency {
        CREATIONS.fetch_add(1, Ordering::SeqCst);
        SyncDependency
    }
    struct AsyncDependency {
        _sync: Arc<SyncDependency>,
    }
    #[singleton]
    async fn async_dependency(sync: Arc<SyncDependency>) -> AsyncDependency {
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        AsyncDependency { _sync: sync }
    }
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn singleton_is_concurrent_safe() {
        // `Container::new` reads the environment; the guard is dropped before any `.await`.
        let container = {
            let _env = env_lock();
            Arc::new(Container::new().unwrap())
        };
        let mut tasks = tokio::task::JoinSet::new();
        for _ in 0..32 {
            let c = container.clone();
            tasks.spawn(async move { c.resolve::<AsyncDependency>().await.unwrap() });
        }
        let mut values = vec![];
        while let Some(v) = tasks.join_next().await {
            values.push(v.unwrap());
        }
        assert!(values.iter().all(|v| Arc::ptr_eq(&values[0], v)));
        assert_eq!(CREATIONS.load(Ordering::SeqCst), 1);
    }

    static PROTOTYPES: AtomicUsize = AtomicUsize::new(0);
    struct PrototypeBean(usize);
    #[singleton(scope = "prototype")]
    fn prototype_bean() -> PrototypeBean {
        PrototypeBean(PROTOTYPES.fetch_add(1, Ordering::SeqCst))
    }

    struct RequestBean;
    #[singleton(scope = "request")]
    fn request_bean() -> RequestBean {
        RequestBean
    }

    trait Greeting: Send + Sync {
        fn text(&self) -> &'static str;
    }
    struct English;
    impl Greeting for English {
        fn text(&self) -> &'static str {
            "hello"
        }
    }
    struct Hindi;
    impl Greeting for Hindi {
        fn text(&self) -> &'static str {
            "namaste"
        }
    }

    #[singleton(name = "english", primary)]
    fn english_greeting() -> Arc<dyn Greeting> {
        Arc::new(English)
    }
    #[singleton(name = "hindi")]
    fn hindi_greeting() -> Arc<dyn Greeting> {
        Arc::new(Hindi)
    }

    struct MissingOptional;
    struct Greeter {
        greeting: Arc<dyn Greeting>,
        optional: Option<Arc<MissingOptional>>,
    }
    struct GreetingLabel(&'static str);
    #[singleton]
    impl Greeter {
        fn new(greeting: Arc<dyn Greeting>, optional: Option<Arc<MissingOptional>>) -> Self {
            Self { greeting, optional }
        }

        #[provider]
        fn label(&self) -> GreetingLabel {
            GreetingLabel(self.greeting.text())
        }
    }
    struct QualifiedGreeter(Arc<dyn Greeting>);
    #[singleton]
    impl QualifiedGreeter {
        fn new(#[qualifier("hindi")] greeting: Arc<dyn Greeting>) -> Self {
            Self(greeting)
        }
    }

    struct StaticConfig(&'static str);
    struct ServiceWithStaticBean(Arc<StaticConfig>);
    #[singleton]
    impl ServiceWithStaticBean {
        fn new(config: Arc<StaticConfig>) -> Self {
            Self(config)
        }

        #[provider]
        fn config() -> StaticConfig {
            StaticConfig("static-bean")
        }
    }

    static STARTED: AtomicUsize = AtomicUsize::new(0);
    static STOPPED: AtomicUsize = AtomicUsize::new(0);
    struct Managed;
    #[singleton(eager, post_construct = "start", pre_destroy = "stop")]
    impl Managed {
        fn new() -> Self {
            Self
        }
        async fn start(&self) {
            STARTED.fetch_add(1, Ordering::SeqCst);
        }
        async fn stop(&self) {
            STOPPED.fetch_add(1, Ordering::SeqCst);
        }
    }

    struct ProfileBean;
    #[singleton(profile = "test")]
    fn profile_bean() -> ProfileBean {
        ProfileBean
    }

    #[derive(Debug)]
    struct Handler(&'static str);
    #[singleton(name = "first")]
    fn first_handler() -> Handler {
        Handler("first")
    }
    #[singleton(name = "second")]
    fn second_handler() -> Handler {
        Handler("second")
    }
    struct Pipeline(Vec<Arc<Handler>>);
    #[singleton]
    impl Pipeline {
        fn new(handlers: Vec<Arc<Handler>>) -> Self {
            Self(handlers)
        }
    }

    #[configuration_properties("testing_dep")]
    struct TestProperties {
        port: u16,
    }

    struct DeferredTarget;
    #[singleton]
    fn deferred_target() -> DeferredTarget {
        DeferredTarget
    }
    struct DeferredConsumer {
        provider: Provider<DeferredTarget>,
        lazy: Lazy<DeferredTarget>,
    }
    #[singleton]
    impl DeferredConsumer {
        fn new(provider: Provider<DeferredTarget>, lazy: Lazy<DeferredTarget>) -> Self {
            Self { provider, lazy }
        }
    }

    struct ConditionalBean;
    #[singleton(condition = "TESTING_DEP_FEATURE=enabled")]
    fn conditional_bean() -> ConditionalBean {
        ConditionalBean
    }

    struct StandaloneBean;
    #[provider]
    fn standalone_bean() -> StandaloneBean {
        StandaloneBean
    }

    #[tokio::test]
    async fn scopes_traits_primary_profiles_and_lifecycle_work() {
        {
            let _env = env_lock();
            // SAFETY: the only concurrent environment reader in this module
            // (`Container::new` in `singleton_is_concurrent_safe`) runs while holding
            // `ENV_LOCK`, which is held here, so no thread in this module reads during the write.
            unsafe { std::env::set_var("TESTING_DEP_FEATURE", "enabled") };
        }
        let container = Container::with_profiles(["test"]).unwrap();
        let first = container.resolve::<PrototypeBean>().await.unwrap();
        let second = container.resolve::<PrototypeBean>().await.unwrap();
        assert_ne!(first.0, second.0);

        assert!(matches!(
            container.resolve::<RequestBean>().await,
            Err(DiError::RequestScopeUnavailable(_))
        ));
        let request = container.request_context();
        let request_first = request.resolve::<RequestBean>().await.unwrap();
        let request_second = request.resolve::<RequestBean>().await.unwrap();
        assert!(Arc::ptr_eq(&request_first, &request_second));

        let greeter = container.resolve::<Greeter>().await.unwrap();
        assert_eq!(greeter.greeting.text(), "hello");
        assert_eq!(
            container.resolve::<GreetingLabel>().await.unwrap().0,
            "hello"
        );
        assert!(greeter.optional.is_none());
        assert_eq!(
            container
                .resolve::<QualifiedGreeter>()
                .await
                .unwrap()
                .0
                .text(),
            "namaste"
        );
        let hindi = container
            .resolve_named::<Arc<dyn Greeting>>("hindi")
            .await
            .unwrap();
        assert_eq!(hindi.text(), "namaste");
        let static_bean_service = container.resolve::<ServiceWithStaticBean>().await.unwrap();
        assert_eq!(static_bean_service.0.0, "static-bean");
        container.resolve::<ProfileBean>().await.unwrap();
        let pipeline = container.resolve::<Pipeline>().await.unwrap();
        let mut handler_names = pipeline.0.iter().map(|h| h.0).collect::<Vec<_>>();
        handler_names.sort_unstable();
        assert_eq!(handler_names, ["first", "second"]);

        {
            let _env = env_lock();
            // SAFETY: same argument as the `TESTING_DEP_FEATURE` write in this test; the
            // guard is dropped before the next `.await`.
            unsafe { std::env::set_var("TESTING_DEP_PORT", "8080") };
        }
        let properties = container.resolve::<TestProperties>().await.unwrap();
        assert_eq!(properties.port, 8080);
        container.resolve::<ConditionalBean>().await.unwrap();
        container.resolve::<StandaloneBean>().await.unwrap();

        let deferred = container.resolve::<DeferredConsumer>().await.unwrap();
        deferred.provider.get().await.unwrap();
        deferred.lazy.get().await.unwrap();

        container.initialize_eager().await.unwrap();
        assert_eq!(STARTED.load(Ordering::SeqCst), 1);
        container.shutdown().await.unwrap();
        assert_eq!(STOPPED.load(Ordering::SeqCst), 1);
    }
}