Skip to main content

auto_di/
lib.rs

1extern crate self as auto_di;
2
3use std::{
4    any::{Any, TypeId},
5    collections::HashMap,
6    future::Future,
7    marker::PhantomData,
8    pin::Pin,
9    sync::{Arc, Mutex, OnceLock},
10};
11
12pub use auto_di_macros::{
13    application, bean, component, configuration, configuration_properties, qualifier, repository,
14    service, singleton,
15};
16
/// Type-erased, shareable handle to a bean instance; downcast to recover the concrete type.
pub type DynArc = Arc<dyn Any + Send + Sync>;
/// Pinned, boxed, `Send` future that yields `T` and may borrow data for `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Function pointer that creates a bean: it receives the container and the current
/// resolution context and returns a future yielding the type-erased instance.
type Factory =
    for<'a> fn(&'a Container, ResolutionContext) -> BoxFuture<'a, Result<DynArc, DiError>>;
/// Function pointer run on an already created instance by `Container::shutdown`.
type Destroy = fn(DynArc) -> BoxFuture<'static, Result<(), DiError>>;
22
/// Lifetime policy the container applies to instances produced by a provider.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Scope {
    /// One instance per container, cached in the container's instance map.
    Singleton,
    /// The factory is invoked on every resolution; nothing is cached.
    Prototype,
    /// One instance per `RequestContext`; resolving without one yields
    /// `DiError::RequestScopeUnavailable`.
    Request,
}
29
/// Static description of one provider, collected through `inventory` and consumed
/// by `Container` to select and build beans.
#[doc(hidden)]
pub struct ProviderDescriptor {
    // Returns the `TypeId` under which this provider is grouped in the container.
    type_id: fn() -> TypeId,
    // Returns the type name used in error messages and in the resolution chain.
    type_name: fn() -> &'static str,
    // Builds the (type-erased) instance.
    factory: Factory,
    /// Optional name matched by named resolution.
    pub name: Option<&'static str>,
    /// Preferred provider when several are registered for the same type.
    pub primary: bool,
    /// Caching policy for created instances.
    pub scope: Scope,
    /// Whether `Container::initialize_eager` resolves this provider.
    pub eager: bool,
    /// If set, the provider is active only when this profile is enabled.
    pub profile: Option<&'static str>,
    /// If set, name of an environment variable that must be present for activation.
    pub condition_key: Option<&'static str>,
    /// If set together with `condition_key`, the exact value the variable must have.
    pub condition_value: Option<&'static str>,
    /// Optional hook run by `Container::shutdown` on an initialized singleton instance.
    pub destroy: Option<Destroy>,
}
44
45impl ProviderDescriptor {
46    #[doc(hidden)]
47    pub const fn new(
48        type_id: fn() -> TypeId,
49        type_name: fn() -> &'static str,
50        factory: Factory,
51    ) -> Self {
52        Self::configured(
53            type_id,
54            type_name,
55            factory,
56            None,
57            false,
58            Scope::Singleton,
59            false,
60            None,
61            None,
62            None,
63            None,
64        )
65    }
66
67    #[allow(clippy::too_many_arguments)]
68    #[doc(hidden)]
69    pub const fn configured(
70        type_id: fn() -> TypeId,
71        type_name: fn() -> &'static str,
72        factory: Factory,
73        name: Option<&'static str>,
74        primary: bool,
75        scope: Scope,
76        eager: bool,
77        profile: Option<&'static str>,
78        condition_key: Option<&'static str>,
79        condition_value: Option<&'static str>,
80        destroy: Option<Destroy>,
81    ) -> Self {
82        Self {
83            type_id,
84            type_name,
85            factory,
86            name,
87            primary,
88            scope,
89            eager,
90            profile,
91            condition_key,
92            condition_value,
93            destroy,
94        }
95    }
96
97    fn active(&self, profiles: &[String]) -> bool {
98        let profile_matches = self
99            .profile
100            .is_none_or(|required| profiles.iter().any(|p| p == required));
101        let condition_matches = self.condition_key.is_none_or(|key| {
102            let actual = std::env::var(key).ok();
103            self.condition_value.map_or(actual.is_some(), |expected| {
104                actual.as_deref() == Some(expected)
105            })
106        });
107        profile_matches && condition_matches
108    }
109}
110
// Declares `ProviderDescriptor` as an `inventory` collection so that registered
// descriptors can be iterated when a container is built.
inventory::collect!(ProviderDescriptor);

/// Instance storage for one scope: a once-initialized cell per provider, keyed by
/// the value returned from `provider_key`.
type InstanceMap = HashMap<usize, Arc<tokio::sync::OnceCell<DynArc>>>;
114
/// State threaded through one resolution: the chain of providers currently being
/// built and, when inside a request, the request-scoped instance storage.
#[derive(Clone, Default)]
pub struct ResolutionContext {
    // Type names of the providers on the current resolution path; used to detect cycles.
    chain: Vec<&'static str>,
    // Storage for request-scoped beans; `None` when resolving outside a `RequestContext`.
    request_instances: Option<Arc<Mutex<InstanceMap>>>,
}
120
/// Errors reported by the container while registering or resolving providers.
#[derive(Debug, thiserror::Error)]
pub enum DiError {
    /// No active provider exists for the requested type (or for the requested name).
    #[error("no active provider is registered for {0}")]
    MissingProvider(&'static str),
    /// Several providers match and none of them is marked primary.
    #[error("multiple providers match {0}; add a name/qualifier or mark one primary")]
    AmbiguousProvider(&'static str),
    /// More than one active provider of the same type is marked primary.
    #[error("multiple primary providers are registered for {0}")]
    MultiplePrimary(&'static str),
    /// A provider was reached again while it was still being resolved; the payload
    /// is the resolution chain joined with `" -> "`.
    #[error("circular dependency detected: {0}")]
    CircularDependency(String),
    /// The type-erased instance could not be downcast to the requested type.
    #[error("provider for {0} returned an incompatible type")]
    TypeMismatch(&'static str),
    /// A request-scoped provider was resolved without request instance storage.
    #[error("request-scoped bean {0} was resolved outside RequestContext")]
    RequestScopeUnavailable(&'static str),
    /// A configuration property could not be read.
    // NOTE(review): not constructed in this file; presumably produced by generated
    // `ConfigurationProperties` implementations — confirm.
    #[error("configuration property {key} is missing or invalid: {message}")]
    Configuration { key: String, message: String },
    /// A lifecycle hook failed.
    // NOTE(review): not constructed in this file; presumably produced by generated
    // post-construct / pre-destroy glue — confirm.
    #[error("lifecycle hook failed for {0}")]
    Lifecycle(&'static str),
}
140
/// Dependency-injection container: the active providers grouped by produced type,
/// plus the cache of singleton instances.
pub struct Container {
    // Active providers, grouped by the `TypeId` each one produces.
    providers: HashMap<TypeId, Vec<&'static ProviderDescriptor>>,
    // Singleton instances, created on first resolution.
    instances: Mutex<InstanceMap>,
}
145
// Process-wide container, created on first use by `global_container`.
static GLOBAL_CONTAINER: OnceLock<Container> = OnceLock::new();
147
148pub fn global_container() -> Result<&'static Container, DiError> {
149    if let Some(container) = GLOBAL_CONTAINER.get() {
150        return Ok(container);
151    }
152    let container = Container::new()?;
153    let _ = GLOBAL_CONTAINER.set(container);
154    Ok(GLOBAL_CONTAINER
155        .get()
156        .expect("global DI container initialized"))
157}
158
159pub async fn resolve<T>() -> Result<Arc<T>, DiError>
160where
161    T: Any + Send + Sync,
162{
163    global_container()?.resolve::<T>().await
164}
165
166impl Container {
167    pub fn new() -> Result<Self, DiError> {
168        let profiles = std::env::var("APP_PROFILES")
169            .unwrap_or_default()
170            .split(',')
171            .map(str::trim)
172            .filter(|p| !p.is_empty())
173            .map(str::to_owned)
174            .collect::<Vec<_>>();
175        Self::with_profiles(profiles)
176    }
177
178    pub fn with_profiles(
179        profiles: impl IntoIterator<Item = impl Into<String>>,
180    ) -> Result<Self, DiError> {
181        let profiles = profiles.into_iter().map(Into::into).collect::<Vec<_>>();
182        let mut providers: HashMap<TypeId, Vec<&'static ProviderDescriptor>> = HashMap::new();
183        for provider in inventory::iter::<ProviderDescriptor> {
184            if provider.active(&profiles) {
185                providers
186                    .entry((provider.type_id)())
187                    .or_default()
188                    .push(provider);
189            }
190        }
191        for group in providers.values() {
192            if group.iter().filter(|p| p.primary).count() > 1 {
193                return Err(DiError::MultiplePrimary((group[0].type_name)()));
194            }
195        }
196        Ok(Self {
197            providers,
198            instances: Mutex::new(HashMap::new()),
199        })
200    }
201
202    pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
203    where
204        T: Any + Send + Sync,
205    {
206        self.resolve_dependency::<T>(&ResolutionContext::default())
207            .await
208    }
209
210    pub async fn resolve_named<T>(&self, name: &str) -> Result<Arc<T>, DiError>
211    where
212        T: Any + Send + Sync,
213    {
214        self.resolve_named_dependency::<T>(name, &ResolutionContext::default())
215            .await
216    }
217
218    pub async fn resolve_optional<T>(&self) -> Result<Option<Arc<T>>, DiError>
219    where
220        T: Any + Send + Sync,
221    {
222        match self.resolve::<T>().await {
223            Ok(value) => Ok(Some(value)),
224            Err(DiError::MissingProvider(_)) => Ok(None),
225            Err(error) => Err(error),
226        }
227    }
228
229    pub async fn resolve_all<T>(&self) -> Result<Vec<Arc<T>>, DiError>
230    where
231        T: Any + Send + Sync,
232    {
233        self.resolve_all_dependency::<T>(&ResolutionContext::default())
234            .await
235    }
236
237    pub fn request_context(&self) -> RequestContext<'_> {
238        RequestContext {
239            container: self,
240            context: ResolutionContext {
241                chain: vec![],
242                request_instances: Some(Arc::new(Mutex::new(HashMap::new()))),
243            },
244        }
245    }
246
247    pub async fn initialize_eager(&self) -> Result<(), DiError> {
248        for providers in self.providers.values() {
249            for provider in providers.iter().filter(|p| p.eager) {
250                self.resolve_provider(provider, ResolutionContext::default())
251                    .await?;
252            }
253        }
254        Ok(())
255    }
256
257    pub async fn shutdown(&self) -> Result<(), DiError> {
258        let initialized = self
259            .instances
260            .lock()
261            .expect("DI instance lock poisoned")
262            .iter()
263            .filter_map(|(key, cell)| cell.get().map(|value| (*key, value.clone())))
264            .collect::<Vec<_>>();
265        for providers in self.providers.values() {
266            for provider in providers {
267                let key = provider_key(provider);
268                if let (Some(destroy), Some((_, value))) = (
269                    provider.destroy,
270                    initialized.iter().find(|(k, _)| *k == key),
271                ) {
272                    destroy(value.clone()).await?;
273                }
274            }
275        }
276        Ok(())
277    }
278
279    #[doc(hidden)]
280    pub async fn resolve_dependency<T>(
281        &self,
282        context: &ResolutionContext,
283    ) -> Result<Arc<T>, DiError>
284    where
285        T: Any + Send + Sync,
286    {
287        self.resolve_selected::<T>(None, context).await
288    }
289
290    #[doc(hidden)]
291    pub async fn resolve_named_dependency<T>(
292        &self,
293        name: &str,
294        context: &ResolutionContext,
295    ) -> Result<Arc<T>, DiError>
296    where
297        T: Any + Send + Sync,
298    {
299        self.resolve_selected::<T>(Some(name), context).await
300    }
301
302    #[doc(hidden)]
303    pub async fn resolve_optional_dependency<T>(
304        &self,
305        context: &ResolutionContext,
306    ) -> Result<Option<Arc<T>>, DiError>
307    where
308        T: Any + Send + Sync,
309    {
310        match self.resolve_dependency::<T>(context).await {
311            Ok(value) => Ok(Some(value)),
312            Err(DiError::MissingProvider(_)) => Ok(None),
313            Err(error) => Err(error),
314        }
315    }
316
317    #[doc(hidden)]
318    pub async fn resolve_all_dependency<T>(
319        &self,
320        context: &ResolutionContext,
321    ) -> Result<Vec<Arc<T>>, DiError>
322    where
323        T: Any + Send + Sync,
324    {
325        let Some(providers) = self.providers.get(&TypeId::of::<T>()) else {
326            return Ok(vec![]);
327        };
328        let mut values = Vec::with_capacity(providers.len());
329        for provider in providers {
330            let value = self.resolve_provider(provider, context.clone()).await?;
331            values.push(
332                value
333                    .downcast::<T>()
334                    .map_err(|_| DiError::TypeMismatch(std::any::type_name::<T>()))?,
335            );
336        }
337        Ok(values)
338    }
339
340    async fn resolve_selected<T>(
341        &self,
342        name: Option<&str>,
343        context: &ResolutionContext,
344    ) -> Result<Arc<T>, DiError>
345    where
346        T: Any + Send + Sync,
347    {
348        let type_name = std::any::type_name::<T>();
349        let providers = self
350            .providers
351            .get(&TypeId::of::<T>())
352            .ok_or(DiError::MissingProvider(type_name))?;
353        let selected = if let Some(name) = name {
354            providers
355                .iter()
356                .copied()
357                .find(|p| p.name == Some(name))
358                .ok_or(DiError::MissingProvider(type_name))?
359        } else if providers.len() == 1 {
360            providers[0]
361        } else {
362            providers
363                .iter()
364                .copied()
365                .find(|p| p.primary)
366                .ok_or(DiError::AmbiguousProvider(type_name))?
367        };
368        let value = self.resolve_provider(selected, context.clone()).await?;
369        value
370            .downcast::<T>()
371            .map_err(|_| DiError::TypeMismatch(type_name))
372    }
373
374    fn resolve_provider<'a>(
375        &'a self,
376        provider: &'static ProviderDescriptor,
377        mut context: ResolutionContext,
378    ) -> BoxFuture<'a, Result<DynArc, DiError>> {
379        Box::pin(async move {
380            let type_name = (provider.type_name)();
381            if context.chain.contains(&type_name) {
382                context.chain.push(type_name);
383                return Err(DiError::CircularDependency(context.chain.join(" -> ")));
384            }
385            context.chain.push(type_name);
386            if provider.scope == Scope::Prototype {
387                return (provider.factory)(self, context).await;
388            }
389            let map = match provider.scope {
390                Scope::Singleton => &self.instances,
391                Scope::Request => context
392                    .request_instances
393                    .as_deref()
394                    .ok_or(DiError::RequestScopeUnavailable(type_name))?,
395                Scope::Prototype => unreachable!(),
396            };
397            let cell = {
398                let mut instances = map.lock().expect("DI instance lock poisoned");
399                instances
400                    .entry(provider_key(provider))
401                    .or_insert_with(|| Arc::new(tokio::sync::OnceCell::new()))
402                    .clone()
403            };
404            let value = cell
405                .get_or_try_init(|| async move { (provider.factory)(self, context).await })
406                .await?;
407            Ok(value.clone())
408        })
409    }
410}
411
412fn provider_key(provider: &'static ProviderDescriptor) -> usize {
413    provider as *const ProviderDescriptor as usize
414}
415
/// Handle for resolving beans within one request: request-scoped providers
/// resolved through it share the instance storage held in `context`.
pub struct RequestContext<'a> {
    // Container that owns the providers.
    container: &'a Container,
    // Resolution context carrying this request's instance storage.
    context: ResolutionContext,
}
420impl RequestContext<'_> {
421    pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
422    where
423        T: Any + Send + Sync,
424    {
425        self.container.resolve_dependency::<T>(&self.context).await
426    }
427}
428
/// Zero-sized handle that looks up `T` on demand instead of holding an instance.
pub struct Provider<T> {
    // Ties the handle to `T` without storing a value of that type.
    _marker: PhantomData<fn() -> T>,
}
432impl<T> Clone for Provider<T> {
433    fn clone(&self) -> Self {
434        Self {
435            _marker: PhantomData,
436        }
437    }
438}
439impl<T> Default for Provider<T> {
440    fn default() -> Self {
441        Self {
442            _marker: PhantomData,
443        }
444    }
445}
446impl<T> Provider<T>
447where
448    T: Any + Send + Sync,
449{
450    pub async fn get(&self) -> Result<Arc<T>, DiError> {
451        resolve::<T>().await
452    }
453}
454
/// Handle that resolves `T` on first access and caches the result.
pub struct Lazy<T> {
    // Holds the resolved instance once the first successful `get` completes.
    cell: tokio::sync::OnceCell<Arc<T>>,
}
458impl<T> Default for Lazy<T> {
459    fn default() -> Self {
460        Self {
461            cell: tokio::sync::OnceCell::new(),
462        }
463    }
464}
465impl<T> Lazy<T>
466where
467    T: Any + Send + Sync,
468{
469    pub async fn get(&self) -> Result<&Arc<T>, DiError> {
470        self.cell
471            .get_or_try_init(|| async { resolve::<T>().await })
472            .await
473    }
474}
475
/// Types that can be constructed from the process environment.
pub trait ConfigurationProperties: Sized {
    /// Builds the value from the environment, or reports why it could not be built.
    fn from_environment() -> Result<Self, DiError>;
}
479
// Re-exports of third-party crates under a stable path.
// NOTE(review): presumably so code generated by the `auto_di_macros` macros can
// reach them through `auto_di::__private` — confirm against the macro crate.
#[doc(hidden)]
pub mod __private {
    pub use inventory;
    pub use tokio;
}
485
// Integration-style tests exercising the container through the attribute macros.
// The fixtures below are registered by those macros; the comments on them describe
// what the assertions in the two tests check, not how the macros are implemented.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    // Counts how many times the `SyncDependency` factory runs.
    static CREATIONS: AtomicUsize = AtomicUsize::new(0);
    struct SyncDependency;
    #[singleton]
    fn sync_dependency() -> SyncDependency {
        CREATIONS.fetch_add(1, Ordering::SeqCst);
        SyncDependency
    }
    struct AsyncDependency {
        _sync: Arc<SyncDependency>,
    }
    // The sleep widens the window in which concurrent resolutions can overlap.
    #[singleton]
    async fn async_dependency(sync: Arc<SyncDependency>) -> AsyncDependency {
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        AsyncDependency { _sync: sync }
    }
    // 32 concurrent resolutions must observe the same instance, and the
    // dependency factory must have run exactly once.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn singleton_is_concurrent_safe() {
        let container = Arc::new(Container::new().unwrap());
        let mut tasks = tokio::task::JoinSet::new();
        for _ in 0..32 {
            let c = container.clone();
            tasks.spawn(async move { c.resolve::<AsyncDependency>().await.unwrap() });
        }
        let mut values = vec![];
        while let Some(v) = tasks.join_next().await {
            values.push(v.unwrap());
        }
        assert!(values.iter().all(|v| Arc::ptr_eq(&values[0], v)));
        assert_eq!(CREATIONS.load(Ordering::SeqCst), 1);
    }

    // Hands out a distinct number to every `PrototypeBean` created.
    static PROTOTYPES: AtomicUsize = AtomicUsize::new(0);
    struct PrototypeBean(usize);
    #[component(scope = "prototype")]
    fn prototype_bean() -> PrototypeBean {
        PrototypeBean(PROTOTYPES.fetch_add(1, Ordering::SeqCst))
    }

    struct RequestBean;
    #[component(scope = "request")]
    fn request_bean() -> RequestBean {
        RequestBean
    }

    // Two providers of the same trait object: "english" is primary, "hindi" is
    // reachable by name / qualifier only.
    trait Greeting: Send + Sync {
        fn text(&self) -> &'static str;
    }
    struct English;
    impl Greeting for English {
        fn text(&self) -> &'static str {
            "hello"
        }
    }
    struct Hindi;
    impl Greeting for Hindi {
        fn text(&self) -> &'static str {
            "namaste"
        }
    }

    #[repository(name = "english", primary)]
    fn english_greeting() -> Arc<dyn Greeting> {
        Arc::new(English)
    }
    #[repository(name = "hindi")]
    fn hindi_greeting() -> Arc<dyn Greeting> {
        Arc::new(Hindi)
    }

    // `MissingOptional` has no provider, so the optional dependency must be `None`.
    struct MissingOptional;
    struct Greeter {
        greeting: Arc<dyn Greeting>,
        optional: Option<Arc<MissingOptional>>,
    }
    #[service]
    impl Greeter {
        fn new(greeting: Arc<dyn Greeting>, optional: Option<Arc<MissingOptional>>) -> Self {
            Self { greeting, optional }
        }
    }
    struct QualifiedGreeter(Arc<dyn Greeting>);
    #[service]
    impl QualifiedGreeter {
        fn new(#[qualifier("hindi")] greeting: Arc<dyn Greeting>) -> Self {
            Self(greeting)
        }
    }

    // Lifecycle counters: each hook is expected to run exactly once.
    static STARTED: AtomicUsize = AtomicUsize::new(0);
    static STOPPED: AtomicUsize = AtomicUsize::new(0);
    struct Managed;
    #[service(eager, post_construct = "start", pre_destroy = "stop")]
    impl Managed {
        fn new() -> Self {
            Self
        }
        async fn start(&self) {
            STARTED.fetch_add(1, Ordering::SeqCst);
        }
        async fn stop(&self) {
            STOPPED.fetch_add(1, Ordering::SeqCst);
        }
    }

    // Only resolvable when the "test" profile is active.
    struct ProfileBean;
    #[component(profile = "test")]
    fn profile_bean() -> ProfileBean {
        ProfileBean
    }

    // Two named providers of the same type, collected into a `Vec` by `Pipeline`.
    #[derive(Debug)]
    struct Handler(&'static str);
    #[component(name = "first")]
    fn first_handler() -> Handler {
        Handler("first")
    }
    #[component(name = "second")]
    fn second_handler() -> Handler {
        Handler("second")
    }
    struct Pipeline(Vec<Arc<Handler>>);
    #[service]
    impl Pipeline {
        fn new(handlers: Vec<Arc<Handler>>) -> Self {
            Self(handlers)
        }
    }

    // The test below sets `TESTING_DEP_PORT` and expects `port` to be 8080.
    #[configuration_properties("testing_dep")]
    struct TestProperties {
        port: u16,
    }

    struct DeferredTarget;
    #[component]
    fn deferred_target() -> DeferredTarget {
        DeferredTarget
    }
    struct DeferredConsumer {
        provider: Provider<DeferredTarget>,
        lazy: Lazy<DeferredTarget>,
    }
    #[service]
    impl DeferredConsumer {
        fn new(provider: Provider<DeferredTarget>, lazy: Lazy<DeferredTarget>) -> Self {
            Self { provider, lazy }
        }
    }

    // Active only while `TESTING_DEP_FEATURE` equals "enabled".
    struct ConditionalBean;
    #[component(condition = "TESTING_DEP_FEATURE=enabled")]
    fn conditional_bean() -> ConditionalBean {
        ConditionalBean
    }

    // Statement order matters: the condition variable is set before the container
    // is built, and eager initialization happens before shutdown.
    #[tokio::test]
    async fn scopes_traits_primary_profiles_and_lifecycle_work() {
        // NOTE(review): `set_var` is unsafe because other threads may access the
        // environment concurrently; the other test in this module reads it through
        // `Container::new` — confirm the tests cannot overlap.
        unsafe { std::env::set_var("TESTING_DEP_FEATURE", "enabled") };
        let container = Container::with_profiles(["test"]).unwrap();
        // Prototype scope: every resolution creates a new instance.
        let first = container.resolve::<PrototypeBean>().await.unwrap();
        let second = container.resolve::<PrototypeBean>().await.unwrap();
        assert_ne!(first.0, second.0);

        // Request scope: rejected on the container, shared within one request context.
        assert!(matches!(
            container.resolve::<RequestBean>().await,
            Err(DiError::RequestScopeUnavailable(_))
        ));
        let request = container.request_context();
        let request_first = request.resolve::<RequestBean>().await.unwrap();
        let request_second = request.resolve::<RequestBean>().await.unwrap();
        assert!(Arc::ptr_eq(&request_first, &request_second));

        // Primary selection, optional dependency, qualifier and named lookup.
        let greeter = container.resolve::<Greeter>().await.unwrap();
        assert_eq!(greeter.greeting.text(), "hello");
        assert!(greeter.optional.is_none());
        assert_eq!(
            container
                .resolve::<QualifiedGreeter>()
                .await
                .unwrap()
                .0
                .text(),
            "namaste"
        );
        let hindi = container
            .resolve_named::<Arc<dyn Greeting>>("hindi")
            .await
            .unwrap();
        assert_eq!(hindi.text(), "namaste");
        container.resolve::<ProfileBean>().await.unwrap();
        // Collection injection; sorted because provider order is not asserted.
        let pipeline = container.resolve::<Pipeline>().await.unwrap();
        let mut handler_names = pipeline.0.iter().map(|h| h.0).collect::<Vec<_>>();
        handler_names.sort_unstable();
        assert_eq!(handler_names, ["first", "second"]);

        // This process-local key is unique to the test crate.
        unsafe { std::env::set_var("TESTING_DEP_PORT", "8080") };
        let properties = container.resolve::<TestProperties>().await.unwrap();
        assert_eq!(properties.port, 8080);
        container.resolve::<ConditionalBean>().await.unwrap();

        // `Provider` and `Lazy` resolve through the global container.
        let deferred = container.resolve::<DeferredConsumer>().await.unwrap();
        deferred.provider.get().await.unwrap();
        deferred.lazy.get().await.unwrap();

        // Lifecycle: each hook is expected to have run exactly once.
        container.initialize_eager().await.unwrap();
        assert_eq!(STARTED.load(Ordering::SeqCst), 1);
        container.shutdown().await.unwrap();
        assert_eq!(STOPPED.load(Ordering::SeqCst), 1);
    }
}