// auto_di/runtime.rs

1use std::{
2    any::{Any, TypeId},
3    collections::{HashMap, HashSet, VecDeque},
4    future::Future,
5    pin::Pin,
6    sync::{
7        Arc, Mutex, OnceLock,
8        atomic::{AtomicBool, Ordering},
9    },
10};
11
12pub type DynArc = Arc<dyn Any + Send + Sync>;
13pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
14type Factory =
15    for<'a> fn(&'a Container, ResolutionContext) -> BoxFuture<'a, Result<DynArc, DiError>>;
16type Destroy = fn(DynArc) -> BoxFuture<'static, Result<(), DiError>>;
17
/// Lifetime policy for the instances a provider produces.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Scope {
    /// One instance per container, created on first resolution and reused afterwards.
    Singleton,
    /// A fresh instance on every resolution.
    Prototype,
    /// One instance per `RequestContext`; resolving it outside a request context fails.
    Request,
}
24
25#[doc(hidden)]
26pub struct ProviderDescriptor {
27    type_id: fn() -> TypeId,
28    type_name: fn() -> &'static str,
29    factory: Factory,
30    pub name: Option<&'static str>,
31    pub primary: bool,
32    pub scope: Scope,
33    pub eager: bool,
34    pub profile: Option<&'static str>,
35    pub condition_key: Option<&'static str>,
36    pub condition_value: Option<&'static str>,
37    pub destroy: Option<Destroy>,
38}
39
40impl ProviderDescriptor {
41    #[doc(hidden)]
42    pub const fn new(
43        type_id: fn() -> TypeId,
44        type_name: fn() -> &'static str,
45        factory: Factory,
46    ) -> Self {
47        Self::configured(
48            type_id,
49            type_name,
50            factory,
51            None,
52            false,
53            Scope::Singleton,
54            false,
55            None,
56            None,
57            None,
58            None,
59        )
60    }
61
62    #[allow(clippy::too_many_arguments)]
63    #[doc(hidden)]
64    pub const fn configured(
65        type_id: fn() -> TypeId,
66        type_name: fn() -> &'static str,
67        factory: Factory,
68        name: Option<&'static str>,
69        primary: bool,
70        scope: Scope,
71        eager: bool,
72        profile: Option<&'static str>,
73        condition_key: Option<&'static str>,
74        condition_value: Option<&'static str>,
75        destroy: Option<Destroy>,
76    ) -> Self {
77        Self {
78            type_id,
79            type_name,
80            factory,
81            name,
82            primary,
83            scope,
84            eager,
85            profile,
86            condition_key,
87            condition_value,
88            destroy,
89        }
90    }
91
92    fn active(&self, profiles: &[String]) -> bool {
93        let profile_matches = self
94            .profile
95            .is_none_or(|required| profiles.iter().any(|p| p == required));
96        let condition_matches = self.condition_key.is_none_or(|key| {
97            let actual = std::env::var(key).ok();
98            self.condition_value.map_or(actual.is_some(), |expected| {
99                actual.as_deref() == Some(expected)
100            })
101        });
102        profile_matches && condition_matches
103    }
104}
105
106inventory::collect!(ProviderDescriptor);
107
108type InstanceMap = HashMap<usize, Arc<tokio::sync::OnceCell<DynArc>>>;
109
110struct RuntimeProvider {
111    descriptor: &'static ProviderDescriptor,
112    singleton: tokio::sync::OnceCell<DynArc>,
113}
114
115impl RuntimeProvider {
116    fn new(descriptor: &'static ProviderDescriptor) -> Self {
117        Self {
118            descriptor,
119            singleton: tokio::sync::OnceCell::new(),
120        }
121    }
122}
123
124#[derive(Clone, Default)]
125pub struct ResolutionContext {
126    pub(crate) chain: Vec<&'static str>,
127    provider_chain: Vec<usize>,
128    scope_chain: Vec<Scope>,
129    pub(crate) request_instances: Option<Arc<Mutex<InstanceMap>>>,
130}
131
132#[derive(Debug, thiserror::Error)]
133pub enum DiError {
134    #[error("no active provider is registered for {0}")]
135    MissingProvider(&'static str),
136    #[error("multiple providers match {0}; add a name/qualifier or mark one primary")]
137    AmbiguousProvider(&'static str),
138    #[error("multiple primary providers are registered for {0}")]
139    MultiplePrimary(&'static str),
140    #[error("duplicate provider name '{name}' is registered for {type_name}")]
141    DuplicateProviderName {
142        type_name: &'static str,
143        name: &'static str,
144    },
145    #[error("circular dependency detected: {0}")]
146    CircularDependency(String),
147    #[error("provider for {0} returned an incompatible type")]
148    TypeMismatch(&'static str),
149    #[error("request-scoped dependency {0} was resolved outside RequestContext")]
150    RequestScopeUnavailable(&'static str),
151    #[error("singleton dependency {singleton} cannot capture request-scoped {request}")]
152    InvalidScope {
153        singleton: &'static str,
154        request: &'static str,
155    },
156    #[error("the dependency container has already been shut down")]
157    ContainerShutdown,
158    #[error("configuration property {key} is missing or invalid: {message}")]
159    Configuration { key: String, message: String },
160    #[error("provider for {provider} failed: {message}")]
161    Factory {
162        provider: &'static str,
163        message: String,
164    },
165    #[error("lifecycle hook failed for {provider}: {message}")]
166    Lifecycle {
167        provider: &'static str,
168        message: String,
169    },
170}
171
172struct ContainerInner {
173    providers: HashMap<TypeId, Vec<RuntimeProvider>>,
174    dependency_graph: Mutex<HashMap<usize, HashSet<usize>>>,
175    shut_down: AtomicBool,
176}
177
178#[derive(Clone)]
179pub struct Container {
180    inner: Arc<ContainerInner>,
181}
182
183static GLOBAL_CONTAINER: OnceLock<Container> = OnceLock::new();
184
185pub fn global_container() -> Result<&'static Container, DiError> {
186    if let Some(container) = GLOBAL_CONTAINER.get() {
187        return Ok(container);
188    }
189    let container = Container::new()?;
190    let _ = GLOBAL_CONTAINER.set(container);
191    Ok(GLOBAL_CONTAINER
192        .get()
193        .expect("global DI container initialized"))
194}
195
196pub async fn resolve<T>() -> Result<Arc<T>, DiError>
197where
198    T: Any + Send + Sync,
199{
200    global_container()?.resolve::<T>().await
201}
202
203impl Container {
204    pub fn new() -> Result<Self, DiError> {
205        let profiles = std::env::var("APP_PROFILES")
206            .unwrap_or_default()
207            .split(',')
208            .map(str::trim)
209            .filter(|p| !p.is_empty())
210            .map(str::to_owned)
211            .collect::<Vec<_>>();
212        Self::with_profiles(profiles)
213    }
214
215    pub fn with_profiles(
216        profiles: impl IntoIterator<Item = impl Into<String>>,
217    ) -> Result<Self, DiError> {
218        let profiles = profiles.into_iter().map(Into::into).collect::<Vec<_>>();
219        let mut providers: HashMap<TypeId, Vec<RuntimeProvider>> = HashMap::new();
220        for provider in inventory::iter::<ProviderDescriptor> {
221            if provider.active(&profiles) {
222                providers
223                    .entry((provider.type_id)())
224                    .or_default()
225                    .push(RuntimeProvider::new(provider));
226            }
227        }
228        for group in providers.values() {
229            if group.iter().filter(|p| p.descriptor.primary).count() > 1 {
230                return Err(DiError::MultiplePrimary((group[0].descriptor.type_name)()));
231            }
232            let mut names = HashSet::new();
233            for provider in group {
234                if let Some(name) = provider.descriptor.name
235                    && !names.insert(name)
236                {
237                    return Err(DiError::DuplicateProviderName {
238                        type_name: (provider.descriptor.type_name)(),
239                        name,
240                    });
241                }
242            }
243        }
244        Ok(Self {
245            inner: Arc::new(ContainerInner {
246                providers,
247                dependency_graph: Mutex::new(HashMap::new()),
248                shut_down: AtomicBool::new(false),
249            }),
250        })
251    }
252
253    pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
254    where
255        T: Any + Send + Sync,
256    {
257        self.resolve_dependency::<T>(&ResolutionContext::default())
258            .await
259    }
260
261    pub async fn resolve_named<T>(&self, name: &str) -> Result<Arc<T>, DiError>
262    where
263        T: Any + Send + Sync,
264    {
265        self.resolve_named_dependency::<T>(name, &ResolutionContext::default())
266            .await
267    }
268
269    pub async fn resolve_optional<T>(&self) -> Result<Option<Arc<T>>, DiError>
270    where
271        T: Any + Send + Sync,
272    {
273        match self.resolve::<T>().await {
274            Ok(value) => Ok(Some(value)),
275            Err(DiError::MissingProvider(_)) => Ok(None),
276            Err(error) => Err(error),
277        }
278    }
279
280    pub async fn resolve_all<T>(&self) -> Result<Vec<Arc<T>>, DiError>
281    where
282        T: Any + Send + Sync,
283    {
284        self.resolve_all_dependency::<T>(&ResolutionContext::default())
285            .await
286    }
287
288    pub fn request_context(&self) -> RequestContext<'_> {
289        RequestContext {
290            container: self,
291            context: ResolutionContext {
292                chain: vec![],
293                provider_chain: vec![],
294                scope_chain: vec![],
295                request_instances: Some(Arc::new(Mutex::new(HashMap::new()))),
296            },
297        }
298    }
299
300    pub async fn initialize_eager(&self) -> Result<(), DiError> {
301        for providers in self.inner.providers.values() {
302            for provider in providers.iter().filter(|p| p.descriptor.eager) {
303                self.resolve_provider(provider, ResolutionContext::default())
304                    .await?;
305            }
306        }
307        Ok(())
308    }
309
310    /// Validates the active singleton graph by constructing every singleton.
311    /// This surfaces missing, ambiguous, circular, and invalid-scope dependencies
312    /// during application startup instead of on the first request.
313    pub async fn validate(&self) -> Result<(), DiError> {
314        for providers in self.inner.providers.values() {
315            for provider in providers
316                .iter()
317                .filter(|provider| provider.descriptor.scope == Scope::Singleton)
318            {
319                self.resolve_provider(provider, ResolutionContext::default())
320                    .await?;
321            }
322        }
323        Ok(())
324    }
325
326    pub async fn shutdown(&self) -> Result<(), DiError> {
327        if self.inner.shut_down.swap(true, Ordering::AcqRel) {
328            return Ok(());
329        }
330        let graph = self
331            .inner
332            .dependency_graph
333            .lock()
334            .expect("DI dependency graph lock poisoned")
335            .clone();
336        let providers = self
337            .inner
338            .providers
339            .values()
340            .flatten()
341            .map(|provider| (runtime_provider_key(provider), provider))
342            .collect::<HashMap<_, _>>();
343        let order = shutdown_order(providers.keys().copied(), &graph);
344        for key in order {
345            let provider = providers[&key];
346            if let (Some(destroy), Some(value)) =
347                (provider.descriptor.destroy, provider.singleton.get())
348            {
349                destroy(value.clone()).await?;
350            }
351        }
352        Ok(())
353    }
354
355    #[doc(hidden)]
356    pub async fn resolve_dependency<T>(
357        &self,
358        context: &ResolutionContext,
359    ) -> Result<Arc<T>, DiError>
360    where
361        T: Any + Send + Sync,
362    {
363        self.resolve_selected::<T>(None, context).await
364    }
365
366    #[doc(hidden)]
367    pub async fn resolve_named_dependency<T>(
368        &self,
369        name: &str,
370        context: &ResolutionContext,
371    ) -> Result<Arc<T>, DiError>
372    where
373        T: Any + Send + Sync,
374    {
375        self.resolve_selected::<T>(Some(name), context).await
376    }
377
378    #[doc(hidden)]
379    pub async fn resolve_optional_dependency<T>(
380        &self,
381        context: &ResolutionContext,
382    ) -> Result<Option<Arc<T>>, DiError>
383    where
384        T: Any + Send + Sync,
385    {
386        match self.resolve_dependency::<T>(context).await {
387            Ok(value) => Ok(Some(value)),
388            Err(DiError::MissingProvider(_)) => Ok(None),
389            Err(error) => Err(error),
390        }
391    }
392
393    #[doc(hidden)]
394    pub async fn resolve_all_dependency<T>(
395        &self,
396        context: &ResolutionContext,
397    ) -> Result<Vec<Arc<T>>, DiError>
398    where
399        T: Any + Send + Sync,
400    {
401        let Some(providers) = self.inner.providers.get(&TypeId::of::<T>()) else {
402            return Ok(vec![]);
403        };
404        let mut values = Vec::with_capacity(providers.len());
405        for provider in providers {
406            let value = self.resolve_provider(provider, context.clone()).await?;
407            values.push(
408                value
409                    .downcast::<T>()
410                    .map_err(|_| DiError::TypeMismatch(std::any::type_name::<T>()))?,
411            );
412        }
413        Ok(values)
414    }
415
416    async fn resolve_selected<T>(
417        &self,
418        name: Option<&str>,
419        context: &ResolutionContext,
420    ) -> Result<Arc<T>, DiError>
421    where
422        T: Any + Send + Sync,
423    {
424        let type_name = std::any::type_name::<T>();
425        let providers = self
426            .inner
427            .providers
428            .get(&TypeId::of::<T>())
429            .ok_or(DiError::MissingProvider(type_name))?;
430        let selected = if let Some(name) = name {
431            providers
432                .iter()
433                .find(|p| p.descriptor.name == Some(name))
434                .ok_or(DiError::MissingProvider(type_name))?
435        } else if providers.len() == 1 {
436            &providers[0]
437        } else {
438            providers
439                .iter()
440                .find(|p| p.descriptor.primary)
441                .ok_or(DiError::AmbiguousProvider(type_name))?
442        };
443        let value = self.resolve_provider(selected, context.clone()).await?;
444        value
445            .downcast::<T>()
446            .map_err(|_| DiError::TypeMismatch(type_name))
447    }
448
449    fn resolve_provider<'a>(
450        &'a self,
451        provider: &'a RuntimeProvider,
452        mut context: ResolutionContext,
453    ) -> BoxFuture<'a, Result<DynArc, DiError>> {
454        Box::pin(async move {
455            if self.inner.shut_down.load(Ordering::Acquire) {
456                return Err(DiError::ContainerShutdown);
457            }
458            let descriptor = provider.descriptor;
459            let type_name = (descriptor.type_name)();
460            if context.chain.contains(&type_name) {
461                context.chain.push(type_name);
462                return Err(DiError::CircularDependency(context.chain.join(" -> ")));
463            }
464            let runtime_key = runtime_provider_key(provider);
465            if let Some(parent) = context.provider_chain.last().copied() {
466                self.add_dependency_edge(parent, runtime_key, &context, type_name)?;
467            }
468            if descriptor.scope == Scope::Request
469                && let Some(position) = context
470                    .scope_chain
471                    .iter()
472                    .position(|scope| *scope == Scope::Singleton)
473            {
474                return Err(DiError::InvalidScope {
475                    singleton: context.chain[position],
476                    request: type_name,
477                });
478            }
479            context.chain.push(type_name);
480            context.provider_chain.push(runtime_key);
481            context.scope_chain.push(descriptor.scope);
482            if descriptor.scope == Scope::Prototype {
483                return (descriptor.factory)(self, context).await;
484            }
485            match descriptor.scope {
486                Scope::Singleton => {
487                    let value = provider
488                        .singleton
489                        .get_or_try_init(
490                            || async move { (descriptor.factory)(self, context).await },
491                        )
492                        .await?;
493                    Ok(value.clone())
494                }
495                Scope::Request => {
496                    let map = context
497                        .request_instances
498                        .as_deref()
499                        .ok_or(DiError::RequestScopeUnavailable(type_name))?;
500                    let cell = {
501                        let mut instances = map.lock().expect("DI instance lock poisoned");
502                        instances
503                            .entry(provider_key(descriptor))
504                            .or_insert_with(|| Arc::new(tokio::sync::OnceCell::new()))
505                            .clone()
506                    };
507                    let value = cell
508                        .get_or_try_init(
509                            || async move { (descriptor.factory)(self, context).await },
510                        )
511                        .await?;
512                    Ok(value.clone())
513                }
514                Scope::Prototype => unreachable!(),
515            }
516        })
517    }
518
519    fn add_dependency_edge(
520        &self,
521        parent: usize,
522        dependency: usize,
523        context: &ResolutionContext,
524        dependency_name: &'static str,
525    ) -> Result<(), DiError> {
526        let mut graph = self
527            .inner
528            .dependency_graph
529            .lock()
530            .expect("DI dependency graph lock poisoned");
531        graph.entry(parent).or_default().insert(dependency);
532        if graph_path_exists(&graph, dependency, parent, &mut HashSet::new()) {
533            let mut chain = context.chain.clone();
534            chain.push(dependency_name);
535            return Err(DiError::CircularDependency(chain.join(" -> ")));
536        }
537        Ok(())
538    }
539}
540
541fn provider_key(provider: &'static ProviderDescriptor) -> usize {
542    provider as *const ProviderDescriptor as usize
543}
544
545fn runtime_provider_key(provider: &RuntimeProvider) -> usize {
546    provider as *const RuntimeProvider as usize
547}
548
/// Depth-first reachability test: is `target` reachable from `current` along
/// `graph` edges? Every node expanded is recorded in `visited`, so each node
/// is explored at most once and cycles terminate.
fn graph_path_exists(
    graph: &HashMap<usize, HashSet<usize>>,
    current: usize,
    target: usize,
    visited: &mut HashSet<usize>,
) -> bool {
    if current == target {
        return true;
    }
    if !visited.insert(current) {
        return false;
    }
    let Some(dependencies) = graph.get(&current) else {
        return false;
    };
    for &next in dependencies {
        if graph_path_exists(graph, next, target, visited) {
            return true;
        }
    }
    false
}
565
/// Orders `keys` for teardown so that every provider precedes the providers it
/// depends on (dependents are destroyed first). Uses Kahn's algorithm over the
/// `dependent -> dependencies` edges of `graph`, ignoring edges whose
/// endpoints are not in `keys`.
fn shutdown_order(
    keys: impl IntoIterator<Item = usize>,
    graph: &HashMap<usize, HashSet<usize>>,
) -> Vec<usize> {
    let keys: HashSet<usize> = keys.into_iter().collect();

    // In-degree of every known key, counting only edges from known dependents.
    let mut incoming: HashMap<usize, usize> = HashMap::with_capacity(keys.len());
    for &key in &keys {
        incoming.insert(key, 0);
    }
    let known_edges = graph
        .iter()
        .filter(|(parent, _)| keys.contains(*parent))
        .flat_map(|(_, dependencies)| dependencies.iter());
    for dependency in known_edges {
        if let Some(count) = incoming.get_mut(dependency) {
            *count += 1;
        }
    }

    let mut ready: VecDeque<usize> = incoming
        .iter()
        .filter(|(_, count)| **count == 0)
        .map(|(key, _)| *key)
        .collect();
    let mut order = Vec::with_capacity(keys.len());
    while let Some(current) = ready.pop_front() {
        order.push(current);
        let Some(dependencies) = graph.get(&current) else {
            continue;
        };
        for dependency in dependencies {
            if let Some(count) = incoming.get_mut(dependency) {
                *count -= 1;
                if *count == 0 {
                    ready.push_back(*dependency);
                }
            }
        }
    }

    // Cycles are normally rejected during resolution. Should one slip through
    // anyway, still emit its members (in arbitrary order) so no provider is
    // skipped at shutdown.
    let emitted: HashSet<usize> = order.iter().copied().collect();
    order.extend(keys.into_iter().filter(|key| !emitted.contains(key)));
    order
}
610
611pub struct RequestContext<'a> {
612    container: &'a Container,
613    context: ResolutionContext,
614}
615impl RequestContext<'_> {
616    pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
617    where
618        T: Any + Send + Sync,
619    {
620        self.container.resolve_dependency::<T>(&self.context).await
621    }
622}
623
624#[doc(hidden)]
625pub mod __private {
626    pub use inventory;
627    pub use tokio;
628
629    use std::fmt::Display;
630
631    use crate::DiError;
632
633    pub fn factory_result<T, E: Display>(
634        result: Result<T, E>,
635        provider: &'static str,
636    ) -> Result<T, DiError> {
637        result.map_err(|error| DiError::Factory {
638            provider,
639            message: error.to_string(),
640        })
641    }
642
643    pub trait IntoLifecycleResult {
644        fn into_lifecycle_result(self, provider: &'static str) -> Result<(), DiError>;
645    }
646
647    impl IntoLifecycleResult for () {
648        fn into_lifecycle_result(self, _provider: &'static str) -> Result<(), DiError> {
649            Ok(())
650        }
651    }
652
653    impl<E: Display> IntoLifecycleResult for Result<(), E> {
654        fn into_lifecycle_result(self, provider: &'static str) -> Result<(), DiError> {
655            self.map_err(|error| DiError::Lifecycle {
656                provider,
657                message: error.to_string(),
658            })
659        }
660    }
661
662    pub fn lifecycle_result<R: IntoLifecycleResult>(
663        result: R,
664        provider: &'static str,
665    ) -> Result<(), DiError> {
666        result.into_lifecycle_result(provider)
667    }
668}
669
670#[cfg(test)]
671mod tests {
672    use super::*;
673    use crate::{
674        Lazy, Provider, configuration_properties, injectable, injected, provider, singleton,
675    };
676    use std::sync::atomic::{AtomicUsize, Ordering};
677    static CREATIONS: AtomicUsize = AtomicUsize::new(0);
678    struct SyncDependency;
679    #[singleton]
680    fn sync_dependency() -> SyncDependency {
681        CREATIONS.fetch_add(1, Ordering::SeqCst);
682        SyncDependency
683    }
684    struct AsyncDependency {
685        _sync: Arc<SyncDependency>,
686    }
687    #[singleton]
688    async fn async_dependency(sync: Arc<SyncDependency>) -> AsyncDependency {
689        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
690        AsyncDependency { _sync: sync }
691    }
692    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
693    async fn singleton_is_concurrent_safe() {
694        let container = Arc::new(Container::new().unwrap());
695        let mut tasks = tokio::task::JoinSet::new();
696        for _ in 0..32 {
697            let c = container.clone();
698            tasks.spawn(async move { c.resolve::<AsyncDependency>().await.unwrap() });
699        }
700        let mut values = vec![];
701        while let Some(v) = tasks.join_next().await {
702            values.push(v.unwrap());
703        }
704        assert!(values.iter().all(|v| Arc::ptr_eq(&values[0], v)));
705        assert_eq!(CREATIONS.load(Ordering::SeqCst), 1);
706    }
707
708    static PROTOTYPES: AtomicUsize = AtomicUsize::new(0);
709    struct PrototypeBean(usize);
710    #[singleton(scope = "prototype")]
711    fn prototype_bean() -> PrototypeBean {
712        PrototypeBean(PROTOTYPES.fetch_add(1, Ordering::SeqCst))
713    }
714
715    struct RequestBean;
716    #[singleton(scope = "request")]
717    fn request_bean() -> RequestBean {
718        RequestBean
719    }
720
721    trait Greeting: Send + Sync {
722        fn text(&self) -> &'static str;
723    }
724    struct English;
725    impl Greeting for English {
726        fn text(&self) -> &'static str {
727            "hello"
728        }
729    }
730    struct Hindi;
731    impl Greeting for Hindi {
732        fn text(&self) -> &'static str {
733            "namaste"
734        }
735    }
736
737    #[singleton(name = "english", primary)]
738    fn english_greeting() -> Arc<dyn Greeting> {
739        Arc::new(English)
740    }
741    #[singleton(name = "hindi")]
742    fn hindi_greeting() -> Arc<dyn Greeting> {
743        Arc::new(Hindi)
744    }
745
746    struct MissingOptional;
747    struct Greeter {
748        greeting: Arc<dyn Greeting>,
749        optional: Option<Arc<MissingOptional>>,
750    }
751    struct GreetingLabel(&'static str);
752    #[singleton]
753    impl Greeter {
754        fn new(greeting: Arc<dyn Greeting>, optional: Option<Arc<MissingOptional>>) -> Self {
755            Self { greeting, optional }
756        }
757
758        #[provider]
759        fn label(&self) -> GreetingLabel {
760            GreetingLabel(self.greeting.text())
761        }
762    }
763    struct QualifiedGreeter(Arc<dyn Greeting>);
764    #[singleton]
765    impl QualifiedGreeter {
766        fn new(#[qualifier("hindi")] greeting: Arc<dyn Greeting>) -> Self {
767            Self(greeting)
768        }
769    }
770
771    struct StaticConfig(&'static str);
772    struct ServiceWithStaticBean(Arc<StaticConfig>);
773    #[singleton]
774    impl ServiceWithStaticBean {
775        fn new(config: Arc<StaticConfig>) -> Self {
776            Self(config)
777        }
778
779        // No `&self`: this provider can run before its owning singleton and can
780        // therefore participate in the service's constructor graph.
781        #[provider]
782        fn config() -> StaticConfig {
783            StaticConfig("static-bean")
784        }
785    }
786
787    static STARTED: AtomicUsize = AtomicUsize::new(0);
788    static STOPPED: AtomicUsize = AtomicUsize::new(0);
789    struct Managed;
790    #[singleton(eager, post_construct = "start", pre_destroy = "stop")]
791    impl Managed {
792        fn new() -> Self {
793            Self
794        }
795        async fn start(&self) {
796            STARTED.fetch_add(1, Ordering::SeqCst);
797        }
798        async fn stop(&self) {
799            STOPPED.fetch_add(1, Ordering::SeqCst);
800        }
801    }
802
803    struct ProfileBean;
804    #[singleton(profile = "test")]
805    fn profile_bean() -> ProfileBean {
806        ProfileBean
807    }
808
809    #[derive(Debug)]
810    struct Handler(&'static str);
811    #[singleton(name = "first")]
812    fn first_handler() -> Handler {
813        Handler("first")
814    }
815    #[singleton(name = "second")]
816    fn second_handler() -> Handler {
817        Handler("second")
818    }
819    struct Pipeline(Vec<Arc<Handler>>);
820    #[singleton]
821    impl Pipeline {
822        fn new(handlers: Vec<Arc<Handler>>) -> Self {
823            Self(handlers)
824        }
825    }
826
827    #[configuration_properties("testing_dep")]
828    struct TestProperties {
829        port: u16,
830    }
831
832    struct DeferredTarget;
833    #[singleton]
834    fn deferred_target() -> DeferredTarget {
835        DeferredTarget
836    }
837    struct DeferredConsumer {
838        provider: Provider<DeferredTarget>,
839        lazy: Lazy<DeferredTarget>,
840    }
841    #[singleton]
842    impl DeferredConsumer {
843        fn new(provider: Provider<DeferredTarget>, lazy: Lazy<DeferredTarget>) -> Self {
844            Self { provider, lazy }
845        }
846    }
847
848    struct ConditionalBean;
849    #[singleton(condition = "TESTING_DEP_FEATURE=enabled")]
850    fn conditional_bean() -> ConditionalBean {
851        ConditionalBean
852    }
853
854    struct FallibleDependency;
855    #[provider]
856    fn fallible_dependency() -> Result<FallibleDependency, &'static str> {
857        Err("expected factory failure")
858    }
859
860    #[derive(Clone)]
861    struct FieldDependency(u32);
862    #[singleton]
863    fn field_dependency() -> FieldDependency {
864        FieldDependency(10)
865    }
866
867    #[injectable]
868    struct FieldInjectedFacade {
869        dependency: Arc<FieldDependency>,
870        #[inject(7)]
871        literal: u32,
872        #[inject(|dependency: Arc<FieldDependency>| dependency)]
873        transformed: Arc<FieldDependency>,
874    }
875
876    #[injectable]
877    struct TupleInjected(#[inject(123)] u32);
878
879    #[injected]
880    fn injected_total(
881        dependency: Arc<FieldDependency>,
882        #[inject(2)] offset: u32,
883        multiplier: u32,
884    ) -> u32 {
885        (dependency.0 + offset) * multiplier
886    }
887
888    #[injected]
889    fn caller_owned_arc(
890        #[argument] dependency: Arc<FieldDependency>,
891        #[inject(1)] offset: u32,
892    ) -> u32 {
893        dependency.0 + offset
894    }
895
896    struct InjectedMethods;
897    impl InjectedMethods {
898        #[injected]
899        fn calculate(&self, dependency: Arc<FieldDependency>, value: u32) -> u32 {
900            dependency.0 + value
901        }
902    }
903
904    struct RequestLeaf;
905    #[singleton(scope = "request")]
906    fn request_leaf() -> RequestLeaf {
907        RequestLeaf
908    }
909    struct CaptiveSingleton {
910        _request: Arc<RequestLeaf>,
911    }
912    #[singleton]
913    impl CaptiveSingleton {
914        fn new(request: Arc<RequestLeaf>) -> Self {
915            Self { _request: request }
916        }
917    }
918
919    struct CycleA {
920        _b: Arc<CycleB>,
921    }
922    struct CycleB {
923        _a: Arc<CycleA>,
924    }
925    #[singleton]
926    impl CycleA {
927        fn new(b: Arc<CycleB>) -> Self {
928            Self { _b: b }
929        }
930    }
931    #[singleton]
932    impl CycleB {
933        fn new(a: Arc<CycleA>) -> Self {
934            Self { _a: a }
935        }
936    }
937
938    struct AllGreetings(Vec<Arc<dyn Greeting>>);
939    #[singleton]
940    impl AllGreetings {
941        fn new(greetings: Vec<Arc<dyn Greeting>>) -> Self {
942            Self(greetings)
943        }
944    }
945
946    struct ProfileDeferred;
947    #[singleton(profile = "test")]
948    fn profile_deferred() -> ProfileDeferred {
949        ProfileDeferred
950    }
951    struct ProfileDeferredConsumer(Provider<ProfileDeferred>);
952    #[singleton]
953    impl ProfileDeferredConsumer {
954        fn new(provider: Provider<ProfileDeferred>) -> Self {
955            Self(provider)
956        }
957    }
958
959    struct StandaloneBean;
960    #[provider]
961    fn standalone_bean() -> StandaloneBean {
962        StandaloneBean
963    }
964
    // End-to-end happy path on a single container built with the "test" profile. It covers:
    // - prototype and request scopes;
    // - trait-object injection with default and named selection;
    // - profile- and condition-gated beans and env-backed properties;
    // - deferred `Provider`/lazy handles and collection injection;
    // - eager initialization, shutdown, and resolution after shutdown.
    //
    // NOTE(review): this test mutates process environment variables and asserts exact values
    // of the global STARTED/STOPPED counters. It presumably assumes no other test in the crate
    // touches either concurrently — confirm when adding tests.
    #[tokio::test]
    async fn scopes_traits_primary_profiles_and_lifecycle_work() {
        // SAFETY(review): `set_var` is unsafe because another thread may read the environment
        // at the same time. Tests run in parallel and nothing here synchronizes with them.
        // Presumably acceptable because only this crate's own code reads the key — confirm.
        // Set before the container is built; presumably it satisfies the condition on
        // `ConditionalBean`, which is resolved further down — TODO confirm.
        unsafe { std::env::set_var("TESTING_DEP_FEATURE", "enabled") };
        let container = Container::with_profiles(["test"]).unwrap();
        // Prototype scope: two resolutions yield different values.
        let first = container.resolve::<PrototypeBean>().await.unwrap();
        let second = container.resolve::<PrototypeBean>().await.unwrap();
        assert_ne!(first.0, second.0);

        // Request scope: unavailable from the root container; within one request context
        // repeated resolutions return the same `Arc`.
        assert!(matches!(
            container.resolve::<RequestBean>().await,
            Err(DiError::RequestScopeUnavailable(_))
        ));
        let request = container.request_context();
        let request_first = request.resolve::<RequestBean>().await.unwrap();
        let request_second = request.resolve::<RequestBean>().await.unwrap();
        assert!(Arc::ptr_eq(&request_first, &request_second));

        // Trait-object injection: the unqualified greeting is "hello" (presumably the primary
        // implementation) and the greeter's optional dependency is absent.
        let greeter = container.resolve::<Greeter>().await.unwrap();
        assert_eq!(greeter.greeting.text(), "hello");
        assert_eq!(
            container.resolve::<GreetingLabel>().await.unwrap().0,
            "hello"
        );
        assert!(greeter.optional.is_none());
        // Named selection: both a qualified dependency and `resolve_named` pick "hindi".
        assert_eq!(
            container
                .resolve::<QualifiedGreeter>()
                .await
                .unwrap()
                .0
                .text(),
            "namaste"
        );
        let hindi = container
            .resolve_named::<Arc<dyn Greeting>>("hindi")
            .await
            .unwrap();
        assert_eq!(hindi.text(), "namaste");
        let static_bean_service = container.resolve::<ServiceWithStaticBean>().await.unwrap();
        assert_eq!(static_bean_service.0.0, "static-bean");
        // Presumably only registered under the "test" profile this container was built with.
        container.resolve::<ProfileBean>().await.unwrap();
        // Collection injection; registration order is not asserted, hence the sort.
        let pipeline = container.resolve::<Pipeline>().await.unwrap();
        let mut handler_names = pipeline.0.iter().map(|h| h.0).collect::<Vec<_>>();
        handler_names.sort_unstable();
        assert_eq!(handler_names, ["first", "second"]);

        // This process-local key is unique to the test crate. It is set after the container
        // is built and read when `TestProperties` is resolved.
        // SAFETY(review): same caveat as the `set_var` above — no synchronization with other
        // threads that may read the environment; confirm none do.
        unsafe { std::env::set_var("TESTING_DEP_PORT", "8080") };
        let properties = container.resolve::<TestProperties>().await.unwrap();
        assert_eq!(properties.port, 8080);
        container.resolve::<ConditionalBean>().await.unwrap();
        container.resolve::<StandaloneBean>().await.unwrap();

        // Deferred handles must produce their target on `get()`.
        let deferred = container.resolve::<DeferredConsumer>().await.unwrap();
        deferred.provider.get().await.unwrap();
        deferred.lazy.get().await.unwrap();
        let profile_deferred = container
            .resolve::<ProfileDeferredConsumer>()
            .await
            .unwrap();
        profile_deferred.0.get().await.unwrap();
        let greetings = container.resolve::<AllGreetings>().await.unwrap();
        assert_eq!(greetings.0.len(), 2);

        // Lifecycle: eager initialization bumps STARTED once, and shutdown bumps STOPPED once.
        // Afterwards resolution is refused with `ContainerShutdown`.
        container.initialize_eager().await.unwrap();
        assert_eq!(STARTED.load(Ordering::SeqCst), 1);
        container.shutdown().await.unwrap();
        assert_eq!(STOPPED.load(Ordering::SeqCst), 1);
        assert!(matches!(
            container.resolve::<SyncDependency>().await,
            Err(DiError::ContainerShutdown)
        ));
    }
1038
1039    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1040    async fn failures_are_reported_instead_of_hanging_or_capturing_scope() {
1041        let container = Arc::new(Container::new().unwrap());
1042        assert!(matches!(
1043            container.resolve::<FallibleDependency>().await,
1044            Err(DiError::Factory { .. })
1045        ));
1046
1047        let request = container.request_context();
1048        let captive = request.resolve::<CaptiveSingleton>().await;
1049        assert!(matches!(captive, Err(DiError::InvalidScope { .. })));
1050
1051        let first = {
1052            let container = container.clone();
1053            tokio::spawn(async move { container.resolve::<CycleA>().await })
1054        };
1055        let second = {
1056            let container = container.clone();
1057            tokio::spawn(async move { container.resolve::<CycleB>().await })
1058        };
1059        let results = tokio::time::timeout(std::time::Duration::from_secs(1), async {
1060            (first.await.unwrap(), second.await.unwrap())
1061        })
1062        .await
1063        .expect("cross-task cycle resolution must not deadlock");
1064        assert!(
1065            matches!(results.0, Err(DiError::CircularDependency(_)))
1066                || matches!(results.1, Err(DiError::CircularDependency(_)))
1067        );
1068    }
1069
1070    #[tokio::test]
1071    async fn field_and_function_injection_are_automatic() {
1072        let facade = resolve::<FieldInjectedFacade>().await.unwrap();
1073        assert_eq!(facade.dependency.0, 10);
1074        assert_eq!(facade.literal, 7);
1075        assert!(Arc::ptr_eq(&facade.dependency, &facade.transformed));
1076        assert_eq!(resolve::<TupleInjected>().await.unwrap().0, 123);
1077        assert_eq!(injected_total(3).await.unwrap(), 36);
1078        assert_eq!(injected_total_with(Arc::new(FieldDependency(20)), 4, 2), 48);
1079        assert_eq!(
1080            caller_owned_arc(Arc::new(FieldDependency(20)))
1081                .await
1082                .unwrap(),
1083            21
1084        );
1085        assert_eq!(InjectedMethods.calculate(5).await.unwrap(), 15);
1086    }
1087}