1use std::any::{Any, TypeId};
2use std::collections::{HashMap, HashSet};
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::sync::OnceLock;
7
8use async_trait::async_trait;
9use axum::Router;
10
11mod database;
12mod discovery;
13mod execution_context;
14mod guard;
15mod metadata;
16mod module_ref;
17mod pipe;
18mod platform;
19mod route_registry;
20mod strategy;
21
22pub use database::DatabasePing;
23pub use discovery::DiscoveryService;
24pub use execution_context::{ExecutionContext, HostType, HttpExecutionArguments};
25pub use guard::{CanActivate, GuardError};
26pub use metadata::MetadataRegistry;
27pub use module_ref::ModuleRef;
28pub use pipe::PipeTransform;
29pub use platform::{AxumHttpEngine, HttpServerEngine};
30pub use route_registry::{OpenApiResponseDesc, OpenApiRouteSpec, RouteInfo, RouteRegistry};
31pub use strategy::{AuthError, AuthStrategy};
32
33type CustomFactoryFn =
34 std::sync::Arc<dyn Fn(&ProviderRegistry) -> Arc<dyn Any + Send + Sync> + Send + Sync>;
35
/// Lifetime policy applied by `ProviderRegistry` when a provider is resolved.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ProviderScope {
    /// Built at most once; the instance is stored in the entry's `OnceLock`
    /// and handed out on every later resolution.
    Singleton,
    /// The factory runs on every resolution; nothing is cached.
    Transient,
    /// Cached in the task-local request cache installed by
    /// `with_request_scope`; resolving outside such a scope panics.
    Request,
}
52
/// How a provider instance is produced when it is not already cached.
#[derive(Clone)]
enum ProviderFactory {
    /// Plain function pointer; used for `Injectable` types and for overrides.
    InjectableFn(fn(&ProviderRegistry) -> Arc<dyn Any + Send + Sync>),
    /// Shared closure; used by `register_use_value` / `register_use_factory`.
    Custom(CustomFactoryFn),
}
58
/// Everything the registry stores about one registered provider type.
#[derive(Clone)]
struct ProviderEntry {
    /// `std::any::type_name` of the provider; used in panic messages.
    type_name: &'static str,
    /// Caching policy applied when the provider is resolved.
    scope: ProviderScope,
    /// Produces a new instance when none is cached.
    factory: ProviderFactory,
    /// Singleton cell. It sits behind an `Arc`, so clones of this entry
    /// (e.g. in a cloned registry) share the same instance.
    instance: Arc<OnceLock<Arc<dyn Any + Send + Sync>>>,
    /// Lifecycle hook run by `ProviderRegistry::run_on_module_init`.
    on_module_init: HookFn,
    /// Lifecycle hook run by `ProviderRegistry::run_on_module_destroy`.
    on_module_destroy: HookFn,
    /// Lifecycle hook run by `ProviderRegistry::run_on_application_bootstrap`.
    on_application_bootstrap: HookFn,
    /// Lifecycle hook run by `ProviderRegistry::run_on_application_shutdown`.
    on_application_shutdown: HookFn,
}
70
71fn noop_hook<'a>(_registry: &'a ProviderRegistry) -> HookFuture<'a> {
72 Box::pin(async {})
73}
74
75fn create_entry_for_injectable<T: Injectable + Send + Sync + 'static>() -> ProviderEntry {
76 fn factory<T: Injectable + Send + Sync + 'static>(
77 registry: &ProviderRegistry,
78 ) -> Arc<dyn Any + Send + Sync> {
79 T::construct(registry)
80 }
81
82 ProviderEntry {
83 type_name: std::any::type_name::<T>(),
84 scope: T::scope(),
85 factory: ProviderFactory::InjectableFn(factory::<T>),
86 instance: Arc::new(OnceLock::new()),
87 on_module_init: hook_on_module_init::<T>,
88 on_module_destroy: hook_on_module_destroy::<T>,
89 on_application_bootstrap: hook_on_application_bootstrap::<T>,
90 on_application_shutdown: hook_on_application_shutdown::<T>,
91 }
92}
93
/// Dependency-injection container: maps a provider's `TypeId` to the entry
/// describing how to build, cache, and run lifecycle hooks for it.
pub struct ProviderRegistry {
    /// One entry per registered provider type; a later registration for the
    /// same type replaces the earlier one.
    entries: HashMap<TypeId, ProviderEntry>,
}
97
/// Newtype around a `&'static str` key.
///
/// NOTE(review): not referenced in the visible part of this file; presumably
/// identifies a route handler — confirm against callers.
#[derive(Clone, Copy, Debug)]
pub struct HandlerKey(pub &'static str);
101
102impl ProviderRegistry {
103 pub fn new() -> Self {
104 Self {
105 entries: HashMap::new(),
106 }
107 }
108
109 pub fn register<T>(&mut self)
110 where
111 T: Injectable + Send + Sync + 'static,
112 {
113 self.entries
114 .insert(TypeId::of::<T>(), create_entry_for_injectable::<T>());
115 }
116
117 pub fn register_use_value<T: Send + Sync + 'static>(&mut self, value: Arc<T>) {
119 let preset: Arc<dyn Any + Send + Sync> = value;
120 let cell = Arc::new(OnceLock::new());
121 let _ = cell.set(preset.clone());
122 self.entries.insert(
123 TypeId::of::<T>(),
124 ProviderEntry {
125 type_name: std::any::type_name::<T>(),
126 scope: ProviderScope::Singleton,
127 factory: ProviderFactory::Custom(Arc::new(move |_| preset.clone())),
128 instance: cell,
129 on_module_init: noop_hook,
130 on_module_destroy: noop_hook,
131 on_application_bootstrap: noop_hook,
132 on_application_shutdown: noop_hook,
133 },
134 );
135 }
136
137 pub fn register_use_factory<T, F>(&mut self, scope: ProviderScope, factory: F)
146 where
147 T: Send + Sync + 'static,
148 F: Fn(&ProviderRegistry) -> Arc<T> + Send + Sync + 'static,
149 {
150 let factory: std::sync::Arc<F> = std::sync::Arc::new(factory);
151 let factory = factory.clone();
152 self.entries.insert(
153 TypeId::of::<T>(),
154 ProviderEntry {
155 type_name: std::any::type_name::<T>(),
156 scope,
157 factory: ProviderFactory::Custom(Arc::new(move |r| {
158 let v = factory(r);
159 v as Arc<dyn Any + Send + Sync>
160 })),
161 instance: Arc::new(OnceLock::new()),
162 on_module_init: noop_hook,
163 on_module_destroy: noop_hook,
164 on_application_bootstrap: noop_hook,
165 on_application_shutdown: noop_hook,
166 },
167 );
168 }
169
170 #[inline]
172 pub fn register_use_class<T>(&mut self)
173 where
174 T: Injectable + Send + Sync + 'static,
175 {
176 self.register::<T>();
177 }
178
179 pub fn override_provider<T>(&mut self, instance: Arc<T>)
184 where
185 T: Injectable + Send + Sync + 'static,
186 {
187 let entry = ProviderEntry {
188 type_name: std::any::type_name::<T>(),
189 scope: ProviderScope::Singleton,
190 factory: ProviderFactory::InjectableFn(|_| unreachable!("override preset")),
191 instance: Arc::new(OnceLock::new()),
192 on_module_init: hook_on_module_init::<T>,
193 on_module_destroy: hook_on_module_destroy::<T>,
194 on_application_bootstrap: hook_on_application_bootstrap::<T>,
195 on_application_shutdown: hook_on_application_shutdown::<T>,
196 };
197
198 let any: Arc<dyn Any + Send + Sync> = instance;
199 let _ = entry.instance.set(any);
200
201 self.entries.insert(TypeId::of::<T>(), entry);
202 }
203
204 fn produce_any(&self, type_id: TypeId, entry: &ProviderEntry) -> Arc<dyn Any + Send + Sync> {
205 match entry.scope {
206 ProviderScope::Singleton => {
207 let _guard = ConstructionGuard::push(type_id, entry.type_name);
208 entry
209 .instance
210 .get_or_init(|| match &entry.factory {
211 ProviderFactory::InjectableFn(f) => f(self),
212 ProviderFactory::Custom(f) => f(self),
213 })
214 .clone()
215 }
216 ProviderScope::Transient => {
217 let _guard = ConstructionGuard::push(type_id, entry.type_name);
218 match &entry.factory {
219 ProviderFactory::InjectableFn(f) => f(self),
220 ProviderFactory::Custom(f) => f(self),
221 }
222 }
223 ProviderScope::Request => {
224 let _guard = ConstructionGuard::push(type_id, entry.type_name);
225 REQUEST_SCOPE_CACHE
226 .try_with(|cell| {
227 if let Some(existing) = cell.borrow().get(&type_id).cloned() {
228 return existing;
229 }
230 let value = match &entry.factory {
231 ProviderFactory::InjectableFn(f) => f(self),
232 ProviderFactory::Custom(f) => f(self),
233 };
234 cell.borrow_mut().insert(type_id, value.clone());
235 value
236 })
237 .unwrap_or_else(|_| {
238 panic!(
239 "Request-scoped provider `{}` requested outside request scope; enable request scope middleware",
240 entry.type_name
241 )
242 })
243 }
244 }
245 }
246
247 pub fn get<T>(&self) -> Arc<T>
248 where
249 T: Send + Sync + 'static,
250 {
251 let type_id = TypeId::of::<T>();
252 let entry = self
253 .entries
254 .get(&type_id)
255 .unwrap_or_else(|| panic!("Provider `{}` not registered", std::any::type_name::<T>()));
256
257 let any = self.produce_any(type_id, entry);
258
259 any.downcast::<T>().unwrap_or_else(|_| {
260 panic!(
261 "Provider downcast failed for `{}`",
262 std::any::type_name::<T>()
263 )
264 })
265 }
266
267 pub fn registered_type_ids(&self) -> Vec<TypeId> {
269 self.entries.keys().copied().collect()
270 }
271
272 pub fn registered_type_names(&self) -> Vec<&'static str> {
274 self.entries.values().map(|e| e.type_name).collect()
275 }
276
277 pub fn absorb(&mut self, other: ProviderRegistry) {
278 self.entries.extend(other.entries);
279 }
280
281 pub fn absorb_exported(&mut self, other: ProviderRegistry, exported: &[TypeId]) {
282 if exported.is_empty() {
283 return;
284 }
285 let allow = exported.iter().copied().collect::<HashSet<_>>();
286 for (type_id, entry) in other.entries {
287 if allow.contains(&type_id) {
288 self.entries.insert(type_id, entry);
289 }
290 }
291 }
292
293 pub fn absorb_exported_from(&mut self, other: &ProviderRegistry, exported: &[TypeId]) {
296 if exported.is_empty() {
297 return;
298 }
299 let allow = exported.iter().copied().collect::<HashSet<_>>();
300 for (type_id, entry) in &other.entries {
301 if allow.contains(type_id) {
302 self.entries.insert(*type_id, entry.clone());
303 }
304 }
305 }
306
307 pub fn eager_init_singletons(&self) {
309 for (type_id, entry) in self.entries.iter() {
310 if entry.scope == ProviderScope::Singleton {
311 let _guard = ConstructionGuard::push(*type_id, entry.type_name);
312 let _ = entry.instance.get_or_init(|| match &entry.factory {
313 ProviderFactory::InjectableFn(f) => f(self),
314 ProviderFactory::Custom(f) => f(self),
315 });
316 }
317 }
318 }
319
320 pub async fn run_on_module_init(&self) {
321 for entry in self.entries.values() {
322 if entry.scope == ProviderScope::Singleton {
323 (entry.on_module_init)(self).await;
324 }
325 }
326 }
327
328 pub async fn run_on_module_destroy(&self) {
329 for entry in self.entries.values() {
330 if entry.scope == ProviderScope::Singleton {
331 (entry.on_module_destroy)(self).await;
332 }
333 }
334 }
335
336 pub async fn run_on_application_bootstrap(&self) {
337 for entry in self.entries.values() {
338 if entry.scope == ProviderScope::Singleton {
339 (entry.on_application_bootstrap)(self).await;
340 }
341 }
342 }
343
344 pub async fn run_on_application_shutdown(&self) {
345 for entry in self.entries.values() {
346 if entry.scope == ProviderScope::Singleton {
347 (entry.on_application_shutdown)(self).await;
348 }
349 }
350 }
351}
352
353impl Clone for ProviderRegistry {
354 fn clone(&self) -> Self {
355 Self {
356 entries: self.entries.clone(),
357 }
358 }
359}
360
361impl Default for ProviderRegistry {
362 fn default() -> Self {
363 Self::new()
364 }
365}
366
/// Boxed, `Send` future returned by a lifecycle hook; it may borrow the
/// registry for `'a`.
type HookFuture<'a> = Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
/// Function-pointer type of a lifecycle hook stored in a `ProviderEntry`.
type HookFn = for<'a> fn(&'a ProviderRegistry) -> HookFuture<'a>;
369
370fn hook_on_module_init<'a, T>(registry: &'a ProviderRegistry) -> HookFuture<'a>
371where
372 T: Injectable + Send + Sync + 'static,
373{
374 Box::pin(async move {
375 let v = registry.get::<T>();
376 v.on_module_init().await;
377 })
378}
379
380fn hook_on_module_destroy<'a, T>(registry: &'a ProviderRegistry) -> HookFuture<'a>
381where
382 T: Injectable + Send + Sync + 'static,
383{
384 Box::pin(async move {
385 let v = registry.get::<T>();
386 v.on_module_destroy().await;
387 })
388}
389
390fn hook_on_application_bootstrap<'a, T>(registry: &'a ProviderRegistry) -> HookFuture<'a>
391where
392 T: Injectable + Send + Sync + 'static,
393{
394 Box::pin(async move {
395 let v = registry.get::<T>();
396 v.on_application_bootstrap().await;
397 })
398}
399
400fn hook_on_application_shutdown<'a, T>(registry: &'a ProviderRegistry) -> HookFuture<'a>
401where
402 T: Injectable + Send + Sync + 'static,
403{
404 Box::pin(async move {
405 let v = registry.get::<T>();
406 v.on_application_shutdown().await;
407 })
408}
409
/// A provider type the registry can construct on its own.
///
/// All lifecycle hooks default to no-ops. `ProviderRegistry::run_on_*` calls
/// them only for singleton-scoped providers.
#[async_trait]
pub trait Injectable: Send + Sync + 'static {
    /// Builds an instance, resolving dependencies from `registry` as needed.
    fn construct(registry: &ProviderRegistry) -> Arc<Self>;

    /// Caching policy for this provider; defaults to `Singleton`.
    fn scope() -> ProviderScope {
        ProviderScope::Singleton
    }

    /// Hook awaited by `ProviderRegistry::run_on_module_init`.
    async fn on_module_init(&self) {}
    /// Hook awaited by `ProviderRegistry::run_on_module_destroy`.
    async fn on_module_destroy(&self) {}
    /// Hook awaited by `ProviderRegistry::run_on_application_bootstrap`.
    async fn on_application_bootstrap(&self) {}
    /// Hook awaited by `ProviderRegistry::run_on_application_shutdown`.
    async fn on_application_shutdown(&self) {}
}
433
/// A type that contributes routes to an axum [`Router`].
pub trait Controller {
    /// Takes `router` by value and returns the router to use afterwards;
    /// `registry` is available for resolving providers.
    /// NOTE(review): presumably adds this controller's routes — implementors
    /// are not visible here; confirm.
    fn register(router: Router, registry: &ProviderRegistry) -> Router;
}
437
/// A unit of composition that yields a provider registry and a router.
pub trait Module {
    /// Builds this module's registry and router.
    fn build() -> (ProviderRegistry, Router);

    /// `TypeId`s of the providers this module exports; defaults to none.
    /// NOTE(review): presumably consumed via `ProviderRegistry::absorb_exported*`
    /// by importing modules — confirm against the `#[module]` macro output.
    fn exports() -> Vec<TypeId> {
        Vec::new()
    }
}
445
/// Split-phase module construction used by `DynamicModuleBuilder`: providers
/// are registered first, then (after overrides) controllers are mounted.
pub trait ModuleGraph {
    /// Registers this module's providers into `registry`.
    fn register_providers(registry: &mut ProviderRegistry);
    /// Mounts this module's controllers on `router` and returns the result.
    fn register_controllers(router: Router, registry: &ProviderRegistry) -> Router;
}
455
/// A module in built form: its registry, its router, and its export list.
pub struct DynamicModule {
    /// Providers registered by the module.
    pub registry: ProviderRegistry,
    /// Routes contributed by the module.
    pub router: Router,
    /// `TypeId`s of the providers the module exports.
    pub exports: Vec<TypeId>,
}
470
471impl DynamicModule {
472 pub fn from_module<M: Module>() -> Self {
474 let (registry, router) = M::build();
475 let exports = <M as Module>::exports();
476 Self {
477 registry,
478 router,
479 exports,
480 }
481 }
482
483 pub fn from_router(router: Router) -> Self {
485 Self {
486 registry: ProviderRegistry::new(),
487 router,
488 exports: Vec::new(),
489 }
490 }
491
492 pub fn from_parts(registry: ProviderRegistry, router: Router, exports: Vec<TypeId>) -> Self {
494 Self {
495 registry,
496 router,
497 exports,
498 }
499 }
500
501 pub fn lazy<M: Module + 'static>() -> Self {
504 static CELL: std::sync::OnceLock<DynamicModule> = std::sync::OnceLock::new();
505 CELL.get_or_init(DynamicModule::from_module::<M>).clone()
506 }
507}
508
509impl Clone for DynamicModule {
510 fn clone(&self) -> Self {
511 Self {
512 registry: self.registry.clone(),
513 router: self.router.clone(),
514 exports: self.exports.clone(),
515 }
516 }
517}
518
/// Options value `O` tagged with the module type `M` it belongs to, so that
/// options for different modules are distinct provider types.
pub struct ModuleOptions<O, M> {
    /// The wrapped options value.
    inner: O,
    /// Ties the type to `M` without storing an `M`.
    _marker: std::marker::PhantomData<fn() -> M>,
}
527
528impl<O, M> ModuleOptions<O, M> {
529 pub fn new(inner: O) -> Self {
530 Self {
531 inner,
532 _marker: std::marker::PhantomData,
533 }
534 }
535
536 pub fn get(&self) -> &O {
537 &self.inner
538 }
539
540 pub fn into_inner(self) -> O {
541 self.inner
542 }
543}
544
545impl<O, M> std::ops::Deref for ModuleOptions<O, M> {
546 type Target = O;
547
548 fn deref(&self) -> &Self::Target {
549 &self.inner
550 }
551}
552
/// `ModuleOptions` is `Injectable` only so that it can be supplied through
/// `override_provider`; it has no way to construct itself.
#[async_trait]
impl<O, M> Injectable for ModuleOptions<O, M>
where
    O: Send + Sync + 'static,
    M: 'static,
{
    /// Always panics: an options value must be provided up front (see the
    /// message), it is never built from the registry.
    fn construct(_registry: &ProviderRegistry) -> Arc<Self> {
        panic!(
            "ModuleOptions requested but no value was provided. Use ConfigurableModuleBuilder / DynamicModuleBuilder to supply module options."
        );
    }
}
565
/// One-shot mutation applied to a registry by `DynamicModuleBuilder::build`.
type RegistryOverrideFn = Box<dyn FnOnce(&mut ProviderRegistry) + Send>;
567
/// Builder that produces a [`DynamicModule`] for `M` with optional provider
/// overrides applied.
pub struct DynamicModuleBuilder<M>
where
    M: Module + ModuleGraph,
{
    /// Overrides queued by `override_provider`, applied in insertion order.
    overrides: Vec<RegistryOverrideFn>,
    /// Ties the builder to `M` without storing an `M`.
    _marker: std::marker::PhantomData<M>,
}
577
578impl<M> DynamicModuleBuilder<M>
579where
580 M: Module + ModuleGraph,
581{
582 pub fn new() -> Self {
583 Self {
584 overrides: Vec::new(),
585 _marker: std::marker::PhantomData,
586 }
587 }
588
589 pub fn override_provider<T>(mut self, instance: Arc<T>) -> Self
590 where
591 T: Injectable + Send + Sync + 'static,
592 {
593 self.overrides
594 .push(Box::new(move |r| r.override_provider::<T>(instance)));
595 self
596 }
597
598 pub fn build(self) -> DynamicModule {
599 let mut registry = ProviderRegistry::new();
600 M::register_providers(&mut registry);
601 for apply in self.overrides {
602 apply(&mut registry);
603 }
604 let router = M::register_controllers(Router::new(), ®istry);
605 DynamicModule::from_parts(registry, router, M::exports())
606 }
607}
608
609impl<M> Default for DynamicModuleBuilder<M>
610where
611 M: Module + ModuleGraph,
612{
613 fn default() -> Self {
614 Self::new()
615 }
616}
617
/// Namespace type for building modules that take an options value of type `O`.
/// It carries no data.
pub struct ConfigurableModuleBuilder<O> {
    /// Ties the type to `O` without storing an `O`.
    _marker: std::marker::PhantomData<O>,
}
622
623impl<O> ConfigurableModuleBuilder<O>
624where
625 O: Send + Sync + 'static,
626{
627 pub fn for_root<M>(options: O) -> DynamicModule
628 where
629 M: Module + ModuleGraph + 'static,
630 {
631 DynamicModuleBuilder::<M>::new()
632 .override_provider::<ModuleOptions<O, M>>(Arc::new(ModuleOptions::new(options)))
633 .build()
634 }
635
636 pub async fn for_root_async<M, F, Fut>(factory: F) -> DynamicModule
637 where
638 M: Module + ModuleGraph + 'static,
639 F: FnOnce() -> Fut,
640 Fut: Future<Output = O>,
641 {
642 let options = factory().await;
643 Self::for_root::<M>(options)
644 }
645}
646
// Per-thread stack of `(type name, TypeId)` for the modules currently being
// built; used to detect circular module imports and to print the chain.
thread_local! {
    static MODULE_BUILD_STACK: std::cell::RefCell<Vec<(&'static str, TypeId)>> =
        const { std::cell::RefCell::new(Vec::new()) };
}
651
/// Scope guard marking a module as "being built" on the current thread: `push`
/// adds it to the build stack and dropping the guard removes it again.
#[doc(hidden)]
pub struct __NestrsModuleBuildGuard {
    /// Module this guard was pushed for; checked against the stack top on drop.
    type_id: TypeId,
}
657
658impl __NestrsModuleBuildGuard {
659 pub fn push(type_id: TypeId, type_name: &'static str) -> Self {
660 let is_cycle = MODULE_BUILD_STACK.with(|stack| {
661 let mut guard = stack.borrow_mut();
662 let cycle = guard.iter().any(|(_, id)| *id == type_id);
663 if !cycle {
664 guard.push((type_name, type_id));
665 }
666 cycle
667 });
668
669 if is_cycle {
670 __nestrs_panic_circular_module_dependency(type_name);
671 }
672
673 Self { type_id }
674 }
675}
676
677impl Drop for __NestrsModuleBuildGuard {
678 fn drop(&mut self) {
679 MODULE_BUILD_STACK.with(|stack| {
680 let mut guard = stack.borrow_mut();
681 if let Some((_, id)) = guard.last() {
682 if *id == self.type_id {
683 guard.pop();
684 }
685 }
686 });
687 }
688}
689
690#[doc(hidden)]
691pub fn __nestrs_module_stack_contains(type_id: TypeId) -> bool {
692 MODULE_BUILD_STACK.with(|stack| stack.borrow().iter().any(|(_, id)| *id == type_id))
693}
694
695#[doc(hidden)]
696pub fn __nestrs_panic_circular_module_dependency(import_type_name: &'static str) -> ! {
697 let chain = MODULE_BUILD_STACK.with(|stack| {
698 stack
699 .borrow()
700 .iter()
701 .map(|(name, _)| *name)
702 .chain(std::iter::once(import_type_name))
703 .collect::<Vec<_>>()
704 .join(" -> ")
705 });
706
707 panic!(
708 "Circular module dependency detected: {chain}. If intentional, mark the NestJS-style back-edge import with `forward_ref::<T>()` (or `forwardRef` alias in the `#[module]` macro). See the nestrs mdBook chapter **Fundamentals** (`docs/src/fundamentals.md`).",
709 );
710}
711
// Task-local cache of request-scoped provider instances, keyed by `TypeId`.
// It is installed by `with_request_scope` and read by
// `ProviderRegistry::produce_any`.
tokio::task_local! {
    static REQUEST_SCOPE_CACHE: std::cell::RefCell<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>;
}
715
716pub async fn with_request_scope<Fut, T>(future: Fut) -> T
718where
719 Fut: std::future::Future<Output = T>,
720{
721 REQUEST_SCOPE_CACHE
722 .scope(std::cell::RefCell::new(HashMap::new()), future)
723 .await
724}
725
// Per-thread stack of `(type name, TypeId)` for the providers currently being
// resolved; used to detect circular provider dependencies and print the chain.
thread_local! {
    static CONSTRUCTION_STACK: std::cell::RefCell<Vec<(&'static str, TypeId)>> =
        const { std::cell::RefCell::new(Vec::new()) };
}
730
/// Scope guard marking a provider as "being resolved" on the current thread:
/// `push` adds it to the construction stack and dropping the guard removes it.
struct ConstructionGuard {
    /// Provider this guard was pushed for; checked against the stack top on drop.
    type_id: TypeId,
}
734
735impl ConstructionGuard {
736 fn push(type_id: TypeId, type_name: &'static str) -> Self {
737 CONSTRUCTION_STACK.with(|stack| {
738 let mut guard = stack.borrow_mut();
739 if guard.iter().any(|(_, id)| *id == type_id) {
740 let chain = guard
741 .iter()
742 .map(|(name, _)| *name)
743 .chain(std::iter::once(type_name))
744 .collect::<Vec<_>>()
745 .join(" -> ");
746 panic!(
747 "Circular provider dependency detected: {chain}. Break the cycle with lazy construction (`register_use_factory`), split types, defer work to `on_module_init`, or a `forward_ref`-style module import for module graphs. See the nestrs mdBook chapter **Fundamentals** (`docs/src/fundamentals.md`)."
748 );
749 }
750 guard.push((type_name, type_id));
751 });
752 Self { type_id }
753 }
754}
755
756impl Drop for ConstructionGuard {
757 fn drop(&mut self) {
758 CONSTRUCTION_STACK.with(|stack| {
759 let mut guard = stack.borrow_mut();
760 if let Some((_, id)) = guard.last() {
761 if *id == self.type_id {
762 guard.pop();
763 }
764 }
765 });
766 }
767}