Skip to main content

auto_di/
runtime.rs

1use std::{
2    any::{Any, TypeId},
3    collections::HashMap,
4    future::Future,
5    pin::Pin,
6    sync::{Arc, Mutex, OnceLock},
7};
8
9pub type DynArc = Arc<dyn Any + Send + Sync>;
10pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
11type Factory =
12    for<'a> fn(&'a Container, ResolutionContext) -> BoxFuture<'a, Result<DynArc, DiError>>;
13type Destroy = fn(DynArc) -> BoxFuture<'static, Result<(), DiError>>;
14
/// How long an instance handed out by a provider lives.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Scope {
    /// Created on first resolution and shared by every later resolution on the container.
    Singleton,
    /// Created anew on every resolution.
    Prototype,
    /// Created once per `RequestContext`; resolving it outside one is an error.
    Request,
}
21
22#[doc(hidden)]
23pub struct ProviderDescriptor {
24    type_id: fn() -> TypeId,
25    type_name: fn() -> &'static str,
26    factory: Factory,
27    pub name: Option<&'static str>,
28    pub primary: bool,
29    pub scope: Scope,
30    pub eager: bool,
31    pub profile: Option<&'static str>,
32    pub condition_key: Option<&'static str>,
33    pub condition_value: Option<&'static str>,
34    pub destroy: Option<Destroy>,
35}
36
37impl ProviderDescriptor {
38    #[doc(hidden)]
39    pub const fn new(
40        type_id: fn() -> TypeId,
41        type_name: fn() -> &'static str,
42        factory: Factory,
43    ) -> Self {
44        Self::configured(
45            type_id,
46            type_name,
47            factory,
48            None,
49            false,
50            Scope::Singleton,
51            false,
52            None,
53            None,
54            None,
55            None,
56        )
57    }
58
59    #[allow(clippy::too_many_arguments)]
60    #[doc(hidden)]
61    pub const fn configured(
62        type_id: fn() -> TypeId,
63        type_name: fn() -> &'static str,
64        factory: Factory,
65        name: Option<&'static str>,
66        primary: bool,
67        scope: Scope,
68        eager: bool,
69        profile: Option<&'static str>,
70        condition_key: Option<&'static str>,
71        condition_value: Option<&'static str>,
72        destroy: Option<Destroy>,
73    ) -> Self {
74        Self {
75            type_id,
76            type_name,
77            factory,
78            name,
79            primary,
80            scope,
81            eager,
82            profile,
83            condition_key,
84            condition_value,
85            destroy,
86        }
87    }
88
89    fn active(&self, profiles: &[String]) -> bool {
90        let profile_matches = self
91            .profile
92            .is_none_or(|required| profiles.iter().any(|p| p == required));
93        let condition_matches = self.condition_key.is_none_or(|key| {
94            let actual = std::env::var(key).ok();
95            self.condition_value.map_or(actual.is_some(), |expected| {
96                actual.as_deref() == Some(expected)
97            })
98        });
99        profile_matches && condition_matches
100    }
101}
102
103inventory::collect!(ProviderDescriptor);
104
105type InstanceMap = HashMap<usize, Arc<tokio::sync::OnceCell<DynArc>>>;
106
107struct RuntimeProvider {
108    descriptor: &'static ProviderDescriptor,
109    singleton: tokio::sync::OnceCell<DynArc>,
110}
111
112impl RuntimeProvider {
113    fn new(descriptor: &'static ProviderDescriptor) -> Self {
114        Self {
115            descriptor,
116            singleton: tokio::sync::OnceCell::new(),
117        }
118    }
119}
120
121#[derive(Clone, Default)]
122pub struct ResolutionContext {
123    chain: Vec<&'static str>,
124    request_instances: Option<Arc<Mutex<InstanceMap>>>,
125}
126
127#[derive(Debug, thiserror::Error)]
128pub enum DiError {
129    #[error("no active provider is registered for {0}")]
130    MissingProvider(&'static str),
131    #[error("multiple providers match {0}; add a name/qualifier or mark one primary")]
132    AmbiguousProvider(&'static str),
133    #[error("multiple primary providers are registered for {0}")]
134    MultiplePrimary(&'static str),
135    #[error("circular dependency detected: {0}")]
136    CircularDependency(String),
137    #[error("provider for {0} returned an incompatible type")]
138    TypeMismatch(&'static str),
139    #[error("request-scoped dependency {0} was resolved outside RequestContext")]
140    RequestScopeUnavailable(&'static str),
141    #[error("configuration property {key} is missing or invalid: {message}")]
142    Configuration { key: String, message: String },
143    #[error("lifecycle hook failed for {0}")]
144    Lifecycle(&'static str),
145}
146
147pub struct Container {
148    providers: HashMap<TypeId, Vec<RuntimeProvider>>,
149}
150
151static GLOBAL_CONTAINER: OnceLock<Container> = OnceLock::new();
152
153pub fn global_container() -> Result<&'static Container, DiError> {
154    if let Some(container) = GLOBAL_CONTAINER.get() {
155        return Ok(container);
156    }
157    let container = Container::new()?;
158    let _ = GLOBAL_CONTAINER.set(container);
159    Ok(GLOBAL_CONTAINER
160        .get()
161        .expect("global DI container initialized"))
162}
163
164pub async fn resolve<T>() -> Result<Arc<T>, DiError>
165where
166    T: Any + Send + Sync,
167{
168    global_container()?.resolve::<T>().await
169}
170
171impl Container {
172    pub fn new() -> Result<Self, DiError> {
173        let profiles = std::env::var("APP_PROFILES")
174            .unwrap_or_default()
175            .split(',')
176            .map(str::trim)
177            .filter(|p| !p.is_empty())
178            .map(str::to_owned)
179            .collect::<Vec<_>>();
180        Self::with_profiles(profiles)
181    }
182
183    pub fn with_profiles(
184        profiles: impl IntoIterator<Item = impl Into<String>>,
185    ) -> Result<Self, DiError> {
186        let profiles = profiles.into_iter().map(Into::into).collect::<Vec<_>>();
187        let mut providers: HashMap<TypeId, Vec<RuntimeProvider>> = HashMap::new();
188        for provider in inventory::iter::<ProviderDescriptor> {
189            if provider.active(&profiles) {
190                providers
191                    .entry((provider.type_id)())
192                    .or_default()
193                    .push(RuntimeProvider::new(provider));
194            }
195        }
196        for group in providers.values() {
197            if group.iter().filter(|p| p.descriptor.primary).count() > 1 {
198                return Err(DiError::MultiplePrimary((group[0].descriptor.type_name)()));
199            }
200        }
201        Ok(Self { providers })
202    }
203
204    pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
205    where
206        T: Any + Send + Sync,
207    {
208        self.resolve_dependency::<T>(&ResolutionContext::default())
209            .await
210    }
211
212    pub async fn resolve_named<T>(&self, name: &str) -> Result<Arc<T>, DiError>
213    where
214        T: Any + Send + Sync,
215    {
216        self.resolve_named_dependency::<T>(name, &ResolutionContext::default())
217            .await
218    }
219
220    pub async fn resolve_optional<T>(&self) -> Result<Option<Arc<T>>, DiError>
221    where
222        T: Any + Send + Sync,
223    {
224        match self.resolve::<T>().await {
225            Ok(value) => Ok(Some(value)),
226            Err(DiError::MissingProvider(_)) => Ok(None),
227            Err(error) => Err(error),
228        }
229    }
230
231    pub async fn resolve_all<T>(&self) -> Result<Vec<Arc<T>>, DiError>
232    where
233        T: Any + Send + Sync,
234    {
235        self.resolve_all_dependency::<T>(&ResolutionContext::default())
236            .await
237    }
238
239    pub fn request_context(&self) -> RequestContext<'_> {
240        RequestContext {
241            container: self,
242            context: ResolutionContext {
243                chain: vec![],
244                request_instances: Some(Arc::new(Mutex::new(HashMap::new()))),
245            },
246        }
247    }
248
249    pub async fn initialize_eager(&self) -> Result<(), DiError> {
250        for providers in self.providers.values() {
251            for provider in providers.iter().filter(|p| p.descriptor.eager) {
252                self.resolve_provider(provider, ResolutionContext::default())
253                    .await?;
254            }
255        }
256        Ok(())
257    }
258
259    pub async fn shutdown(&self) -> Result<(), DiError> {
260        for providers in self.providers.values() {
261            for provider in providers {
262                if let (Some(destroy), Some(value)) =
263                    (provider.descriptor.destroy, provider.singleton.get())
264                {
265                    destroy(value.clone()).await?;
266                }
267            }
268        }
269        Ok(())
270    }
271
272    #[doc(hidden)]
273    pub async fn resolve_dependency<T>(
274        &self,
275        context: &ResolutionContext,
276    ) -> Result<Arc<T>, DiError>
277    where
278        T: Any + Send + Sync,
279    {
280        self.resolve_selected::<T>(None, context).await
281    }
282
283    #[doc(hidden)]
284    pub async fn resolve_named_dependency<T>(
285        &self,
286        name: &str,
287        context: &ResolutionContext,
288    ) -> Result<Arc<T>, DiError>
289    where
290        T: Any + Send + Sync,
291    {
292        self.resolve_selected::<T>(Some(name), context).await
293    }
294
295    #[doc(hidden)]
296    pub async fn resolve_optional_dependency<T>(
297        &self,
298        context: &ResolutionContext,
299    ) -> Result<Option<Arc<T>>, DiError>
300    where
301        T: Any + Send + Sync,
302    {
303        match self.resolve_dependency::<T>(context).await {
304            Ok(value) => Ok(Some(value)),
305            Err(DiError::MissingProvider(_)) => Ok(None),
306            Err(error) => Err(error),
307        }
308    }
309
310    #[doc(hidden)]
311    pub async fn resolve_all_dependency<T>(
312        &self,
313        context: &ResolutionContext,
314    ) -> Result<Vec<Arc<T>>, DiError>
315    where
316        T: Any + Send + Sync,
317    {
318        let Some(providers) = self.providers.get(&TypeId::of::<T>()) else {
319            return Ok(vec![]);
320        };
321        let mut values = Vec::with_capacity(providers.len());
322        for provider in providers {
323            let value = self.resolve_provider(provider, context.clone()).await?;
324            values.push(
325                value
326                    .downcast::<T>()
327                    .map_err(|_| DiError::TypeMismatch(std::any::type_name::<T>()))?,
328            );
329        }
330        Ok(values)
331    }
332
333    async fn resolve_selected<T>(
334        &self,
335        name: Option<&str>,
336        context: &ResolutionContext,
337    ) -> Result<Arc<T>, DiError>
338    where
339        T: Any + Send + Sync,
340    {
341        let type_name = std::any::type_name::<T>();
342        let providers = self
343            .providers
344            .get(&TypeId::of::<T>())
345            .ok_or(DiError::MissingProvider(type_name))?;
346        let selected = if let Some(name) = name {
347            providers
348                .iter()
349                .find(|p| p.descriptor.name == Some(name))
350                .ok_or(DiError::MissingProvider(type_name))?
351        } else if providers.len() == 1 {
352            &providers[0]
353        } else {
354            providers
355                .iter()
356                .find(|p| p.descriptor.primary)
357                .ok_or(DiError::AmbiguousProvider(type_name))?
358        };
359        let value = self.resolve_provider(selected, context.clone()).await?;
360        value
361            .downcast::<T>()
362            .map_err(|_| DiError::TypeMismatch(type_name))
363    }
364
365    fn resolve_provider<'a>(
366        &'a self,
367        provider: &'a RuntimeProvider,
368        mut context: ResolutionContext,
369    ) -> BoxFuture<'a, Result<DynArc, DiError>> {
370        Box::pin(async move {
371            let descriptor = provider.descriptor;
372            let type_name = (descriptor.type_name)();
373            if context.chain.contains(&type_name) {
374                context.chain.push(type_name);
375                return Err(DiError::CircularDependency(context.chain.join(" -> ")));
376            }
377            context.chain.push(type_name);
378            if descriptor.scope == Scope::Prototype {
379                return (descriptor.factory)(self, context).await;
380            }
381            match descriptor.scope {
382                Scope::Singleton => {
383                    let value = provider
384                        .singleton
385                        .get_or_try_init(
386                            || async move { (descriptor.factory)(self, context).await },
387                        )
388                        .await?;
389                    Ok(value.clone())
390                }
391                Scope::Request => {
392                    let map = context
393                        .request_instances
394                        .as_deref()
395                        .ok_or(DiError::RequestScopeUnavailable(type_name))?;
396                    let cell = {
397                        let mut instances = map.lock().expect("DI instance lock poisoned");
398                        instances
399                            .entry(provider_key(descriptor))
400                            .or_insert_with(|| Arc::new(tokio::sync::OnceCell::new()))
401                            .clone()
402                    };
403                    let value = cell
404                        .get_or_try_init(
405                            || async move { (descriptor.factory)(self, context).await },
406                        )
407                        .await?;
408                    Ok(value.clone())
409                }
410                Scope::Prototype => unreachable!(),
411            }
412        })
413    }
414}
415
416fn provider_key(provider: &'static ProviderDescriptor) -> usize {
417    provider as *const ProviderDescriptor as usize
418}
419
420pub struct RequestContext<'a> {
421    container: &'a Container,
422    context: ResolutionContext,
423}
424impl RequestContext<'_> {
425    pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
426    where
427        T: Any + Send + Sync,
428    {
429        self.container.resolve_dependency::<T>(&self.context).await
430    }
431}
432
433#[doc(hidden)]
434pub mod __private {
435    pub use inventory;
436    pub use tokio;
437}
438
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Lazy, Provider, configuration_properties, provider, singleton};
    use std::sync::atomic::{AtomicUsize, Ordering};

    static CREATIONS: AtomicUsize = AtomicUsize::new(0);
    struct SyncDependency;
    #[singleton]
    fn sync_dependency() -> SyncDependency {
        CREATIONS.fetch_add(1, Ordering::SeqCst);
        SyncDependency
    }
    struct AsyncDependency {
        _sync: Arc<SyncDependency>,
    }
    #[singleton]
    async fn async_dependency(sync: Arc<SyncDependency>) -> AsyncDependency {
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        AsyncDependency { _sync: sync }
    }
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn singleton_is_concurrent_safe() {
        let container = Arc::new(Container::new().unwrap());
        let mut tasks = tokio::task::JoinSet::new();
        for _ in 0..32 {
            let c = container.clone();
            tasks.spawn(async move { c.resolve::<AsyncDependency>().await.unwrap() });
        }
        let mut values = vec![];
        while let Some(v) = tasks.join_next().await {
            values.push(v.unwrap());
        }
        assert!(values.iter().all(|v| Arc::ptr_eq(&values[0], v)));
        assert_eq!(CREATIONS.load(Ordering::SeqCst), 1);
    }

    static PROTOTYPES: AtomicUsize = AtomicUsize::new(0);
    struct PrototypeBean(usize);
    #[singleton(scope = "prototype")]
    fn prototype_bean() -> PrototypeBean {
        PrototypeBean(PROTOTYPES.fetch_add(1, Ordering::SeqCst))
    }

    struct RequestBean;
    #[singleton(scope = "request")]
    fn request_bean() -> RequestBean {
        RequestBean
    }

    trait Greeting: Send + Sync {
        fn text(&self) -> &'static str;
    }
    struct English;
    impl Greeting for English {
        fn text(&self) -> &'static str {
            "hello"
        }
    }
    struct Hindi;
    impl Greeting for Hindi {
        fn text(&self) -> &'static str {
            "namaste"
        }
    }

    #[singleton(name = "english", primary)]
    fn english_greeting() -> Arc<dyn Greeting> {
        Arc::new(English)
    }
    #[singleton(name = "hindi")]
    fn hindi_greeting() -> Arc<dyn Greeting> {
        Arc::new(Hindi)
    }

    struct MissingOptional;
    struct Greeter {
        greeting: Arc<dyn Greeting>,
        optional: Option<Arc<MissingOptional>>,
    }
    struct GreetingLabel(&'static str);
    #[singleton]
    impl Greeter {
        fn new(greeting: Arc<dyn Greeting>, optional: Option<Arc<MissingOptional>>) -> Self {
            Self { greeting, optional }
        }

        #[provider]
        fn label(&self) -> GreetingLabel {
            GreetingLabel(self.greeting.text())
        }
    }
    struct QualifiedGreeter(Arc<dyn Greeting>);
    #[singleton]
    impl QualifiedGreeter {
        fn new(#[qualifier("hindi")] greeting: Arc<dyn Greeting>) -> Self {
            Self(greeting)
        }
    }

    struct StaticConfig(&'static str);
    struct ServiceWithStaticBean(Arc<StaticConfig>);
    #[singleton]
    impl ServiceWithStaticBean {
        fn new(config: Arc<StaticConfig>) -> Self {
            Self(config)
        }

        // No `&self`: this provider can run before its owning singleton and can
        // therefore participate in the service's constructor graph.
        #[provider]
        fn config() -> StaticConfig {
            StaticConfig("static-bean")
        }
    }

    static STARTED: AtomicUsize = AtomicUsize::new(0);
    static STOPPED: AtomicUsize = AtomicUsize::new(0);
    struct Managed;
    #[singleton(eager, post_construct = "start", pre_destroy = "stop")]
    impl Managed {
        fn new() -> Self {
            Self
        }
        async fn start(&self) {
            STARTED.fetch_add(1, Ordering::SeqCst);
        }
        async fn stop(&self) {
            STOPPED.fetch_add(1, Ordering::SeqCst);
        }
    }

    struct ProfileBean;
    #[singleton(profile = "test")]
    fn profile_bean() -> ProfileBean {
        ProfileBean
    }

    #[derive(Debug)]
    struct Handler(&'static str);
    #[singleton(name = "first")]
    fn first_handler() -> Handler {
        Handler("first")
    }
    #[singleton(name = "second")]
    fn second_handler() -> Handler {
        Handler("second")
    }
    struct Pipeline(Vec<Arc<Handler>>);
    #[singleton]
    impl Pipeline {
        fn new(handlers: Vec<Arc<Handler>>) -> Self {
            Self(handlers)
        }
    }

    #[configuration_properties("testing_dep")]
    struct TestProperties {
        port: u16,
    }

    struct DeferredTarget;
    #[singleton]
    fn deferred_target() -> DeferredTarget {
        DeferredTarget
    }
    struct DeferredConsumer {
        provider: Provider<DeferredTarget>,
        lazy: Lazy<DeferredTarget>,
    }
    #[singleton]
    impl DeferredConsumer {
        fn new(provider: Provider<DeferredTarget>, lazy: Lazy<DeferredTarget>) -> Self {
            Self { provider, lazy }
        }
    }

    struct ConditionalBean;
    #[singleton(condition = "TESTING_DEP_FEATURE=enabled")]
    fn conditional_bean() -> ConditionalBean {
        ConditionalBean
    }

    struct StandaloneBean;
    #[provider]
    fn standalone_bean() -> StandaloneBean {
        StandaloneBean
    }

    #[tokio::test]
    async fn scopes_traits_primary_profiles_and_lifecycle_work() {
        // All environment writes happen here, before the container is built and before
        // any async work starts. These process-local keys are unique to the test crate.
        // `TESTING_DEP_FEATURE` gates `ConditionalBean` when providers are filtered, and
        // `TESTING_DEP_PORT` is the value `TestProperties::port` is expected to pick up.
        //
        // SAFETY: `set_var` requires that no other thread concurrently accesses the
        // environment except through `std::env`, which std serializes internally. The
        // other test in this module reads the environment only through
        // `Container::new` (i.e. `std::env`). NOTE(review): re-check this if a test
        // starts calling C code that reads the environment.
        unsafe {
            std::env::set_var("TESTING_DEP_FEATURE", "enabled");
            std::env::set_var("TESTING_DEP_PORT", "8080");
        }
        let container = Container::with_profiles(["test"]).unwrap();
        let first = container.resolve::<PrototypeBean>().await.unwrap();
        let second = container.resolve::<PrototypeBean>().await.unwrap();
        assert_ne!(first.0, second.0);

        assert!(matches!(
            container.resolve::<RequestBean>().await,
            Err(DiError::RequestScopeUnavailable(_))
        ));
        let request = container.request_context();
        let request_first = request.resolve::<RequestBean>().await.unwrap();
        let request_second = request.resolve::<RequestBean>().await.unwrap();
        assert!(Arc::ptr_eq(&request_first, &request_second));

        let greeter = container.resolve::<Greeter>().await.unwrap();
        assert_eq!(greeter.greeting.text(), "hello");
        assert_eq!(
            container.resolve::<GreetingLabel>().await.unwrap().0,
            "hello"
        );
        assert!(greeter.optional.is_none());
        assert_eq!(
            container
                .resolve::<QualifiedGreeter>()
                .await
                .unwrap()
                .0
                .text(),
            "namaste"
        );
        let hindi = container
            .resolve_named::<Arc<dyn Greeting>>("hindi")
            .await
            .unwrap();
        assert_eq!(hindi.text(), "namaste");
        let static_bean_service = container.resolve::<ServiceWithStaticBean>().await.unwrap();
        assert_eq!(static_bean_service.0.0, "static-bean");
        container.resolve::<ProfileBean>().await.unwrap();
        let pipeline = container.resolve::<Pipeline>().await.unwrap();
        let mut handler_names = pipeline.0.iter().map(|h| h.0).collect::<Vec<_>>();
        handler_names.sort_unstable();
        assert_eq!(handler_names, ["first", "second"]);

        let properties = container.resolve::<TestProperties>().await.unwrap();
        assert_eq!(properties.port, 8080);
        container.resolve::<ConditionalBean>().await.unwrap();
        container.resolve::<StandaloneBean>().await.unwrap();

        let deferred = container.resolve::<DeferredConsumer>().await.unwrap();
        deferred.provider.get().await.unwrap();
        deferred.lazy.get().await.unwrap();

        container.initialize_eager().await.unwrap();
        assert_eq!(STARTED.load(Ordering::SeqCst), 1);
        container.shutdown().await.unwrap();
        assert_eq!(STOPPED.load(Ordering::SeqCst), 1);
    }
}