Skip to main content

auto_di/
runtime.rs

1use std::{
2    any::{Any, TypeId},
3    collections::{HashMap, HashSet, VecDeque},
4    future::Future,
5    pin::Pin,
6    sync::{
7        Arc, Mutex, OnceLock,
8        atomic::{AtomicBool, Ordering},
9    },
10};
11
12pub type DynArc = Arc<dyn Any + Send + Sync>;
13pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
14type Factory =
15    for<'a> fn(&'a Container, ResolutionContext) -> BoxFuture<'a, Result<DynArc, DiError>>;
16type Destroy = fn(DynArc) -> BoxFuture<'static, Result<(), DiError>>;
17
/// Lifetime policy for the instances a provider creates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Scope {
    /// One instance per container, created on first resolution and then shared.
    Singleton,
    /// A fresh instance on every resolution.
    Prototype,
    /// One instance per `RequestContext`; resolving it outside one is an error.
    Request,
}
24
25#[doc(hidden)]
26pub struct ProviderDescriptor {
27    type_id: fn() -> TypeId,
28    type_name: fn() -> &'static str,
29    factory: Factory,
30    pub name: Option<&'static str>,
31    pub primary: bool,
32    pub scope: Scope,
33    pub eager: bool,
34    pub profile: Option<&'static str>,
35    pub condition_key: Option<&'static str>,
36    pub condition_value: Option<&'static str>,
37    pub destroy: Option<Destroy>,
38}
39
40impl ProviderDescriptor {
41    #[doc(hidden)]
42    pub const fn new(
43        type_id: fn() -> TypeId,
44        type_name: fn() -> &'static str,
45        factory: Factory,
46    ) -> Self {
47        Self::configured(
48            type_id,
49            type_name,
50            factory,
51            None,
52            false,
53            Scope::Singleton,
54            false,
55            None,
56            None,
57            None,
58            None,
59        )
60    }
61
62    #[allow(clippy::too_many_arguments)]
63    #[doc(hidden)]
64    pub const fn configured(
65        type_id: fn() -> TypeId,
66        type_name: fn() -> &'static str,
67        factory: Factory,
68        name: Option<&'static str>,
69        primary: bool,
70        scope: Scope,
71        eager: bool,
72        profile: Option<&'static str>,
73        condition_key: Option<&'static str>,
74        condition_value: Option<&'static str>,
75        destroy: Option<Destroy>,
76    ) -> Self {
77        Self {
78            type_id,
79            type_name,
80            factory,
81            name,
82            primary,
83            scope,
84            eager,
85            profile,
86            condition_key,
87            condition_value,
88            destroy,
89        }
90    }
91
92    fn active(&self, profiles: &[String]) -> bool {
93        let profile_matches = self
94            .profile
95            .is_none_or(|required| profiles.iter().any(|p| p == required));
96        let condition_matches = self.condition_key.is_none_or(|key| {
97            let actual = std::env::var(key).ok();
98            self.condition_value.map_or(actual.is_some(), |expected| {
99                actual.as_deref() == Some(expected)
100            })
101        });
102        profile_matches && condition_matches
103    }
104}
105
106inventory::collect!(ProviderDescriptor);
107
108type InstanceMap = HashMap<usize, Arc<tokio::sync::OnceCell<DynArc>>>;
109
110struct RuntimeProvider {
111    descriptor: &'static ProviderDescriptor,
112    singleton: tokio::sync::OnceCell<DynArc>,
113}
114
115impl RuntimeProvider {
116    fn new(descriptor: &'static ProviderDescriptor) -> Self {
117        Self {
118            descriptor,
119            singleton: tokio::sync::OnceCell::new(),
120        }
121    }
122}
123
124#[derive(Clone, Default)]
125pub struct ResolutionContext {
126    pub(crate) chain: Vec<&'static str>,
127    provider_chain: Vec<usize>,
128    scope_chain: Vec<Scope>,
129    pub(crate) request_instances: Option<Arc<Mutex<InstanceMap>>>,
130}
131
132#[derive(Debug, thiserror::Error)]
133pub enum DiError {
134    #[error("no active provider is registered for {0}")]
135    MissingProvider(&'static str),
136    #[error("multiple providers match {0}; add a name/qualifier or mark one primary")]
137    AmbiguousProvider(&'static str),
138    #[error("multiple primary providers are registered for {0}")]
139    MultiplePrimary(&'static str),
140    #[error("duplicate provider name '{name}' is registered for {type_name}")]
141    DuplicateProviderName {
142        type_name: &'static str,
143        name: &'static str,
144    },
145    #[error("circular dependency detected: {0}")]
146    CircularDependency(String),
147    #[error("provider for {0} returned an incompatible type")]
148    TypeMismatch(&'static str),
149    #[error("request-scoped dependency {0} was resolved outside RequestContext")]
150    RequestScopeUnavailable(&'static str),
151    #[error("singleton dependency {singleton} cannot capture request-scoped {request}")]
152    InvalidScope {
153        singleton: &'static str,
154        request: &'static str,
155    },
156    #[error("the dependency container has already been shut down")]
157    ContainerShutdown,
158    #[error("configuration property {key} is missing or invalid: {message}")]
159    Configuration { key: String, message: String },
160    #[error("provider for {provider} failed: {message}")]
161    Factory {
162        provider: &'static str,
163        message: String,
164    },
165    #[error("lifecycle hook failed for {provider}: {message}")]
166    Lifecycle {
167        provider: &'static str,
168        message: String,
169    },
170}
171
172struct ContainerInner {
173    providers: HashMap<TypeId, Vec<RuntimeProvider>>,
174    dependency_graph: Mutex<HashMap<usize, HashSet<usize>>>,
175    shut_down: AtomicBool,
176}
177
178#[derive(Clone)]
179pub struct Container {
180    inner: Arc<ContainerInner>,
181}
182
183static GLOBAL_CONTAINER: OnceLock<Container> = OnceLock::new();
184
185pub fn global_container() -> Result<&'static Container, DiError> {
186    if let Some(container) = GLOBAL_CONTAINER.get() {
187        return Ok(container);
188    }
189    let container = Container::new()?;
190    let _ = GLOBAL_CONTAINER.set(container);
191    Ok(GLOBAL_CONTAINER
192        .get()
193        .expect("global DI container initialized"))
194}
195
196pub async fn resolve<T>() -> Result<Arc<T>, DiError>
197where
198    T: Any + Send + Sync,
199{
200    global_container()?.resolve::<T>().await
201}
202
203impl Container {
204    pub fn new() -> Result<Self, DiError> {
205        let profiles = std::env::var("APP_PROFILES")
206            .unwrap_or_default()
207            .split(',')
208            .map(str::trim)
209            .filter(|p| !p.is_empty())
210            .map(str::to_owned)
211            .collect::<Vec<_>>();
212        Self::with_profiles(profiles)
213    }
214
215    pub fn with_profiles(
216        profiles: impl IntoIterator<Item = impl Into<String>>,
217    ) -> Result<Self, DiError> {
218        let profiles = profiles.into_iter().map(Into::into).collect::<Vec<_>>();
219        let mut providers: HashMap<TypeId, Vec<RuntimeProvider>> = HashMap::new();
220        for provider in inventory::iter::<ProviderDescriptor> {
221            if provider.active(&profiles) {
222                providers
223                    .entry((provider.type_id)())
224                    .or_default()
225                    .push(RuntimeProvider::new(provider));
226            }
227        }
228        for group in providers.values() {
229            if group.iter().filter(|p| p.descriptor.primary).count() > 1 {
230                return Err(DiError::MultiplePrimary((group[0].descriptor.type_name)()));
231            }
232            let mut names = HashSet::new();
233            for provider in group {
234                if let Some(name) = provider.descriptor.name
235                    && !names.insert(name)
236                {
237                    return Err(DiError::DuplicateProviderName {
238                        type_name: (provider.descriptor.type_name)(),
239                        name,
240                    });
241                }
242            }
243        }
244        Ok(Self {
245            inner: Arc::new(ContainerInner {
246                providers,
247                dependency_graph: Mutex::new(HashMap::new()),
248                shut_down: AtomicBool::new(false),
249            }),
250        })
251    }
252
253    pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
254    where
255        T: Any + Send + Sync,
256    {
257        self.resolve_dependency::<T>(&ResolutionContext::default())
258            .await
259    }
260
261    pub async fn resolve_named<T>(&self, name: &str) -> Result<Arc<T>, DiError>
262    where
263        T: Any + Send + Sync,
264    {
265        self.resolve_named_dependency::<T>(name, &ResolutionContext::default())
266            .await
267    }
268
269    pub async fn resolve_optional<T>(&self) -> Result<Option<Arc<T>>, DiError>
270    where
271        T: Any + Send + Sync,
272    {
273        match self.resolve::<T>().await {
274            Ok(value) => Ok(Some(value)),
275            Err(DiError::MissingProvider(_)) => Ok(None),
276            Err(error) => Err(error),
277        }
278    }
279
280    pub async fn resolve_all<T>(&self) -> Result<Vec<Arc<T>>, DiError>
281    where
282        T: Any + Send + Sync,
283    {
284        self.resolve_all_dependency::<T>(&ResolutionContext::default())
285            .await
286    }
287
288    pub fn request_context(&self) -> RequestContext<'_> {
289        RequestContext {
290            container: self,
291            context: ResolutionContext {
292                chain: vec![],
293                provider_chain: vec![],
294                scope_chain: vec![],
295                request_instances: Some(Arc::new(Mutex::new(HashMap::new()))),
296            },
297        }
298    }
299
300    pub async fn initialize_eager(&self) -> Result<(), DiError> {
301        for providers in self.inner.providers.values() {
302            for provider in providers.iter().filter(|p| p.descriptor.eager) {
303                self.resolve_provider(provider, ResolutionContext::default())
304                    .await?;
305            }
306        }
307        Ok(())
308    }
309
310    /// Validates the active singleton graph by constructing every singleton.
311    /// This surfaces missing, ambiguous, circular, and invalid-scope dependencies
312    /// during application startup instead of on the first request.
313    pub async fn validate(&self) -> Result<(), DiError> {
314        for providers in self.inner.providers.values() {
315            for provider in providers
316                .iter()
317                .filter(|provider| provider.descriptor.scope == Scope::Singleton)
318            {
319                self.resolve_provider(provider, ResolutionContext::default())
320                    .await?;
321            }
322        }
323        Ok(())
324    }
325
326    pub async fn shutdown(&self) -> Result<(), DiError> {
327        if self.inner.shut_down.swap(true, Ordering::AcqRel) {
328            return Ok(());
329        }
330        let graph = self
331            .inner
332            .dependency_graph
333            .lock()
334            .expect("DI dependency graph lock poisoned")
335            .clone();
336        let providers = self
337            .inner
338            .providers
339            .values()
340            .flatten()
341            .map(|provider| (runtime_provider_key(provider), provider))
342            .collect::<HashMap<_, _>>();
343        let order = shutdown_order(providers.keys().copied(), &graph);
344        for key in order {
345            let provider = providers[&key];
346            if let (Some(destroy), Some(value)) =
347                (provider.descriptor.destroy, provider.singleton.get())
348            {
349                destroy(value.clone()).await?;
350            }
351        }
352        Ok(())
353    }
354
355    #[doc(hidden)]
356    pub async fn resolve_dependency<T>(
357        &self,
358        context: &ResolutionContext,
359    ) -> Result<Arc<T>, DiError>
360    where
361        T: Any + Send + Sync,
362    {
363        self.resolve_selected::<T>(None, context).await
364    }
365
366    #[doc(hidden)]
367    pub async fn resolve_named_dependency<T>(
368        &self,
369        name: &str,
370        context: &ResolutionContext,
371    ) -> Result<Arc<T>, DiError>
372    where
373        T: Any + Send + Sync,
374    {
375        self.resolve_selected::<T>(Some(name), context).await
376    }
377
378    #[doc(hidden)]
379    pub async fn resolve_optional_dependency<T>(
380        &self,
381        context: &ResolutionContext,
382    ) -> Result<Option<Arc<T>>, DiError>
383    where
384        T: Any + Send + Sync,
385    {
386        match self.resolve_dependency::<T>(context).await {
387            Ok(value) => Ok(Some(value)),
388            Err(DiError::MissingProvider(_)) => Ok(None),
389            Err(error) => Err(error),
390        }
391    }
392
393    #[doc(hidden)]
394    pub async fn resolve_all_dependency<T>(
395        &self,
396        context: &ResolutionContext,
397    ) -> Result<Vec<Arc<T>>, DiError>
398    where
399        T: Any + Send + Sync,
400    {
401        let Some(providers) = self.inner.providers.get(&TypeId::of::<T>()) else {
402            return Ok(vec![]);
403        };
404        let mut values = Vec::with_capacity(providers.len());
405        for provider in providers {
406            let value = self.resolve_provider(provider, context.clone()).await?;
407            values.push(
408                value
409                    .downcast::<T>()
410                    .map_err(|_| DiError::TypeMismatch(std::any::type_name::<T>()))?,
411            );
412        }
413        Ok(values)
414    }
415
416    async fn resolve_selected<T>(
417        &self,
418        name: Option<&str>,
419        context: &ResolutionContext,
420    ) -> Result<Arc<T>, DiError>
421    where
422        T: Any + Send + Sync,
423    {
424        let type_name = std::any::type_name::<T>();
425        let providers = self
426            .inner
427            .providers
428            .get(&TypeId::of::<T>())
429            .ok_or(DiError::MissingProvider(type_name))?;
430        let selected = if let Some(name) = name {
431            providers
432                .iter()
433                .find(|p| p.descriptor.name == Some(name))
434                .ok_or(DiError::MissingProvider(type_name))?
435        } else if providers.len() == 1 {
436            &providers[0]
437        } else {
438            providers
439                .iter()
440                .find(|p| p.descriptor.primary)
441                .ok_or(DiError::AmbiguousProvider(type_name))?
442        };
443        let value = self.resolve_provider(selected, context.clone()).await?;
444        value
445            .downcast::<T>()
446            .map_err(|_| DiError::TypeMismatch(type_name))
447    }
448
449    fn resolve_provider<'a>(
450        &'a self,
451        provider: &'a RuntimeProvider,
452        mut context: ResolutionContext,
453    ) -> BoxFuture<'a, Result<DynArc, DiError>> {
454        Box::pin(async move {
455            if self.inner.shut_down.load(Ordering::Acquire) {
456                return Err(DiError::ContainerShutdown);
457            }
458            let descriptor = provider.descriptor;
459            let type_name = (descriptor.type_name)();
460            if context.chain.contains(&type_name) {
461                context.chain.push(type_name);
462                return Err(DiError::CircularDependency(context.chain.join(" -> ")));
463            }
464            let runtime_key = runtime_provider_key(provider);
465            if let Some(parent) = context.provider_chain.last().copied() {
466                self.add_dependency_edge(parent, runtime_key, &context, type_name)?;
467            }
468            if descriptor.scope == Scope::Request
469                && let Some(position) = context
470                    .scope_chain
471                    .iter()
472                    .position(|scope| *scope == Scope::Singleton)
473            {
474                return Err(DiError::InvalidScope {
475                    singleton: context.chain[position],
476                    request: type_name,
477                });
478            }
479            context.chain.push(type_name);
480            context.provider_chain.push(runtime_key);
481            context.scope_chain.push(descriptor.scope);
482            if descriptor.scope == Scope::Prototype {
483                return (descriptor.factory)(self, context).await;
484            }
485            match descriptor.scope {
486                Scope::Singleton => {
487                    let value = provider
488                        .singleton
489                        .get_or_try_init(
490                            || async move { (descriptor.factory)(self, context).await },
491                        )
492                        .await?;
493                    Ok(value.clone())
494                }
495                Scope::Request => {
496                    let map = context
497                        .request_instances
498                        .as_deref()
499                        .ok_or(DiError::RequestScopeUnavailable(type_name))?;
500                    let cell = {
501                        let mut instances = map.lock().expect("DI instance lock poisoned");
502                        instances
503                            .entry(provider_key(descriptor))
504                            .or_insert_with(|| Arc::new(tokio::sync::OnceCell::new()))
505                            .clone()
506                    };
507                    let value = cell
508                        .get_or_try_init(
509                            || async move { (descriptor.factory)(self, context).await },
510                        )
511                        .await?;
512                    Ok(value.clone())
513                }
514                Scope::Prototype => unreachable!(),
515            }
516        })
517    }
518
519    fn add_dependency_edge(
520        &self,
521        parent: usize,
522        dependency: usize,
523        context: &ResolutionContext,
524        dependency_name: &'static str,
525    ) -> Result<(), DiError> {
526        let mut graph = self
527            .inner
528            .dependency_graph
529            .lock()
530            .expect("DI dependency graph lock poisoned");
531        graph.entry(parent).or_default().insert(dependency);
532        if graph_path_exists(&graph, dependency, parent, &mut HashSet::new()) {
533            let mut chain = context.chain.clone();
534            chain.push(dependency_name);
535            return Err(DiError::CircularDependency(chain.join(" -> ")));
536        }
537        Ok(())
538    }
539}
540
541fn provider_key(provider: &'static ProviderDescriptor) -> usize {
542    provider as *const ProviderDescriptor as usize
543}
544
545fn runtime_provider_key(provider: &RuntimeProvider) -> usize {
546    provider as *const RuntimeProvider as usize
547}
548
/// Depth-first reachability test: is `target` reachable from `current` by
/// following edges in `graph`? `visited` records expanded nodes so cycles in
/// the graph terminate; callers normally pass a fresh, empty set.
fn graph_path_exists(
    graph: &HashMap<usize, HashSet<usize>>,
    current: usize,
    target: usize,
    visited: &mut HashSet<usize>,
) -> bool {
    if current == target {
        return true;
    }
    if !visited.insert(current) {
        // Already expanded along another branch.
        return false;
    }
    let Some(successors) = graph.get(&current) else {
        return false;
    };
    for &successor in successors {
        if graph_path_exists(graph, successor, target, visited) {
            return true;
        }
    }
    false
}
565
/// Orders `keys` so every provider precedes the providers it depends on
/// (Kahn's algorithm over the tracked subset of `graph`). Edges whose parent is
/// not tracked are ignored.
fn shutdown_order(
    keys: impl IntoIterator<Item = usize>,
    graph: &HashMap<usize, HashSet<usize>>,
) -> Vec<usize> {
    let tracked: HashSet<usize> = keys.into_iter().collect();

    // For each tracked provider: how many tracked providers depend on it.
    let mut dependents: HashMap<usize, usize> =
        tracked.iter().map(|&key| (key, 0usize)).collect();
    graph
        .iter()
        .filter(|&(parent, _)| tracked.contains(parent))
        .flat_map(|(_, dependencies)| dependencies)
        .for_each(|dependency| {
            if let Some(count) = dependents.get_mut(dependency) {
                *count += 1;
            }
        });

    let mut queue: VecDeque<usize> = dependents
        .iter()
        .filter(|&(_, &count)| count == 0)
        .map(|(&key, _)| key)
        .collect();
    let mut order = Vec::with_capacity(tracked.len());
    while let Some(current) = queue.pop_front() {
        order.push(current);
        let Some(dependencies) = graph.get(&current) else {
            continue;
        };
        for dependency in dependencies {
            let Some(count) = dependents.get_mut(dependency) else {
                continue;
            };
            *count -= 1;
            if *count == 0 {
                queue.push_back(*dependency);
            }
        }
    }

    // Resolution normally keeps the graph acyclic. If a cycle slipped in anyway,
    // still return every key so shutdown visits each provider exactly once.
    let placed: HashSet<usize> = order.iter().copied().collect();
    let leftovers: Vec<usize> = tracked
        .into_iter()
        .filter(|key| !placed.contains(key))
        .collect();
    order.extend(leftovers);
    order
}
610
611pub struct RequestContext<'a> {
612    container: &'a Container,
613    context: ResolutionContext,
614}
615impl RequestContext<'_> {
616    pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
617    where
618        T: Any + Send + Sync,
619    {
620        self.container.resolve_dependency::<T>(&self.context).await
621    }
622}
623
624#[doc(hidden)]
625pub mod __private {
626    pub use inventory;
627    pub use tokio;
628
629    use std::fmt::Display;
630
631    use crate::DiError;
632
633    pub fn factory_result<T, E: Display>(
634        result: Result<T, E>,
635        provider: &'static str,
636    ) -> Result<T, DiError> {
637        result.map_err(|error| DiError::Factory {
638            provider,
639            message: error.to_string(),
640        })
641    }
642
643    pub trait IntoLifecycleResult {
644        fn into_lifecycle_result(self, provider: &'static str) -> Result<(), DiError>;
645    }
646
647    impl IntoLifecycleResult for () {
648        fn into_lifecycle_result(self, _provider: &'static str) -> Result<(), DiError> {
649            Ok(())
650        }
651    }
652
653    impl<E: Display> IntoLifecycleResult for Result<(), E> {
654        fn into_lifecycle_result(self, provider: &'static str) -> Result<(), DiError> {
655            self.map_err(|error| DiError::Lifecycle {
656                provider,
657                message: error.to_string(),
658            })
659        }
660    }
661
662    pub fn lifecycle_result<R: IntoLifecycleResult>(
663        result: R,
664        provider: &'static str,
665    ) -> Result<(), DiError> {
666        result.into_lifecycle_result(provider)
667    }
668}
669
#[cfg(test)]
// Each test holds `ENV_LOCK` (a std mutex) across `.await` on purpose. Only the
// test bodies themselves hold it, never spawned tasks, so all it does is run
// the tests in this module one at a time.
#[allow(clippy::await_holding_lock)]
mod tests {
    use super::*;
    use crate::{Lazy, Provider, configuration_properties, provider, singleton};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Serializes every test here that reads or writes the process environment.
    /// `Container::new` / `with_profiles` read environment variables while
    /// filtering providers, and `std::env::set_var` is only sound when no other
    /// thread accesses the environment concurrently. The test harness runs
    /// tests on parallel threads by default.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires `ENV_LOCK`, ignoring poisoning so one failing test does not
    /// cascade into the others.
    fn env_guard() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    static CREATIONS: AtomicUsize = AtomicUsize::new(0);
    struct SyncDependency;
    #[singleton]
    fn sync_dependency() -> SyncDependency {
        CREATIONS.fetch_add(1, Ordering::SeqCst);
        SyncDependency
    }
    struct AsyncDependency {
        _sync: Arc<SyncDependency>,
    }
    #[singleton]
    async fn async_dependency(sync: Arc<SyncDependency>) -> AsyncDependency {
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        AsyncDependency { _sync: sync }
    }
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn singleton_is_concurrent_safe() {
        let _env = env_guard();
        let container = Arc::new(Container::new().unwrap());
        let mut tasks = tokio::task::JoinSet::new();
        for _ in 0..32 {
            let c = container.clone();
            tasks.spawn(async move { c.resolve::<AsyncDependency>().await.unwrap() });
        }
        let mut values = vec![];
        while let Some(v) = tasks.join_next().await {
            values.push(v.unwrap());
        }
        assert!(values.iter().all(|v| Arc::ptr_eq(&values[0], v)));
        assert_eq!(CREATIONS.load(Ordering::SeqCst), 1);
    }

    static PROTOTYPES: AtomicUsize = AtomicUsize::new(0);
    struct PrototypeBean(usize);
    #[singleton(scope = "prototype")]
    fn prototype_bean() -> PrototypeBean {
        PrototypeBean(PROTOTYPES.fetch_add(1, Ordering::SeqCst))
    }

    struct RequestBean;
    #[singleton(scope = "request")]
    fn request_bean() -> RequestBean {
        RequestBean
    }

    trait Greeting: Send + Sync {
        fn text(&self) -> &'static str;
    }
    struct English;
    impl Greeting for English {
        fn text(&self) -> &'static str {
            "hello"
        }
    }
    struct Hindi;
    impl Greeting for Hindi {
        fn text(&self) -> &'static str {
            "namaste"
        }
    }

    #[singleton(name = "english", primary)]
    fn english_greeting() -> Arc<dyn Greeting> {
        Arc::new(English)
    }
    #[singleton(name = "hindi")]
    fn hindi_greeting() -> Arc<dyn Greeting> {
        Arc::new(Hindi)
    }

    struct MissingOptional;
    struct Greeter {
        greeting: Arc<dyn Greeting>,
        optional: Option<Arc<MissingOptional>>,
    }
    struct GreetingLabel(&'static str);
    #[singleton]
    impl Greeter {
        fn new(greeting: Arc<dyn Greeting>, optional: Option<Arc<MissingOptional>>) -> Self {
            Self { greeting, optional }
        }

        #[provider]
        fn label(&self) -> GreetingLabel {
            GreetingLabel(self.greeting.text())
        }
    }
    struct QualifiedGreeter(Arc<dyn Greeting>);
    #[singleton]
    impl QualifiedGreeter {
        fn new(#[qualifier("hindi")] greeting: Arc<dyn Greeting>) -> Self {
            Self(greeting)
        }
    }

    struct StaticConfig(&'static str);
    struct ServiceWithStaticBean(Arc<StaticConfig>);
    #[singleton]
    impl ServiceWithStaticBean {
        fn new(config: Arc<StaticConfig>) -> Self {
            Self(config)
        }

        // No `&self`: this provider can run before its owning singleton and can
        // therefore participate in the service's constructor graph.
        #[provider]
        fn config() -> StaticConfig {
            StaticConfig("static-bean")
        }
    }

    static STARTED: AtomicUsize = AtomicUsize::new(0);
    static STOPPED: AtomicUsize = AtomicUsize::new(0);
    struct Managed;
    #[singleton(eager, post_construct = "start", pre_destroy = "stop")]
    impl Managed {
        fn new() -> Self {
            Self
        }
        async fn start(&self) {
            STARTED.fetch_add(1, Ordering::SeqCst);
        }
        async fn stop(&self) {
            STOPPED.fetch_add(1, Ordering::SeqCst);
        }
    }

    struct ProfileBean;
    #[singleton(profile = "test")]
    fn profile_bean() -> ProfileBean {
        ProfileBean
    }

    #[derive(Debug)]
    struct Handler(&'static str);
    #[singleton(name = "first")]
    fn first_handler() -> Handler {
        Handler("first")
    }
    #[singleton(name = "second")]
    fn second_handler() -> Handler {
        Handler("second")
    }
    struct Pipeline(Vec<Arc<Handler>>);
    #[singleton]
    impl Pipeline {
        fn new(handlers: Vec<Arc<Handler>>) -> Self {
            Self(handlers)
        }
    }

    #[configuration_properties("testing_dep")]
    struct TestProperties {
        port: u16,
    }

    struct DeferredTarget;
    #[singleton]
    fn deferred_target() -> DeferredTarget {
        DeferredTarget
    }
    struct DeferredConsumer {
        provider: Provider<DeferredTarget>,
        lazy: Lazy<DeferredTarget>,
    }
    #[singleton]
    impl DeferredConsumer {
        fn new(provider: Provider<DeferredTarget>, lazy: Lazy<DeferredTarget>) -> Self {
            Self { provider, lazy }
        }
    }

    struct ConditionalBean;
    #[singleton(condition = "TESTING_DEP_FEATURE=enabled")]
    fn conditional_bean() -> ConditionalBean {
        ConditionalBean
    }

    struct FallibleDependency;
    #[provider]
    fn fallible_dependency() -> Result<FallibleDependency, &'static str> {
        Err("expected factory failure")
    }

    struct RequestLeaf;
    #[singleton(scope = "request")]
    fn request_leaf() -> RequestLeaf {
        RequestLeaf
    }
    struct CaptiveSingleton {
        _request: Arc<RequestLeaf>,
    }
    #[singleton]
    impl CaptiveSingleton {
        fn new(request: Arc<RequestLeaf>) -> Self {
            Self { _request: request }
        }
    }

    struct CycleA {
        _b: Arc<CycleB>,
    }
    struct CycleB {
        _a: Arc<CycleA>,
    }
    #[singleton]
    impl CycleA {
        fn new(b: Arc<CycleB>) -> Self {
            Self { _b: b }
        }
    }
    #[singleton]
    impl CycleB {
        fn new(a: Arc<CycleA>) -> Self {
            Self { _a: a }
        }
    }

    struct AllGreetings(Vec<Arc<dyn Greeting>>);
    #[singleton]
    impl AllGreetings {
        fn new(greetings: Vec<Arc<dyn Greeting>>) -> Self {
            Self(greetings)
        }
    }

    struct ProfileDeferred;
    #[singleton(profile = "test")]
    fn profile_deferred() -> ProfileDeferred {
        ProfileDeferred
    }
    struct ProfileDeferredConsumer(Provider<ProfileDeferred>);
    #[singleton]
    impl ProfileDeferredConsumer {
        fn new(provider: Provider<ProfileDeferred>) -> Self {
            Self(provider)
        }
    }

    struct StandaloneBean;
    #[provider]
    fn standalone_bean() -> StandaloneBean {
        StandaloneBean
    }

    #[tokio::test]
    async fn scopes_traits_primary_profiles_and_lifecycle_work() {
        let _env = env_guard();
        // SAFETY: `_env` holds ENV_LOCK, so no other test in this module reads or
        // writes the environment concurrently. NOTE(review): code outside this
        // module is not covered by the lock.
        unsafe { std::env::set_var("TESTING_DEP_FEATURE", "enabled") };
        let container = Container::with_profiles(["test"]).unwrap();
        let first = container.resolve::<PrototypeBean>().await.unwrap();
        let second = container.resolve::<PrototypeBean>().await.unwrap();
        assert_ne!(first.0, second.0);

        assert!(matches!(
            container.resolve::<RequestBean>().await,
            Err(DiError::RequestScopeUnavailable(_))
        ));
        let request = container.request_context();
        let request_first = request.resolve::<RequestBean>().await.unwrap();
        let request_second = request.resolve::<RequestBean>().await.unwrap();
        assert!(Arc::ptr_eq(&request_first, &request_second));

        let greeter = container.resolve::<Greeter>().await.unwrap();
        assert_eq!(greeter.greeting.text(), "hello");
        assert_eq!(
            container.resolve::<GreetingLabel>().await.unwrap().0,
            "hello"
        );
        assert!(greeter.optional.is_none());
        assert_eq!(
            container
                .resolve::<QualifiedGreeter>()
                .await
                .unwrap()
                .0
                .text(),
            "namaste"
        );
        let hindi = container
            .resolve_named::<Arc<dyn Greeting>>("hindi")
            .await
            .unwrap();
        assert_eq!(hindi.text(), "namaste");
        let static_bean_service = container.resolve::<ServiceWithStaticBean>().await.unwrap();
        assert_eq!(static_bean_service.0.0, "static-bean");
        container.resolve::<ProfileBean>().await.unwrap();
        let pipeline = container.resolve::<Pipeline>().await.unwrap();
        let mut handler_names = pipeline.0.iter().map(|h| h.0).collect::<Vec<_>>();
        handler_names.sort_unstable();
        assert_eq!(handler_names, ["first", "second"]);

        // This process-local key is unique to the test crate.
        // SAFETY: `_env` still holds ENV_LOCK (see above).
        unsafe { std::env::set_var("TESTING_DEP_PORT", "8080") };
        let properties = container.resolve::<TestProperties>().await.unwrap();
        assert_eq!(properties.port, 8080);
        container.resolve::<ConditionalBean>().await.unwrap();
        container.resolve::<StandaloneBean>().await.unwrap();

        let deferred = container.resolve::<DeferredConsumer>().await.unwrap();
        deferred.provider.get().await.unwrap();
        deferred.lazy.get().await.unwrap();
        let profile_deferred = container
            .resolve::<ProfileDeferredConsumer>()
            .await
            .unwrap();
        profile_deferred.0.get().await.unwrap();
        let greetings = container.resolve::<AllGreetings>().await.unwrap();
        assert_eq!(greetings.0.len(), 2);

        container.initialize_eager().await.unwrap();
        assert_eq!(STARTED.load(Ordering::SeqCst), 1);
        container.shutdown().await.unwrap();
        assert_eq!(STOPPED.load(Ordering::SeqCst), 1);
        assert!(matches!(
            container.resolve::<SyncDependency>().await,
            Err(DiError::ContainerShutdown)
        ));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn failures_are_reported_instead_of_hanging_or_capturing_scope() {
        let _env = env_guard();
        let container = Arc::new(Container::new().unwrap());
        assert!(matches!(
            container.resolve::<FallibleDependency>().await,
            Err(DiError::Factory { .. })
        ));

        let request = container.request_context();
        let captive = request.resolve::<CaptiveSingleton>().await;
        assert!(matches!(captive, Err(DiError::InvalidScope { .. })));

        let first = {
            let container = container.clone();
            tokio::spawn(async move { container.resolve::<CycleA>().await })
        };
        let second = {
            let container = container.clone();
            tokio::spawn(async move { container.resolve::<CycleB>().await })
        };
        let results = tokio::time::timeout(std::time::Duration::from_secs(1), async {
            (first.await.unwrap(), second.await.unwrap())
        })
        .await
        .expect("cross-task cycle resolution must not deadlock");
        assert!(
            matches!(results.0, Err(DiError::CircularDependency(_)))
                || matches!(results.1, Err(DiError::CircularDependency(_)))
        );
    }
}