1use std::{
2 any::{Any, TypeId},
3 collections::{HashMap, HashSet, VecDeque},
4 future::Future,
5 pin::Pin,
6 sync::{
7 Arc, Mutex, OnceLock,
8 atomic::{AtomicBool, Ordering},
9 },
10};
11
/// Type-erased, shareable instance handle. Factories return values in this form and the
/// typed `resolve*` methods downcast them back to `Arc<T>`.
pub type DynArc = Arc<dyn Any + Send + Sync>;
/// Boxed, pinned, `Send` future that may borrow data for `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
// Signature of a provider factory: given the container and the current resolution context,
// asynchronously produce one type-erased instance.
type Factory =
    for<'a> fn(&'a Container, ResolutionContext) -> BoxFuture<'a, Result<DynArc, DiError>>;
// Signature of a destroy hook; `Container::shutdown` calls it with the cached singleton.
type Destroy = fn(DynArc) -> BoxFuture<'static, Result<(), DiError>>;
17
/// Lifetime policy applied to the instances produced by a provider.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Scope {
    /// One instance per container, created on first resolution and cached.
    Singleton,
    /// A new instance on every resolution; nothing is cached.
    Prototype,
    /// One instance per `RequestContext`; resolving outside a request context fails with
    /// `DiError::RequestScopeUnavailable`.
    Request,
}
24
/// Static description of one provider, collected through `inventory` and turned into a
/// `RuntimeProvider` when a container is built.
#[doc(hidden)]
pub struct ProviderDescriptor {
    // Returns the `TypeId` the provider is registered under.
    type_id: fn() -> TypeId,
    // Returns the human-readable type name used in error messages and cycle chains.
    type_name: fn() -> &'static str,
    // Builds an instance; see the `Factory` alias.
    factory: Factory,
    /// Optional qualifier matched by `resolve_named`; must be unique per type.
    pub name: Option<&'static str>,
    /// Preferred candidate when several providers exist for the same type.
    pub primary: bool,
    /// Lifetime policy for produced instances.
    pub scope: Scope,
    /// Whether `Container::initialize_eager` should instantiate this provider.
    pub eager: bool,
    /// When set, the provider is only active if this profile is among the active profiles.
    pub profile: Option<&'static str>,
    /// When set, names an environment variable that gates activation.
    pub condition_key: Option<&'static str>,
    /// Expected value of `condition_key`; when `None`, the variable merely has to be set.
    pub condition_value: Option<&'static str>,
    /// Optional hook run by `Container::shutdown` on the cached singleton instance.
    pub destroy: Option<Destroy>,
}
39
40impl ProviderDescriptor {
41 #[doc(hidden)]
42 pub const fn new(
43 type_id: fn() -> TypeId,
44 type_name: fn() -> &'static str,
45 factory: Factory,
46 ) -> Self {
47 Self::configured(
48 type_id,
49 type_name,
50 factory,
51 None,
52 false,
53 Scope::Singleton,
54 false,
55 None,
56 None,
57 None,
58 None,
59 )
60 }
61
62 #[allow(clippy::too_many_arguments)]
63 #[doc(hidden)]
64 pub const fn configured(
65 type_id: fn() -> TypeId,
66 type_name: fn() -> &'static str,
67 factory: Factory,
68 name: Option<&'static str>,
69 primary: bool,
70 scope: Scope,
71 eager: bool,
72 profile: Option<&'static str>,
73 condition_key: Option<&'static str>,
74 condition_value: Option<&'static str>,
75 destroy: Option<Destroy>,
76 ) -> Self {
77 Self {
78 type_id,
79 type_name,
80 factory,
81 name,
82 primary,
83 scope,
84 eager,
85 profile,
86 condition_key,
87 condition_value,
88 destroy,
89 }
90 }
91
92 fn active(&self, profiles: &[String]) -> bool {
93 let profile_matches = self
94 .profile
95 .is_none_or(|required| profiles.iter().any(|p| p == required));
96 let condition_matches = self.condition_key.is_none_or(|key| {
97 let actual = std::env::var(key).ok();
98 self.condition_value.map_or(actual.is_some(), |expected| {
99 actual.as_deref() == Some(expected)
100 })
101 });
102 profile_matches && condition_matches
103 }
104}
105
// Declares `ProviderDescriptor` as an inventory-collected type so that
// `inventory::iter::<ProviderDescriptor>` (used by `Container::with_profiles`) can enumerate
// every submitted descriptor.
inventory::collect!(ProviderDescriptor);
107
// Per-request instance cache: provider key (descriptor address) -> lazily initialised cell.
type InstanceMap = HashMap<usize, Arc<tokio::sync::OnceCell<DynArc>>>;
109
// A descriptor that is active in one particular container, paired with that container's
// singleton cache slot for it.
struct RuntimeProvider {
    // The static registration this runtime entry was built from.
    descriptor: &'static ProviderDescriptor,
    // Holds the instance once a singleton-scoped provider has been resolved successfully.
    singleton: tokio::sync::OnceCell<DynArc>,
}
114
115impl RuntimeProvider {
116 fn new(descriptor: &'static ProviderDescriptor) -> Self {
117 Self {
118 descriptor,
119 singleton: tokio::sync::OnceCell::new(),
120 }
121 }
122}
123
/// State threaded through one resolution: the path of providers currently being built and,
/// inside a request, the request-scoped instance cache.
///
/// The three chains are pushed together in `Container::resolve_provider`, so within this
/// file index `i` of each refers to the same provider.
#[derive(Clone, Default)]
pub struct ResolutionContext {
    // Type names of the providers on the current path; used for cycle detection and messages.
    pub(crate) chain: Vec<&'static str>,
    // Runtime keys of the providers on the current path; the last one is the direct parent.
    provider_chain: Vec<usize>,
    // Scopes of the providers on the current path; used to reject request-in-singleton capture.
    scope_chain: Vec<Scope>,
    // Request-scoped cache; `None` when resolving outside a `RequestContext`.
    pub(crate) request_instances: Option<Arc<Mutex<InstanceMap>>>,
}
131
/// Errors reported while building a container, resolving dependencies, or shutting down.
#[derive(Debug, thiserror::Error)]
pub enum DiError {
    /// No active provider (or no provider with the requested name) exists for the type.
    #[error("no active provider is registered for {0}")]
    MissingProvider(&'static str),
    /// Several providers match an unnamed request and none is marked primary.
    #[error("multiple providers match {0}; add a name/qualifier or mark one primary")]
    AmbiguousProvider(&'static str),
    /// More than one active provider of the type is marked primary.
    #[error("multiple primary providers are registered for {0}")]
    MultiplePrimary(&'static str),
    /// Two active providers of the same type share a name.
    #[error("duplicate provider name '{name}' is registered for {type_name}")]
    DuplicateProviderName {
        type_name: &'static str,
        name: &'static str,
    },
    /// A dependency cycle was found; the payload is the ` -> `-joined chain of type names.
    #[error("circular dependency detected: {0}")]
    CircularDependency(String),
    /// A factory produced a value that could not be downcast to the requested type.
    #[error("provider for {0} returned an incompatible type")]
    TypeMismatch(&'static str),
    /// A request-scoped provider was resolved without a request instance cache.
    #[error("request-scoped dependency {0} was resolved outside RequestContext")]
    RequestScopeUnavailable(&'static str),
    /// A request-scoped provider was reached while a singleton was being built.
    #[error("singleton dependency {singleton} cannot capture request-scoped {request}")]
    InvalidScope {
        singleton: &'static str,
        request: &'static str,
    },
    /// Resolution was attempted after `Container::shutdown`.
    #[error("the dependency container has already been shut down")]
    ContainerShutdown,
    /// A configuration property is missing or invalid.
    #[error("configuration property {key} is missing or invalid: {message}")]
    Configuration { key: String, message: String },
    /// A provider factory returned an error; see `__private::factory_result`.
    #[error("provider for {provider} failed: {message}")]
    Factory {
        provider: &'static str,
        message: String,
    },
    /// A lifecycle hook returned an error; see `__private::lifecycle_result`.
    #[error("lifecycle hook failed for {provider}: {message}")]
    Lifecycle {
        provider: &'static str,
        message: String,
    },
}
171
// Shared state behind every clone of a `Container`.
struct ContainerInner {
    // Active providers grouped by the type they produce; fixed after construction.
    providers: HashMap<TypeId, Vec<RuntimeProvider>>,
    // Edges "provider -> providers it resolved", recorded during resolution; used for
    // cycle detection and for ordering destroy hooks at shutdown.
    dependency_graph: Mutex<HashMap<usize, HashSet<usize>>>,
    // Set once by `shutdown`; checked at the start of every provider resolution.
    shut_down: AtomicBool,
}
177
/// Dependency container holding the active providers and their cached singletons.
///
/// Clones share the same inner state through an `Arc`.
#[derive(Clone)]
pub struct Container {
    inner: Arc<ContainerInner>,
}
182
// Process-wide container, created on first use by `global_container`.
static GLOBAL_CONTAINER: OnceLock<Container> = OnceLock::new();
// Optional process-wide replacement for the default panic in `handle_resolution_error`.
static RESOLUTION_ERROR_HANDLER: OnceLock<fn(DiError) -> !> = OnceLock::new();
185
/// Installs the process-wide handler invoked by [`handle_resolution_error`].
///
/// The handler can be installed only once.
///
/// # Errors
/// Returns the rejected `handler` back if one has already been installed.
pub fn set_resolution_error_handler(handler: fn(DiError) -> !) -> Result<(), fn(DiError) -> !> {
    RESOLUTION_ERROR_HANDLER.set(handler)
}
193
194#[doc(hidden)]
195pub fn handle_resolution_error(error: DiError) -> ! {
196 if let Some(handler) = RESOLUTION_ERROR_HANDLER.get() {
197 handler(error)
198 }
199 panic!("auto-di dependency resolution failed: {error}")
200}
201
202pub fn global_container() -> Result<&'static Container, DiError> {
203 if let Some(container) = GLOBAL_CONTAINER.get() {
204 return Ok(container);
205 }
206 let container = Container::new()?;
207 let _ = GLOBAL_CONTAINER.set(container);
208 Ok(GLOBAL_CONTAINER
209 .get()
210 .expect("global DI container initialized"))
211}
212
/// Resolves `T` from the process-wide container (see [`global_container`]).
///
/// # Errors
/// Returns any error from building the global container or from resolving `T`.
pub async fn resolve<T>() -> Result<Arc<T>, DiError>
where
    T: Any + Send + Sync,
{
    global_container()?.resolve::<T>().await
}
219
220impl Container {
221 pub fn new() -> Result<Self, DiError> {
222 let profiles = std::env::var("APP_PROFILES")
223 .unwrap_or_default()
224 .split(',')
225 .map(str::trim)
226 .filter(|p| !p.is_empty())
227 .map(str::to_owned)
228 .collect::<Vec<_>>();
229 Self::with_profiles(profiles)
230 }
231
232 pub fn with_profiles(
233 profiles: impl IntoIterator<Item = impl Into<String>>,
234 ) -> Result<Self, DiError> {
235 let profiles = profiles.into_iter().map(Into::into).collect::<Vec<_>>();
236 let mut providers: HashMap<TypeId, Vec<RuntimeProvider>> = HashMap::new();
237 for provider in inventory::iter::<ProviderDescriptor> {
238 if provider.active(&profiles) {
239 providers
240 .entry((provider.type_id)())
241 .or_default()
242 .push(RuntimeProvider::new(provider));
243 }
244 }
245 for group in providers.values() {
246 if group.iter().filter(|p| p.descriptor.primary).count() > 1 {
247 return Err(DiError::MultiplePrimary((group[0].descriptor.type_name)()));
248 }
249 let mut names = HashSet::new();
250 for provider in group {
251 if let Some(name) = provider.descriptor.name
252 && !names.insert(name)
253 {
254 return Err(DiError::DuplicateProviderName {
255 type_name: (provider.descriptor.type_name)(),
256 name,
257 });
258 }
259 }
260 }
261 Ok(Self {
262 inner: Arc::new(ContainerInner {
263 providers,
264 dependency_graph: Mutex::new(HashMap::new()),
265 shut_down: AtomicBool::new(false),
266 }),
267 })
268 }
269
270 pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
271 where
272 T: Any + Send + Sync,
273 {
274 self.resolve_dependency::<T>(&ResolutionContext::default())
275 .await
276 }
277
278 pub async fn resolve_named<T>(&self, name: &str) -> Result<Arc<T>, DiError>
279 where
280 T: Any + Send + Sync,
281 {
282 self.resolve_named_dependency::<T>(name, &ResolutionContext::default())
283 .await
284 }
285
286 pub async fn resolve_optional<T>(&self) -> Result<Option<Arc<T>>, DiError>
287 where
288 T: Any + Send + Sync,
289 {
290 match self.resolve::<T>().await {
291 Ok(value) => Ok(Some(value)),
292 Err(DiError::MissingProvider(_)) => Ok(None),
293 Err(error) => Err(error),
294 }
295 }
296
297 pub async fn resolve_all<T>(&self) -> Result<Vec<Arc<T>>, DiError>
298 where
299 T: Any + Send + Sync,
300 {
301 self.resolve_all_dependency::<T>(&ResolutionContext::default())
302 .await
303 }
304
305 pub fn request_context(&self) -> RequestContext<'_> {
306 RequestContext {
307 container: self,
308 context: ResolutionContext {
309 chain: vec![],
310 provider_chain: vec![],
311 scope_chain: vec![],
312 request_instances: Some(Arc::new(Mutex::new(HashMap::new()))),
313 },
314 }
315 }
316
317 pub async fn initialize_eager(&self) -> Result<(), DiError> {
318 for providers in self.inner.providers.values() {
319 for provider in providers.iter().filter(|p| p.descriptor.eager) {
320 self.resolve_provider(provider, ResolutionContext::default())
321 .await?;
322 }
323 }
324 Ok(())
325 }
326
327 pub async fn validate(&self) -> Result<(), DiError> {
331 for providers in self.inner.providers.values() {
332 for provider in providers
333 .iter()
334 .filter(|provider| provider.descriptor.scope == Scope::Singleton)
335 {
336 self.resolve_provider(provider, ResolutionContext::default())
337 .await?;
338 }
339 }
340 Ok(())
341 }
342
343 pub async fn shutdown(&self) -> Result<(), DiError> {
344 if self.inner.shut_down.swap(true, Ordering::AcqRel) {
345 return Ok(());
346 }
347 let graph = self
348 .inner
349 .dependency_graph
350 .lock()
351 .expect("DI dependency graph lock poisoned")
352 .clone();
353 let providers = self
354 .inner
355 .providers
356 .values()
357 .flatten()
358 .map(|provider| (runtime_provider_key(provider), provider))
359 .collect::<HashMap<_, _>>();
360 let order = shutdown_order(providers.keys().copied(), &graph);
361 for key in order {
362 let provider = providers[&key];
363 if let (Some(destroy), Some(value)) =
364 (provider.descriptor.destroy, provider.singleton.get())
365 {
366 destroy(value.clone()).await?;
367 }
368 }
369 Ok(())
370 }
371
372 #[doc(hidden)]
373 pub async fn resolve_dependency<T>(
374 &self,
375 context: &ResolutionContext,
376 ) -> Result<Arc<T>, DiError>
377 where
378 T: Any + Send + Sync,
379 {
380 self.resolve_selected::<T>(None, context).await
381 }
382
383 #[doc(hidden)]
384 pub async fn resolve_named_dependency<T>(
385 &self,
386 name: &str,
387 context: &ResolutionContext,
388 ) -> Result<Arc<T>, DiError>
389 where
390 T: Any + Send + Sync,
391 {
392 self.resolve_selected::<T>(Some(name), context).await
393 }
394
395 #[doc(hidden)]
396 pub async fn resolve_optional_dependency<T>(
397 &self,
398 context: &ResolutionContext,
399 ) -> Result<Option<Arc<T>>, DiError>
400 where
401 T: Any + Send + Sync,
402 {
403 match self.resolve_dependency::<T>(context).await {
404 Ok(value) => Ok(Some(value)),
405 Err(DiError::MissingProvider(_)) => Ok(None),
406 Err(error) => Err(error),
407 }
408 }
409
410 #[doc(hidden)]
411 pub async fn resolve_all_dependency<T>(
412 &self,
413 context: &ResolutionContext,
414 ) -> Result<Vec<Arc<T>>, DiError>
415 where
416 T: Any + Send + Sync,
417 {
418 let Some(providers) = self.inner.providers.get(&TypeId::of::<T>()) else {
419 return Ok(vec![]);
420 };
421 let mut values = Vec::with_capacity(providers.len());
422 for provider in providers {
423 let value = self.resolve_provider(provider, context.clone()).await?;
424 values.push(
425 value
426 .downcast::<T>()
427 .map_err(|_| DiError::TypeMismatch(std::any::type_name::<T>()))?,
428 );
429 }
430 Ok(values)
431 }
432
433 async fn resolve_selected<T>(
434 &self,
435 name: Option<&str>,
436 context: &ResolutionContext,
437 ) -> Result<Arc<T>, DiError>
438 where
439 T: Any + Send + Sync,
440 {
441 let type_name = std::any::type_name::<T>();
442 let providers = self
443 .inner
444 .providers
445 .get(&TypeId::of::<T>())
446 .ok_or(DiError::MissingProvider(type_name))?;
447 let selected = if let Some(name) = name {
448 providers
449 .iter()
450 .find(|p| p.descriptor.name == Some(name))
451 .ok_or(DiError::MissingProvider(type_name))?
452 } else if providers.len() == 1 {
453 &providers[0]
454 } else {
455 providers
456 .iter()
457 .find(|p| p.descriptor.primary)
458 .ok_or(DiError::AmbiguousProvider(type_name))?
459 };
460 let value = self.resolve_provider(selected, context.clone()).await?;
461 value
462 .downcast::<T>()
463 .map_err(|_| DiError::TypeMismatch(type_name))
464 }
465
466 fn resolve_provider<'a>(
467 &'a self,
468 provider: &'a RuntimeProvider,
469 mut context: ResolutionContext,
470 ) -> BoxFuture<'a, Result<DynArc, DiError>> {
471 Box::pin(async move {
472 if self.inner.shut_down.load(Ordering::Acquire) {
473 return Err(DiError::ContainerShutdown);
474 }
475 let descriptor = provider.descriptor;
476 let type_name = (descriptor.type_name)();
477 if context.chain.contains(&type_name) {
478 context.chain.push(type_name);
479 return Err(DiError::CircularDependency(context.chain.join(" -> ")));
480 }
481 let runtime_key = runtime_provider_key(provider);
482 if let Some(parent) = context.provider_chain.last().copied() {
483 self.add_dependency_edge(parent, runtime_key, &context, type_name)?;
484 }
485 if descriptor.scope == Scope::Request
486 && let Some(position) = context
487 .scope_chain
488 .iter()
489 .position(|scope| *scope == Scope::Singleton)
490 {
491 return Err(DiError::InvalidScope {
492 singleton: context.chain[position],
493 request: type_name,
494 });
495 }
496 context.chain.push(type_name);
497 context.provider_chain.push(runtime_key);
498 context.scope_chain.push(descriptor.scope);
499 if descriptor.scope == Scope::Prototype {
500 return (descriptor.factory)(self, context).await;
501 }
502 match descriptor.scope {
503 Scope::Singleton => {
504 let value = provider
505 .singleton
506 .get_or_try_init(
507 || async move { (descriptor.factory)(self, context).await },
508 )
509 .await?;
510 Ok(value.clone())
511 }
512 Scope::Request => {
513 let map = context
514 .request_instances
515 .as_deref()
516 .ok_or(DiError::RequestScopeUnavailable(type_name))?;
517 let cell = {
518 let mut instances = map.lock().expect("DI instance lock poisoned");
519 instances
520 .entry(provider_key(descriptor))
521 .or_insert_with(|| Arc::new(tokio::sync::OnceCell::new()))
522 .clone()
523 };
524 let value = cell
525 .get_or_try_init(
526 || async move { (descriptor.factory)(self, context).await },
527 )
528 .await?;
529 Ok(value.clone())
530 }
531 Scope::Prototype => unreachable!(),
532 }
533 })
534 }
535
536 fn add_dependency_edge(
537 &self,
538 parent: usize,
539 dependency: usize,
540 context: &ResolutionContext,
541 dependency_name: &'static str,
542 ) -> Result<(), DiError> {
543 let mut graph = self
544 .inner
545 .dependency_graph
546 .lock()
547 .expect("DI dependency graph lock poisoned");
548 graph.entry(parent).or_default().insert(dependency);
549 if graph_path_exists(&graph, dependency, parent, &mut HashSet::new()) {
550 let mut chain = context.chain.clone();
551 chain.push(dependency_name);
552 return Err(DiError::CircularDependency(chain.join(" -> ")));
553 }
554 Ok(())
555 }
556}
557
558fn provider_key(provider: &'static ProviderDescriptor) -> usize {
559 provider as *const ProviderDescriptor as usize
560}
561
562fn runtime_provider_key(provider: &RuntimeProvider) -> usize {
563 provider as *const RuntimeProvider as usize
564}
565
// Depth-first search: is `target` reachable from `current` by following edges in `graph`?
//
// `current == target` counts as reachable (a zero-length path). `visited` carries the nodes
// already expanded so that cycles in the graph cannot cause unbounded recursion; callers
// start with an empty set.
fn graph_path_exists(
    graph: &HashMap<usize, HashSet<usize>>,
    current: usize,
    target: usize,
    visited: &mut HashSet<usize>,
) -> bool {
    if current == target {
        return true;
    }
    // `insert` returns false for a node that was expanded before, which prunes the branch.
    visited.insert(current)
        && graph.get(&current).is_some_and(|dependencies| {
            dependencies
                .iter()
                .any(|next| graph_path_exists(graph, *next, target, visited))
        })
}
582
// Orders `keys` so that every provider comes before the providers it depends on.
//
// `graph` maps a provider to the set of providers it resolved. Edges whose parent or
// dependency is not in `keys` are ignored. Keys that can never become ready (members of a
// cycle, or anything only reachable through one) are appended after the ordered part, in
// unspecified order, so each key appears exactly once.
fn shutdown_order(
    keys: impl IntoIterator<Item = usize>,
    graph: &HashMap<usize, HashSet<usize>>,
) -> Vec<usize> {
    let keys: HashSet<usize> = keys.into_iter().collect();

    // Number of known dependents still pending for each key.
    let mut incoming: HashMap<usize, usize> = keys.iter().map(|key| (*key, 0)).collect();
    let known_edges = graph
        .iter()
        .filter(|(parent, _)| keys.contains(*parent))
        .flat_map(|(_, dependencies)| dependencies.iter());
    for dependency in known_edges {
        if let Some(count) = incoming.get_mut(dependency) {
            *count += 1;
        }
    }

    // Start from the keys nothing depends on and peel the graph layer by layer.
    let mut ready: VecDeque<usize> = incoming
        .iter()
        .filter(|(_, count)| **count == 0)
        .map(|(key, _)| *key)
        .collect();
    let mut order = Vec::with_capacity(keys.len());
    let mut emitted = HashSet::with_capacity(keys.len());
    while let Some(parent) = ready.pop_front() {
        order.push(parent);
        emitted.insert(parent);
        for dependency in graph.get(&parent).into_iter().flatten() {
            let Some(count) = incoming.get_mut(dependency) else {
                continue;
            };
            *count -= 1;
            if *count == 0 {
                ready.push_back(*dependency);
            }
        }
    }

    order.extend(keys.into_iter().filter(|key| !emitted.contains(key)));
    order
}
627
/// Handle for resolving dependencies within one request; created by
/// `Container::request_context`.
pub struct RequestContext<'a> {
    // Container the request resolves against.
    container: &'a Container,
    // Base resolution context carrying this request's instance cache.
    context: ResolutionContext,
}
impl RequestContext<'_> {
    /// Resolves the single (or primary) provider of `T` using this request's context.
    ///
    /// # Errors
    /// Returns any [`DiError`] raised during resolution.
    pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
    where
        T: Any + Send + Sync,
    {
        self.container.resolve_dependency::<T>(&self.context).await
    }
}
640
641#[doc(hidden)]
642pub mod __private {
643 pub use inventory;
644 pub use tokio;
645
646 use std::fmt::Display;
647
648 use crate::DiError;
649
650 pub fn factory_result<T, E: Display>(
651 result: Result<T, E>,
652 provider: &'static str,
653 ) -> Result<T, DiError> {
654 result.map_err(|error| DiError::Factory {
655 provider,
656 message: error.to_string(),
657 })
658 }
659
660 pub trait IntoLifecycleResult {
661 fn into_lifecycle_result(self, provider: &'static str) -> Result<(), DiError>;
662 }
663
664 impl IntoLifecycleResult for () {
665 fn into_lifecycle_result(self, _provider: &'static str) -> Result<(), DiError> {
666 Ok(())
667 }
668 }
669
670 impl<E: Display> IntoLifecycleResult for Result<(), E> {
671 fn into_lifecycle_result(self, provider: &'static str) -> Result<(), DiError> {
672 self.map_err(|error| DiError::Lifecycle {
673 provider,
674 message: error.to_string(),
675 })
676 }
677 }
678
679 pub fn lifecycle_result<R: IntoLifecycleResult>(
680 result: R,
681 provider: &'static str,
682 ) -> Result<(), DiError> {
683 result.into_lifecycle_result(provider)
684 }
685}
686
// Integration-style tests for the container. The fixtures are registered through the
// crate's attribute macros, whose expansion is not visible in this file; comments about
// what a macro generates are therefore hedged.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Lazy, Provider, configuration_properties, injectable, injected, provider, singleton,
    };
    use std::sync::atomic::{AtomicUsize, Ordering};
    // Counts how often the `SyncDependency` factory body runs.
    static CREATIONS: AtomicUsize = AtomicUsize::new(0);
    struct SyncDependency;
    #[singleton]
    fn sync_dependency() -> SyncDependency {
        CREATIONS.fetch_add(1, Ordering::SeqCst);
        SyncDependency
    }
    struct AsyncDependency {
        _sync: Arc<SyncDependency>,
    }
    // Async factory with a short sleep so concurrent resolvers overlap.
    #[singleton]
    async fn async_dependency(sync: Arc<SyncDependency>) -> AsyncDependency {
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        AsyncDependency { _sync: sync }
    }
    // 32 concurrent resolutions must all receive the same instance, and the dependency's
    // factory must have run exactly once.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn singleton_is_concurrent_safe() {
        let container = Arc::new(Container::new().unwrap());
        let mut tasks = tokio::task::JoinSet::new();
        for _ in 0..32 {
            let c = container.clone();
            tasks.spawn(async move { c.resolve::<AsyncDependency>().await.unwrap() });
        }
        let mut values = vec![];
        while let Some(v) = tasks.join_next().await {
            values.push(v.unwrap());
        }
        assert!(values.iter().all(|v| Arc::ptr_eq(&values[0], v)));
        assert_eq!(CREATIONS.load(Ordering::SeqCst), 1);
    }

    // Prototype fixture: each created instance carries the next counter value.
    static PROTOTYPES: AtomicUsize = AtomicUsize::new(0);
    struct PrototypeBean(usize);
    #[singleton(scope = "prototype")]
    fn prototype_bean() -> PrototypeBean {
        PrototypeBean(PROTOTYPES.fetch_add(1, Ordering::SeqCst))
    }

    // Request-scoped fixture.
    struct RequestBean;
    #[singleton(scope = "request")]
    fn request_bean() -> RequestBean {
        RequestBean
    }

    // Trait-object fixtures: two named providers of `Arc<dyn Greeting>`, one of them primary.
    trait Greeting: Send + Sync {
        fn text(&self) -> &'static str;
    }
    struct English;
    impl Greeting for English {
        fn text(&self) -> &'static str {
            "hello"
        }
    }
    struct Hindi;
    impl Greeting for Hindi {
        fn text(&self) -> &'static str {
            "namaste"
        }
    }

    #[singleton(name = "english", primary)]
    fn english_greeting() -> Arc<dyn Greeting> {
        Arc::new(English)
    }
    #[singleton(name = "hindi")]
    fn hindi_greeting() -> Arc<dyn Greeting> {
        Arc::new(Hindi)
    }

    // `MissingOptional` deliberately has no provider, so the optional parameter stays `None`.
    struct MissingOptional;
    struct Greeter {
        greeting: Arc<dyn Greeting>,
        optional: Option<Arc<MissingOptional>>,
    }
    struct GreetingLabel(&'static str);
    #[singleton]
    impl Greeter {
        fn new(greeting: Arc<dyn Greeting>, optional: Option<Arc<MissingOptional>>) -> Self {
            Self { greeting, optional }
        }

        // Instance-method provider: `GreetingLabel` is derived from the `Greeter` bean.
        #[provider]
        fn label(&self) -> GreetingLabel {
            GreetingLabel(self.greeting.text())
        }
    }
    // Selects the non-primary greeting through a qualifier.
    struct QualifiedGreeter(Arc<dyn Greeting>);
    #[singleton]
    impl QualifiedGreeter {
        fn new(#[qualifier("hindi")] greeting: Arc<dyn Greeting>) -> Self {
            Self(greeting)
        }
    }

    struct StaticConfig(&'static str);
    struct ServiceWithStaticBean(Arc<StaticConfig>);
    #[singleton]
    impl ServiceWithStaticBean {
        fn new(config: Arc<StaticConfig>) -> Self {
            Self(config)
        }

        // Provider without a receiver: it supplies the `StaticConfig` that `new` consumes.
        #[provider]
        fn config() -> StaticConfig {
            StaticConfig("static-bean")
        }
    }

    // Lifecycle fixture: counters record how often the start and stop hooks ran.
    static STARTED: AtomicUsize = AtomicUsize::new(0);
    static STOPPED: AtomicUsize = AtomicUsize::new(0);
    struct Managed;
    #[singleton(eager, post_construct = "start", pre_destroy = "stop")]
    impl Managed {
        fn new() -> Self {
            Self
        }
        async fn start(&self) {
            STARTED.fetch_add(1, Ordering::SeqCst);
        }
        async fn stop(&self) {
            STOPPED.fetch_add(1, Ordering::SeqCst);
        }
    }

    // Only active when the "test" profile is enabled.
    struct ProfileBean;
    #[singleton(profile = "test")]
    fn profile_bean() -> ProfileBean {
        ProfileBean
    }

    // Two named providers of one type, collected into a `Vec` by `Pipeline`.
    #[derive(Debug)]
    struct Handler(&'static str);
    #[singleton(name = "first")]
    fn first_handler() -> Handler {
        Handler("first")
    }
    #[singleton(name = "second")]
    fn second_handler() -> Handler {
        Handler("second")
    }
    struct Pipeline(Vec<Arc<Handler>>);
    #[singleton]
    impl Pipeline {
        fn new(handlers: Vec<Arc<Handler>>) -> Self {
            Self(handlers)
        }
    }

    // Presumably populated from `TESTING_DEP_*` environment variables (the test sets
    // `TESTING_DEP_PORT`); verify against the macro.
    #[configuration_properties("testing_dep")]
    struct TestProperties {
        port: u16,
    }

    // Fixtures for deferred resolution through `Provider<T>` and `Lazy<T>`.
    struct DeferredTarget;
    #[singleton]
    fn deferred_target() -> DeferredTarget {
        DeferredTarget
    }
    struct DeferredConsumer {
        provider: Provider<DeferredTarget>,
        lazy: Lazy<DeferredTarget>,
    }
    #[singleton]
    impl DeferredConsumer {
        fn new(provider: Provider<DeferredTarget>, lazy: Lazy<DeferredTarget>) -> Self {
            Self { provider, lazy }
        }
    }

    // Only active when `TESTING_DEP_FEATURE` equals `enabled`.
    struct ConditionalBean;
    #[singleton(condition = "TESTING_DEP_FEATURE=enabled")]
    fn conditional_bean() -> ConditionalBean {
        ConditionalBean
    }

    // Factory that always fails; expected to surface as `DiError::Factory`.
    struct FallibleDependency;
    #[provider]
    fn fallible_dependency() -> Result<FallibleDependency, &'static str> {
        Err("expected factory failure")
    }

    #[derive(Clone)]
    struct FieldDependency(u32);
    #[singleton]
    fn field_dependency() -> FieldDependency {
        FieldDependency(10)
    }

    // Field injection: one resolved field, one literal, one computed from a dependency.
    #[injectable]
    struct FieldInjectedFacade {
        dependency: Arc<FieldDependency>,
        #[inject(7)]
        literal: u32,
        #[inject(|dependency: Arc<FieldDependency>| dependency)]
        transformed: Arc<FieldDependency>,
    }

    #[injectable]
    struct TupleInjected(#[inject(123)] u32);

    // Function injection: judging by the calls in the last test, callers pass only
    // `multiplier`, and an `injected_total_with` variant taking every argument is generated.
    #[injected]
    fn injected_total(
        dependency: Arc<FieldDependency>,
        #[inject(2)] offset: u32,
        multiplier: u32,
    ) -> u32 {
        (dependency.0 + offset) * multiplier
    }

    // `#[argument]` keeps the `Arc` parameter caller-supplied; the wrapper returns a `Result`.
    #[injected(fallible)]
    fn caller_owned_arc(
        #[argument] dependency: Arc<FieldDependency>,
        #[inject(1)] offset: u32,
    ) -> u32 {
        dependency.0 + offset
    }

    struct InjectedMethods;
    impl InjectedMethods {
        #[injected]
        fn calculate(&self, dependency: Arc<FieldDependency>, value: u32) -> u32 {
            dependency.0 + value
        }
    }

    // A singleton that depends on a request-scoped bean; must be rejected as `InvalidScope`.
    struct RequestLeaf;
    #[singleton(scope = "request")]
    fn request_leaf() -> RequestLeaf {
        RequestLeaf
    }
    struct CaptiveSingleton {
        _request: Arc<RequestLeaf>,
    }
    #[singleton]
    impl CaptiveSingleton {
        fn new(request: Arc<RequestLeaf>) -> Self {
            Self { _request: request }
        }
    }

    // Mutually dependent singletons used to exercise cycle detection.
    struct CycleA {
        _b: Arc<CycleB>,
    }
    struct CycleB {
        _a: Arc<CycleA>,
    }
    #[singleton]
    impl CycleA {
        fn new(b: Arc<CycleB>) -> Self {
            Self { _b: b }
        }
    }
    #[singleton]
    impl CycleB {
        fn new(a: Arc<CycleA>) -> Self {
            Self { _a: a }
        }
    }

    // Collects every provider of a trait-object type.
    struct AllGreetings(Vec<Arc<dyn Greeting>>);
    #[singleton]
    impl AllGreetings {
        fn new(greetings: Vec<Arc<dyn Greeting>>) -> Self {
            Self(greetings)
        }
    }

    // Deferred access to a profile-gated bean.
    struct ProfileDeferred;
    #[singleton(profile = "test")]
    fn profile_deferred() -> ProfileDeferred {
        ProfileDeferred
    }
    struct ProfileDeferredConsumer(Provider<ProfileDeferred>);
    #[singleton]
    impl ProfileDeferredConsumer {
        fn new(provider: Provider<ProfileDeferred>) -> Self {
            Self(provider)
        }
    }

    // Free-function provider registered with `#[provider]` rather than `#[singleton]`.
    struct StandaloneBean;
    #[provider]
    fn standalone_bean() -> StandaloneBean {
        StandaloneBean
    }

    // End-to-end walk through scopes, trait objects, qualifiers, profiles, conditions,
    // configuration properties, deferred handles, eager initialisation and shutdown.
    #[tokio::test]
    async fn scopes_traits_primary_profiles_and_lifecycle_work() {
        // NOTE(review): `set_var` is unsafe because the process environment is shared; other
        // tests in this module read the environment while building containers. Confirm the
        // tests cannot race on it.
        unsafe { std::env::set_var("TESTING_DEP_FEATURE", "enabled") };
        let container = Container::with_profiles(["test"]).unwrap();
        let first = container.resolve::<PrototypeBean>().await.unwrap();
        let second = container.resolve::<PrototypeBean>().await.unwrap();
        assert_ne!(first.0, second.0);

        // Request scope: unavailable on the bare container, cached within one request.
        assert!(matches!(
            container.resolve::<RequestBean>().await,
            Err(DiError::RequestScopeUnavailable(_))
        ));
        let request = container.request_context();
        let request_first = request.resolve::<RequestBean>().await.unwrap();
        let request_second = request.resolve::<RequestBean>().await.unwrap();
        assert!(Arc::ptr_eq(&request_first, &request_second));

        // The primary greeting wins for unqualified injection; the optional stays empty.
        let greeter = container.resolve::<Greeter>().await.unwrap();
        assert_eq!(greeter.greeting.text(), "hello");
        assert_eq!(
            container.resolve::<GreetingLabel>().await.unwrap().0,
            "hello"
        );
        assert!(greeter.optional.is_none());
        assert_eq!(
            container
                .resolve::<QualifiedGreeter>()
                .await
                .unwrap()
                .0
                .text(),
            "namaste"
        );
        let hindi = container
            .resolve_named::<Arc<dyn Greeting>>("hindi")
            .await
            .unwrap();
        assert_eq!(hindi.text(), "namaste");
        let static_bean_service = container.resolve::<ServiceWithStaticBean>().await.unwrap();
        assert_eq!(static_bean_service.0.0, "static-bean");
        container.resolve::<ProfileBean>().await.unwrap();
        let pipeline = container.resolve::<Pipeline>().await.unwrap();
        // Collection order is not asserted, so the names are sorted before comparing.
        let mut handler_names = pipeline.0.iter().map(|h| h.0).collect::<Vec<_>>();
        handler_names.sort_unstable();
        assert_eq!(handler_names, ["first", "second"]);

        unsafe { std::env::set_var("TESTING_DEP_PORT", "8080") };
        let properties = container.resolve::<TestProperties>().await.unwrap();
        assert_eq!(properties.port, 8080);
        container.resolve::<ConditionalBean>().await.unwrap();
        container.resolve::<StandaloneBean>().await.unwrap();

        let deferred = container.resolve::<DeferredConsumer>().await.unwrap();
        deferred.provider.get().await.unwrap();
        deferred.lazy.get().await.unwrap();
        let profile_deferred = container
            .resolve::<ProfileDeferredConsumer>()
            .await
            .unwrap();
        profile_deferred.0.get().await.unwrap();
        let greetings = container.resolve::<AllGreetings>().await.unwrap();
        assert_eq!(greetings.0.len(), 2);

        // Lifecycle: the eager bean starts once, stops once at shutdown, and the container
        // refuses further resolution afterwards.
        container.initialize_eager().await.unwrap();
        assert_eq!(STARTED.load(Ordering::SeqCst), 1);
        container.shutdown().await.unwrap();
        assert_eq!(STOPPED.load(Ordering::SeqCst), 1);
        assert!(matches!(
            container.resolve::<SyncDependency>().await,
            Err(DiError::ContainerShutdown)
        ));
    }

    // Failure paths: factory errors, request capture by a singleton, and a dependency cycle
    // resolved from two tasks at once, which must end in an error rather than a deadlock.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn failures_are_reported_instead_of_hanging_or_capturing_scope() {
        let container = Arc::new(Container::new().unwrap());
        assert!(matches!(
            container.resolve::<FallibleDependency>().await,
            Err(DiError::Factory { .. })
        ));

        let request = container.request_context();
        let captive = request.resolve::<CaptiveSingleton>().await;
        assert!(matches!(captive, Err(DiError::InvalidScope { .. })));

        let first = {
            let container = container.clone();
            tokio::spawn(async move { container.resolve::<CycleA>().await })
        };
        let second = {
            let container = container.clone();
            tokio::spawn(async move { container.resolve::<CycleB>().await })
        };
        // The timeout turns a potential deadlock into a test failure.
        let results = tokio::time::timeout(std::time::Duration::from_secs(1), async {
            (first.await.unwrap(), second.await.unwrap())
        })
        .await
        .expect("cross-task cycle resolution must not deadlock");
        assert!(
            matches!(results.0, Err(DiError::CircularDependency(_)))
                || matches!(results.1, Err(DiError::CircularDependency(_)))
        );
    }

    // Exercises `#[injectable]` and `#[injected]` against the global container.
    #[tokio::test]
    async fn field_and_function_injection_are_automatic() {
        let facade = resolve::<FieldInjectedFacade>().await.unwrap();
        assert_eq!(facade.dependency.0, 10);
        assert_eq!(facade.literal, 7);
        assert!(Arc::ptr_eq(&facade.dependency, &facade.transformed));
        assert_eq!(resolve::<TupleInjected>().await.unwrap().0, 123);
        // (10 + 2) * 3 with the injected dependency and offset.
        assert_eq!(injected_total(3).await, 36);
        // Presumably the macro-generated variant that takes every argument explicitly.
        assert_eq!(injected_total_with(Arc::new(FieldDependency(20)), 4, 2), 48);
        assert_eq!(
            caller_owned_arc(Arc::new(FieldDependency(20)))
                .await
                .unwrap(),
            21
        );
        assert_eq!(InjectedMethods.calculate(5).await, 15);
    }
}