1use crate::{DiResult, Injectable, InjectionContext};
8use async_trait::async_trait;
9use dashmap::DashMap;
10use std::any::{Any, TypeId};
11use std::future::Future;
12use std::marker::PhantomData;
13use std::sync::{Arc, OnceLock};
14
/// Lifetime policy declared for a registered dependency.
///
/// `Singleton` is the default. The caching behavior of each variant is applied
/// elsewhere (by the injection context / scopes), not by this enum itself.
///
/// Derives `Hash` in addition to `Eq` so scopes can be used as `HashMap`/`HashSet`
/// keys (e.g. when grouping registrations by scope).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum DependencyScope {
    /// Default scope.
    /// NOTE(review): presumably one shared instance per singleton scope — confirm in
    /// the injection context.
    #[default]
    Singleton,
    /// NOTE(review): presumably one instance per request — confirm in the injection
    /// context.
    Request,
    /// NOTE(review): presumably a new instance on every resolution — confirm in the
    /// injection context.
    Transient,
}
26
#[async_trait]
/// Object-safe, type-erased asynchronous constructor for one registered type.
///
/// The registry stores implementors as `Box<dyn FactoryTrait>` and, in
/// `DependencyRegistry::create`, downcasts the returned `Arc<dyn Any + Send + Sync>`
/// back to the concrete type that was requested.
///
/// The `Send + Sync` supertraits are what allow the boxed factories to sit in the
/// registry's `DashMap`s and in the `static` global registry shared across threads.
pub trait FactoryTrait: Send + Sync {
    /// Produces one instance using `ctx` as the injection context.
    ///
    /// # Errors
    ///
    /// Returns whatever error the underlying constructor reports.
    async fn create(&self, ctx: &InjectionContext) -> DiResult<Arc<dyn Any + Send + Sync>>;
}
36
/// Adapts an async closure `Fn(Arc<InjectionContext>) -> Fut` into a `FactoryTrait`.
///
/// The struct definition deliberately carries no trait bounds: the bounds live on
/// the `impl` blocks that actually need them (the constructor and the `FactoryTrait`
/// impl). Duplicating them here added nothing and had already drifted out of sync
/// (the `FactoryTrait` impl additionally requires `Fut: 'static`).
///
/// `_phantom` is `PhantomData<fn() -> (Fut, T)>` because the factory never stores a
/// `Fut` or a `T`, it only produces them. A function-pointer phantom is `Send + Sync`
/// for any parameters, so `AsyncFactory`'s auto traits depend only on `F`.
pub struct AsyncFactory<F, Fut, T> {
    /// The wrapped constructor closure.
    factory: F,
    _phantom: std::marker::PhantomData<fn() -> (Fut, T)>,
}
47
48impl<F, Fut, T> AsyncFactory<F, Fut, T>
49where
50 F: Fn(Arc<InjectionContext>) -> Fut + Send + Sync,
51 Fut: Future<Output = DiResult<T>> + Send,
52 T: Any + Send + Sync + 'static,
53{
54 pub fn new(factory: F) -> Self {
56 Self {
57 factory,
58 _phantom: std::marker::PhantomData,
59 }
60 }
61}
62
63#[async_trait]
64impl<F, Fut, T> FactoryTrait for AsyncFactory<F, Fut, T>
65where
66 F: Fn(Arc<InjectionContext>) -> Fut + Send + Sync,
67 Fut: Future<Output = DiResult<T>> + Send + 'static,
68 T: Any + Send + Sync + 'static,
69{
70 async fn create(&self, ctx: &InjectionContext) -> DiResult<Arc<dyn Any + Send + Sync>> {
71 let ctx_arc = Arc::new(ctx.clone());
72 let instance = (self.factory)(ctx_arc).await?;
73 Ok(Arc::new(instance))
74 }
75}
76
/// Zero-sized factory that, when `T: Injectable`, implements `FactoryTrait` by
/// building `T` through `Injectable::inject`.
///
/// The marker is `PhantomData<fn() -> T>` (the same pattern `AsyncFactory` uses)
/// because the factory never stores a `T`, it only produces one. This keeps the
/// factory `Send + Sync` for every `T`, and keeps drop-check from treating it as an
/// owner of `T`.
pub struct InjectableFactory<T>(PhantomData<fn() -> T>);
83
84impl<T> Default for InjectableFactory<T> {
85 fn default() -> Self {
86 Self(PhantomData)
87 }
88}
89
90impl<T> InjectableFactory<T> {
91 pub fn new() -> Self {
93 Self::default()
94 }
95}
96
97#[async_trait]
98impl<T: Injectable + Any + Send + Sync + 'static> FactoryTrait for InjectableFactory<T> {
99 async fn create(&self, ctx: &InjectionContext) -> DiResult<Arc<dyn Any + Send + Sync>> {
100 let ctx_arc = Arc::new(ctx.clone());
103 let resolve_ctx = crate::resolve_context::ResolveContext {
104 root: crate::resolve_context::RESOLVE_CTX
105 .try_with(|outer| Arc::clone(&outer.root))
106 .unwrap_or_else(|_| Arc::clone(&ctx_arc)),
107 current: Arc::clone(&ctx_arc),
108 };
109
110 let value = crate::resolve_context::RESOLVE_CTX
111 .scope(resolve_ctx, T::inject(ctx))
112 .await?;
113 Ok(Arc::new(value))
114 }
115}
116
117type BoxedFactory = Box<dyn FactoryTrait>;
119
/// Registry mapping each dependency type (keyed by `TypeId`) to its factory, its
/// declared scope, and diagnostic metadata.
///
/// Every map is a `DashMap`, so all mutating methods take `&self` and the registry
/// can be shared across threads (e.g. behind the `Arc` returned by `global_registry`).
pub struct DependencyRegistry {
    /// Factory used to build each registered type. `TypeId` is the only key, so a
    /// type can have at most one factory.
    factories: DashMap<TypeId, BoxedFactory>,
    /// Scope declared for each registered type, set together with its factory.
    scopes: DashMap<TypeId, DependencyScope>,
    /// Dependency lists recorded via `register_dependencies`.
    /// NOTE(review): presumably the direct dependencies of each type — confirm with
    /// the code that submits them.
    dependencies: DashMap<TypeId, Vec<TypeId>>,
    /// Type names recorded via `register_type_name`.
    type_names: DashMap<TypeId, &'static str>,
    /// Qualified type names recorded via `register_qualified_type_name`.
    qualified_type_names: DashMap<TypeId, &'static str>,
}
135
136impl DependencyRegistry {
137 pub fn new() -> Self {
139 Self {
140 factories: DashMap::new(),
141 scopes: DashMap::new(),
142 dependencies: DashMap::new(),
143 type_names: DashMap::new(),
144 qualified_type_names: DashMap::new(),
145 }
146 }
147
148 pub fn register<T: Any + Send + Sync + 'static>(
162 &self,
163 scope: DependencyScope,
164 factory: impl FactoryTrait + 'static,
165 ) {
166 let type_id = TypeId::of::<T>();
167 let type_name = std::any::type_name::<T>();
168 if self.factories.contains_key(&type_id) {
173 let short = type_name.rsplit("::").next().unwrap_or(type_name);
174 panic!(
175 "Duplicate DependencyRegistry registration for type `{type_name}`.\n\
176\n\
177Hint: reinhardt DI uses TypeId as the sole registry key. Two factories\n\
178returning the same type will conflict regardless of function name or scope.\n\
179Use a distinct newtype (e.g., `struct Primary{short}({short})`) for each."
180 );
181 }
182 self.factories.insert(type_id, Box::new(factory));
183 self.scopes.insert(type_id, scope);
184 }
185
186 pub fn register_async<T, F, Fut>(&self, scope: DependencyScope, factory: F)
188 where
189 T: Any + Send + Sync + 'static,
190 F: Fn(Arc<InjectionContext>) -> Fut + Send + Sync + 'static,
191 Fut: Future<Output = DiResult<T>> + Send + 'static,
192 {
193 self.register::<T>(scope, AsyncFactory::new(factory));
194 }
195
196 pub fn get_scope<T: Any + 'static>(&self) -> Option<DependencyScope> {
198 let type_id = TypeId::of::<T>();
199 self.scopes.get(&type_id).map(|entry| *entry.value())
200 }
201
202 pub fn is_registered<T: Any + 'static>(&self) -> bool {
204 let type_id = TypeId::of::<T>();
205 self.factories.contains_key(&type_id)
206 }
207
208 pub fn len(&self) -> usize {
210 self.factories.len()
211 }
212
213 pub fn is_empty(&self) -> bool {
215 self.factories.is_empty()
216 }
217
218 pub async fn create<T: Any + Send + Sync + 'static>(
220 &self,
221 ctx: &InjectionContext,
222 ) -> DiResult<Arc<T>> {
223 let type_id = TypeId::of::<T>();
224
225 let factory = self.factories.get(&type_id).ok_or_else(|| {
226 crate::DiError::DependencyNotRegistered {
227 type_name: std::any::type_name::<T>().to_string(),
228 }
229 })?;
230
231 let any_arc = factory.create(ctx).await?;
232
233 any_arc
234 .downcast::<T>()
235 .map_err(|_| crate::DiError::Internal {
236 message: format!(
237 "Failed to downcast dependency: expected {}, got different type",
238 std::any::type_name::<T>()
239 ),
240 })
241 }
242
243 pub fn get_dependencies(&self, type_id: TypeId) -> Vec<TypeId> {
247 self.dependencies
248 .get(&type_id)
249 .map(|deps| deps.value().clone())
250 .unwrap_or_default()
251 }
252
253 pub fn get_all_dependencies(&self) -> std::collections::HashMap<TypeId, Vec<TypeId>> {
257 self.dependencies
258 .iter()
259 .map(|entry| (*entry.key(), entry.value().clone()))
260 .collect()
261 }
262
263 pub fn get_type_names(&self) -> std::collections::HashMap<TypeId, &'static str> {
267 self.type_names
268 .iter()
269 .map(|entry| (*entry.key(), *entry.value()))
270 .collect()
271 }
272
273 #[doc(hidden)]
278 pub fn register_dependencies(&self, type_id: TypeId, deps: impl AsRef<[TypeId]>) {
279 self.dependencies.insert(type_id, deps.as_ref().to_vec());
280 }
281
282 #[doc(hidden)]
287 pub fn register_type_name(&self, type_id: TypeId, type_name: &'static str) {
288 self.type_names.insert(type_id, type_name);
289 }
290
291 pub(crate) fn is_registered_by_id(&self, type_id: TypeId) -> bool {
293 self.factories.contains_key(&type_id)
294 }
295
296 pub(crate) fn get_scope_by_id(&self, type_id: TypeId) -> Option<DependencyScope> {
298 self.scopes.get(&type_id).map(|entry| *entry.value())
299 }
300
301 pub(crate) fn get_type_name(&self, type_id: TypeId) -> Option<&'static str> {
303 self.type_names.get(&type_id).map(|entry| *entry.value())
304 }
305
306 #[doc(hidden)]
310 pub fn register_qualified_type_name(&self, type_id: TypeId, qualified_name: &'static str) {
311 self.qualified_type_names.insert(type_id, qualified_name);
312 }
313
314 pub fn get_qualified_type_name(&self, type_id: &TypeId) -> Option<&'static str> {
316 self.qualified_type_names.get(type_id).map(|r| *r.value())
317 }
318
319 pub fn iter_qualified_type_names(&self) -> impl Iterator<Item = (TypeId, &'static str)> + '_ {
321 self.qualified_type_names
322 .iter()
323 .map(|entry| (*entry.key(), *entry.value()))
324 }
325}
326
#[cfg(feature = "testing")]
impl DependencyRegistry {
    /// Installs `factory` as the factory for `T` under `scope`, replacing any existing
    /// registration, and returns a guard that owns whatever was replaced.
    ///
    /// Unlike `register`, this does not panic when `T` is already registered. The guard
    /// records `T`'s `TypeId`, the previous `(factory, scope)` pair (or `None` if `T`
    /// was not registered), and a weak reference to this registry.
    /// NOTE(review): the guard presumably reinstates the previous registration (via
    /// `restore_override` / `remove_override`) when dropped — confirm in
    /// `crate::testing`.
    ///
    /// Marked `#[must_use]` because discarding the guard immediately also discards
    /// the saved previous registration, which is almost certainly a test bug.
    #[must_use = "the returned OverrideGuard owns the replaced registration; bind it to a named variable instead of discarding it"]
    pub fn register_override<T, F, Fut>(
        self: &std::sync::Arc<Self>,
        scope: crate::registry::DependencyScope,
        factory: F,
    ) -> crate::testing::OverrideGuard
    where
        T: std::any::Any + Send + Sync + 'static,
        F: Fn(std::sync::Arc<crate::InjectionContext>) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = crate::DiResult<T>> + Send + 'static,
    {
        let type_id = std::any::TypeId::of::<T>();
        let async_factory = crate::registry::AsyncFactory::new(factory);
        let boxed: Box<dyn crate::registry::FactoryTrait> = Box::new(async_factory);

        let previous_factory = self.factories.insert(type_id, boxed);
        let previous_scope = self.scopes.insert(type_id, scope);

        // Factories and scopes are always written together, so either both had an
        // entry for `T` or neither did.
        debug_assert!(
            previous_factory.is_some() == previous_scope.is_some(),
            "torn override state: factories/scopes diverged for `{}`",
            std::any::type_name::<T>()
        );
        // Only a complete (factory, scope) pair is worth restoring. A torn state is
        // treated as "nothing registered".
        let previous = previous_factory.zip(previous_scope);

        crate::testing::OverrideGuard {
            type_id,
            previous,
            registry: std::sync::Arc::downgrade(self),
        }
    }

    /// Reinstates a previously saved `(factory, scope)` pair for `type_id`,
    /// overwriting whatever is currently registered.
    pub(crate) fn restore_override(
        &self,
        type_id: std::any::TypeId,
        factory: Box<dyn crate::registry::FactoryTrait>,
        scope: crate::registry::DependencyScope,
    ) {
        self.factories.insert(type_id, factory);
        self.scopes.insert(type_id, scope);
    }

    /// Removes the factory and scope registered for `type_id`, if any.
    pub(crate) fn remove_override(&self, type_id: std::any::TypeId) {
        self.factories.remove(&type_id);
        self.scopes.remove(&type_id);
    }
}
426
427impl Default for DependencyRegistry {
428 fn default() -> Self {
429 Self::new()
430 }
431}
432
433static GLOBAL_REGISTRY: OnceLock<Arc<DependencyRegistry>> = OnceLock::new();
435
436pub fn global_registry() -> &'static Arc<DependencyRegistry> {
438 GLOBAL_REGISTRY.get_or_init(|| {
439 let registry = Arc::new(DependencyRegistry::new());
440 initialize_registry(®istry);
441 registry
442 })
443}
444
445#[cfg(test)]
456pub fn reset_global_registry() {
457 unsafe {
461 let ptr = std::ptr::addr_of!(GLOBAL_REGISTRY) as *mut OnceLock<Arc<DependencyRegistry>>;
462 std::ptr::write(ptr, OnceLock::new());
463 }
464}
465
/// Registration record collected through `inventory::collect!` and replayed by
/// `initialize_registry` when the global registry is first built.
///
/// Within this module only `register_fn` is invoked; the other fields are
/// metadata describing the registration.
pub struct DependencyRegistration {
    /// `TypeId` of the registered dependency type (`TypeId::of::<T>()`).
    pub type_id: TypeId,
    /// Name of the dependency type, as supplied by the caller.
    pub type_name: &'static str,
    /// Scope the dependency is declared with.
    pub scope: DependencyScope,
    /// `TypeId`s this dependency declares; empty when built with `new`.
    pub dependencies: &'static [TypeId],
    /// Callback that performs the actual registration against a registry.
    pub register_fn: fn(&DependencyRegistry),
}
479
480impl DependencyRegistration {
481 pub const fn new<T: Send + Sync + 'static>(
483 type_name: &'static str,
484 scope: DependencyScope,
485 register_fn: fn(&DependencyRegistry),
486 ) -> Self {
487 Self {
488 type_id: TypeId::of::<T>(),
489 type_name,
490 scope,
491 dependencies: &[],
492 register_fn,
493 }
494 }
495
496 pub const fn new_with_deps<T: Send + Sync + 'static>(
498 type_name: &'static str,
499 scope: DependencyScope,
500 dependencies: &'static [TypeId],
501 register_fn: fn(&DependencyRegistry),
502 ) -> Self {
503 Self {
504 type_id: TypeId::of::<T>(),
505 type_name,
506 scope,
507 dependencies,
508 register_fn,
509 }
510 }
511}
512
513inventory::collect!(DependencyRegistration);
515
/// Registration callback collected through `inventory::collect!`.
///
/// `initialize_registry` runs these after all `DependencyRegistration` callbacks.
/// NOTE(review): presumably submitted for `Injectable` types — confirm with the code
/// that submits them.
pub struct InjectableRegistration {
    /// Callback that registers the type with the given registry.
    pub register_fn: fn(&DependencyRegistry),
}
525
526impl InjectableRegistration {
527 pub const fn new(register_fn: fn(&DependencyRegistry)) -> Self {
529 Self { register_fn }
530 }
531}
532
533inventory::collect!(InjectableRegistration);
534
535fn initialize_registry(registry: &DependencyRegistry) {
537 for registration in inventory::iter::<DependencyRegistration> {
538 (registration.register_fn)(registry);
539 }
540 for registration in inventory::iter::<InjectableRegistration> {
541 (registration.register_fn)(registry);
542 }
543}
544
#[macro_export]
/// Submits a registration value (for example a `DependencyRegistration` or an
/// `InjectableRegistration`) through this crate's re-exported `inventory::submit!`.
///
/// Submitted values are picked up by `initialize_registry` when the global registry
/// is first built. Going through `$crate::inventory` means callers do not need their
/// own dependency on `inventory`.
macro_rules! submit_registration {
    ($registration:expr) => {
        $crate::inventory::submit! {
            $registration
        }
    };
}
556
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scope::SingletonScope;
    use rstest::*;

    #[derive(Clone)]
    struct TestService {
        value: i32,
    }

    #[rstest]
    #[tokio::test]
    async fn test_registry_basic() {
        let registry = DependencyRegistry::new();

        registry.register_async::<TestService, _, _>(DependencyScope::Singleton, |_ctx| async {
            Ok(TestService { value: 42 })
        });

        assert!(registry.is_registered::<TestService>());
        assert_eq!(
            registry.get_scope::<TestService>(),
            Some(DependencyScope::Singleton)
        );

        let singleton_scope = Arc::new(SingletonScope::new());
        let ctx = InjectionContext::builder(singleton_scope).build();

        let service = registry.create::<TestService>(&ctx).await.unwrap();
        assert_eq!(service.value, 42);
    }

    #[rstest]
    #[tokio::test]
    async fn test_registry_not_registered() {
        let registry = DependencyRegistry::new();
        let singleton_scope = Arc::new(SingletonScope::new());
        let ctx = InjectionContext::builder(singleton_scope).build();

        let result = registry.create::<TestService>(&ctx).await;
        // Must be the specific "not registered" error naming the type, not just any
        // failure.
        assert!(
            matches!(
                &result,
                Err(crate::DiError::DependencyNotRegistered { type_name, .. })
                    if type_name.contains("TestService")
            ),
            "expected DependencyNotRegistered naming TestService"
        );
    }

    #[rstest]
    fn test_duplicate_registration_panics() {
        let registry = DependencyRegistry::new();

        registry.register_async::<TestService, _, _>(DependencyScope::Singleton, |_ctx| async {
            Ok(TestService { value: 1 })
        });

        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            registry.register_async::<TestService, _, _>(DependencyScope::Request, |_ctx| async {
                Ok(TestService { value: 2 })
            });
        }));

        let err = result.expect_err("expected panic on duplicate registration");
        let msg = err
            .downcast_ref::<String>()
            .map(|s| s.as_str())
            .or_else(|| err.downcast_ref::<&str>().copied())
            .expect("panic payload should be a string");
        assert!(
            msg.contains("Duplicate DependencyRegistry registration"),
            "missing duplicate prefix: {msg}"
        );
        assert!(
            msg.contains("TestService"),
            "missing type name in panic message: {msg}"
        );
        assert!(
            msg.contains("newtype"),
            "missing newtype hint in panic message: {msg}"
        );

        // The rejected registration must not have clobbered the original one.
        assert!(registry.is_registered::<TestService>());
        assert_eq!(registry.len(), 1);
        assert_eq!(
            registry.get_scope::<TestService>(),
            Some(DependencyScope::Singleton),
            "duplicate registration must leave the original scope intact"
        );
    }

    #[rstest]
    fn test_is_registered_guard_allows_skip() {
        let registry = DependencyRegistry::new();

        registry.register_async::<TestService, _, _>(DependencyScope::Singleton, |_ctx| async {
            Ok(TestService { value: 1 })
        });

        if !registry.is_registered::<TestService>() {
            registry.register_async::<TestService, _, _>(DependencyScope::Request, |_ctx| async {
                Ok(TestService { value: 2 })
            });
        }

        assert!(registry.is_registered::<TestService>());
        // The guarded second registration was skipped, so the first one's scope wins.
        assert_eq!(
            registry.get_scope::<TestService>(),
            Some(DependencyScope::Singleton)
        );
    }

    #[rstest]
    fn test_len_and_is_empty_track_registrations() {
        let registry = DependencyRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert_eq!(registry.get_scope::<TestService>(), None);

        registry.register_async::<TestService, _, _>(DependencyScope::Transient, |_ctx| async {
            Ok(TestService { value: 3 })
        });

        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
        assert_eq!(
            registry.get_scope::<TestService>(),
            Some(DependencyScope::Transient)
        );
    }

    #[rstest]
    fn test_dependency_metadata_roundtrip() {
        let registry = DependencyRegistry::new();
        let id = TypeId::of::<TestService>();
        let deps = [TypeId::of::<i32>(), TypeId::of::<String>()];

        assert!(registry.get_dependencies(id).is_empty());
        registry.register_dependencies(id, deps);
        assert_eq!(registry.get_dependencies(id), deps.to_vec());
        assert_eq!(
            registry.get_all_dependencies().get(&id),
            Some(&deps.to_vec())
        );

        assert_eq!(registry.get_type_name(id), None);
        registry.register_type_name(id, "TestService");
        assert_eq!(registry.get_type_name(id), Some("TestService"));
        assert_eq!(registry.get_type_names().get(&id), Some(&"TestService"));
    }
}