// auto_di/runtime.rs

1use std::{
2    any::{Any, TypeId},
3    collections::{HashMap, HashSet, VecDeque},
4    future::Future,
5    pin::Pin,
6    sync::{
7        Arc, Mutex, OnceLock,
8        atomic::{AtomicBool, Ordering},
9    },
10};
11
/// Type-erased, shareable handle to a provided instance. Resolution downcasts
/// it back to the requested `Arc<T>`.
12pub type DynArc = Arc<dyn Any + Send + Sync>;
/// Heap-allocated, pinned, `Send` future that may borrow for `'a` and yields `T`.
13pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Function pointer that builds a provider's instance from the container and
/// the resolution context of the dependency chain being resolved.
14type Factory =
15    for<'a> fn(&'a Container, ResolutionContext) -> BoxFuture<'a, Result<DynArc, DiError>>;
/// Function pointer that `Container::shutdown` calls with a created singleton.
16type Destroy = fn(DynArc) -> BoxFuture<'static, Result<(), DiError>>;
17
/// Lifetime policy applied to the instances a provider creates.
18#[derive(Clone, Copy, Debug, Eq, PartialEq)]
19pub enum Scope {
    /// One instance per container, created on first resolution and reused.
20    Singleton,
    /// The factory runs on every resolution; nothing is cached.
21    Prototype,
    /// One instance per `RequestContext`; resolving without one fails with
    /// `DiError::RequestScopeUnavailable`.
22    Request,
23}
24
/// Static registration record for a single provider.
///
/// Descriptors are collected through `inventory` and filtered by
/// `ProviderDescriptor::active` when a container is built.
25#[doc(hidden)]
26pub struct ProviderDescriptor {
    /// Returns the `TypeId` the provider is registered under.
27    type_id: fn() -> TypeId,
    /// Returns the type name used in error messages and dependency chains.
28    type_name: fn() -> &'static str,
    /// Builds the (type-erased) instance.
29    factory: Factory,
    /// Optional name used by named resolution; must be unique per type.
30    pub name: Option<&'static str>,
    /// Marks the provider chosen when several match an unnamed request.
31    pub primary: bool,
    /// Caching policy for created instances.
32    pub scope: Scope,
    /// When set, `Container::initialize_eager` resolves this provider.
33    pub eager: bool,
    /// Profile that must be among the container's profiles for this provider
    /// to be active; `None` means always eligible.
34    pub profile: Option<&'static str>,
    /// Environment variable that gates activation; `None` means no condition.
35    pub condition_key: Option<&'static str>,
    /// Required value of `condition_key`; when `None` the variable only has
    /// to be set.
36    pub condition_value: Option<&'static str>,
    /// Hook run by `Container::shutdown` for a singleton that was created.
37    pub destroy: Option<Destroy>,
38}
39
40impl ProviderDescriptor {
41    #[doc(hidden)]
42    pub const fn new(
43        type_id: fn() -> TypeId,
44        type_name: fn() -> &'static str,
45        factory: Factory,
46    ) -> Self {
47        Self::configured(
48            type_id,
49            type_name,
50            factory,
51            None,
52            false,
53            Scope::Singleton,
54            false,
55            None,
56            None,
57            None,
58            None,
59        )
60    }
61
62    #[allow(clippy::too_many_arguments)]
63    #[doc(hidden)]
64    pub const fn configured(
65        type_id: fn() -> TypeId,
66        type_name: fn() -> &'static str,
67        factory: Factory,
68        name: Option<&'static str>,
69        primary: bool,
70        scope: Scope,
71        eager: bool,
72        profile: Option<&'static str>,
73        condition_key: Option<&'static str>,
74        condition_value: Option<&'static str>,
75        destroy: Option<Destroy>,
76    ) -> Self {
77        Self {
78            type_id,
79            type_name,
80            factory,
81            name,
82            primary,
83            scope,
84            eager,
85            profile,
86            condition_key,
87            condition_value,
88            destroy,
89        }
90    }
91
92    fn active(&self, profiles: &[String]) -> bool {
93        let profile_matches = self
94            .profile
95            .is_none_or(|required| profiles.iter().any(|p| p == required));
96        let condition_matches = self.condition_key.is_none_or(|key| {
97            let actual = std::env::var(key).ok();
98            self.condition_value.map_or(actual.is_some(), |expected| {
99                actual.as_deref() == Some(expected)
100            })
101        });
102        profile_matches && condition_matches
103    }
104}
105
106inventory::collect!(ProviderDescriptor);
107
/// Cache of request-scoped instances: one lazily initialised cell per
/// provider, keyed by `provider_key`.
108type InstanceMap = HashMap<usize, Arc<tokio::sync::OnceCell<DynArc>>>;
109
/// A registered descriptor together with this container's storage for its
/// singleton instance.
110struct RuntimeProvider {
    /// Static registration record this entry was built from.
111    descriptor: &'static ProviderDescriptor,
    /// Holds the singleton once its factory has succeeded.
112    singleton: tokio::sync::OnceCell<DynArc>,
113}
114
115impl RuntimeProvider {
116    fn new(descriptor: &'static ProviderDescriptor) -> Self {
117        Self {
118            descriptor,
119            singleton: tokio::sync::OnceCell::new(),
120        }
121    }
122}
123
/// State threaded through one resolution: the chain of providers currently
/// being built plus the optional request-scoped instance cache.
124#[derive(Clone, Default)]
125pub struct ResolutionContext {
    /// Type names of the providers being resolved, outermost first; used for
    /// cycle detection and for rendering dependency chains in errors.
126    pub(crate) chain: Vec<&'static str>,
    /// Runtime keys of the same providers; the last entry is the parent when
    /// a dependency edge is recorded.
127    provider_chain: Vec<usize>,
    /// Scopes of the same providers; used to reject a request-scoped
    /// dependency underneath a singleton.
128    scope_chain: Vec<Scope>,
    /// Request-scoped instance cache; `None` outside a `RequestContext`.
129    pub(crate) request_instances: Option<Arc<Mutex<InstanceMap>>>,
130}
131
/// Errors reported while building a container, resolving dependencies, or
/// running lifecycle hooks. The `#[error]` strings are the user-facing
/// messages.
132#[derive(Debug, thiserror::Error)]
133pub enum DiError {
    /// No provider matched the requested type (or type and name).
134    #[error("no active provider is registered for {0}")]
135    MissingProvider(&'static str),
    /// Several providers matched an unnamed request and none is primary.
136    #[error("multiple providers match {0}; add a name/qualifier or mark one primary")]
137    AmbiguousProvider(&'static str),
    /// More than one active provider of the type is marked primary.
138    #[error("multiple primary providers are registered for {0}")]
139    MultiplePrimary(&'static str),
    /// Two active providers of the same type share a name.
140    #[error("duplicate provider name '{name}' is registered for {type_name}")]
141    DuplicateProviderName {
142        type_name: &'static str,
143        name: &'static str,
144    },
    /// The payload is the offending chain rendered as `A -> B -> A`.
145    #[error("circular dependency detected: {0}")]
146    CircularDependency(String),
    /// The type-erased instance could not be downcast to the requested type.
147    #[error("provider for {0} returned an incompatible type")]
148    TypeMismatch(&'static str),
    /// A request-scoped provider was resolved without a request instance map.
149    #[error("request-scoped dependency {0} was resolved outside RequestContext")]
150    RequestScopeUnavailable(&'static str),
    /// A request-scoped provider was reached from inside a singleton's chain.
151    #[error("singleton dependency {singleton} cannot capture request-scoped {request}")]
152    InvalidScope {
153        singleton: &'static str,
154        request: &'static str,
155    },
    /// Resolution was attempted after `Container::shutdown`.
156    #[error("the dependency container has already been shut down")]
157    ContainerShutdown,
158    #[error("configuration property {key} is missing or invalid: {message}")]
159    Configuration { key: String, message: String },
    /// A provider factory returned an error; see `__private::factory_result`.
160    #[error("provider for {provider} failed: {message}")]
161    Factory {
162        provider: &'static str,
163        message: String,
164    },
    /// A lifecycle hook returned an error; see `__private::lifecycle_result`.
165    #[error("lifecycle hook failed for {provider}: {message}")]
166    Lifecycle {
167        provider: &'static str,
168        message: String,
169    },
170}
171
/// Shared state behind every clone of a `Container`.
172struct ContainerInner {
    /// Active providers grouped by the `TypeId` they are registered under.
173    providers: HashMap<TypeId, Vec<RuntimeProvider>>,
    /// Edges `parent -> dependency` between runtime provider keys, recorded
    /// during resolution and read by `shutdown` to order destroy hooks.
174    dependency_graph: Mutex<HashMap<usize, HashSet<usize>>>,
    /// Set by `shutdown`; once set, provider resolution is refused.
175    shut_down: AtomicBool,
176}
177
/// Dependency container handle. Clones share the same providers, singletons,
/// and shutdown state through the inner `Arc`.
178#[derive(Clone)]
179pub struct Container {
180    inner: Arc<ContainerInner>,
181}
182
// Process-wide container, created on first use by `global_container`.
183static GLOBAL_CONTAINER: OnceLock<Container> = OnceLock::new();
// Optional diverging handler consulted by `handle_resolution_error` before it
// falls back to panicking.
184static RESOLUTION_ERROR_HANDLER: OnceLock<fn(DiError) -> !> = OnceLock::new();
185
186/// Installs the process-wide handler used when a plain `#[injected]` wrapper
187/// cannot resolve one of its dependencies.
188///
189/// The handler must not return. Install it before calling injected functions.
///
/// # Errors
///
/// Returns the rejected `handler` when a handler has already been installed;
/// the previously installed handler stays in effect.
190pub fn set_resolution_error_handler(handler: fn(DiError) -> !) -> Result<(), fn(DiError) -> !> {
191    RESOLUTION_ERROR_HANDLER.set(handler)
192}
193
194#[doc(hidden)]
195pub fn handle_resolution_error(error: DiError) -> ! {
196    if let Some(handler) = RESOLUTION_ERROR_HANDLER.get() {
197        handler(error)
198    }
199    panic!("auto-di dependency resolution failed: {error}")
200}
201
202pub fn global_container() -> Result<&'static Container, DiError> {
203    if let Some(container) = GLOBAL_CONTAINER.get() {
204        return Ok(container);
205    }
206    let container = Container::new()?;
207    let _ = GLOBAL_CONTAINER.set(container);
208    Ok(GLOBAL_CONTAINER
209        .get()
210        .expect("global DI container initialized"))
211}
212
213pub async fn resolve<T>() -> Result<Arc<T>, DiError>
214where
215    T: Any + Send + Sync,
216{
217    global_container()?.resolve::<T>().await
218}
219
220impl Container {
221    pub fn new() -> Result<Self, DiError> {
222        let profiles = std::env::var("APP_PROFILES")
223            .unwrap_or_default()
224            .split(',')
225            .map(str::trim)
226            .filter(|p| !p.is_empty())
227            .map(str::to_owned)
228            .collect::<Vec<_>>();
229        Self::with_profiles(profiles)
230    }
231
232    pub fn with_profiles(
233        profiles: impl IntoIterator<Item = impl Into<String>>,
234    ) -> Result<Self, DiError> {
235        let profiles = profiles.into_iter().map(Into::into).collect::<Vec<_>>();
236        let mut providers: HashMap<TypeId, Vec<RuntimeProvider>> = HashMap::new();
237        for provider in inventory::iter::<ProviderDescriptor> {
238            if provider.active(&profiles) {
239                providers
240                    .entry((provider.type_id)())
241                    .or_default()
242                    .push(RuntimeProvider::new(provider));
243            }
244        }
245        for group in providers.values() {
246            if group.iter().filter(|p| p.descriptor.primary).count() > 1 {
247                return Err(DiError::MultiplePrimary((group[0].descriptor.type_name)()));
248            }
249            let mut names = HashSet::new();
250            for provider in group {
251                if let Some(name) = provider.descriptor.name
252                    && !names.insert(name)
253                {
254                    return Err(DiError::DuplicateProviderName {
255                        type_name: (provider.descriptor.type_name)(),
256                        name,
257                    });
258                }
259            }
260        }
261        Ok(Self {
262            inner: Arc::new(ContainerInner {
263                providers,
264                dependency_graph: Mutex::new(HashMap::new()),
265                shut_down: AtomicBool::new(false),
266            }),
267        })
268    }
269
270    pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
271    where
272        T: Any + Send + Sync,
273    {
274        self.resolve_dependency::<T>(&ResolutionContext::default())
275            .await
276    }
277
278    pub async fn resolve_named<T>(&self, name: &str) -> Result<Arc<T>, DiError>
279    where
280        T: Any + Send + Sync,
281    {
282        self.resolve_named_dependency::<T>(name, &ResolutionContext::default())
283            .await
284    }
285
286    pub async fn resolve_optional<T>(&self) -> Result<Option<Arc<T>>, DiError>
287    where
288        T: Any + Send + Sync,
289    {
290        match self.resolve::<T>().await {
291            Ok(value) => Ok(Some(value)),
292            Err(DiError::MissingProvider(_)) => Ok(None),
293            Err(error) => Err(error),
294        }
295    }
296
297    pub async fn resolve_all<T>(&self) -> Result<Vec<Arc<T>>, DiError>
298    where
299        T: Any + Send + Sync,
300    {
301        self.resolve_all_dependency::<T>(&ResolutionContext::default())
302            .await
303    }
304
305    pub fn request_context(&self) -> RequestContext<'_> {
306        RequestContext {
307            container: self,
308            context: ResolutionContext {
309                chain: vec![],
310                provider_chain: vec![],
311                scope_chain: vec![],
312                request_instances: Some(Arc::new(Mutex::new(HashMap::new()))),
313            },
314        }
315    }
316
317    pub async fn initialize_eager(&self) -> Result<(), DiError> {
318        for providers in self.inner.providers.values() {
319            for provider in providers.iter().filter(|p| p.descriptor.eager) {
320                self.resolve_provider(provider, ResolutionContext::default())
321                    .await?;
322            }
323        }
324        Ok(())
325    }
326
327    /// Validates the active singleton graph by constructing every singleton.
328    /// This surfaces missing, ambiguous, circular, and invalid-scope dependencies
329    /// during application startup instead of on the first request.
330    pub async fn validate(&self) -> Result<(), DiError> {
331        for providers in self.inner.providers.values() {
332            for provider in providers
333                .iter()
334                .filter(|provider| provider.descriptor.scope == Scope::Singleton)
335            {
336                self.resolve_provider(provider, ResolutionContext::default())
337                    .await?;
338            }
339        }
340        Ok(())
341    }
342
343    pub async fn shutdown(&self) -> Result<(), DiError> {
344        if self.inner.shut_down.swap(true, Ordering::AcqRel) {
345            return Ok(());
346        }
347        let graph = self
348            .inner
349            .dependency_graph
350            .lock()
351            .expect("DI dependency graph lock poisoned")
352            .clone();
353        let providers = self
354            .inner
355            .providers
356            .values()
357            .flatten()
358            .map(|provider| (runtime_provider_key(provider), provider))
359            .collect::<HashMap<_, _>>();
360        let order = shutdown_order(providers.keys().copied(), &graph);
361        for key in order {
362            let provider = providers[&key];
363            if let (Some(destroy), Some(value)) =
364                (provider.descriptor.destroy, provider.singleton.get())
365            {
366                destroy(value.clone()).await?;
367            }
368        }
369        Ok(())
370    }
371
372    #[doc(hidden)]
373    pub async fn resolve_dependency<T>(
374        &self,
375        context: &ResolutionContext,
376    ) -> Result<Arc<T>, DiError>
377    where
378        T: Any + Send + Sync,
379    {
380        self.resolve_selected::<T>(None, context).await
381    }
382
383    #[doc(hidden)]
384    pub async fn resolve_named_dependency<T>(
385        &self,
386        name: &str,
387        context: &ResolutionContext,
388    ) -> Result<Arc<T>, DiError>
389    where
390        T: Any + Send + Sync,
391    {
392        self.resolve_selected::<T>(Some(name), context).await
393    }
394
395    #[doc(hidden)]
396    pub async fn resolve_optional_dependency<T>(
397        &self,
398        context: &ResolutionContext,
399    ) -> Result<Option<Arc<T>>, DiError>
400    where
401        T: Any + Send + Sync,
402    {
403        match self.resolve_dependency::<T>(context).await {
404            Ok(value) => Ok(Some(value)),
405            Err(DiError::MissingProvider(_)) => Ok(None),
406            Err(error) => Err(error),
407        }
408    }
409
410    #[doc(hidden)]
411    pub async fn resolve_all_dependency<T>(
412        &self,
413        context: &ResolutionContext,
414    ) -> Result<Vec<Arc<T>>, DiError>
415    where
416        T: Any + Send + Sync,
417    {
418        let Some(providers) = self.inner.providers.get(&TypeId::of::<T>()) else {
419            return Ok(vec![]);
420        };
421        let mut values = Vec::with_capacity(providers.len());
422        for provider in providers {
423            let value = self.resolve_provider(provider, context.clone()).await?;
424            values.push(
425                value
426                    .downcast::<T>()
427                    .map_err(|_| DiError::TypeMismatch(std::any::type_name::<T>()))?,
428            );
429        }
430        Ok(values)
431    }
432
433    async fn resolve_selected<T>(
434        &self,
435        name: Option<&str>,
436        context: &ResolutionContext,
437    ) -> Result<Arc<T>, DiError>
438    where
439        T: Any + Send + Sync,
440    {
441        let type_name = std::any::type_name::<T>();
442        let providers = self
443            .inner
444            .providers
445            .get(&TypeId::of::<T>())
446            .ok_or(DiError::MissingProvider(type_name))?;
447        let selected = if let Some(name) = name {
448            providers
449                .iter()
450                .find(|p| p.descriptor.name == Some(name))
451                .ok_or(DiError::MissingProvider(type_name))?
452        } else if providers.len() == 1 {
453            &providers[0]
454        } else {
455            providers
456                .iter()
457                .find(|p| p.descriptor.primary)
458                .ok_or(DiError::AmbiguousProvider(type_name))?
459        };
460        let value = self.resolve_provider(selected, context.clone()).await?;
461        value
462            .downcast::<T>()
463            .map_err(|_| DiError::TypeMismatch(type_name))
464    }
465
466    fn resolve_provider<'a>(
467        &'a self,
468        provider: &'a RuntimeProvider,
469        mut context: ResolutionContext,
470    ) -> BoxFuture<'a, Result<DynArc, DiError>> {
471        Box::pin(async move {
472            if self.inner.shut_down.load(Ordering::Acquire) {
473                return Err(DiError::ContainerShutdown);
474            }
475            let descriptor = provider.descriptor;
476            let type_name = (descriptor.type_name)();
477            if context.chain.contains(&type_name) {
478                context.chain.push(type_name);
479                return Err(DiError::CircularDependency(context.chain.join(" -> ")));
480            }
481            let runtime_key = runtime_provider_key(provider);
482            if let Some(parent) = context.provider_chain.last().copied() {
483                self.add_dependency_edge(parent, runtime_key, &context, type_name)?;
484            }
485            if descriptor.scope == Scope::Request
486                && let Some(position) = context
487                    .scope_chain
488                    .iter()
489                    .position(|scope| *scope == Scope::Singleton)
490            {
491                return Err(DiError::InvalidScope {
492                    singleton: context.chain[position],
493                    request: type_name,
494                });
495            }
496            context.chain.push(type_name);
497            context.provider_chain.push(runtime_key);
498            context.scope_chain.push(descriptor.scope);
499            if descriptor.scope == Scope::Prototype {
500                return (descriptor.factory)(self, context).await;
501            }
502            match descriptor.scope {
503                Scope::Singleton => {
504                    let value = provider
505                        .singleton
506                        .get_or_try_init(
507                            || async move { (descriptor.factory)(self, context).await },
508                        )
509                        .await?;
510                    Ok(value.clone())
511                }
512                Scope::Request => {
513                    let map = context
514                        .request_instances
515                        .as_deref()
516                        .ok_or(DiError::RequestScopeUnavailable(type_name))?;
517                    let cell = {
518                        let mut instances = map.lock().expect("DI instance lock poisoned");
519                        instances
520                            .entry(provider_key(descriptor))
521                            .or_insert_with(|| Arc::new(tokio::sync::OnceCell::new()))
522                            .clone()
523                    };
524                    let value = cell
525                        .get_or_try_init(
526                            || async move { (descriptor.factory)(self, context).await },
527                        )
528                        .await?;
529                    Ok(value.clone())
530                }
531                Scope::Prototype => unreachable!(),
532            }
533        })
534    }
535
536    fn add_dependency_edge(
537        &self,
538        parent: usize,
539        dependency: usize,
540        context: &ResolutionContext,
541        dependency_name: &'static str,
542    ) -> Result<(), DiError> {
543        let mut graph = self
544            .inner
545            .dependency_graph
546            .lock()
547            .expect("DI dependency graph lock poisoned");
548        graph.entry(parent).or_default().insert(dependency);
549        if graph_path_exists(&graph, dependency, parent, &mut HashSet::new()) {
550            let mut chain = context.chain.clone();
551            chain.push(dependency_name);
552            return Err(DiError::CircularDependency(chain.join(" -> ")));
553        }
554        Ok(())
555    }
556}
557
558fn provider_key(provider: &'static ProviderDescriptor) -> usize {
559    provider as *const ProviderDescriptor as usize
560}
561
562fn runtime_provider_key(provider: &RuntimeProvider) -> usize {
563    provider as *const RuntimeProvider as usize
564}
565
/// Reports whether `target` can be reached from `current` by following
/// dependency edges in `graph`.
///
/// `visited` collects the nodes that were expanded; a node already in it is
/// not expanded again. A node equal to `target` counts as reached even when
/// it has no entry in `graph`.
fn graph_path_exists(
    graph: &HashMap<usize, HashSet<usize>>,
    current: usize,
    target: usize,
    visited: &mut HashSet<usize>,
) -> bool {
    let mut pending = vec![current];
    while let Some(node) = pending.pop() {
        if node == target {
            return true;
        }
        if !visited.insert(node) {
            continue;
        }
        if let Some(dependencies) = graph.get(&node) {
            pending.extend(dependencies.iter().copied());
        }
    }
    false
}
582
/// Orders `keys` so that every provider comes before the providers it
/// depends on, following the `parent -> dependency` edges in `graph`.
///
/// Edges whose parent or dependency is not in `keys` are ignored. The order
/// among unrelated providers is unspecified.
fn shutdown_order(
    keys: impl IntoIterator<Item = usize>,
    graph: &HashMap<usize, HashSet<usize>>,
) -> Vec<usize> {
    let keys: HashSet<usize> = keys.into_iter().collect();

    // Number of known parents that still have to be placed before each key.
    let mut incoming: HashMap<usize, usize> = keys.iter().map(|&key| (key, 0)).collect();
    let counted_dependencies = graph
        .iter()
        .filter(|(parent, _)| keys.contains(*parent))
        .flat_map(|(_, dependencies)| dependencies);
    for dependency in counted_dependencies {
        if let Some(count) = incoming.get_mut(dependency) {
            *count += 1;
        }
    }

    let mut ready: VecDeque<usize> = incoming
        .iter()
        .filter(|(_, count)| **count == 0)
        .map(|(key, _)| *key)
        .collect();
    let mut order = Vec::with_capacity(keys.len());
    while let Some(parent) = ready.pop_front() {
        order.push(parent);
        for dependency in graph.get(&parent).into_iter().flatten() {
            let Some(count) = incoming.get_mut(dependency) else {
                continue;
            };
            *count -= 1;
            if *count == 0 {
                ready.push_back(*dependency);
            }
        }
    }

    // Cycles are normally rejected during resolution. Should one slip through
    // (corrupted state, or a resolver that skips the check), append whatever
    // could not be ordered so every provider is still shut down.
    let placed: HashSet<usize> = order.iter().copied().collect();
    order.extend(keys.into_iter().filter(|key| !placed.contains(key)));
    order
}
627
628pub struct RequestContext<'a> {
629    container: &'a Container,
630    context: ResolutionContext,
631}
632impl RequestContext<'_> {
633    pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
634    where
635        T: Any + Send + Sync,
636    {
637        self.container.resolve_dependency::<T>(&self.context).await
638    }
639}
640
641#[doc(hidden)]
642pub mod __private {
643    pub use inventory;
644    pub use tokio;
645
646    use std::fmt::Display;
647
648    use crate::DiError;
649
650    pub fn factory_result<T, E: Display>(
651        result: Result<T, E>,
652        provider: &'static str,
653    ) -> Result<T, DiError> {
654        result.map_err(|error| DiError::Factory {
655            provider,
656            message: error.to_string(),
657        })
658    }
659
660    pub trait IntoLifecycleResult {
661        fn into_lifecycle_result(self, provider: &'static str) -> Result<(), DiError>;
662    }
663
664    impl IntoLifecycleResult for () {
665        fn into_lifecycle_result(self, _provider: &'static str) -> Result<(), DiError> {
666            Ok(())
667        }
668    }
669
670    impl<E: Display> IntoLifecycleResult for Result<(), E> {
671        fn into_lifecycle_result(self, provider: &'static str) -> Result<(), DiError> {
672            self.map_err(|error| DiError::Lifecycle {
673                provider,
674                message: error.to_string(),
675            })
676        }
677    }
678
679    pub fn lifecycle_result<R: IntoLifecycleResult>(
680        result: R,
681        provider: &'static str,
682    ) -> Result<(), DiError> {
683        result.into_lifecycle_result(provider)
684    }
685}
686
687#[cfg(test)]
688mod tests {
689    use super::*;
690    use crate::{
691        Lazy, Provider, configuration_properties, injectable, injected, provider, singleton,
692    };
693    use std::sync::atomic::{AtomicUsize, Ordering};
694    static CREATIONS: AtomicUsize = AtomicUsize::new(0);
695    struct SyncDependency;
696    #[singleton]
697    fn sync_dependency() -> SyncDependency {
698        CREATIONS.fetch_add(1, Ordering::SeqCst);
699        SyncDependency
700    }
701    struct AsyncDependency {
702        _sync: Arc<SyncDependency>,
703    }
704    #[singleton]
705    async fn async_dependency(sync: Arc<SyncDependency>) -> AsyncDependency {
706        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
707        AsyncDependency { _sync: sync }
708    }
709    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
710    async fn singleton_is_concurrent_safe() {
711        let container = Arc::new(Container::new().unwrap());
712        let mut tasks = tokio::task::JoinSet::new();
713        for _ in 0..32 {
714            let c = container.clone();
715            tasks.spawn(async move { c.resolve::<AsyncDependency>().await.unwrap() });
716        }
717        let mut values = vec![];
718        while let Some(v) = tasks.join_next().await {
719            values.push(v.unwrap());
720        }
721        assert!(values.iter().all(|v| Arc::ptr_eq(&values[0], v)));
722        assert_eq!(CREATIONS.load(Ordering::SeqCst), 1);
723    }
724
725    static PROTOTYPES: AtomicUsize = AtomicUsize::new(0);
726    struct PrototypeBean(usize);
727    #[singleton(scope = "prototype")]
728    fn prototype_bean() -> PrototypeBean {
729        PrototypeBean(PROTOTYPES.fetch_add(1, Ordering::SeqCst))
730    }
731
732    struct RequestBean;
733    #[singleton(scope = "request")]
734    fn request_bean() -> RequestBean {
735        RequestBean
736    }
737
738    trait Greeting: Send + Sync {
739        fn text(&self) -> &'static str;
740    }
741    struct English;
742    impl Greeting for English {
743        fn text(&self) -> &'static str {
744            "hello"
745        }
746    }
747    struct Hindi;
748    impl Greeting for Hindi {
749        fn text(&self) -> &'static str {
750            "namaste"
751        }
752    }
753
754    #[singleton(name = "english", primary)]
755    fn english_greeting() -> Arc<dyn Greeting> {
756        Arc::new(English)
757    }
758    #[singleton(name = "hindi")]
759    fn hindi_greeting() -> Arc<dyn Greeting> {
760        Arc::new(Hindi)
761    }
762
763    struct MissingOptional;
764    struct Greeter {
765        greeting: Arc<dyn Greeting>,
766        optional: Option<Arc<MissingOptional>>,
767    }
768    struct GreetingLabel(&'static str);
769    #[singleton]
770    impl Greeter {
771        fn new(greeting: Arc<dyn Greeting>, optional: Option<Arc<MissingOptional>>) -> Self {
772            Self { greeting, optional }
773        }
774
775        #[provider]
776        fn label(&self) -> GreetingLabel {
777            GreetingLabel(self.greeting.text())
778        }
779    }
780    struct QualifiedGreeter(Arc<dyn Greeting>);
781    #[singleton]
782    impl QualifiedGreeter {
783        fn new(#[qualifier("hindi")] greeting: Arc<dyn Greeting>) -> Self {
784            Self(greeting)
785        }
786    }
787
788    struct StaticConfig(&'static str);
789    struct ServiceWithStaticBean(Arc<StaticConfig>);
790    #[singleton]
791    impl ServiceWithStaticBean {
792        fn new(config: Arc<StaticConfig>) -> Self {
793            Self(config)
794        }
795
796        // No `&self`: this provider can run before its owning singleton and can
797        // therefore participate in the service's constructor graph.
798        #[provider]
799        fn config() -> StaticConfig {
800            StaticConfig("static-bean")
801        }
802    }
803
804    static STARTED: AtomicUsize = AtomicUsize::new(0);
805    static STOPPED: AtomicUsize = AtomicUsize::new(0);
806    struct Managed;
807    #[singleton(eager, post_construct = "start", pre_destroy = "stop")]
808    impl Managed {
809        fn new() -> Self {
810            Self
811        }
812        async fn start(&self) {
813            STARTED.fetch_add(1, Ordering::SeqCst);
814        }
815        async fn stop(&self) {
816            STOPPED.fetch_add(1, Ordering::SeqCst);
817        }
818    }
819
820    struct ProfileBean;
821    #[singleton(profile = "test")]
822    fn profile_bean() -> ProfileBean {
823        ProfileBean
824    }
825
826    #[derive(Debug)]
827    struct Handler(&'static str);
828    #[singleton(name = "first")]
829    fn first_handler() -> Handler {
830        Handler("first")
831    }
832    #[singleton(name = "second")]
833    fn second_handler() -> Handler {
834        Handler("second")
835    }
836    struct Pipeline(Vec<Arc<Handler>>);
837    #[singleton]
838    impl Pipeline {
839        fn new(handlers: Vec<Arc<Handler>>) -> Self {
840            Self(handlers)
841        }
842    }
843
844    #[configuration_properties("testing_dep")]
845    struct TestProperties {
846        port: u16,
847    }
848
849    struct DeferredTarget;
850    #[singleton]
851    fn deferred_target() -> DeferredTarget {
852        DeferredTarget
853    }
854    struct DeferredConsumer {
855        provider: Provider<DeferredTarget>,
856        lazy: Lazy<DeferredTarget>,
857    }
858    #[singleton]
859    impl DeferredConsumer {
860        fn new(provider: Provider<DeferredTarget>, lazy: Lazy<DeferredTarget>) -> Self {
861            Self { provider, lazy }
862        }
863    }
864
865    struct ConditionalBean;
866    #[singleton(condition = "TESTING_DEP_FEATURE=enabled")]
867    fn conditional_bean() -> ConditionalBean {
868        ConditionalBean
869    }
870
871    struct FallibleDependency;
872    #[provider]
873    fn fallible_dependency() -> Result<FallibleDependency, &'static str> {
874        Err("expected factory failure")
875    }
876
877    #[derive(Clone)]
878    struct FieldDependency(u32);
879    #[singleton]
880    fn field_dependency() -> FieldDependency {
881        FieldDependency(10)
882    }
883
884    #[injectable]
885    struct FieldInjectedFacade {
886        dependency: Arc<FieldDependency>,
887        #[inject(7)]
888        literal: u32,
889        #[inject(|dependency: Arc<FieldDependency>| dependency)]
890        transformed: Arc<FieldDependency>,
891    }
892
893    #[injectable]
894    struct TupleInjected(#[inject(123)] u32);
895
896    #[injected]
897    fn injected_total(
898        dependency: Arc<FieldDependency>,
899        #[inject(2)] offset: u32,
900        multiplier: u32,
901    ) -> u32 {
902        (dependency.0 + offset) * multiplier
903    }
904
905    #[injected(fallible)]
906    fn caller_owned_arc(
907        #[argument] dependency: Arc<FieldDependency>,
908        #[inject(1)] offset: u32,
909    ) -> u32 {
910        dependency.0 + offset
911    }
912
913    struct InjectedMethods;
914    impl InjectedMethods {
915        #[injected]
916        fn calculate(&self, dependency: Arc<FieldDependency>, value: u32) -> u32 {
917            dependency.0 + value
918        }
919    }
920
921    struct RequestLeaf;
922    #[singleton(scope = "request")]
923    fn request_leaf() -> RequestLeaf {
924        RequestLeaf
925    }
926    struct CaptiveSingleton {
927        _request: Arc<RequestLeaf>,
928    }
929    #[singleton]
930    impl CaptiveSingleton {
931        fn new(request: Arc<RequestLeaf>) -> Self {
932            Self { _request: request }
933        }
934    }
935
936    struct CycleA {
937        _b: Arc<CycleB>,
938    }
939    struct CycleB {
940        _a: Arc<CycleA>,
941    }
942    #[singleton]
943    impl CycleA {
944        fn new(b: Arc<CycleB>) -> Self {
945            Self { _b: b }
946        }
947    }
948    #[singleton]
949    impl CycleB {
950        fn new(a: Arc<CycleA>) -> Self {
951            Self { _a: a }
952        }
953    }
954
955    struct AllGreetings(Vec<Arc<dyn Greeting>>);
956    #[singleton]
957    impl AllGreetings {
958        fn new(greetings: Vec<Arc<dyn Greeting>>) -> Self {
959            Self(greetings)
960        }
961    }
962
963    struct ProfileDeferred;
964    #[singleton(profile = "test")]
965    fn profile_deferred() -> ProfileDeferred {
966        ProfileDeferred
967    }
968    struct ProfileDeferredConsumer(Provider<ProfileDeferred>);
    // Takes a `Provider<ProfileDeferred>` handle rather than an `Arc<ProfileDeferred>`.
    // The lifecycle test calls `.get().await` on the stored provider and unwraps it.
969    #[singleton]
970    impl ProfileDeferredConsumer {
971        fn new(provider: Provider<ProfileDeferred>) -> Self {
972            Self(provider)
973        }
974    }
975
976    struct StandaloneBean;
    // Fixture for the `#[provider]` attribute (the neighbouring fixtures use
    // `#[singleton]`). The lifecycle test resolves `StandaloneBean` and unwraps it.
977    #[provider]
978    fn standalone_bean() -> StandaloneBean {
979        StandaloneBean
980    }
981
982    #[tokio::test]
    // End-to-end check of a container built with the "test" profile: prototype and
    // request resolution, trait-object and named lookups, environment-driven beans,
    // deferred providers, eager initialisation and shutdown.
    // Statement order matters: the environment variables, the `STARTED`/`STOPPED`
    // counters and the final shutdown all act on shared state.
983    async fn scopes_traits_primary_profiles_and_lifecycle_work() {
        // NOTE(review): `set_var` is `unsafe` because the process environment is shared
        // by every thread; this assumes nothing else touches the environment
        // concurrently — confirm for the test harness.
        // The variable is presumably what `ConditionalBean` (resolved below) is
        // conditioned on — confirm against its declaration.
984        unsafe { std::env::set_var("TESTING_DEP_FEATURE", "enabled") };
985        let container = Container::with_profiles(["test"]).unwrap();
        // Two resolutions of `PrototypeBean` must carry different `.0` values.
986        let first = container.resolve::<PrototypeBean>().await.unwrap();
987        let second = container.resolve::<PrototypeBean>().await.unwrap();
988        assert_ne!(first.0, second.0);
989
        // Resolving `RequestBean` directly from the container must fail with
        // `RequestScopeUnavailable`.
990        assert!(matches!(
991            container.resolve::<RequestBean>().await,
992            Err(DiError::RequestScopeUnavailable(_))
993        ));
        // Through one request context, both resolutions must return the same `Arc`.
994        let request = container.request_context();
995        let request_first = request.resolve::<RequestBean>().await.unwrap();
996        let request_second = request.resolve::<RequestBean>().await.unwrap();
997        assert!(Arc::ptr_eq(&request_first, &request_second));
998
        // `Greeter` and `GreetingLabel` both see the "hello" greeting, and `Greeter`'s
        // optional dependency is absent.
999        let greeter = container.resolve::<Greeter>().await.unwrap();
1000        assert_eq!(greeter.greeting.text(), "hello");
1001        assert_eq!(
1002            container.resolve::<GreetingLabel>().await.unwrap().0,
1003            "hello"
1004        );
1005        assert!(greeter.optional.is_none());
        // `QualifiedGreeter` and the lookup named "hindi" both yield "namaste".
1006        assert_eq!(
1007            container
1008                .resolve::<QualifiedGreeter>()
1009                .await
1010                .unwrap()
1011                .0
1012                .text(),
1013            "namaste"
1014        );
1015        let hindi = container
1016            .resolve_named::<Arc<dyn Greeting>>("hindi")
1017            .await
1018            .unwrap();
1019        assert_eq!(hindi.text(), "namaste");
1020        let static_bean_service = container.resolve::<ServiceWithStaticBean>().await.unwrap();
1021        assert_eq!(static_bean_service.0.0, "static-bean");
        // Only successful resolution is checked for `ProfileBean`.
1022        container.resolve::<ProfileBean>().await.unwrap();
        // Handler order is not asserted: the names are sorted before the comparison.
1023        let pipeline = container.resolve::<Pipeline>().await.unwrap();
1024        let mut handler_names = pipeline.0.iter().map(|h| h.0).collect::<Vec<_>>();
1025        handler_names.sort_unstable();
1026        assert_eq!(handler_names, ["first", "second"]);
1027
1028        // This process-local key is unique to the test crate.
1029        unsafe { std::env::set_var("TESTING_DEP_PORT", "8080") };
        // Resolved only after the variable is set; the test expects `port` to be 8080.
1030        let properties = container.resolve::<TestProperties>().await.unwrap();
1031        assert_eq!(properties.port, 8080);
1032        container.resolve::<ConditionalBean>().await.unwrap();
1033        container.resolve::<StandaloneBean>().await.unwrap();
1034
        // Both deferred handles on `DeferredConsumer`, and the provider held by
        // `ProfileDeferredConsumer`, must yield a value when asked.
1035        let deferred = container.resolve::<DeferredConsumer>().await.unwrap();
1036        deferred.provider.get().await.unwrap();
1037        deferred.lazy.get().await.unwrap();
1038        let profile_deferred = container
1039            .resolve::<ProfileDeferredConsumer>()
1040            .await
1041            .unwrap();
1042        profile_deferred.0.get().await.unwrap();
1043        let greetings = container.resolve::<AllGreetings>().await.unwrap();
1044        assert_eq!(greetings.0.len(), 2);
1045
        // `STARTED` and `STOPPED` are counters defined outside this chunk; each must
        // read 1 after eager initialisation and after shutdown respectively.
1046        container.initialize_eager().await.unwrap();
1047        assert_eq!(STARTED.load(Ordering::SeqCst), 1);
1048        container.shutdown().await.unwrap();
1049        assert_eq!(STOPPED.load(Ordering::SeqCst), 1);
        // After shutdown, further resolution must fail with `ContainerShutdown`.
1050        assert!(matches!(
1051            container.resolve::<SyncDependency>().await,
1052            Err(DiError::ContainerShutdown)
1053        ));
1054    }
1055
1056    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    // Error-path checks: a failing factory, a bare `#[singleton]` that depends on a
    // request-scoped bean, and a dependency cycle resolved from two tasks at once.
    // NOTE(review): the two-worker multi-thread runtime is presumably chosen so the
    // spawned tasks can run in parallel — confirm.
1057    async fn failures_are_reported_instead_of_hanging_or_capturing_scope() {
        // Wrapped in `Arc` so clones can be moved into the spawned tasks below.
1058        let container = Arc::new(Container::new().unwrap());
1059        assert!(matches!(
1060            container.resolve::<FallibleDependency>().await,
1061            Err(DiError::Factory { .. })
1062        ));
1063
        // Even through a request context, `CaptiveSingleton` must be rejected with
        // `InvalidScope`.
1064        let request = container.request_context();
1065        let captive = request.resolve::<CaptiveSingleton>().await;
1066        assert!(matches!(captive, Err(DiError::InvalidScope { .. })));
1067
        // Resolve both ends of the `CycleA`/`CycleB` cycle from separate tasks.
1068        let first = {
1069            let container = container.clone();
1070            tokio::spawn(async move { container.resolve::<CycleA>().await })
1071        };
1072        let second = {
1073            let container = container.clone();
1074            tokio::spawn(async move { container.resolve::<CycleB>().await })
1075        };
        // Both tasks must finish within one second; a timeout here fails the test with
        // the deadlock message. The inner `unwrap`s panic if a task itself panicked.
1076        let results = tokio::time::timeout(std::time::Duration::from_secs(1), async {
1077            (first.await.unwrap(), second.await.unwrap())
1078        })
1079        .await
1080        .expect("cross-task cycle resolution must not deadlock");
        // Only one of the two results is required to be `CircularDependency`.
1081        assert!(
1082            matches!(results.0, Err(DiError::CircularDependency(_)))
1083                || matches!(results.1, Err(DiError::CircularDependency(_)))
1084        );
1085    }
1086
1087    #[tokio::test]
    // Exercises the free `resolve` function and the `#[injected]` fixtures without
    // building a `Container` explicitly in the test body.
1088    async fn field_and_function_injection_are_automatic() {
1089        let facade = resolve::<FieldInjectedFacade>().await.unwrap();
1090        assert_eq!(facade.dependency.0, 10);
1091        assert_eq!(facade.literal, 7);
        // `dependency` and `transformed` must point at the same allocation.
1092        assert!(Arc::ptr_eq(&facade.dependency, &facade.transformed));
1093        assert_eq!(resolve::<TupleInjected>().await.unwrap().0, 123);
        // Only `multiplier` is passed here; with the `#[inject(2)]` offset, 36 is
        // consistent with an injected `FieldDependency` of 10: (10 + 2) * 3.
1094        assert_eq!(injected_total(3).await, 36);
        // The `_with` variant takes every argument explicitly and is not awaited:
        // (20 + 4) * 2 = 48.
1095        assert_eq!(injected_total_with(Arc::new(FieldDependency(20)), 4, 2), 48);
        // The caller supplies the `Arc`; the `#[inject(1)]` offset gives 20 + 1 = 21.
1096        assert_eq!(
1097            caller_owned_arc(Arc::new(FieldDependency(20)))
1098                .await
1099                .unwrap(),
1100            21
1101        );
        // Method form: only `value` is passed, and the expected result is 15.
1102        assert_eq!(InjectedMethods.calculate(5).await, 15);
1103    }
1104}