// nestforge_core/module.rs

1use anyhow::Result;
2use axum::Router;
3use std::collections::HashSet;
4
5use crate::{framework_log, Container};
6
/// Path metadata for a controller type.
///
/// Implemented by the `#[controller("/...")]` attribute macro; presumably the
/// returned value is the `"/..."` argument given to that attribute.
pub trait ControllerBasePath {
    /// Base path declared for this controller.
    fn base_path() -> &'static str;
}
13
14/*
15ControllerDefinition = generated by #[routes] for a controller
16*/
17pub trait ControllerDefinition: Send + Sync + 'static {
18    fn router() -> Router<Container>;
19}
20
21#[derive(Clone, Copy)]
22pub struct ModuleRef {
23    pub name: &'static str,
24    pub register: fn(&Container) -> Result<()>,
25    pub controllers: fn() -> Vec<Router<Container>>,
26    pub imports: fn() -> Vec<ModuleRef>,
27    pub exports: fn() -> Vec<&'static str>,
28    pub is_global: fn() -> bool,
29}
30
31impl ModuleRef {
32    pub fn of<M: ModuleDefinition>() -> Self {
33        Self {
34            name: M::module_name(),
35            register: M::register,
36            controllers: M::controllers,
37            imports: M::imports,
38            exports: M::exports,
39            is_global: M::is_global,
40        }
41    }
42}
43
44/*
45ModuleDefinition = app module contract (manual for now)
46*/
47pub trait ModuleDefinition: Send + Sync + 'static {
48    fn module_name() -> &'static str {
49        std::any::type_name::<Self>()
50    }
51
52    fn register(container: &Container) -> Result<()>;
53
54    fn imports() -> Vec<ModuleRef> {
55        Vec::new()
56    }
57
58    fn is_global() -> bool {
59        false
60    }
61
62    fn exports() -> Vec<&'static str> {
63        Vec::new()
64    }
65
66    fn controllers() -> Vec<Router<Container>> {
67        Vec::new()
68    }
69}
70
71pub fn initialize_module_graph<M: ModuleDefinition>(
72    container: &Container,
73) -> Result<Vec<Router<Container>>> {
74    let mut state = ModuleGraphState::default();
75    visit_module(ModuleRef::of::<M>(), container, &mut state)?;
76    Ok(state.controllers)
77}
78
79#[derive(Default)]
80struct ModuleGraphState {
81    visited: HashSet<&'static str>,
82    visiting: HashSet<&'static str>,
83    stack: Vec<&'static str>,
84    controllers: Vec<Router<Container>>,
85    global_modules: HashSet<&'static str>,
86}
87
88fn visit_module(
89    module: ModuleRef,
90    container: &Container,
91    state: &mut ModuleGraphState,
92) -> Result<()> {
93    if state.visited.contains(module.name) {
94        return Ok(());
95    }
96
97    if state.visiting.contains(module.name) {
98        let mut cycle = state.stack.clone();
99        cycle.push(module.name);
100        anyhow::bail!("Detected module import cycle: {}", cycle.join(" -> "));
101    }
102
103    state.visiting.insert(module.name);
104    state.stack.push(module.name);
105    framework_log(format!("Registering module {}.", module.name));
106
107    for imported in (module.imports)() {
108        visit_module(imported, container, state)?;
109    }
110
111    (module.register)(container)
112        .map_err(|err| anyhow::anyhow!("Failed to register module `{}`: {}", module.name, err))?;
113
114    state.controllers.extend((module.controllers)());
115
116    if (module.is_global)() {
117        state.global_modules.insert(module.name);
118    }
119
120    for export in (module.exports)() {
121        let is_registered = container.is_type_registered_name(export).map_err(|err| {
122            anyhow::anyhow!(
123                "Failed to verify exports for module `{}`: {}",
124                module.name,
125                err
126            )
127        })?;
128        if !is_registered {
129            anyhow::bail!(
130                "Module `{}` exports `{}` but that provider is not registered in the container",
131                module.name,
132                export
133            );
134        }
135    }
136
137    state.stack.pop();
138    state.visiting.remove(module.name);
139    state.visited.insert(module.name);
140
141    Ok(())
142}
143
#[cfg(test)]
mod tests {
    //! Tests for module graph initialization. They cover:
    //! - import ordering and de-duplication of shared imports;
    //! - cycle detection;
    //! - registration error context;
    //! - export verification;
    //! - router collection.

    use super::*;

    struct ImportedConfig;
    struct AppService;

    struct ImportedModule;
    impl ModuleDefinition for ImportedModule {
        fn register(container: &Container) -> Result<()> {
            container.register(ImportedConfig)?;
            Ok(())
        }
    }

    struct AppModule;
    impl ModuleDefinition for AppModule {
        fn imports() -> Vec<ModuleRef> {
            vec![ModuleRef::of::<ImportedModule>()]
        }

        fn register(container: &Container) -> Result<()> {
            let imported = container.resolve::<ImportedConfig>()?;
            container.register(AppService::from(imported))?;
            Ok(())
        }
    }

    impl From<std::sync::Arc<ImportedConfig>> for AppService {
        fn from(_: std::sync::Arc<ImportedConfig>) -> Self {
            Self
        }
    }

    #[test]
    fn registers_imported_modules_before_local_providers() {
        let container = Container::new();
        let result = initialize_module_graph::<AppModule>(&container);

        assert!(result.is_ok(), "module graph registration should succeed");
        assert!(container.resolve::<ImportedConfig>().is_ok());
        assert!(container.resolve::<AppService>().is_ok());
    }

    struct SharedImportedModule;
    impl ModuleDefinition for SharedImportedModule {
        fn register(container: &Container) -> Result<()> {
            container.register(SharedMarker)?;
            Ok(())
        }
    }

    struct SharedMarker;
    struct LeftModule;
    struct RightModule;
    struct RootModule;

    impl ModuleDefinition for LeftModule {
        fn imports() -> Vec<ModuleRef> {
            vec![ModuleRef::of::<SharedImportedModule>()]
        }

        fn register(_container: &Container) -> Result<()> {
            Ok(())
        }
    }

    impl ModuleDefinition for RightModule {
        fn imports() -> Vec<ModuleRef> {
            vec![ModuleRef::of::<SharedImportedModule>()]
        }

        fn register(_container: &Container) -> Result<()> {
            Ok(())
        }
    }

    impl ModuleDefinition for RootModule {
        fn imports() -> Vec<ModuleRef> {
            vec![
                ModuleRef::of::<LeftModule>(),
                ModuleRef::of::<RightModule>(),
            ]
        }

        fn register(_container: &Container) -> Result<()> {
            Ok(())
        }
    }

    #[test]
    fn deduplicates_shared_imported_modules() {
        let container = Container::new();
        let result = initialize_module_graph::<RootModule>(&container);

        assert!(
            result.is_ok(),
            "shared imported modules should only register once"
        );
        assert!(container.resolve::<SharedMarker>().is_ok());
    }

    #[test]
    fn modules_without_controllers_contribute_no_routers() {
        let container = Container::new();
        let routers = initialize_module_graph::<RootModule>(&container)
            .expect("module graph registration should succeed");

        assert!(
            routers.is_empty(),
            "no module in this graph declares controllers"
        );
    }

    struct CycleA;
    struct CycleB;

    impl ModuleDefinition for CycleA {
        fn imports() -> Vec<ModuleRef> {
            vec![ModuleRef::of::<CycleB>()]
        }

        fn register(_container: &Container) -> Result<()> {
            Ok(())
        }
    }

    impl ModuleDefinition for CycleB {
        fn imports() -> Vec<ModuleRef> {
            vec![ModuleRef::of::<CycleA>()]
        }

        fn register(_container: &Container) -> Result<()> {
            Ok(())
        }
    }

    #[test]
    fn detects_module_import_cycles() {
        let container = Container::new();
        let err = initialize_module_graph::<CycleA>(&container).unwrap_err();
        let message = err.to_string();

        assert!(
            message.contains("Detected module import cycle"),
            "error should include cycle detection message"
        );
        assert!(
            message.contains("CycleA") && message.contains("CycleB"),
            "error should name every module on the cycle, got: {}",
            message
        );
    }

    struct MissingDependency;
    struct BrokenModule;

    impl ModuleDefinition for BrokenModule {
        fn register(container: &Container) -> Result<()> {
            let _ = container.resolve_in_module::<MissingDependency>(Self::module_name())?;
            Ok(())
        }
    }

    #[test]
    fn module_registration_error_includes_module_and_type_context() {
        let container = Container::new();
        let err = initialize_module_graph::<BrokenModule>(&container).unwrap_err();
        let message = err.to_string();

        assert!(message.contains("Failed to register module"));
        assert!(message.contains("BrokenModule"));
        assert!(message.contains("MissingDependency"));
    }

    // Only named in `exports()`; never registered in any container.
    struct NeverRegistered;
    struct ExportsUnregisteredModule;

    impl ModuleDefinition for ExportsUnregisteredModule {
        fn register(_container: &Container) -> Result<()> {
            Ok(())
        }

        fn exports() -> Vec<&'static str> {
            vec![std::any::type_name::<NeverRegistered>()]
        }
    }

    #[test]
    fn rejects_exports_that_were_never_registered() {
        let container = Container::new();
        let result = initialize_module_graph::<ExportsUnregisteredModule>(&container);

        assert!(
            result.is_err(),
            "exporting a provider that was never registered should fail"
        );
    }
}