1use std::{
2 any::{Any, TypeId},
3 collections::{HashMap, HashSet, VecDeque},
4 future::Future,
5 pin::Pin,
6 sync::{
7 Arc, Mutex, OnceLock,
8 atomic::{AtomicBool, Ordering},
9 },
10};
11
/// Type-erased, thread-shareable instance as stored in the container's caches.
pub type DynArc = Arc<dyn Any + Sync + Send>;
/// Heap-allocated `Send` future that may borrow data for `'a`; used by type-erased factories.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
14type Factory =
15 for<'a> fn(&'a Container, ResolutionContext) -> BoxFuture<'a, Result<DynArc, DiError>>;
16type Destroy = fn(DynArc) -> BoxFuture<'static, Result<(), DiError>>;
17
/// Lifetime policy of a provider's instances.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Scope {
    /// One instance per container, created on first resolution and cached.
    Singleton,
    /// The factory runs on every resolution; nothing is cached.
    Prototype,
    /// One instance per `RequestContext`; resolving it outside one is an error.
    Request,
}
24
25#[doc(hidden)]
26pub struct ProviderDescriptor {
27 type_id: fn() -> TypeId,
28 type_name: fn() -> &'static str,
29 factory: Factory,
30 pub name: Option<&'static str>,
31 pub primary: bool,
32 pub scope: Scope,
33 pub eager: bool,
34 pub profile: Option<&'static str>,
35 pub condition_key: Option<&'static str>,
36 pub condition_value: Option<&'static str>,
37 pub destroy: Option<Destroy>,
38}
39
40impl ProviderDescriptor {
41 #[doc(hidden)]
42 pub const fn new(
43 type_id: fn() -> TypeId,
44 type_name: fn() -> &'static str,
45 factory: Factory,
46 ) -> Self {
47 Self::configured(
48 type_id,
49 type_name,
50 factory,
51 None,
52 false,
53 Scope::Singleton,
54 false,
55 None,
56 None,
57 None,
58 None,
59 )
60 }
61
62 #[allow(clippy::too_many_arguments)]
63 #[doc(hidden)]
64 pub const fn configured(
65 type_id: fn() -> TypeId,
66 type_name: fn() -> &'static str,
67 factory: Factory,
68 name: Option<&'static str>,
69 primary: bool,
70 scope: Scope,
71 eager: bool,
72 profile: Option<&'static str>,
73 condition_key: Option<&'static str>,
74 condition_value: Option<&'static str>,
75 destroy: Option<Destroy>,
76 ) -> Self {
77 Self {
78 type_id,
79 type_name,
80 factory,
81 name,
82 primary,
83 scope,
84 eager,
85 profile,
86 condition_key,
87 condition_value,
88 destroy,
89 }
90 }
91
92 fn active(&self, profiles: &[String]) -> bool {
93 let profile_matches = self
94 .profile
95 .is_none_or(|required| profiles.iter().any(|p| p == required));
96 let condition_matches = self.condition_key.is_none_or(|key| {
97 let actual = std::env::var(key).ok();
98 self.condition_value.map_or(actual.is_some(), |expected| {
99 actual.as_deref() == Some(expected)
100 })
101 });
102 profile_matches && condition_matches
103 }
104}
105
106inventory::collect!(ProviderDescriptor);
107
108type InstanceMap = HashMap<usize, Arc<tokio::sync::OnceCell<DynArc>>>;
109
110struct RuntimeProvider {
111 descriptor: &'static ProviderDescriptor,
112 singleton: tokio::sync::OnceCell<DynArc>,
113}
114
115impl RuntimeProvider {
116 fn new(descriptor: &'static ProviderDescriptor) -> Self {
117 Self {
118 descriptor,
119 singleton: tokio::sync::OnceCell::new(),
120 }
121 }
122}
123
124#[derive(Clone, Default)]
125pub struct ResolutionContext {
126 pub(crate) chain: Vec<&'static str>,
127 provider_chain: Vec<usize>,
128 scope_chain: Vec<Scope>,
129 pub(crate) request_instances: Option<Arc<Mutex<InstanceMap>>>,
130}
131
132#[derive(Debug, thiserror::Error)]
133pub enum DiError {
134 #[error("no active provider is registered for {0}")]
135 MissingProvider(&'static str),
136 #[error("multiple providers match {0}; add a name/qualifier or mark one primary")]
137 AmbiguousProvider(&'static str),
138 #[error("multiple primary providers are registered for {0}")]
139 MultiplePrimary(&'static str),
140 #[error("duplicate provider name '{name}' is registered for {type_name}")]
141 DuplicateProviderName {
142 type_name: &'static str,
143 name: &'static str,
144 },
145 #[error("circular dependency detected: {0}")]
146 CircularDependency(String),
147 #[error("provider for {0} returned an incompatible type")]
148 TypeMismatch(&'static str),
149 #[error("request-scoped dependency {0} was resolved outside RequestContext")]
150 RequestScopeUnavailable(&'static str),
151 #[error("singleton dependency {singleton} cannot capture request-scoped {request}")]
152 InvalidScope {
153 singleton: &'static str,
154 request: &'static str,
155 },
156 #[error("the dependency container has already been shut down")]
157 ContainerShutdown,
158 #[error("configuration property {key} is missing or invalid: {message}")]
159 Configuration { key: String, message: String },
160 #[error("provider for {provider} failed: {message}")]
161 Factory {
162 provider: &'static str,
163 message: String,
164 },
165 #[error("lifecycle hook failed for {provider}: {message}")]
166 Lifecycle {
167 provider: &'static str,
168 message: String,
169 },
170}
171
172struct ContainerInner {
173 providers: HashMap<TypeId, Vec<RuntimeProvider>>,
174 dependency_graph: Mutex<HashMap<usize, HashSet<usize>>>,
175 shut_down: AtomicBool,
176}
177
178#[derive(Clone)]
179pub struct Container {
180 inner: Arc<ContainerInner>,
181}
182
183static GLOBAL_CONTAINER: OnceLock<Container> = OnceLock::new();
184
185pub fn global_container() -> Result<&'static Container, DiError> {
186 if let Some(container) = GLOBAL_CONTAINER.get() {
187 return Ok(container);
188 }
189 let container = Container::new()?;
190 let _ = GLOBAL_CONTAINER.set(container);
191 Ok(GLOBAL_CONTAINER
192 .get()
193 .expect("global DI container initialized"))
194}
195
196pub async fn resolve<T>() -> Result<Arc<T>, DiError>
197where
198 T: Any + Send + Sync,
199{
200 global_container()?.resolve::<T>().await
201}
202
203impl Container {
204 pub fn new() -> Result<Self, DiError> {
205 let profiles = std::env::var("APP_PROFILES")
206 .unwrap_or_default()
207 .split(',')
208 .map(str::trim)
209 .filter(|p| !p.is_empty())
210 .map(str::to_owned)
211 .collect::<Vec<_>>();
212 Self::with_profiles(profiles)
213 }
214
215 pub fn with_profiles(
216 profiles: impl IntoIterator<Item = impl Into<String>>,
217 ) -> Result<Self, DiError> {
218 let profiles = profiles.into_iter().map(Into::into).collect::<Vec<_>>();
219 let mut providers: HashMap<TypeId, Vec<RuntimeProvider>> = HashMap::new();
220 for provider in inventory::iter::<ProviderDescriptor> {
221 if provider.active(&profiles) {
222 providers
223 .entry((provider.type_id)())
224 .or_default()
225 .push(RuntimeProvider::new(provider));
226 }
227 }
228 for group in providers.values() {
229 if group.iter().filter(|p| p.descriptor.primary).count() > 1 {
230 return Err(DiError::MultiplePrimary((group[0].descriptor.type_name)()));
231 }
232 let mut names = HashSet::new();
233 for provider in group {
234 if let Some(name) = provider.descriptor.name
235 && !names.insert(name)
236 {
237 return Err(DiError::DuplicateProviderName {
238 type_name: (provider.descriptor.type_name)(),
239 name,
240 });
241 }
242 }
243 }
244 Ok(Self {
245 inner: Arc::new(ContainerInner {
246 providers,
247 dependency_graph: Mutex::new(HashMap::new()),
248 shut_down: AtomicBool::new(false),
249 }),
250 })
251 }
252
253 pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
254 where
255 T: Any + Send + Sync,
256 {
257 self.resolve_dependency::<T>(&ResolutionContext::default())
258 .await
259 }
260
261 pub async fn resolve_named<T>(&self, name: &str) -> Result<Arc<T>, DiError>
262 where
263 T: Any + Send + Sync,
264 {
265 self.resolve_named_dependency::<T>(name, &ResolutionContext::default())
266 .await
267 }
268
269 pub async fn resolve_optional<T>(&self) -> Result<Option<Arc<T>>, DiError>
270 where
271 T: Any + Send + Sync,
272 {
273 match self.resolve::<T>().await {
274 Ok(value) => Ok(Some(value)),
275 Err(DiError::MissingProvider(_)) => Ok(None),
276 Err(error) => Err(error),
277 }
278 }
279
280 pub async fn resolve_all<T>(&self) -> Result<Vec<Arc<T>>, DiError>
281 where
282 T: Any + Send + Sync,
283 {
284 self.resolve_all_dependency::<T>(&ResolutionContext::default())
285 .await
286 }
287
288 pub fn request_context(&self) -> RequestContext<'_> {
289 RequestContext {
290 container: self,
291 context: ResolutionContext {
292 chain: vec![],
293 provider_chain: vec![],
294 scope_chain: vec![],
295 request_instances: Some(Arc::new(Mutex::new(HashMap::new()))),
296 },
297 }
298 }
299
300 pub async fn initialize_eager(&self) -> Result<(), DiError> {
301 for providers in self.inner.providers.values() {
302 for provider in providers.iter().filter(|p| p.descriptor.eager) {
303 self.resolve_provider(provider, ResolutionContext::default())
304 .await?;
305 }
306 }
307 Ok(())
308 }
309
310 pub async fn validate(&self) -> Result<(), DiError> {
314 for providers in self.inner.providers.values() {
315 for provider in providers
316 .iter()
317 .filter(|provider| provider.descriptor.scope == Scope::Singleton)
318 {
319 self.resolve_provider(provider, ResolutionContext::default())
320 .await?;
321 }
322 }
323 Ok(())
324 }
325
326 pub async fn shutdown(&self) -> Result<(), DiError> {
327 if self.inner.shut_down.swap(true, Ordering::AcqRel) {
328 return Ok(());
329 }
330 let graph = self
331 .inner
332 .dependency_graph
333 .lock()
334 .expect("DI dependency graph lock poisoned")
335 .clone();
336 let providers = self
337 .inner
338 .providers
339 .values()
340 .flatten()
341 .map(|provider| (runtime_provider_key(provider), provider))
342 .collect::<HashMap<_, _>>();
343 let order = shutdown_order(providers.keys().copied(), &graph);
344 for key in order {
345 let provider = providers[&key];
346 if let (Some(destroy), Some(value)) =
347 (provider.descriptor.destroy, provider.singleton.get())
348 {
349 destroy(value.clone()).await?;
350 }
351 }
352 Ok(())
353 }
354
355 #[doc(hidden)]
356 pub async fn resolve_dependency<T>(
357 &self,
358 context: &ResolutionContext,
359 ) -> Result<Arc<T>, DiError>
360 where
361 T: Any + Send + Sync,
362 {
363 self.resolve_selected::<T>(None, context).await
364 }
365
366 #[doc(hidden)]
367 pub async fn resolve_named_dependency<T>(
368 &self,
369 name: &str,
370 context: &ResolutionContext,
371 ) -> Result<Arc<T>, DiError>
372 where
373 T: Any + Send + Sync,
374 {
375 self.resolve_selected::<T>(Some(name), context).await
376 }
377
378 #[doc(hidden)]
379 pub async fn resolve_optional_dependency<T>(
380 &self,
381 context: &ResolutionContext,
382 ) -> Result<Option<Arc<T>>, DiError>
383 where
384 T: Any + Send + Sync,
385 {
386 match self.resolve_dependency::<T>(context).await {
387 Ok(value) => Ok(Some(value)),
388 Err(DiError::MissingProvider(_)) => Ok(None),
389 Err(error) => Err(error),
390 }
391 }
392
393 #[doc(hidden)]
394 pub async fn resolve_all_dependency<T>(
395 &self,
396 context: &ResolutionContext,
397 ) -> Result<Vec<Arc<T>>, DiError>
398 where
399 T: Any + Send + Sync,
400 {
401 let Some(providers) = self.inner.providers.get(&TypeId::of::<T>()) else {
402 return Ok(vec![]);
403 };
404 let mut values = Vec::with_capacity(providers.len());
405 for provider in providers {
406 let value = self.resolve_provider(provider, context.clone()).await?;
407 values.push(
408 value
409 .downcast::<T>()
410 .map_err(|_| DiError::TypeMismatch(std::any::type_name::<T>()))?,
411 );
412 }
413 Ok(values)
414 }
415
416 async fn resolve_selected<T>(
417 &self,
418 name: Option<&str>,
419 context: &ResolutionContext,
420 ) -> Result<Arc<T>, DiError>
421 where
422 T: Any + Send + Sync,
423 {
424 let type_name = std::any::type_name::<T>();
425 let providers = self
426 .inner
427 .providers
428 .get(&TypeId::of::<T>())
429 .ok_or(DiError::MissingProvider(type_name))?;
430 let selected = if let Some(name) = name {
431 providers
432 .iter()
433 .find(|p| p.descriptor.name == Some(name))
434 .ok_or(DiError::MissingProvider(type_name))?
435 } else if providers.len() == 1 {
436 &providers[0]
437 } else {
438 providers
439 .iter()
440 .find(|p| p.descriptor.primary)
441 .ok_or(DiError::AmbiguousProvider(type_name))?
442 };
443 let value = self.resolve_provider(selected, context.clone()).await?;
444 value
445 .downcast::<T>()
446 .map_err(|_| DiError::TypeMismatch(type_name))
447 }
448
449 fn resolve_provider<'a>(
450 &'a self,
451 provider: &'a RuntimeProvider,
452 mut context: ResolutionContext,
453 ) -> BoxFuture<'a, Result<DynArc, DiError>> {
454 Box::pin(async move {
455 if self.inner.shut_down.load(Ordering::Acquire) {
456 return Err(DiError::ContainerShutdown);
457 }
458 let descriptor = provider.descriptor;
459 let type_name = (descriptor.type_name)();
460 if context.chain.contains(&type_name) {
461 context.chain.push(type_name);
462 return Err(DiError::CircularDependency(context.chain.join(" -> ")));
463 }
464 let runtime_key = runtime_provider_key(provider);
465 if let Some(parent) = context.provider_chain.last().copied() {
466 self.add_dependency_edge(parent, runtime_key, &context, type_name)?;
467 }
468 if descriptor.scope == Scope::Request
469 && let Some(position) = context
470 .scope_chain
471 .iter()
472 .position(|scope| *scope == Scope::Singleton)
473 {
474 return Err(DiError::InvalidScope {
475 singleton: context.chain[position],
476 request: type_name,
477 });
478 }
479 context.chain.push(type_name);
480 context.provider_chain.push(runtime_key);
481 context.scope_chain.push(descriptor.scope);
482 if descriptor.scope == Scope::Prototype {
483 return (descriptor.factory)(self, context).await;
484 }
485 match descriptor.scope {
486 Scope::Singleton => {
487 let value = provider
488 .singleton
489 .get_or_try_init(
490 || async move { (descriptor.factory)(self, context).await },
491 )
492 .await?;
493 Ok(value.clone())
494 }
495 Scope::Request => {
496 let map = context
497 .request_instances
498 .as_deref()
499 .ok_or(DiError::RequestScopeUnavailable(type_name))?;
500 let cell = {
501 let mut instances = map.lock().expect("DI instance lock poisoned");
502 instances
503 .entry(provider_key(descriptor))
504 .or_insert_with(|| Arc::new(tokio::sync::OnceCell::new()))
505 .clone()
506 };
507 let value = cell
508 .get_or_try_init(
509 || async move { (descriptor.factory)(self, context).await },
510 )
511 .await?;
512 Ok(value.clone())
513 }
514 Scope::Prototype => unreachable!(),
515 }
516 })
517 }
518
519 fn add_dependency_edge(
520 &self,
521 parent: usize,
522 dependency: usize,
523 context: &ResolutionContext,
524 dependency_name: &'static str,
525 ) -> Result<(), DiError> {
526 let mut graph = self
527 .inner
528 .dependency_graph
529 .lock()
530 .expect("DI dependency graph lock poisoned");
531 graph.entry(parent).or_default().insert(dependency);
532 if graph_path_exists(&graph, dependency, parent, &mut HashSet::new()) {
533 let mut chain = context.chain.clone();
534 chain.push(dependency_name);
535 return Err(DiError::CircularDependency(chain.join(" -> ")));
536 }
537 Ok(())
538 }
539}
540
541fn provider_key(provider: &'static ProviderDescriptor) -> usize {
542 provider as *const ProviderDescriptor as usize
543}
544
545fn runtime_provider_key(provider: &RuntimeProvider) -> usize {
546 provider as *const RuntimeProvider as usize
547}
548
/// Returns whether `target` is reachable from `current` by following `graph` edges.
///
/// A node always reaches itself. `visited` holds nodes already expanded; callers normally
/// pass a fresh set, and nodes present on entry are treated as already explored.
///
/// The search uses an explicit stack instead of recursion, so long dependency chains
/// cannot overflow the thread stack.
fn graph_path_exists(
    graph: &HashMap<usize, HashSet<usize>>,
    current: usize,
    target: usize,
    visited: &mut HashSet<usize>,
) -> bool {
    let mut pending = vec![current];
    while let Some(node) = pending.pop() {
        if node == target {
            return true;
        }
        // Expand each node at most once; this also terminates on cyclic graphs.
        if !visited.insert(node) {
            continue;
        }
        if let Some(dependencies) = graph.get(&node) {
            pending.extend(dependencies.iter().copied());
        }
    }
    false
}
565
/// Orders `keys` for teardown so dependents come before their dependencies.
///
/// This is Kahn's topological sort over `graph` restricted to `keys`. Edges whose parent
/// is not among `keys` are ignored. Nodes left over because of cycles are appended at the
/// end in unspecified order.
fn shutdown_order(
    keys: impl IntoIterator<Item = usize>,
    graph: &HashMap<usize, HashSet<usize>>,
) -> Vec<usize> {
    let members: HashSet<usize> = keys.into_iter().collect();

    // For each member, count the members that still depend on it.
    let mut dependents: HashMap<usize, usize> = members.iter().map(|&key| (key, 0)).collect();
    for dependency in graph
        .iter()
        .filter(|&(parent, _)| members.contains(parent))
        .flat_map(|(_, deps)| deps)
    {
        if let Some(remaining) = dependents.get_mut(dependency) {
            *remaining += 1;
        }
    }

    let mut queue: VecDeque<usize> = dependents
        .iter()
        .filter_map(|(&key, &remaining)| (remaining == 0).then_some(key))
        .collect();
    let mut order = Vec::with_capacity(members.len());
    while let Some(node) = queue.pop_front() {
        order.push(node);
        for dependency in graph.get(&node).into_iter().flatten() {
            let Some(remaining) = dependents.get_mut(dependency) else {
                continue;
            };
            *remaining -= 1;
            if *remaining == 0 {
                queue.push_back(*dependency);
            }
        }
    }

    let placed: HashSet<usize> = order.iter().copied().collect();
    order.extend(members.into_iter().filter(|key| !placed.contains(key)));
    order
}
610
611pub struct RequestContext<'a> {
612 container: &'a Container,
613 context: ResolutionContext,
614}
615impl RequestContext<'_> {
616 pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
617 where
618 T: Any + Send + Sync,
619 {
620 self.container.resolve_dependency::<T>(&self.context).await
621 }
622}
623
624#[doc(hidden)]
625pub mod __private {
626 pub use inventory;
627 pub use tokio;
628
629 use std::fmt::Display;
630
631 use crate::DiError;
632
633 pub fn factory_result<T, E: Display>(
634 result: Result<T, E>,
635 provider: &'static str,
636 ) -> Result<T, DiError> {
637 result.map_err(|error| DiError::Factory {
638 provider,
639 message: error.to_string(),
640 })
641 }
642
643 pub trait IntoLifecycleResult {
644 fn into_lifecycle_result(self, provider: &'static str) -> Result<(), DiError>;
645 }
646
647 impl IntoLifecycleResult for () {
648 fn into_lifecycle_result(self, _provider: &'static str) -> Result<(), DiError> {
649 Ok(())
650 }
651 }
652
653 impl<E: Display> IntoLifecycleResult for Result<(), E> {
654 fn into_lifecycle_result(self, provider: &'static str) -> Result<(), DiError> {
655 self.map_err(|error| DiError::Lifecycle {
656 provider,
657 message: error.to_string(),
658 })
659 }
660 }
661
662 pub fn lifecycle_result<R: IntoLifecycleResult>(
663 result: R,
664 provider: &'static str,
665 ) -> Result<(), DiError> {
666 result.into_lifecycle_result(provider)
667 }
668}
669
// Integration-style tests for the container.
//
// NOTE(review): `#[singleton]`, `#[provider]`, `#[qualifier]` and
// `#[configuration_properties]` are this crate's attribute macros, defined outside this
// file. Presumably each annotated item generates a factory and submits a
// `ProviderDescriptor` through `inventory`, so every provider below is visible to every
// container built in this test binary — confirm against the macro crate.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Lazy, Provider, configuration_properties, provider, singleton};
    use std::sync::atomic::{AtomicUsize, Ordering};
    // Counts runs of `sync_dependency`. This static is shared by all tests in the binary;
    // `singleton_is_concurrent_safe` relies on no other test successfully resolving
    // `SyncDependency` (the lifecycle test only resolves it after shutdown, where
    // `resolve_provider` fails before running any factory).
    static CREATIONS: AtomicUsize = AtomicUsize::new(0);
    struct SyncDependency;
    #[singleton]
    fn sync_dependency() -> SyncDependency {
        CREATIONS.fetch_add(1, Ordering::SeqCst);
        SyncDependency
    }
    struct AsyncDependency {
        _sync: Arc<SyncDependency>,
    }
    // Async factory with a short sleep to widen the window in which concurrent resolvers
    // race for the same singleton cell.
    #[singleton]
    async fn async_dependency(sync: Arc<SyncDependency>) -> AsyncDependency {
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        AsyncDependency { _sync: sync }
    }
    // 32 tasks resolve the same singleton concurrently. All must receive the same `Arc`,
    // and the transitive `SyncDependency` factory must run exactly once.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn singleton_is_concurrent_safe() {
        let container = Arc::new(Container::new().unwrap());
        let mut tasks = tokio::task::JoinSet::new();
        for _ in 0..32 {
            let c = container.clone();
            tasks.spawn(async move { c.resolve::<AsyncDependency>().await.unwrap() });
        }
        let mut values = vec![];
        while let Some(v) = tasks.join_next().await {
            values.push(v.unwrap());
        }
        assert!(values.iter().all(|v| Arc::ptr_eq(&values[0], v)));
        assert_eq!(CREATIONS.load(Ordering::SeqCst), 1);
    }

    // Prototype provider: every construction takes a new sequence number, so two
    // resolutions can be told apart.
    static PROTOTYPES: AtomicUsize = AtomicUsize::new(0);
    struct PrototypeBean(usize);
    #[singleton(scope = "prototype")]
    fn prototype_bean() -> PrototypeBean {
        PrototypeBean(PROTOTYPES.fetch_add(1, Ordering::SeqCst))
    }

    // Request-scoped provider: fails outside a `RequestContext` and is shared within one.
    struct RequestBean;
    #[singleton(scope = "request")]
    fn request_bean() -> RequestBean {
        RequestBean
    }

    // Two trait-object providers of the same type (`Arc<dyn Greeting>`), distinguished
    // by name; "english" is primary, so unnamed resolution picks it.
    trait Greeting: Send + Sync {
        fn text(&self) -> &'static str;
    }
    struct English;
    impl Greeting for English {
        fn text(&self) -> &'static str {
            "hello"
        }
    }
    struct Hindi;
    impl Greeting for Hindi {
        fn text(&self) -> &'static str {
            "namaste"
        }
    }

    #[singleton(name = "english", primary)]
    fn english_greeting() -> Arc<dyn Greeting> {
        Arc::new(English)
    }
    #[singleton(name = "hindi")]
    fn hindi_greeting() -> Arc<dyn Greeting> {
        Arc::new(Hindi)
    }

    // `MissingOptional` has no provider, so `Greeter.optional` is expected to be `None`.
    // `Greeter::label` is a `#[provider]` method, exposing `GreetingLabel` as a
    // provider derived from the `Greeter` instance.
    struct MissingOptional;
    struct Greeter {
        greeting: Arc<dyn Greeting>,
        optional: Option<Arc<MissingOptional>>,
    }
    struct GreetingLabel(&'static str);
    #[singleton]
    impl Greeter {
        fn new(greeting: Arc<dyn Greeting>, optional: Option<Arc<MissingOptional>>) -> Self {
            Self { greeting, optional }
        }

        #[provider]
        fn label(&self) -> GreetingLabel {
            GreetingLabel(self.greeting.text())
        }
    }
    // `#[qualifier("hindi")]` selects the named provider instead of the primary one.
    struct QualifiedGreeter(Arc<dyn Greeting>);
    #[singleton]
    impl QualifiedGreeter {
        fn new(#[qualifier("hindi")] greeting: Arc<dyn Greeting>) -> Self {
            Self(greeting)
        }
    }

    // An associated `#[provider]` function without `self` supplies `StaticConfig`, which
    // the same impl's constructor then consumes.
    struct StaticConfig(&'static str);
    struct ServiceWithStaticBean(Arc<StaticConfig>);
    #[singleton]
    impl ServiceWithStaticBean {
        fn new(config: Arc<StaticConfig>) -> Self {
            Self(config)
        }

        #[provider]
        fn config() -> StaticConfig {
            StaticConfig("static-bean")
        }
    }

    // Eager singleton with lifecycle hooks. The lifecycle test expects `start` to run
    // after `initialize_eager` and `stop` to run after `shutdown`, once each.
    static STARTED: AtomicUsize = AtomicUsize::new(0);
    static STOPPED: AtomicUsize = AtomicUsize::new(0);
    struct Managed;
    #[singleton(eager, post_construct = "start", pre_destroy = "stop")]
    impl Managed {
        fn new() -> Self {
            Self
        }
        async fn start(&self) {
            STARTED.fetch_add(1, Ordering::SeqCst);
        }
        async fn stop(&self) {
            STOPPED.fetch_add(1, Ordering::SeqCst);
        }
    }

    // Only registered when the "test" profile is active.
    struct ProfileBean;
    #[singleton(profile = "test")]
    fn profile_bean() -> ProfileBean {
        ProfileBean
    }

    // Two named providers of `Handler` with no primary. Unnamed resolution would be
    // ambiguous, so `Pipeline` collects both via `Vec<Arc<Handler>>`.
    #[derive(Debug)]
    struct Handler(&'static str);
    #[singleton(name = "first")]
    fn first_handler() -> Handler {
        Handler("first")
    }
    #[singleton(name = "second")]
    fn second_handler() -> Handler {
        Handler("second")
    }
    struct Pipeline(Vec<Arc<Handler>>);
    #[singleton]
    impl Pipeline {
        fn new(handlers: Vec<Arc<Handler>>) -> Self {
            Self(handlers)
        }
    }

    // NOTE(review): presumably binds `TESTING_DEP_*` environment variables to fields; the
    // lifecycle test sets `TESTING_DEP_PORT` before resolving it — confirm the mapping.
    #[configuration_properties("testing_dep")]
    struct TestProperties {
        port: u16,
    }

    // Deferred injection: `Provider<T>` / `Lazy<T>` are resolved later via `.get()`.
    struct DeferredTarget;
    #[singleton]
    fn deferred_target() -> DeferredTarget {
        DeferredTarget
    }
    struct DeferredConsumer {
        provider: Provider<DeferredTarget>,
        lazy: Lazy<DeferredTarget>,
    }
    #[singleton]
    impl DeferredConsumer {
        fn new(provider: Provider<DeferredTarget>, lazy: Lazy<DeferredTarget>) -> Self {
            Self { provider, lazy }
        }
    }

    // Registered only when `TESTING_DEP_FEATURE=enabled` at container construction time.
    struct ConditionalBean;
    #[singleton(condition = "TESTING_DEP_FEATURE=enabled")]
    fn conditional_bean() -> ConditionalBean {
        ConditionalBean
    }

    // Factory that always fails; resolution is expected to yield `DiError::Factory`.
    struct FallibleDependency;
    #[provider]
    fn fallible_dependency() -> Result<FallibleDependency, &'static str> {
        Err("expected factory failure")
    }

    // A singleton depending on a request-scoped bean must be rejected with `InvalidScope`.
    struct RequestLeaf;
    #[singleton(scope = "request")]
    fn request_leaf() -> RequestLeaf {
        RequestLeaf
    }
    struct CaptiveSingleton {
        _request: Arc<RequestLeaf>,
    }
    #[singleton]
    impl CaptiveSingleton {
        fn new(request: Arc<RequestLeaf>) -> Self {
            Self { _request: request }
        }
    }

    // Mutually dependent singletons, used to check that cycles are reported rather than
    // deadlocking, including when the two ends are resolved from different tasks.
    struct CycleA {
        _b: Arc<CycleB>,
    }
    struct CycleB {
        _a: Arc<CycleA>,
    }
    #[singleton]
    impl CycleA {
        fn new(b: Arc<CycleB>) -> Self {
            Self { _b: b }
        }
    }
    #[singleton]
    impl CycleB {
        fn new(a: Arc<CycleA>) -> Self {
            Self { _a: a }
        }
    }

    // Collects every `Arc<dyn Greeting>` provider (english + hindi).
    struct AllGreetings(Vec<Arc<dyn Greeting>>);
    #[singleton]
    impl AllGreetings {
        fn new(greetings: Vec<Arc<dyn Greeting>>) -> Self {
            Self(greetings)
        }
    }

    // Deferred access to a profile-gated provider.
    struct ProfileDeferred;
    #[singleton(profile = "test")]
    fn profile_deferred() -> ProfileDeferred {
        ProfileDeferred
    }
    struct ProfileDeferredConsumer(Provider<ProfileDeferred>);
    #[singleton]
    impl ProfileDeferredConsumer {
        fn new(provider: Provider<ProfileDeferred>) -> Self {
            Self(provider)
        }
    }

    // Plain `#[provider]` function with no extra options.
    struct StandaloneBean;
    #[provider]
    fn standalone_bean() -> StandaloneBean {
        StandaloneBean
    }

    // End-to-end walk through scopes, qualifiers, primaries, profiles, conditions,
    // configuration properties, deferred injection, collection injection, eager
    // initialization and shutdown.
    //
    // NOTE(review): `set_var` is `unsafe` because it can race with environment reads on
    // other threads. The other tests in this module may run concurrently and read the
    // environment (e.g. `Container::new` reads `APP_PROFILES`).
    #[tokio::test]
    async fn scopes_traits_primary_profiles_and_lifecycle_work() {
        // Must be set before `with_profiles`: conditions are evaluated at construction.
        unsafe { std::env::set_var("TESTING_DEP_FEATURE", "enabled") };
        let container = Container::with_profiles(["test"]).unwrap();
        let first = container.resolve::<PrototypeBean>().await.unwrap();
        let second = container.resolve::<PrototypeBean>().await.unwrap();
        assert_ne!(first.0, second.0);

        assert!(matches!(
            container.resolve::<RequestBean>().await,
            Err(DiError::RequestScopeUnavailable(_))
        ));
        let request = container.request_context();
        let request_first = request.resolve::<RequestBean>().await.unwrap();
        let request_second = request.resolve::<RequestBean>().await.unwrap();
        assert!(Arc::ptr_eq(&request_first, &request_second));

        let greeter = container.resolve::<Greeter>().await.unwrap();
        assert_eq!(greeter.greeting.text(), "hello");
        assert_eq!(
            container.resolve::<GreetingLabel>().await.unwrap().0,
            "hello"
        );
        assert!(greeter.optional.is_none());
        assert_eq!(
            container
                .resolve::<QualifiedGreeter>()
                .await
                .unwrap()
                .0
                .text(),
            "namaste"
        );
        let hindi = container
            .resolve_named::<Arc<dyn Greeting>>("hindi")
            .await
            .unwrap();
        assert_eq!(hindi.text(), "namaste");
        let static_bean_service = container.resolve::<ServiceWithStaticBean>().await.unwrap();
        assert_eq!(static_bean_service.0.0, "static-bean");
        container.resolve::<ProfileBean>().await.unwrap();
        let pipeline = container.resolve::<Pipeline>().await.unwrap();
        // Provider order from `inventory` is not guaranteed, hence the sort.
        let mut handler_names = pipeline.0.iter().map(|h| h.0).collect::<Vec<_>>();
        handler_names.sort_unstable();
        assert_eq!(handler_names, ["first", "second"]);

        unsafe { std::env::set_var("TESTING_DEP_PORT", "8080") };
        let properties = container.resolve::<TestProperties>().await.unwrap();
        assert_eq!(properties.port, 8080);
        container.resolve::<ConditionalBean>().await.unwrap();
        container.resolve::<StandaloneBean>().await.unwrap();

        let deferred = container.resolve::<DeferredConsumer>().await.unwrap();
        deferred.provider.get().await.unwrap();
        deferred.lazy.get().await.unwrap();
        let profile_deferred = container
            .resolve::<ProfileDeferredConsumer>()
            .await
            .unwrap();
        profile_deferred.0.get().await.unwrap();
        let greetings = container.resolve::<AllGreetings>().await.unwrap();
        assert_eq!(greetings.0.len(), 2);

        container.initialize_eager().await.unwrap();
        assert_eq!(STARTED.load(Ordering::SeqCst), 1);
        container.shutdown().await.unwrap();
        assert_eq!(STOPPED.load(Ordering::SeqCst), 1);
        // After shutdown every resolution fails before any factory runs.
        assert!(matches!(
            container.resolve::<SyncDependency>().await,
            Err(DiError::ContainerShutdown)
        ));
    }

    // Error paths: failing factory, captive request scope and a cross-task dependency
    // cycle. The cycle case is wrapped in a timeout because it would deadlock if it
    // were not detected through the shared dependency graph.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn failures_are_reported_instead_of_hanging_or_capturing_scope() {
        let container = Arc::new(Container::new().unwrap());
        assert!(matches!(
            container.resolve::<FallibleDependency>().await,
            Err(DiError::Factory { .. })
        ));

        let request = container.request_context();
        let captive = request.resolve::<CaptiveSingleton>().await;
        assert!(matches!(captive, Err(DiError::InvalidScope { .. })));

        let first = {
            let container = container.clone();
            tokio::spawn(async move { container.resolve::<CycleA>().await })
        };
        let second = {
            let container = container.clone();
            tokio::spawn(async move { container.resolve::<CycleB>().await })
        };
        let results = tokio::time::timeout(std::time::Duration::from_secs(1), async {
            (first.await.unwrap(), second.await.unwrap())
        })
        .await
        .expect("cross-task cycle resolution must not deadlock");
        // Which task observes the cycle depends on scheduling; at least one must.
        assert!(
            matches!(results.0, Err(DiError::CircularDependency(_)))
                || matches!(results.1, Err(DiError::CircularDependency(_)))
        );
    }
}