1extern crate self as auto_di;
2
3use std::{
4 any::{Any, TypeId},
5 collections::HashMap,
6 future::Future,
7 marker::PhantomData,
8 pin::Pin,
9 sync::{Arc, Mutex, OnceLock},
10};
11
12pub use auto_di_macros::{
13 application, bean, component, configuration, configuration_properties, qualifier, repository,
14 service, singleton,
15};
16
17pub type DynArc = Arc<dyn Any + Send + Sync>;
18pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
19type Factory =
20 for<'a> fn(&'a Container, ResolutionContext) -> BoxFuture<'a, Result<DynArc, DiError>>;
21type Destroy = fn(DynArc) -> BoxFuture<'static, Result<(), DiError>>;
22
/// Lifetime policy of a registered bean.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Scope {
    /// Built once and cached in the [`Container`] for its whole lifetime.
    Singleton,
    /// Never cached: the factory runs on every resolution.
    Prototype,
    /// Cached per [`RequestContext`]; resolving it outside a request context fails.
    Request,
}
29
30#[doc(hidden)]
31pub struct ProviderDescriptor {
32 type_id: fn() -> TypeId,
33 type_name: fn() -> &'static str,
34 factory: Factory,
35 pub name: Option<&'static str>,
36 pub primary: bool,
37 pub scope: Scope,
38 pub eager: bool,
39 pub profile: Option<&'static str>,
40 pub condition_key: Option<&'static str>,
41 pub condition_value: Option<&'static str>,
42 pub destroy: Option<Destroy>,
43}
44
45impl ProviderDescriptor {
46 #[doc(hidden)]
47 pub const fn new(
48 type_id: fn() -> TypeId,
49 type_name: fn() -> &'static str,
50 factory: Factory,
51 ) -> Self {
52 Self::configured(
53 type_id,
54 type_name,
55 factory,
56 None,
57 false,
58 Scope::Singleton,
59 false,
60 None,
61 None,
62 None,
63 None,
64 )
65 }
66
67 #[allow(clippy::too_many_arguments)]
68 #[doc(hidden)]
69 pub const fn configured(
70 type_id: fn() -> TypeId,
71 type_name: fn() -> &'static str,
72 factory: Factory,
73 name: Option<&'static str>,
74 primary: bool,
75 scope: Scope,
76 eager: bool,
77 profile: Option<&'static str>,
78 condition_key: Option<&'static str>,
79 condition_value: Option<&'static str>,
80 destroy: Option<Destroy>,
81 ) -> Self {
82 Self {
83 type_id,
84 type_name,
85 factory,
86 name,
87 primary,
88 scope,
89 eager,
90 profile,
91 condition_key,
92 condition_value,
93 destroy,
94 }
95 }
96
97 fn active(&self, profiles: &[String]) -> bool {
98 let profile_matches = self
99 .profile
100 .is_none_or(|required| profiles.iter().any(|p| p == required));
101 let condition_matches = self.condition_key.is_none_or(|key| {
102 let actual = std::env::var(key).ok();
103 self.condition_value.map_or(actual.is_some(), |expected| {
104 actual.as_deref() == Some(expected)
105 })
106 });
107 profile_matches && condition_matches
108 }
109}
110
// Declares `ProviderDescriptor` as an `inventory` collection, so registrations
// submitted anywhere in the program can be enumerated with
// `inventory::iter::<ProviderDescriptor>` when a container is built.
inventory::collect!(ProviderDescriptor);

/// Instance cache for one scope, keyed by `provider_key` (the descriptor's address).
/// Each entry is a shared async once-cell, so concurrent resolvers await a single build.
type InstanceMap = HashMap<usize, Arc<tokio::sync::OnceCell<DynArc>>>;
114
115#[derive(Clone, Default)]
116pub struct ResolutionContext {
117 chain: Vec<&'static str>,
118 request_instances: Option<Arc<Mutex<InstanceMap>>>,
119}
120
121#[derive(Debug, thiserror::Error)]
122pub enum DiError {
123 #[error("no active provider is registered for {0}")]
124 MissingProvider(&'static str),
125 #[error("multiple providers match {0}; add a name/qualifier or mark one primary")]
126 AmbiguousProvider(&'static str),
127 #[error("multiple primary providers are registered for {0}")]
128 MultiplePrimary(&'static str),
129 #[error("circular dependency detected: {0}")]
130 CircularDependency(String),
131 #[error("provider for {0} returned an incompatible type")]
132 TypeMismatch(&'static str),
133 #[error("request-scoped bean {0} was resolved outside RequestContext")]
134 RequestScopeUnavailable(&'static str),
135 #[error("configuration property {key} is missing or invalid: {message}")]
136 Configuration { key: String, message: String },
137 #[error("lifecycle hook failed for {0}")]
138 Lifecycle(&'static str),
139}
140
141pub struct Container {
142 providers: HashMap<TypeId, Vec<&'static ProviderDescriptor>>,
143 instances: Mutex<InstanceMap>,
144}
145
146static GLOBAL_CONTAINER: OnceLock<Container> = OnceLock::new();
147
148pub fn global_container() -> Result<&'static Container, DiError> {
149 if let Some(container) = GLOBAL_CONTAINER.get() {
150 return Ok(container);
151 }
152 let container = Container::new()?;
153 let _ = GLOBAL_CONTAINER.set(container);
154 Ok(GLOBAL_CONTAINER
155 .get()
156 .expect("global DI container initialized"))
157}
158
159pub async fn resolve<T>() -> Result<Arc<T>, DiError>
160where
161 T: Any + Send + Sync,
162{
163 global_container()?.resolve::<T>().await
164}
165
166impl Container {
167 pub fn new() -> Result<Self, DiError> {
168 let profiles = std::env::var("APP_PROFILES")
169 .unwrap_or_default()
170 .split(',')
171 .map(str::trim)
172 .filter(|p| !p.is_empty())
173 .map(str::to_owned)
174 .collect::<Vec<_>>();
175 Self::with_profiles(profiles)
176 }
177
178 pub fn with_profiles(
179 profiles: impl IntoIterator<Item = impl Into<String>>,
180 ) -> Result<Self, DiError> {
181 let profiles = profiles.into_iter().map(Into::into).collect::<Vec<_>>();
182 let mut providers: HashMap<TypeId, Vec<&'static ProviderDescriptor>> = HashMap::new();
183 for provider in inventory::iter::<ProviderDescriptor> {
184 if provider.active(&profiles) {
185 providers
186 .entry((provider.type_id)())
187 .or_default()
188 .push(provider);
189 }
190 }
191 for group in providers.values() {
192 if group.iter().filter(|p| p.primary).count() > 1 {
193 return Err(DiError::MultiplePrimary((group[0].type_name)()));
194 }
195 }
196 Ok(Self {
197 providers,
198 instances: Mutex::new(HashMap::new()),
199 })
200 }
201
202 pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
203 where
204 T: Any + Send + Sync,
205 {
206 self.resolve_dependency::<T>(&ResolutionContext::default())
207 .await
208 }
209
210 pub async fn resolve_named<T>(&self, name: &str) -> Result<Arc<T>, DiError>
211 where
212 T: Any + Send + Sync,
213 {
214 self.resolve_named_dependency::<T>(name, &ResolutionContext::default())
215 .await
216 }
217
218 pub async fn resolve_optional<T>(&self) -> Result<Option<Arc<T>>, DiError>
219 where
220 T: Any + Send + Sync,
221 {
222 match self.resolve::<T>().await {
223 Ok(value) => Ok(Some(value)),
224 Err(DiError::MissingProvider(_)) => Ok(None),
225 Err(error) => Err(error),
226 }
227 }
228
229 pub async fn resolve_all<T>(&self) -> Result<Vec<Arc<T>>, DiError>
230 where
231 T: Any + Send + Sync,
232 {
233 self.resolve_all_dependency::<T>(&ResolutionContext::default())
234 .await
235 }
236
237 pub fn request_context(&self) -> RequestContext<'_> {
238 RequestContext {
239 container: self,
240 context: ResolutionContext {
241 chain: vec![],
242 request_instances: Some(Arc::new(Mutex::new(HashMap::new()))),
243 },
244 }
245 }
246
247 pub async fn initialize_eager(&self) -> Result<(), DiError> {
248 for providers in self.providers.values() {
249 for provider in providers.iter().filter(|p| p.eager) {
250 self.resolve_provider(provider, ResolutionContext::default())
251 .await?;
252 }
253 }
254 Ok(())
255 }
256
257 pub async fn shutdown(&self) -> Result<(), DiError> {
258 let initialized = self
259 .instances
260 .lock()
261 .expect("DI instance lock poisoned")
262 .iter()
263 .filter_map(|(key, cell)| cell.get().map(|value| (*key, value.clone())))
264 .collect::<Vec<_>>();
265 for providers in self.providers.values() {
266 for provider in providers {
267 let key = provider_key(provider);
268 if let (Some(destroy), Some((_, value))) = (
269 provider.destroy,
270 initialized.iter().find(|(k, _)| *k == key),
271 ) {
272 destroy(value.clone()).await?;
273 }
274 }
275 }
276 Ok(())
277 }
278
279 #[doc(hidden)]
280 pub async fn resolve_dependency<T>(
281 &self,
282 context: &ResolutionContext,
283 ) -> Result<Arc<T>, DiError>
284 where
285 T: Any + Send + Sync,
286 {
287 self.resolve_selected::<T>(None, context).await
288 }
289
290 #[doc(hidden)]
291 pub async fn resolve_named_dependency<T>(
292 &self,
293 name: &str,
294 context: &ResolutionContext,
295 ) -> Result<Arc<T>, DiError>
296 where
297 T: Any + Send + Sync,
298 {
299 self.resolve_selected::<T>(Some(name), context).await
300 }
301
302 #[doc(hidden)]
303 pub async fn resolve_optional_dependency<T>(
304 &self,
305 context: &ResolutionContext,
306 ) -> Result<Option<Arc<T>>, DiError>
307 where
308 T: Any + Send + Sync,
309 {
310 match self.resolve_dependency::<T>(context).await {
311 Ok(value) => Ok(Some(value)),
312 Err(DiError::MissingProvider(_)) => Ok(None),
313 Err(error) => Err(error),
314 }
315 }
316
317 #[doc(hidden)]
318 pub async fn resolve_all_dependency<T>(
319 &self,
320 context: &ResolutionContext,
321 ) -> Result<Vec<Arc<T>>, DiError>
322 where
323 T: Any + Send + Sync,
324 {
325 let Some(providers) = self.providers.get(&TypeId::of::<T>()) else {
326 return Ok(vec![]);
327 };
328 let mut values = Vec::with_capacity(providers.len());
329 for provider in providers {
330 let value = self.resolve_provider(provider, context.clone()).await?;
331 values.push(
332 value
333 .downcast::<T>()
334 .map_err(|_| DiError::TypeMismatch(std::any::type_name::<T>()))?,
335 );
336 }
337 Ok(values)
338 }
339
340 async fn resolve_selected<T>(
341 &self,
342 name: Option<&str>,
343 context: &ResolutionContext,
344 ) -> Result<Arc<T>, DiError>
345 where
346 T: Any + Send + Sync,
347 {
348 let type_name = std::any::type_name::<T>();
349 let providers = self
350 .providers
351 .get(&TypeId::of::<T>())
352 .ok_or(DiError::MissingProvider(type_name))?;
353 let selected = if let Some(name) = name {
354 providers
355 .iter()
356 .copied()
357 .find(|p| p.name == Some(name))
358 .ok_or(DiError::MissingProvider(type_name))?
359 } else if providers.len() == 1 {
360 providers[0]
361 } else {
362 providers
363 .iter()
364 .copied()
365 .find(|p| p.primary)
366 .ok_or(DiError::AmbiguousProvider(type_name))?
367 };
368 let value = self.resolve_provider(selected, context.clone()).await?;
369 value
370 .downcast::<T>()
371 .map_err(|_| DiError::TypeMismatch(type_name))
372 }
373
374 fn resolve_provider<'a>(
375 &'a self,
376 provider: &'static ProviderDescriptor,
377 mut context: ResolutionContext,
378 ) -> BoxFuture<'a, Result<DynArc, DiError>> {
379 Box::pin(async move {
380 let type_name = (provider.type_name)();
381 if context.chain.contains(&type_name) {
382 context.chain.push(type_name);
383 return Err(DiError::CircularDependency(context.chain.join(" -> ")));
384 }
385 context.chain.push(type_name);
386 if provider.scope == Scope::Prototype {
387 return (provider.factory)(self, context).await;
388 }
389 let map = match provider.scope {
390 Scope::Singleton => &self.instances,
391 Scope::Request => context
392 .request_instances
393 .as_deref()
394 .ok_or(DiError::RequestScopeUnavailable(type_name))?,
395 Scope::Prototype => unreachable!(),
396 };
397 let cell = {
398 let mut instances = map.lock().expect("DI instance lock poisoned");
399 instances
400 .entry(provider_key(provider))
401 .or_insert_with(|| Arc::new(tokio::sync::OnceCell::new()))
402 .clone()
403 };
404 let value = cell
405 .get_or_try_init(|| async move { (provider.factory)(self, context).await })
406 .await?;
407 Ok(value.clone())
408 })
409 }
410}
411
412fn provider_key(provider: &'static ProviderDescriptor) -> usize {
413 provider as *const ProviderDescriptor as usize
414}
415
416pub struct RequestContext<'a> {
417 container: &'a Container,
418 context: ResolutionContext,
419}
420impl RequestContext<'_> {
421 pub async fn resolve<T>(&self) -> Result<Arc<T>, DiError>
422 where
423 T: Any + Send + Sync,
424 {
425 self.container.resolve_dependency::<T>(&self.context).await
426 }
427}
428
429pub struct Provider<T> {
430 _marker: PhantomData<fn() -> T>,
431}
432impl<T> Clone for Provider<T> {
433 fn clone(&self) -> Self {
434 Self {
435 _marker: PhantomData,
436 }
437 }
438}
439impl<T> Default for Provider<T> {
440 fn default() -> Self {
441 Self {
442 _marker: PhantomData,
443 }
444 }
445}
446impl<T> Provider<T>
447where
448 T: Any + Send + Sync,
449{
450 pub async fn get(&self) -> Result<Arc<T>, DiError> {
451 resolve::<T>().await
452 }
453}
454
455pub struct Lazy<T> {
456 cell: tokio::sync::OnceCell<Arc<T>>,
457}
458impl<T> Default for Lazy<T> {
459 fn default() -> Self {
460 Self {
461 cell: tokio::sync::OnceCell::new(),
462 }
463 }
464}
465impl<T> Lazy<T>
466where
467 T: Any + Send + Sync,
468{
469 pub async fn get(&self) -> Result<&Arc<T>, DiError> {
470 self.cell
471 .get_or_try_init(|| async { resolve::<T>().await })
472 .await
473 }
474}
475
476pub trait ConfigurationProperties: Sized {
477 fn from_environment() -> Result<Self, DiError>;
478}
479
480#[doc(hidden)]
481pub mod __private {
482 pub use inventory;
483 pub use tokio;
484}
485
// End-to-end tests that drive the container through the attribute macros
// (`singleton`, `component`, `service`, `repository`, ...).
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    // Counts runs of `sync_dependency`; the concurrency test expects exactly one.
    // NOTE(review): assumes no other test in this binary builds `SyncDependency`.
    static CREATIONS: AtomicUsize = AtomicUsize::new(0);
    struct SyncDependency;
    #[singleton]
    fn sync_dependency() -> SyncDependency {
        CREATIONS.fetch_add(1, Ordering::SeqCst);
        SyncDependency
    }
    struct AsyncDependency {
        _sync: Arc<SyncDependency>,
    }
    // Async singleton with an injected dependency. The sleep widens the window
    // in which concurrent resolvers overlap.
    #[singleton]
    async fn async_dependency(sync: Arc<SyncDependency>) -> AsyncDependency {
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        AsyncDependency { _sync: sync }
    }
    // 32 tasks on a multi-threaded runtime resolve the same singleton at once.
    // All must get the same `Arc`, and the transitive dependency must be built once.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn singleton_is_concurrent_safe() {
        let container = Arc::new(Container::new().unwrap());
        let mut tasks = tokio::task::JoinSet::new();
        for _ in 0..32 {
            let c = container.clone();
            tasks.spawn(async move { c.resolve::<AsyncDependency>().await.unwrap() });
        }
        let mut values = vec![];
        while let Some(v) = tasks.join_next().await {
            values.push(v.unwrap());
        }
        assert!(values.iter().all(|v| Arc::ptr_eq(&values[0], v)));
        assert_eq!(CREATIONS.load(Ordering::SeqCst), 1);
    }

    // Prototype scope: each resolution runs the factory and gets a new sequence number.
    static PROTOTYPES: AtomicUsize = AtomicUsize::new(0);
    struct PrototypeBean(usize);
    #[component(scope = "prototype")]
    fn prototype_bean() -> PrototypeBean {
        PrototypeBean(PROTOTYPES.fetch_add(1, Ordering::SeqCst))
    }

    // Request scope: only resolvable through a `RequestContext`.
    struct RequestBean;
    #[component(scope = "request")]
    fn request_bean() -> RequestBean {
        RequestBean
    }

    // Two providers of the same trait-object type, distinguished by name;
    // "english" is primary, so it wins unqualified lookups.
    trait Greeting: Send + Sync {
        fn text(&self) -> &'static str;
    }
    struct English;
    impl Greeting for English {
        fn text(&self) -> &'static str {
            "hello"
        }
    }
    struct Hindi;
    impl Greeting for Hindi {
        fn text(&self) -> &'static str {
            "namaste"
        }
    }

    #[repository(name = "english", primary)]
    fn english_greeting() -> Arc<dyn Greeting> {
        Arc::new(English)
    }
    #[repository(name = "hindi")]
    fn hindi_greeting() -> Arc<dyn Greeting> {
        Arc::new(Hindi)
    }

    // `MissingOptional` has no provider, so the `Option` parameter must be `None`.
    struct MissingOptional;
    struct Greeter {
        greeting: Arc<dyn Greeting>,
        optional: Option<Arc<MissingOptional>>,
    }
    #[service]
    impl Greeter {
        fn new(greeting: Arc<dyn Greeting>, optional: Option<Arc<MissingOptional>>) -> Self {
            Self { greeting, optional }
        }
    }
    // `#[qualifier]` on a constructor parameter selects the named provider.
    struct QualifiedGreeter(Arc<dyn Greeting>);
    #[service]
    impl QualifiedGreeter {
        fn new(#[qualifier("hindi")] greeting: Arc<dyn Greeting>) -> Self {
            Self(greeting)
        }
    }

    // Eager bean with lifecycle hooks; the counters record how often each hook ran.
    static STARTED: AtomicUsize = AtomicUsize::new(0);
    static STOPPED: AtomicUsize = AtomicUsize::new(0);
    struct Managed;
    #[service(eager, post_construct = "start", pre_destroy = "stop")]
    impl Managed {
        fn new() -> Self {
            Self
        }
        async fn start(&self) {
            STARTED.fetch_add(1, Ordering::SeqCst);
        }
        async fn stop(&self) {
            STOPPED.fetch_add(1, Ordering::SeqCst);
        }
    }

    // Registered only when the "test" profile is active.
    struct ProfileBean;
    #[component(profile = "test")]
    fn profile_bean() -> ProfileBean {
        ProfileBean
    }

    // Two providers of `Handler` collected together through a `Vec<Arc<_>>` parameter.
    #[derive(Debug)]
    struct Handler(&'static str);
    #[component(name = "first")]
    fn first_handler() -> Handler {
        Handler("first")
    }
    #[component(name = "second")]
    fn second_handler() -> Handler {
        Handler("second")
    }
    struct Pipeline(Vec<Arc<Handler>>);
    #[service]
    impl Pipeline {
        fn new(handlers: Vec<Arc<Handler>>) -> Self {
            Self(handlers)
        }
    }

    // The test sets TESTING_DEP_PORT and expects `port == 8080`, so fields presumably
    // map to `<PREFIX>_<FIELD>` upper-cased (TODO confirm against the macro).
    #[configuration_properties("testing_dep")]
    struct TestProperties {
        port: u16,
    }

    // Target of the deferred handles `Provider<_>` and `Lazy<_>`.
    struct DeferredTarget;
    #[component]
    fn deferred_target() -> DeferredTarget {
        DeferredTarget
    }
    struct DeferredConsumer {
        provider: Provider<DeferredTarget>,
        lazy: Lazy<DeferredTarget>,
    }
    #[service]
    impl DeferredConsumer {
        fn new(provider: Provider<DeferredTarget>, lazy: Lazy<DeferredTarget>) -> Self {
            Self { provider, lazy }
        }
    }

    // Registered only while TESTING_DEP_FEATURE=enabled when the container is built.
    struct ConditionalBean;
    #[component(condition = "TESTING_DEP_FEATURE=enabled")]
    fn conditional_bean() -> ConditionalBean {
        ConditionalBean
    }

    // Walks scopes, qualifiers/primary, optional and collection injection, profiles,
    // configuration properties, conditions, deferred handles and lifecycle hooks.
    #[tokio::test]
    async fn scopes_traits_primary_profiles_and_lifecycle_work() {
        // NOTE(review): `set_var` can race with the other test, which reads the
        // environment concurrently (`Container::new` reads APP_PROFILES, and provider
        // conditions read their keys). This is unsound on some platforms; consider
        // serializing these tests.
        unsafe { std::env::set_var("TESTING_DEP_FEATURE", "enabled") };
        let container = Container::with_profiles(["test"]).unwrap();
        let first = container.resolve::<PrototypeBean>().await.unwrap();
        let second = container.resolve::<PrototypeBean>().await.unwrap();
        assert_ne!(first.0, second.0);

        // Request scope: error outside a request, one shared instance inside it.
        assert!(matches!(
            container.resolve::<RequestBean>().await,
            Err(DiError::RequestScopeUnavailable(_))
        ));
        let request = container.request_context();
        let request_first = request.resolve::<RequestBean>().await.unwrap();
        let request_second = request.resolve::<RequestBean>().await.unwrap();
        assert!(Arc::ptr_eq(&request_first, &request_second));

        let greeter = container.resolve::<Greeter>().await.unwrap();
        assert_eq!(greeter.greeting.text(), "hello");
        assert!(greeter.optional.is_none());
        assert_eq!(
            container
                .resolve::<QualifiedGreeter>()
                .await
                .unwrap()
                .0
                .text(),
            "namaste"
        );
        let hindi = container
            .resolve_named::<Arc<dyn Greeting>>("hindi")
            .await
            .unwrap();
        assert_eq!(hindi.text(), "namaste");
        container.resolve::<ProfileBean>().await.unwrap();
        // Provider order is not guaranteed, so sort before comparing.
        let pipeline = container.resolve::<Pipeline>().await.unwrap();
        let mut handler_names = pipeline.0.iter().map(|h| h.0).collect::<Vec<_>>();
        handler_names.sort_unstable();
        assert_eq!(handler_names, ["first", "second"]);

        unsafe { std::env::set_var("TESTING_DEP_PORT", "8080") };
        let properties = container.resolve::<TestProperties>().await.unwrap();
        assert_eq!(properties.port, 8080);
        container.resolve::<ConditionalBean>().await.unwrap();

        // `Provider`/`Lazy` resolve through the global container, not `container`.
        let deferred = container.resolve::<DeferredConsumer>().await.unwrap();
        deferred.provider.get().await.unwrap();
        deferred.lazy.get().await.unwrap();

        container.initialize_eager().await.unwrap();
        assert_eq!(STARTED.load(Ordering::SeqCst), 1);
        container.shutdown().await.unwrap();
        assert_eq!(STOPPED.load(Ordering::SeqCst), 1);
    }
}