Skip to main content

reinhardt_di/
registry.rs

1//! Global dependency registry for FastAPI-style dependency injection
2//!
3//! This module provides a global registry that stores factory functions for creating
4//! dependencies. It uses the `inventory` crate to collect registrations at compile time
5//! and build a runtime registry that can be queried by type.
6
7use crate::{DiResult, Injectable, InjectionContext};
8use async_trait::async_trait;
9use dashmap::DashMap;
10use std::any::{Any, TypeId};
11use std::future::Future;
12use std::marker::PhantomData;
13use std::sync::{Arc, OnceLock};
14
/// Lifetime policy that decides how often a dependency is constructed.
///
/// The default policy is [`DependencyScope::Singleton`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum DependencyScope {
	/// One instance for the whole application, shared by every resolution.
	#[default]
	Singleton,
	/// One instance per request, reused for the lifetime of that request.
	Request,
	/// A fresh instance on every resolution; never cached.
	Transient,
}
26
27/// Factory trait for creating dependencies
28///
29/// Factories are async functions that can resolve dependencies from an InjectionContext
30/// and return a type-erased `Arc<dyn Any>`.
31#[async_trait]
32pub trait FactoryTrait: Send + Sync {
33	/// Create an instance of the dependency
34	async fn create(&self, ctx: &InjectionContext) -> DiResult<Arc<dyn Any + Send + Sync>>;
35}
36
/// Adapts an async closure of the form
/// `Fn(Arc<InjectionContext>) -> impl Future<Output = DiResult<T>>`
/// into a [`FactoryTrait`] implementation.
///
/// The trait bounds are stated on the `impl` blocks that need them rather than
/// on the struct itself, so the type can be named (e.g. in signatures or
/// aliases) without repeating the full bound set.
pub struct AsyncFactory<F, Fut, T> {
	/// The user-supplied factory closure.
	factory: F,
	/// Records `Fut` and `T` without storing values of those types. Using a
	/// `fn() -> _` pointer type keeps `AsyncFactory` `Send`/`Sync` whenever
	/// `F` is, independent of the auto traits of `Fut` and `T`.
	_phantom: PhantomData<fn() -> (Fut, T)>,
}
47
48impl<F, Fut, T> AsyncFactory<F, Fut, T>
49where
50	F: Fn(Arc<InjectionContext>) -> Fut + Send + Sync,
51	Fut: Future<Output = DiResult<T>> + Send,
52	T: Any + Send + Sync + 'static,
53{
54	/// Creates a new async factory from the given closure.
55	pub fn new(factory: F) -> Self {
56		Self {
57			factory,
58			_phantom: std::marker::PhantomData,
59		}
60	}
61}
62
63#[async_trait]
64impl<F, Fut, T> FactoryTrait for AsyncFactory<F, Fut, T>
65where
66	F: Fn(Arc<InjectionContext>) -> Fut + Send + Sync,
67	Fut: Future<Output = DiResult<T>> + Send + 'static,
68	T: Any + Send + Sync + 'static,
69{
70	async fn create(&self, ctx: &InjectionContext) -> DiResult<Arc<dyn Any + Send + Sync>> {
71		let ctx_arc = Arc::new(ctx.clone());
72		let instance = (self.factory)(ctx_arc).await?;
73		Ok(Arc::new(instance))
74	}
75}
76
/// Factory that creates instances via the `Injectable` trait.
///
/// It implements [`FactoryTrait`] directly instead of being built on
/// [`AsyncFactory`], for two reasons visible in its `create` method:
///
/// - `T::inject` is called with the borrowed `&InjectionContext`. Because
///   `Injectable::inject` is an `async_trait` method, the returned future is
///   tied to that borrow. `AsyncFactory`, by contrast, requires a `'static`
///   future produced from an owned `Arc<InjectionContext>`.
/// - `create` runs `T::inject` inside a task-local `RESOLVE_CTX` scope so that
///   `get_di_context()` is available while the dependency is being injected.
pub struct InjectableFactory<T>(PhantomData<T>);
83
84impl<T> Default for InjectableFactory<T> {
85	fn default() -> Self {
86		Self(PhantomData)
87	}
88}
89
90impl<T> InjectableFactory<T> {
91	/// Create a new `InjectableFactory`.
92	pub fn new() -> Self {
93		Self::default()
94	}
95}
96
97#[async_trait]
98impl<T: Injectable + Any + Send + Sync + 'static> FactoryTrait for InjectableFactory<T> {
99	async fn create(&self, ctx: &InjectionContext) -> DiResult<Arc<dyn Any + Send + Sync>> {
100		// Set task-local resolve context for get_di_context() access.
101		// Since we only have &InjectionContext, clone into Arc (same pattern as AsyncFactory).
102		let ctx_arc = Arc::new(ctx.clone());
103		let resolve_ctx = crate::resolve_context::ResolveContext {
104			root: crate::resolve_context::RESOLVE_CTX
105				.try_with(|outer| Arc::clone(&outer.root))
106				.unwrap_or_else(|_| Arc::clone(&ctx_arc)),
107			current: Arc::clone(&ctx_arc),
108		};
109
110		let value = crate::resolve_context::RESOLVE_CTX
111			.scope(resolve_ctx, T::inject(ctx))
112			.await?;
113		Ok(Arc::new(value))
114	}
115}
116
117/// Type-erased factory function
118type BoxedFactory = Box<dyn FactoryTrait>;
119
120/// Global dependency registry
121///
122/// Stores factory functions for each type, along with their scope information.
123/// Uses DashMap for thread-safe concurrent access without blocking.
124pub struct DependencyRegistry {
125	factories: DashMap<TypeId, BoxedFactory>,
126	scopes: DashMap<TypeId, DependencyScope>,
127	/// Maps type ID to its direct dependencies
128	dependencies: DashMap<TypeId, Vec<TypeId>>,
129	/// Maps type ID to its type name for debugging
130	type_names: DashMap<TypeId, &'static str>,
131	/// Maps type ID to its fully-qualified type name from `std::any::type_name`.
132	/// Used for framework type detection (pseudo orphan rule).
133	qualified_type_names: DashMap<TypeId, &'static str>,
134}
135
136impl DependencyRegistry {
137	/// Create a new empty registry
138	pub fn new() -> Self {
139		Self {
140			factories: DashMap::new(),
141			scopes: DashMap::new(),
142			dependencies: DashMap::new(),
143			type_names: DashMap::new(),
144			qualified_type_names: DashMap::new(),
145		}
146	}
147
148	/// Register a factory for a type.
149	///
150	/// # Panics
151	///
152	/// Panics if a factory for the same `TypeId` is already registered.
153	/// This prevents silent overwrites that lead to non-deterministic behavior
154	/// when multiple `#[injectable_factory]` or `#[injectable]` macros produce
155	/// the same return type. See [#3457].
156	///
157	/// To check before registering (e.g. in tests), use
158	/// [`is_registered`](Self::is_registered).
159	///
160	/// [#3457]: https://github.com/kent8192/reinhardt-web/issues/3457
161	pub fn register<T: Any + Send + Sync + 'static>(
162		&self,
163		scope: DependencyScope,
164		factory: impl FactoryTrait + 'static,
165	) {
166		let type_id = TypeId::of::<T>();
167		let type_name = std::any::type_name::<T>();
168		// Check for duplicates before inserting so that no state is mutated on the
169		// error path. This avoids leaving the registry inconsistent if the panic is
170		// caught (e.g. factories pointing to the new registration while scopes still
171		// reflects the old one).
172		if self.factories.contains_key(&type_id) {
173			let short = type_name.rsplit("::").next().unwrap_or(type_name);
174			panic!(
175				"Duplicate DependencyRegistry registration for type `{type_name}`.\n\
176\n\
177Hint: reinhardt DI uses TypeId as the sole registry key. Two factories\n\
178returning the same type will conflict regardless of function name or scope.\n\
179Use a distinct newtype (e.g., `struct Primary{short}({short})`) for each."
180			);
181		}
182		self.factories.insert(type_id, Box::new(factory));
183		self.scopes.insert(type_id, scope);
184	}
185
186	/// Register a simple async factory function
187	pub fn register_async<T, F, Fut>(&self, scope: DependencyScope, factory: F)
188	where
189		T: Any + Send + Sync + 'static,
190		F: Fn(Arc<InjectionContext>) -> Fut + Send + Sync + 'static,
191		Fut: Future<Output = DiResult<T>> + Send + 'static,
192	{
193		self.register::<T>(scope, AsyncFactory::new(factory));
194	}
195
196	/// Get the scope for a type
197	pub fn get_scope<T: Any + 'static>(&self) -> Option<DependencyScope> {
198		let type_id = TypeId::of::<T>();
199		self.scopes.get(&type_id).map(|entry| *entry.value())
200	}
201
202	/// Check if a type is registered
203	pub fn is_registered<T: Any + 'static>(&self) -> bool {
204		let type_id = TypeId::of::<T>();
205		self.factories.contains_key(&type_id)
206	}
207
208	/// Get the number of registered dependencies
209	pub fn len(&self) -> usize {
210		self.factories.len()
211	}
212
213	/// Check if the registry is empty
214	pub fn is_empty(&self) -> bool {
215		self.factories.is_empty()
216	}
217
218	/// Create an instance using the registered factory
219	pub async fn create<T: Any + Send + Sync + 'static>(
220		&self,
221		ctx: &InjectionContext,
222	) -> DiResult<Arc<T>> {
223		let type_id = TypeId::of::<T>();
224
225		let factory = self.factories.get(&type_id).ok_or_else(|| {
226			crate::DiError::DependencyNotRegistered {
227				type_name: std::any::type_name::<T>().to_string(),
228			}
229		})?;
230
231		let any_arc = factory.create(ctx).await?;
232
233		any_arc
234			.downcast::<T>()
235			.map_err(|_| crate::DiError::Internal {
236				message: format!(
237					"Failed to downcast dependency: expected {}, got different type",
238					std::any::type_name::<T>()
239				),
240			})
241	}
242
243	/// Get the direct dependencies of a type
244	///
245	/// Returns a vector of TypeIds representing the types that the given type directly depends on.
246	pub fn get_dependencies(&self, type_id: TypeId) -> Vec<TypeId> {
247		self.dependencies
248			.get(&type_id)
249			.map(|deps| deps.value().clone())
250			.unwrap_or_default()
251	}
252
253	/// Get all dependencies in the registry
254	///
255	/// Returns a HashMap mapping each type to its direct dependencies.
256	pub fn get_all_dependencies(&self) -> std::collections::HashMap<TypeId, Vec<TypeId>> {
257		self.dependencies
258			.iter()
259			.map(|entry| (*entry.key(), entry.value().clone()))
260			.collect()
261	}
262
263	/// Get all type names in the registry
264	///
265	/// Returns a HashMap mapping TypeIds to their human-readable type names.
266	pub fn get_type_names(&self) -> std::collections::HashMap<TypeId, &'static str> {
267		self.type_names
268			.iter()
269			.map(|entry| (*entry.key(), *entry.value()))
270			.collect()
271	}
272
273	/// Register dependencies for a type
274	///
275	/// This is typically called automatically by the registration system.
276	/// Not intended for direct use; exposed for macro-generated code.
277	#[doc(hidden)]
278	pub fn register_dependencies(&self, type_id: TypeId, deps: impl AsRef<[TypeId]>) {
279		self.dependencies.insert(type_id, deps.as_ref().to_vec());
280	}
281
282	/// Register a type name for debugging
283	///
284	/// This is typically called automatically by the registration system.
285	/// Not intended for direct use; exposed for macro-generated code.
286	#[doc(hidden)]
287	pub fn register_type_name(&self, type_id: TypeId, type_name: &'static str) {
288		self.type_names.insert(type_id, type_name);
289	}
290
291	/// Check if a type is registered by its `TypeId`.
292	pub(crate) fn is_registered_by_id(&self, type_id: TypeId) -> bool {
293		self.factories.contains_key(&type_id)
294	}
295
296	/// Get the scope for a type by its `TypeId`.
297	pub(crate) fn get_scope_by_id(&self, type_id: TypeId) -> Option<DependencyScope> {
298		self.scopes.get(&type_id).map(|entry| *entry.value())
299	}
300
301	/// Get the type name for a `TypeId`.
302	pub(crate) fn get_type_name(&self, type_id: TypeId) -> Option<&'static str> {
303		self.type_names.get(&type_id).map(|entry| *entry.value())
304	}
305
306	/// Register the fully-qualified type name obtained from `std::any::type_name::<T>()`.
307	///
308	/// Used by the pseudo orphan rule to detect framework-managed types.
309	#[doc(hidden)]
310	pub fn register_qualified_type_name(&self, type_id: TypeId, qualified_name: &'static str) {
311		self.qualified_type_names.insert(type_id, qualified_name);
312	}
313
314	/// Get the fully-qualified type name for a given `TypeId`.
315	pub fn get_qualified_type_name(&self, type_id: &TypeId) -> Option<&'static str> {
316		self.qualified_type_names.get(type_id).map(|r| *r.value())
317	}
318
319	/// Iterate over all qualified type name mappings without allocating a new map.
320	pub fn iter_qualified_type_names(&self) -> impl Iterator<Item = (TypeId, &'static str)> + '_ {
321		self.qualified_type_names
322			.iter()
323			.map(|entry| (*entry.key(), *entry.value()))
324	}
325}
326
// Test-only helpers for temporarily replacing registrations; compiled only
// with the `testing` feature.
#[cfg(feature = "testing")]
impl DependencyRegistry {
	/// Registers or overrides a factory for type `T` without panicking on
	/// duplicate registration.
	///
	/// Returns an [`OverrideGuard`](crate::testing::OverrideGuard) that
	/// restores the previous factory (or removes the entry entirely if there
	/// was none) when dropped.
	///
	/// Tests using this method **must** run inside the
	/// `#[serial(di_registry)]` group because the global registry is mutated.
	///
	/// # Safety contract
	///
	/// This method inserts into two separate `DashMap`s (`factories` and
	/// `scopes`) non-atomically. Without serialization, another thread can
	/// observe a torn state where the new factory is paired with the old
	/// scope (or vice versa). The `#[serial(di_registry)]` requirement is
	/// what eliminates that window; do not relax it.
	///
	/// # Panics
	///
	/// In debug builds only, panics if `factories` and `scopes` disagree
	/// about whether `T` was previously registered (a torn state left by an
	/// unserialized concurrent writer). Release builds treat that case as
	/// "no previous registration".
	///
	/// # Examples
	///
	/// ```rust,no_run
	/// # use std::sync::Arc;
	/// # use reinhardt_di::{DependencyRegistry, DependencyScope, DiResult, InjectionContext};
	/// # fn _example(registry: Arc<DependencyRegistry>) {
	/// let _guard = registry.register_override::<String, _, _>(
	///     DependencyScope::Singleton,
	///     |_ctx| async { Ok("mock".to_string()) },
	/// );
	/// // ... run test body ...
	/// // `_guard` dropped here → previous factory restored.
	/// # }
	/// ```
	pub fn register_override<T, F, Fut>(
		self: &std::sync::Arc<Self>,
		scope: crate::registry::DependencyScope,
		factory: F,
	) -> crate::testing::OverrideGuard
	where
		T: std::any::Any + Send + Sync + 'static,
		F: Fn(std::sync::Arc<crate::InjectionContext>) -> Fut + Send + Sync + 'static,
		Fut: std::future::Future<Output = crate::DiResult<T>> + Send + 'static,
	{
		let type_id = std::any::TypeId::of::<T>();
		let async_factory = crate::registry::AsyncFactory::new(factory);
		let boxed: Box<dyn crate::registry::FactoryTrait> = Box::new(async_factory);

		// Capture previous (if any) and install the override.
		// `DashMap::insert` returns the replaced value, which is what the guard
		// needs to restore later.
		let previous_factory = self.factories.insert(type_id, boxed);
		let previous_scope = self.scopes.insert(type_id, scope);

		// `factories` and `scopes` are always inserted/removed in lockstep, so
		// observing one without the other indicates the `#[serial(di_registry)]`
		// contract was violated by another writer. Assert in debug builds and
		// fall back to "no previous" in release so the guard at least removes
		// the entry instead of restoring a partial state.
		debug_assert!(
			previous_factory.is_some() == previous_scope.is_some(),
			"torn override state: factories/scopes diverged for `{}`",
			std::any::type_name::<T>()
		);
		let previous = match (previous_factory, previous_scope) {
			(Some(f), Some(s)) => Some((f, s)),
			_ => None,
		};

		// The guard holds only a `Weak` reference, so it does not keep the
		// registry alive on its own.
		crate::testing::OverrideGuard {
			type_id,
			previous,
			registry: std::sync::Arc::downgrade(self),
		}
	}

	/// Restores a previously-installed factory and scope. Used by
	/// [`OverrideGuard::drop`](crate::testing::OverrideGuard).
	///
	/// Like `register_override`, this performs a non-atomic two-`DashMap`
	/// mutation, so callers MUST hold the `#[serial(di_registry)]` lock — in
	/// practice this is enforced by only invoking it from a guard's `Drop`,
	/// which runs inside the same serialized test scope.
	pub(crate) fn restore_override(
		&self,
		type_id: std::any::TypeId,
		factory: Box<dyn crate::registry::FactoryTrait>,
		scope: crate::registry::DependencyScope,
	) {
		// Same factory-then-scope order as `register_override`; the brief
		// window between the two inserts is hidden by the serial contract.
		self.factories.insert(type_id, factory);
		self.scopes.insert(type_id, scope);
	}

	/// Removes an override entry that had no prior registration. Used by
	/// [`OverrideGuard::drop`](crate::testing::OverrideGuard). Same
	/// `#[serial(di_registry)]` requirement as `restore_override`.
	pub(crate) fn remove_override(&self, type_id: std::any::TypeId) {
		// Removing a missing key is a no-op for both maps.
		self.factories.remove(&type_id);
		self.scopes.remove(&type_id);
	}
}
426
427impl Default for DependencyRegistry {
428	fn default() -> Self {
429		Self::new()
430	}
431}
432
433/// Global singleton registry instance
434static GLOBAL_REGISTRY: OnceLock<Arc<DependencyRegistry>> = OnceLock::new();
435
436/// Get the global registry instance
437pub fn global_registry() -> &'static Arc<DependencyRegistry> {
438	GLOBAL_REGISTRY.get_or_init(|| {
439		let registry = Arc::new(DependencyRegistry::new());
440		initialize_registry(&registry);
441		registry
442	})
443}
444
/// Resets the global dependency registry for test isolation.
///
/// If the global registry has already been initialized, every map in it is
/// cleared and the `inventory`-collected registrations are replayed. This
/// leaves it in the same state a first call to `global_registry()` would
/// produce. If it has not been initialized yet, this does nothing, and the
/// next `global_registry()` call initializes it as usual.
///
/// The registry is reset in place rather than by overwriting the static
/// `OnceLock`. `global_registry()` hands out `&'static Arc<DependencyRegistry>`
/// references, and overwriting the static's storage would leave those
/// references pointing at re-initialized memory (and leak the old `Arc`).
/// Resetting in place keeps them valid and needs no `unsafe`.
///
/// # Concurrency
///
/// Call this from serialized tests (e.g. with `#[serial]`). A concurrent
/// reader may otherwise observe the registry while it is partially cleared
/// or re-populated.
#[cfg(test)]
pub fn reset_global_registry() {
	let Some(registry) = GLOBAL_REGISTRY.get() else {
		// Never initialized: the next `global_registry()` starts fresh anyway.
		return;
	};
	registry.factories.clear();
	registry.scopes.clear();
	registry.dependencies.clear();
	registry.type_names.clear();
	registry.qualified_type_names.clear();
	// Replay the same registrations `global_registry()` performs on first use.
	initialize_registry(registry);
}
465
466/// Registration entry for inventory collection
467pub struct DependencyRegistration {
468	/// The `TypeId` of the dependency being registered.
469	pub type_id: TypeId,
470	/// The human-readable name of the type.
471	pub type_name: &'static str,
472	/// The scope (request or singleton) for this dependency.
473	pub scope: DependencyScope,
474	/// Direct dependencies of this type.
475	pub dependencies: &'static [TypeId],
476	/// A function that registers this dependency's factory with the registry.
477	pub register_fn: fn(&DependencyRegistry),
478}
479
480impl DependencyRegistration {
481	/// Create a new registration entry
482	pub const fn new<T: Send + Sync + 'static>(
483		type_name: &'static str,
484		scope: DependencyScope,
485		register_fn: fn(&DependencyRegistry),
486	) -> Self {
487		Self {
488			type_id: TypeId::of::<T>(),
489			type_name,
490			scope,
491			dependencies: &[],
492			register_fn,
493		}
494	}
495
496	/// Create a new registration entry with explicit dependencies
497	pub const fn new_with_deps<T: Send + Sync + 'static>(
498		type_name: &'static str,
499		scope: DependencyScope,
500		dependencies: &'static [TypeId],
501		register_fn: fn(&DependencyRegistry),
502	) -> Self {
503		Self {
504			type_id: TypeId::of::<T>(),
505			type_name,
506			scope,
507			dependencies,
508			register_fn,
509		}
510	}
511}
512
513// Collect all dependency registrations at compile time
514inventory::collect!(DependencyRegistration);
515
516/// Const-constructible registration entry for `#[injectable]` structs with `#[scope]`.
517///
518/// Unlike `DependencyRegistration` which uses `Box<dyn Fn>` (non-const),
519/// this struct stores a plain function pointer so it can be used in
520/// `inventory::submit!` which requires const-evaluable expressions.
521pub struct InjectableRegistration {
522	/// A function that registers this type's factory with the registry.
523	pub register_fn: fn(&DependencyRegistry),
524}
525
526impl InjectableRegistration {
527	/// Create a new `InjectableRegistration` with a function pointer.
528	pub const fn new(register_fn: fn(&DependencyRegistry)) -> Self {
529		Self { register_fn }
530	}
531}
532
533inventory::collect!(InjectableRegistration);
534
535/// Initialize the registry with all collected registrations
536fn initialize_registry(registry: &DependencyRegistry) {
537	for registration in inventory::iter::<DependencyRegistration> {
538		(registration.register_fn)(registry);
539	}
540	for registration in inventory::iter::<InjectableRegistration> {
541		(registration.register_fn)(registry);
542	}
543}
544
/// Submit a registration entry to `inventory`.
///
/// Internal helper expanded by the `#[injectable]` and `#[injectable_factory]`
/// macros. `$entry` must be usable where `inventory::submit!` expects its
/// value (a const-evaluable registration expression).
#[macro_export]
macro_rules! submit_registration {
	($entry:expr) => {
		$crate::inventory::submit!($entry);
	};
}
556
#[cfg(test)]
mod tests {
	use super::*;
	use crate::scope::SingletonScope;
	use rstest::*;

	/// Minimal dependency type registered by the tests below.
	#[derive(Clone)]
	struct TestService {
		value: i32,
	}

	#[rstest]
	#[tokio::test]
	async fn test_registry_basic() {
		// Arrange
		let registry = DependencyRegistry::new();
		registry.register_async::<TestService, _, _>(DependencyScope::Singleton, |_ctx| async {
			Ok(TestService { value: 42 })
		});

		// Assert: metadata recorded at registration time
		assert!(registry.is_registered::<TestService>());
		let recorded_scope = registry.get_scope::<TestService>();
		assert_eq!(recorded_scope, Some(DependencyScope::Singleton));

		// Act
		let ctx = InjectionContext::builder(Arc::new(SingletonScope::new())).build();
		let resolved = registry.create::<TestService>(&ctx).await.unwrap();

		// Assert: the factory's value survives the type-erased round trip
		assert_eq!(resolved.value, 42);
	}

	#[rstest]
	#[tokio::test]
	async fn test_registry_not_registered() {
		// Arrange: an empty registry
		let registry = DependencyRegistry::new();
		let ctx = InjectionContext::builder(Arc::new(SingletonScope::new())).build();

		// Act
		let outcome = registry.create::<TestService>(&ctx).await;

		// Assert
		assert!(outcome.is_err(), "unregistered type must not resolve");
	}

	// Fixes #3457
	#[rstest]
	fn test_duplicate_registration_panics() {
		// Arrange
		let registry = DependencyRegistry::new();
		registry.register_async::<TestService, _, _>(DependencyScope::Singleton, |_ctx| async {
			Ok(TestService { value: 1 })
		});

		// Act: a second registration for the same type must panic
		let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
			registry.register_async::<TestService, _, _>(DependencyScope::Request, |_ctx| async {
				Ok(TestService { value: 2 })
			});
		}));

		// Assert
		let payload = outcome.expect_err("expected panic on duplicate registration");
		let msg: &str = if let Some(owned) = payload.downcast_ref::<String>() {
			owned.as_str()
		} else if let Some(borrowed) = payload.downcast_ref::<&str>() {
			borrowed
		} else {
			panic!("panic payload should be a string");
		};
		let expectations = [
			("Duplicate DependencyRegistry registration", "missing duplicate prefix"),
			("TestService", "missing type name in panic message"),
			("newtype", "missing newtype hint in panic message"),
		];
		for (needle, complaint) in expectations {
			assert!(msg.contains(needle), "{complaint}: {msg}");
		}
	}

	// Fixes #3457 — is_registered guard prevents panic (test helper pattern)
	#[rstest]
	fn test_is_registered_guard_allows_skip() {
		// Arrange
		let registry = DependencyRegistry::new();
		registry.register_async::<TestService, _, _>(DependencyScope::Singleton, |_ctx| async {
			Ok(TestService { value: 1 })
		});

		// Act: the guarded second registration is skipped, so nothing panics
		let already_present = registry.is_registered::<TestService>();
		if !already_present {
			registry.register_async::<TestService, _, _>(DependencyScope::Request, |_ctx| async {
				Ok(TestService { value: 2 })
			});
		}

		// Assert
		assert!(registry.is_registered::<TestService>());
	}
}