1use anyhow::Result;
2use axum::Router;
3use std::collections::HashSet;
4
5use crate::{framework_log, Container};
6
/// Associates a static base path with a controller type.
///
/// NOTE(review): presumably this is the route prefix under which the controller's
/// handlers are mounted — confirm against the code that consumes it.
pub trait ControllerBasePath {
    /// Returns the controller's base path.
    fn base_path() -> &'static str;
}
13
/// A controller type that can build its own axum [`Router`] whose state type is the
/// shared [`Container`].
pub trait ControllerDefinition: Send + Sync + 'static {
    /// Builds the router exposing this controller's routes.
    fn router() -> Router<Container>;
}
20
/// Type-erased handle to a [`ModuleDefinition`] implementor.
///
/// The trait's associated functions are stored as plain function pointers. Modules of
/// different types can therefore be listed side by side (for example in
/// [`ModuleDefinition::imports`]) and copied freely. Build one with [`ModuleRef::of`].
#[derive(Clone, Copy)]
pub struct ModuleRef {
    /// Identifier used to deduplicate modules during graph traversal and to label
    /// log and error messages.
    pub name: &'static str,
    /// Registers the module's providers into the container.
    pub register: fn(&Container) -> Result<()>,
    /// Produces the routers contributed by the module.
    pub controllers: fn() -> Vec<Router<Container>>,
    /// Lists the modules that are visited (and registered) before this one.
    pub imports: fn() -> Vec<ModuleRef>,
    /// Names of providers the module exports. Each one is checked against the
    /// container after the module registers.
    pub exports: fn() -> Vec<&'static str>,
    /// Reports whether the module is marked global.
    pub is_global: fn() -> bool,
}
30
31impl ModuleRef {
32 pub fn of<M: ModuleDefinition>() -> Self {
33 Self {
34 name: M::module_name(),
35 register: M::register,
36 controllers: M::controllers,
37 imports: M::imports,
38 exports: M::exports,
39 is_global: M::is_global,
40 }
41 }
42}
43
44pub trait ModuleDefinition: Send + Sync + 'static {
48 fn module_name() -> &'static str {
49 std::any::type_name::<Self>()
50 }
51
52 fn register(container: &Container) -> Result<()>;
53
54 fn imports() -> Vec<ModuleRef> {
55 Vec::new()
56 }
57
58 fn is_global() -> bool {
59 false
60 }
61
62 fn exports() -> Vec<&'static str> {
63 Vec::new()
64 }
65
66 fn controllers() -> Vec<Router<Container>> {
67 Vec::new()
68 }
69}
70
71pub fn initialize_module_graph<M: ModuleDefinition>(
72 container: &Container,
73) -> Result<Vec<Router<Container>>> {
74 let mut state = ModuleGraphState::default();
75 visit_module(ModuleRef::of::<M>(), container, &mut state)?;
76 Ok(state.controllers)
77}
78
/// Bookkeeping threaded through the depth-first module-graph walk in `visit_module`.
#[derive(Default)]
struct ModuleGraphState {
    /// Modules that have been fully processed; visiting one again is a no-op.
    visited: HashSet<&'static str>,
    /// Modules on the current DFS path; reaching one of these again means an import cycle.
    visiting: HashSet<&'static str>,
    /// The current DFS path in visit order (root first), used to build cycle error messages.
    stack: Vec<&'static str>,
    /// Routers collected from processed modules, in the order the modules were registered.
    controllers: Vec<Router<Container>>,
    /// Names of modules whose `is_global` hook returned `true`.
    /// NOTE(review): populated during the walk but never read in this file — confirm
    /// whether global visibility is meant to be enforced somewhere.
    global_modules: HashSet<&'static str>,
}
87
88fn visit_module(
89 module: ModuleRef,
90 container: &Container,
91 state: &mut ModuleGraphState,
92) -> Result<()> {
93 if state.visited.contains(module.name) {
94 return Ok(());
95 }
96
97 if state.visiting.contains(module.name) {
98 let mut cycle = state.stack.clone();
99 cycle.push(module.name);
100 anyhow::bail!("Detected module import cycle: {}", cycle.join(" -> "));
101 }
102
103 state.visiting.insert(module.name);
104 state.stack.push(module.name);
105 framework_log(format!("Registering module {}.", module.name));
106
107 for imported in (module.imports)() {
108 visit_module(imported, container, state)?;
109 }
110
111 (module.register)(container)
112 .map_err(|err| anyhow::anyhow!("Failed to register module `{}`: {}", module.name, err))?;
113
114 state.controllers.extend((module.controllers)());
115
116 if (module.is_global)() {
117 state.global_modules.insert(module.name);
118 }
119
120 for export in (module.exports)() {
121 let is_registered = container.is_type_registered_name(export).map_err(|err| {
122 anyhow::anyhow!(
123 "Failed to verify exports for module `{}`: {}",
124 module.name,
125 err
126 )
127 })?;
128 if !is_registered {
129 anyhow::bail!(
130 "Module `{}` exports `{}` but that provider is not registered in the container",
131 module.name,
132 export
133 );
134 }
135 }
136
137 state.stack.pop();
138 state.visiting.remove(module.name);
139 state.visited.insert(module.name);
140
141 Ok(())
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    struct ImportedConfig;
    struct AppService;

    /// Leaf module that registers `ImportedConfig`.
    struct ImportedModule;
    impl ModuleDefinition for ImportedModule {
        fn register(container: &Container) -> Result<()> {
            container.register(ImportedConfig)?;
            Ok(())
        }
    }

    /// Imports `ImportedModule` and resolves its provider inside `register`. It can only
    /// succeed if imports are registered before the importing module.
    struct AppModule;
    impl ModuleDefinition for AppModule {
        fn imports() -> Vec<ModuleRef> {
            vec![ModuleRef::of::<ImportedModule>()]
        }

        fn register(container: &Container) -> Result<()> {
            let imported = container.resolve::<ImportedConfig>()?;
            container.register(AppService::from(imported))?;
            Ok(())
        }
    }

    impl From<std::sync::Arc<ImportedConfig>> for AppService {
        fn from(_: std::sync::Arc<ImportedConfig>) -> Self {
            Self
        }
    }

    #[test]
    fn registers_imported_modules_before_local_providers() {
        let container = Container::new();
        let result = initialize_module_graph::<AppModule>(&container);

        assert!(result.is_ok(), "module graph registration should succeed");
        assert!(container.resolve::<ImportedConfig>().is_ok());
        assert!(container.resolve::<AppService>().is_ok());
    }

    /// Number of times `SharedImportedModule::register` has run. Only
    /// `deduplicates_shared_imported_modules` builds a graph containing that module, so
    /// parallel tests cannot race on this counter.
    static SHARED_REGISTRATIONS: AtomicUsize = AtomicUsize::new(0);

    /// Imported by both `LeftModule` and `RightModule` (diamond shape).
    struct SharedImportedModule;
    impl ModuleDefinition for SharedImportedModule {
        fn register(container: &Container) -> Result<()> {
            SHARED_REGISTRATIONS.fetch_add(1, Ordering::SeqCst);
            container.register(SharedMarker)?;
            Ok(())
        }
    }

    struct SharedMarker;
    struct LeftModule;
    struct RightModule;
    struct RootModule;

    impl ModuleDefinition for LeftModule {
        fn imports() -> Vec<ModuleRef> {
            vec![ModuleRef::of::<SharedImportedModule>()]
        }

        fn register(_container: &Container) -> Result<()> {
            Ok(())
        }
    }

    impl ModuleDefinition for RightModule {
        fn imports() -> Vec<ModuleRef> {
            vec![ModuleRef::of::<SharedImportedModule>()]
        }

        fn register(_container: &Container) -> Result<()> {
            Ok(())
        }
    }

    impl ModuleDefinition for RootModule {
        fn imports() -> Vec<ModuleRef> {
            vec![
                ModuleRef::of::<LeftModule>(),
                ModuleRef::of::<RightModule>(),
            ]
        }

        fn register(_container: &Container) -> Result<()> {
            Ok(())
        }
    }

    #[test]
    fn deduplicates_shared_imported_modules() {
        let container = Container::new();
        let result = initialize_module_graph::<RootModule>(&container);

        assert!(
            result.is_ok(),
            "shared imported modules should only register once"
        );
        assert!(container.resolve::<SharedMarker>().is_ok());
        // Checked directly, so the test does not depend on whether the container
        // rejects a duplicate registration.
        assert_eq!(
            SHARED_REGISTRATIONS.load(Ordering::SeqCst),
            1,
            "a module imported from two places must be registered exactly once"
        );
    }

    struct CycleA;
    struct CycleB;

    impl ModuleDefinition for CycleA {
        fn imports() -> Vec<ModuleRef> {
            vec![ModuleRef::of::<CycleB>()]
        }

        fn register(_container: &Container) -> Result<()> {
            Ok(())
        }
    }

    impl ModuleDefinition for CycleB {
        fn imports() -> Vec<ModuleRef> {
            vec![ModuleRef::of::<CycleA>()]
        }

        fn register(_container: &Container) -> Result<()> {
            Ok(())
        }
    }

    #[test]
    fn detects_module_import_cycles() {
        let container = Container::new();
        let err = initialize_module_graph::<CycleA>(&container).unwrap_err();
        let message = err.to_string();

        assert!(
            message.contains("Detected module import cycle"),
            "error should include cycle detection message"
        );
        assert!(
            message.contains("CycleA") && message.contains("CycleB"),
            "error should name every module in the cycle: {message}"
        );
    }

    struct MissingDependency;
    struct BrokenModule;

    impl ModuleDefinition for BrokenModule {
        fn register(container: &Container) -> Result<()> {
            let _ = container.resolve_in_module::<MissingDependency>(Self::module_name())?;
            Ok(())
        }
    }

    #[test]
    fn module_registration_error_includes_module_and_type_context() {
        let container = Container::new();
        let err = initialize_module_graph::<BrokenModule>(&container).unwrap_err();
        let message = err.to_string();

        assert!(message.contains("Failed to register module"));
        assert!(message.contains("BrokenModule"));
        assert!(message.contains("MissingDependency"));
    }
}