// auto_di/runtime.rs

1use std::{
2    any::{Any, TypeId},
3    collections::HashMap,
4    future::Future,
5    pin::Pin,
6    sync::{Arc, Mutex, OnceLock},
7};
8
/// A type-erased, thread-shareable instance, as stored in the container's
/// instance caches and returned by provider factories.
pub type DynArc = Arc<dyn Any + Send + Sync>;

/// A heap-allocated, pinned, `Send` future that may borrow for `'a`; the
/// return shape of every factory and destroy hook.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
11type Factory =
12    for<'a> fn(&'a Container, ResolutionContext) -> BoxFuture<'a, Result<DynArc, DiError>>;
13type Destroy = fn(DynArc) -> BoxFuture<'static, Result<(), DiError>>;
14
/// How long an instance produced by a provider is reused.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Scope {
    /// One instance per container, created on first resolution and shared
    /// by every later resolution.
    Singleton,
    /// A fresh instance on every resolution; nothing is cached.
    Prototype,
    /// One instance per `RequestContext`; resolving such a provider without
    /// a request context is an error.
    Request,
}
21
22#[doc(hidden)]
23pub struct ProviderDescriptor {
24    type_id: fn() -> TypeId,
25    type_name: fn() -> &'static str,
26    factory: Factory,
27    pub name: Option<&'static str>,
28    pub primary: bool,
29    pub scope: Scope,
30    pub eager: bool,
31    pub profile: Option<&'static str>,
32    pub condition_key: Option<&'static str>,
33    pub condition_value: Option<&'static str>,
34    pub destroy: Option<Destroy>,
35}
36
37impl ProviderDescriptor {
38    #[doc(hidden)]
39    pub const fn new(
40        type_id: fn() -> TypeId,
41        type_name: fn() -> &'static str,
42        factory: Factory,
43    ) -> Self {
44        Self::configured(
45            type_id,
46            type_name,
47            factory,
48            None,
49            false,
50            Scope::Singleton,
51            false,
52            None,
53            None,
54            None,
55            None,
56        )
57    }
58
59    #[allow(clippy::too_many_arguments)]
60    #[doc(hidden)]
61    pub const fn configured(
62        type_id: fn() -> TypeId,
63        type_name: fn() -> &'static str,
64        factory: Factory,
65        name: Option<&'static str>,
66        primary: bool,
67        scope: Scope,
68        eager: bool,
69        profile: Option<&'static str>,
70        condition_key: Option<&'static str>,
71        condition_value: Option<&'static str>,
72        destroy: Option<Destroy>,
73    ) -> Self {
74        Self {
75            type_id,
76            type_name,
77            factory,
78            name,
79            primary,
80            scope,
81            eager,
82            profile,
83            condition_key,
84            condition_value,
85            destroy,
86        }
87    }
88
89    fn active(&self, profiles: &[String]) -> bool {
90        let profile_matches = self
91            .profile
92            .is_none_or(|required| profiles.iter().any(|p| p == required));
93        let condition_matches = self.condition_key.is_none_or(|key| {
94            let actual = std::env::var(key).ok();
95            self.condition_value.map_or(actual.is_some(), |expected| {
96                actual.as_deref() == Some(expected)
97            })
98        });
99        profile_matches && condition_matches
100    }
101}
102
103inventory::collect!(ProviderDescriptor);
104
105type InstanceMap = HashMap<usize, Arc<tokio::sync::OnceCell<DynArc>>>;
106
107#[derive(Clone, Default)]
108pub struct ResolutionContext {
109    chain: Vec<&'static str>,
110    request_instances: Option<Arc<Mutex<InstanceMap>>>,
111}
112
113#[derive(Debug, thiserror::Error)]
114pub enum DiError {
115    #[error("no active provider is registered for {0}")]
116    MissingProvider(&'static str),
117    #[error("multiple providers match {0}; add a name/qualifier or mark one primary")]
118    AmbiguousProvider(&'static str),
119    #[error("multiple primary providers are registered for {0}")]
120    MultiplePrimary(&'static str),
121    #[error("circular dependency detected: {0}")]
122    CircularDependency(String),
123    #[error("provider for {0} returned an incompatible type")]
124    TypeMismatch(&'static str),
125    #[error("request-scoped dependency {0} was resolved outside RequestContext")]
126    RequestScopeUnavailable(&'static str),
127    #[error("configuration property {key} is missing or invalid: {message}")]
128    Configuration { key: String, message: String },
129    #[error("lifecycle hook failed for {0}")]
130    Lifecycle(&'static str),
131}
132
133pub struct Container {
134    providers: HashMap<TypeId, Vec<&'static ProviderDescriptor>>,
135    instances: Mutex<InstanceMap>,
136}
137
138static GLOBAL_CONTAINER: OnceLock<Container> = OnceLock::new();
139
140pub fn global_container() -> Result<&'static Container, DiError> {
141    if let Some(container) = GLOBAL_CONTAINER.get() {
142        return Ok(container);
143    }
144    let container = Container::new()?;
145    let _ = GLOBAL_CONTAINER.set(container);
146    Ok(GLOBAL_CONTAINER
147        .get()
148        .expect("global DI container initialized"))
149}
150
151pub async fn resolve<T>() -> Result<Arc<T>, DiError>
152where
153    T: Any + Send + Sync,
154{
155    global_container()?.resolve::<T>().await
156}
157
158impl Container {
159    pub fn new() -> Result<Self, DiError> {
160        let profiles = std::env::var("APP_PROFILES")
161            .unwrap_or_default()
162            .split(',')
163            .map(str::trim)
164            .filter(|p| !p.is_empty())
165            .map(str::to_owned)
166            .collect::<Vec<_>>();
167        Self::with_profiles(profiles)
168    }
169
170    pub fn with_profiles(
171        profiles: impl IntoIterator<Item = impl Into<String>>,
172    ) -> Result<Self, DiError> {
173        let profiles = profiles.into_iter().map(Into::into).collect::<Vec<_>>();
174        let mut providers: HashMap<TypeId, Vec<&'static ProviderDescriptor>> = HashMap::new();
175        for provider in inventory::iter::<ProviderDescriptor> {
176            if provider.active(&profiles) {
177                providers
178                    .entry((provider.type_id)())
179                    .or_default()
180                    .push(provider);
181            }
182        }
183        for group in providers.values() {
184            if group.iter().filter(|p| p.primary).count() > 1 {
185                return Err(DiError::MultiplePrimary((group[0].type_name)()));
186            }
187        }
188        Ok(Self {
189            providers,
190            instances: Mutex::new(HashMap::new()),
191        })
192    }
193
194    pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
195    where
196        T: Any + Send + Sync,
197    {
198        self.resolve_dependency::<T>(&ResolutionContext::default())
199            .await
200    }
201
202    pub async fn resolve_named<T>(&self, name: &str) -> Result<Arc<T>, DiError>
203    where
204        T: Any + Send + Sync,
205    {
206        self.resolve_named_dependency::<T>(name, &ResolutionContext::default())
207            .await
208    }
209
210    pub async fn resolve_optional<T>(&self) -> Result<Option<Arc<T>>, DiError>
211    where
212        T: Any + Send + Sync,
213    {
214        match self.resolve::<T>().await {
215            Ok(value) => Ok(Some(value)),
216            Err(DiError::MissingProvider(_)) => Ok(None),
217            Err(error) => Err(error),
218        }
219    }
220
221    pub async fn resolve_all<T>(&self) -> Result<Vec<Arc<T>>, DiError>
222    where
223        T: Any + Send + Sync,
224    {
225        self.resolve_all_dependency::<T>(&ResolutionContext::default())
226            .await
227    }
228
229    pub fn request_context(&self) -> RequestContext<'_> {
230        RequestContext {
231            container: self,
232            context: ResolutionContext {
233                chain: vec![],
234                request_instances: Some(Arc::new(Mutex::new(HashMap::new()))),
235            },
236        }
237    }
238
239    pub async fn initialize_eager(&self) -> Result<(), DiError> {
240        for providers in self.providers.values() {
241            for provider in providers.iter().filter(|p| p.eager) {
242                self.resolve_provider(provider, ResolutionContext::default())
243                    .await?;
244            }
245        }
246        Ok(())
247    }
248
249    pub async fn shutdown(&self) -> Result<(), DiError> {
250        let initialized = self
251            .instances
252            .lock()
253            .expect("DI instance lock poisoned")
254            .iter()
255            .filter_map(|(key, cell)| cell.get().map(|value| (*key, value.clone())))
256            .collect::<Vec<_>>();
257        for providers in self.providers.values() {
258            for provider in providers {
259                let key = provider_key(provider);
260                if let (Some(destroy), Some((_, value))) = (
261                    provider.destroy,
262                    initialized.iter().find(|(k, _)| *k == key),
263                ) {
264                    destroy(value.clone()).await?;
265                }
266            }
267        }
268        Ok(())
269    }
270
271    #[doc(hidden)]
272    pub async fn resolve_dependency<T>(
273        &self,
274        context: &ResolutionContext,
275    ) -> Result<Arc<T>, DiError>
276    where
277        T: Any + Send + Sync,
278    {
279        self.resolve_selected::<T>(None, context).await
280    }
281
282    #[doc(hidden)]
283    pub async fn resolve_named_dependency<T>(
284        &self,
285        name: &str,
286        context: &ResolutionContext,
287    ) -> Result<Arc<T>, DiError>
288    where
289        T: Any + Send + Sync,
290    {
291        self.resolve_selected::<T>(Some(name), context).await
292    }
293
294    #[doc(hidden)]
295    pub async fn resolve_optional_dependency<T>(
296        &self,
297        context: &ResolutionContext,
298    ) -> Result<Option<Arc<T>>, DiError>
299    where
300        T: Any + Send + Sync,
301    {
302        match self.resolve_dependency::<T>(context).await {
303            Ok(value) => Ok(Some(value)),
304            Err(DiError::MissingProvider(_)) => Ok(None),
305            Err(error) => Err(error),
306        }
307    }
308
309    #[doc(hidden)]
310    pub async fn resolve_all_dependency<T>(
311        &self,
312        context: &ResolutionContext,
313    ) -> Result<Vec<Arc<T>>, DiError>
314    where
315        T: Any + Send + Sync,
316    {
317        let Some(providers) = self.providers.get(&TypeId::of::<T>()) else {
318            return Ok(vec![]);
319        };
320        let mut values = Vec::with_capacity(providers.len());
321        for provider in providers {
322            let value = self.resolve_provider(provider, context.clone()).await?;
323            values.push(
324                value
325                    .downcast::<T>()
326                    .map_err(|_| DiError::TypeMismatch(std::any::type_name::<T>()))?,
327            );
328        }
329        Ok(values)
330    }
331
332    async fn resolve_selected<T>(
333        &self,
334        name: Option<&str>,
335        context: &ResolutionContext,
336    ) -> Result<Arc<T>, DiError>
337    where
338        T: Any + Send + Sync,
339    {
340        let type_name = std::any::type_name::<T>();
341        let providers = self
342            .providers
343            .get(&TypeId::of::<T>())
344            .ok_or(DiError::MissingProvider(type_name))?;
345        let selected = if let Some(name) = name {
346            providers
347                .iter()
348                .copied()
349                .find(|p| p.name == Some(name))
350                .ok_or(DiError::MissingProvider(type_name))?
351        } else if providers.len() == 1 {
352            providers[0]
353        } else {
354            providers
355                .iter()
356                .copied()
357                .find(|p| p.primary)
358                .ok_or(DiError::AmbiguousProvider(type_name))?
359        };
360        let value = self.resolve_provider(selected, context.clone()).await?;
361        value
362            .downcast::<T>()
363            .map_err(|_| DiError::TypeMismatch(type_name))
364    }
365
366    fn resolve_provider<'a>(
367        &'a self,
368        provider: &'static ProviderDescriptor,
369        mut context: ResolutionContext,
370    ) -> BoxFuture<'a, Result<DynArc, DiError>> {
371        Box::pin(async move {
372            let type_name = (provider.type_name)();
373            if context.chain.contains(&type_name) {
374                context.chain.push(type_name);
375                return Err(DiError::CircularDependency(context.chain.join(" -> ")));
376            }
377            context.chain.push(type_name);
378            if provider.scope == Scope::Prototype {
379                return (provider.factory)(self, context).await;
380            }
381            let map = match provider.scope {
382                Scope::Singleton => &self.instances,
383                Scope::Request => context
384                    .request_instances
385                    .as_deref()
386                    .ok_or(DiError::RequestScopeUnavailable(type_name))?,
387                Scope::Prototype => unreachable!(),
388            };
389            let cell = {
390                let mut instances = map.lock().expect("DI instance lock poisoned");
391                instances
392                    .entry(provider_key(provider))
393                    .or_insert_with(|| Arc::new(tokio::sync::OnceCell::new()))
394                    .clone()
395            };
396            let value = cell
397                .get_or_try_init(|| async move { (provider.factory)(self, context).await })
398                .await?;
399            Ok(value.clone())
400        })
401    }
402}
403
404fn provider_key(provider: &'static ProviderDescriptor) -> usize {
405    provider as *const ProviderDescriptor as usize
406}
407
408pub struct RequestContext<'a> {
409    container: &'a Container,
410    context: ResolutionContext,
411}
412impl RequestContext<'_> {
413    pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
414    where
415        T: Any + Send + Sync,
416    {
417        self.container.resolve_dependency::<T>(&self.context).await
418    }
419}
420
421#[doc(hidden)]
422pub mod __private {
423    pub use inventory;
424    pub use tokio;
425}
426
#[cfg(test)]
mod tests {
    // End-to-end coverage of the runtime through the crate's attribute macros:
    // scopes, qualifiers and primaries, profiles and conditions, lifecycle
    // hooks, collection injection and deferred (Provider/Lazy) dependencies.
    use super::*;
    use crate::{Lazy, Provider, beans, configuration_properties, provider, singleton};
    use std::sync::atomic::{AtomicUsize, Ordering};

    static CREATIONS: AtomicUsize = AtomicUsize::new(0);
    struct SyncDependency;
    #[singleton]
    fn sync_dependency() -> SyncDependency {
        CREATIONS.fetch_add(1, Ordering::SeqCst);
        SyncDependency
    }
    struct AsyncDependency {
        _sync: Arc<SyncDependency>,
    }
    #[singleton]
    async fn async_dependency(sync: Arc<SyncDependency>) -> AsyncDependency {
        // Suspends mid-construction so concurrent resolvers overlap with an
        // in-flight factory.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        AsyncDependency { _sync: sync }
    }
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn singleton_is_concurrent_safe() {
        let container = Arc::new(Container::new().unwrap());
        let mut tasks = tokio::task::JoinSet::new();
        for _ in 0..32 {
            let c = container.clone();
            tasks.spawn(async move { c.resolve::<AsyncDependency>().await.unwrap() });
        }
        let mut values = vec![];
        while let Some(v) = tasks.join_next().await {
            values.push(v.unwrap());
        }
        assert!(values.iter().all(|v| Arc::ptr_eq(&values[0], v)));
        assert_eq!(CREATIONS.load(Ordering::SeqCst), 1);
    }

    static PROTOTYPES: AtomicUsize = AtomicUsize::new(0);
    struct PrototypeBean(usize);
    #[singleton(scope = "prototype")]
    fn prototype_bean() -> PrototypeBean {
        PrototypeBean(PROTOTYPES.fetch_add(1, Ordering::SeqCst))
    }

    struct RequestBean;
    #[singleton(scope = "request")]
    fn request_bean() -> RequestBean {
        RequestBean
    }

    trait Greeting: Send + Sync {
        fn text(&self) -> &'static str;
    }
    struct English;
    impl Greeting for English {
        fn text(&self) -> &'static str {
            "hello"
        }
    }
    struct Hindi;
    impl Greeting for Hindi {
        fn text(&self) -> &'static str {
            "namaste"
        }
    }

    #[singleton(name = "english", primary)]
    fn english_greeting() -> Arc<dyn Greeting> {
        Arc::new(English)
    }
    #[singleton(name = "hindi")]
    fn hindi_greeting() -> Arc<dyn Greeting> {
        Arc::new(Hindi)
    }

    struct MissingOptional;
    struct Greeter {
        greeting: Arc<dyn Greeting>,
        optional: Option<Arc<MissingOptional>>,
    }
    struct GreetingLabel(&'static str);
    #[singleton]
    impl Greeter {
        fn new(greeting: Arc<dyn Greeting>, optional: Option<Arc<MissingOptional>>) -> Self {
            Self { greeting, optional }
        }

        #[provider]
        fn label(&self) -> GreetingLabel {
            GreetingLabel(self.greeting.text())
        }
    }
    struct QualifiedGreeter(Arc<dyn Greeting>);
    #[singleton]
    impl QualifiedGreeter {
        fn new(#[qualifier("hindi")] greeting: Arc<dyn Greeting>) -> Self {
            Self(greeting)
        }
    }

    struct StaticConfig(&'static str);
    struct ServiceWithStaticBean(Arc<StaticConfig>);
    #[singleton]
    impl ServiceWithStaticBean {
        fn new(config: Arc<StaticConfig>) -> Self {
            Self(config)
        }

        // No `&self`: this provider can run before its owning singleton and can
        // therefore participate in the service's constructor graph.
        #[provider]
        fn config() -> StaticConfig {
            StaticConfig("static-bean")
        }
    }

    static STARTED: AtomicUsize = AtomicUsize::new(0);
    static STOPPED: AtomicUsize = AtomicUsize::new(0);
    struct Managed;
    #[singleton(eager, post_construct = "start", pre_destroy = "stop")]
    impl Managed {
        fn new() -> Self {
            Self
        }
        async fn start(&self) {
            STARTED.fetch_add(1, Ordering::SeqCst);
        }
        async fn stop(&self) {
            STOPPED.fetch_add(1, Ordering::SeqCst);
        }
    }

    struct ProfileBean;
    #[singleton(profile = "test")]
    fn profile_bean() -> ProfileBean {
        ProfileBean
    }

    #[derive(Debug)]
    struct Handler(&'static str);
    #[singleton(name = "first")]
    fn first_handler() -> Handler {
        Handler("first")
    }
    #[singleton(name = "second")]
    fn second_handler() -> Handler {
        Handler("second")
    }
    struct Pipeline(Vec<Arc<Handler>>);
    #[singleton]
    impl Pipeline {
        fn new(handlers: Vec<Arc<Handler>>) -> Self {
            Self(handlers)
        }
    }

    #[configuration_properties("testing_dep")]
    struct TestProperties {
        port: u16,
    }

    struct DeferredTarget;
    #[singleton]
    fn deferred_target() -> DeferredTarget {
        DeferredTarget
    }
    struct DeferredConsumer {
        provider: Provider<DeferredTarget>,
        lazy: Lazy<DeferredTarget>,
    }
    #[singleton]
    impl DeferredConsumer {
        fn new(provider: Provider<DeferredTarget>, lazy: Lazy<DeferredTarget>) -> Self {
            Self { provider, lazy }
        }
    }

    struct ConditionalBean;
    #[singleton(condition = "TESTING_DEP_FEATURE=enabled")]
    fn conditional_bean() -> ConditionalBean {
        ConditionalBean
    }

    struct StandaloneBean;
    #[provider]
    fn standalone_bean() -> StandaloneBean {
        StandaloneBean
    }

    #[derive(Default)]
    struct OrdinaryFactory;
    struct BeanFromOrdinaryImpl;
    #[beans]
    impl OrdinaryFactory {
        #[provider]
        fn dependency(&self) -> BeanFromOrdinaryImpl {
            BeanFromOrdinaryImpl
        }
    }

    #[tokio::test]
    async fn scopes_traits_primary_profiles_and_lifecycle_work() {
        // SAFETY: `set_var` requires that no other thread touches the process
        // environment concurrently. The key is unique to this test crate.
        // NOTE(review): `singleton_is_concurrent_safe` may run in parallel under
        // the default test runner and reads the environment via
        // `Container::new`. Confirm all such reads go through `std::env`, or run
        // these tests with `--test-threads=1`.
        unsafe { std::env::set_var("TESTING_DEP_FEATURE", "enabled") };
        // The condition above is evaluated here, when the container is built.
        let container = Container::with_profiles(["test"]).unwrap();
        let first = container.resolve::<PrototypeBean>().await.unwrap();
        let second = container.resolve::<PrototypeBean>().await.unwrap();
        assert_ne!(first.0, second.0);

        assert!(matches!(
            container.resolve::<RequestBean>().await,
            Err(DiError::RequestScopeUnavailable(_))
        ));
        let request = container.request_context();
        let request_first = request.resolve::<RequestBean>().await.unwrap();
        let request_second = request.resolve::<RequestBean>().await.unwrap();
        assert!(Arc::ptr_eq(&request_first, &request_second));

        let greeter = container.resolve::<Greeter>().await.unwrap();
        assert_eq!(greeter.greeting.text(), "hello");
        assert_eq!(
            container.resolve::<GreetingLabel>().await.unwrap().0,
            "hello"
        );
        assert!(greeter.optional.is_none());
        assert_eq!(
            container
                .resolve::<QualifiedGreeter>()
                .await
                .unwrap()
                .0
                .text(),
            "namaste"
        );
        let hindi = container
            .resolve_named::<Arc<dyn Greeting>>("hindi")
            .await
            .unwrap();
        assert_eq!(hindi.text(), "namaste");
        let static_bean_service = container.resolve::<ServiceWithStaticBean>().await.unwrap();
        assert_eq!(static_bean_service.0.0, "static-bean");
        container.resolve::<ProfileBean>().await.unwrap();
        let pipeline = container.resolve::<Pipeline>().await.unwrap();
        let mut handler_names = pipeline.0.iter().map(|h| h.0).collect::<Vec<_>>();
        handler_names.sort_unstable();
        assert_eq!(handler_names, ["first", "second"]);

        // This process-local key is unique to the test crate.
        // SAFETY: same requirement as the `set_var` call above; see the
        // NOTE(review) there about the concurrently running test.
        unsafe { std::env::set_var("TESTING_DEP_PORT", "8080") };
        let properties = container.resolve::<TestProperties>().await.unwrap();
        assert_eq!(properties.port, 8080);
        container.resolve::<ConditionalBean>().await.unwrap();
        container.resolve::<StandaloneBean>().await.unwrap();
        container.resolve::<BeanFromOrdinaryImpl>().await.unwrap();

        let deferred = container.resolve::<DeferredConsumer>().await.unwrap();
        deferred.provider.get().await.unwrap();
        deferred.lazy.get().await.unwrap();

        container.initialize_eager().await.unwrap();
        assert_eq!(STARTED.load(Ordering::SeqCst), 1);
        container.shutdown().await.unwrap();
        assert_eq!(STOPPED.load(Ordering::SeqCst), 1);
    }
}