1use crate::container::{Container, ContainerBuilder};
2use std::collections::HashMap;
3use thiserror::Error;
4
5pub trait ServiceProvider: Send + Sync {
7 fn name(&self) -> &'static str;
9
10 fn register(&self, builder: ContainerBuilder) -> Result<ContainerBuilder, ProviderError>;
13
14 fn boot(&self, container: &Container) -> Result<(), ProviderError> {
17 let _ = container; Ok(())
20 }
21
22 fn dependencies(&self) -> Vec<&'static str> {
24 vec![]
25 }
26
27 fn defer_boot(&self) -> bool {
29 false
30 }
31}
32
33pub struct ProviderRegistry {
35 providers: Vec<Box<dyn ServiceProvider>>,
36 registration_order: Vec<usize>,
37 boot_order: Vec<usize>,
38}
39
40impl ProviderRegistry {
41 pub fn new() -> Self {
42 Self {
43 providers: Vec::new(),
44 registration_order: Vec::new(),
45 boot_order: Vec::new(),
46 }
47 }
48
49 pub fn register<P: ServiceProvider + 'static>(&mut self, provider: P) {
51 self.providers.push(Box::new(provider));
52 }
53
54 pub fn resolve_dependencies(&mut self) -> Result<(), ProviderError> {
56
57 let name_to_index: HashMap<String, usize> = self.providers
59 .iter()
60 .enumerate()
61 .map(|(i, p)| (p.name().to_string(), i))
62 .collect();
63
64 self.registration_order = self.topological_sort(&name_to_index, false)?;
66
67 self.boot_order = self.topological_sort(&name_to_index, true)?;
69
70 Ok(())
71 }
72
73 fn topological_sort(&self, name_to_index: &HashMap<String, usize>, consider_defer: bool) -> Result<Vec<usize>, ProviderError> {
75 let provider_count = self.providers.len();
76 let mut visited = vec![false; provider_count];
77 let mut temp_mark = vec![false; provider_count];
78 let mut result = Vec::new();
79
80 for i in 0..provider_count {
82 if !visited[i] {
83 self.visit_provider(i, name_to_index, &mut visited, &mut temp_mark, &mut result, consider_defer)?;
84 }
85 }
86
87 Ok(result)
88 }
89
90 fn visit_provider(
92 &self,
93 index: usize,
94 name_to_index: &HashMap<String, usize>,
95 visited: &mut Vec<bool>,
96 temp_mark: &mut Vec<bool>,
97 result: &mut Vec<usize>,
98 consider_defer: bool,
99 ) -> Result<(), ProviderError> {
100 if temp_mark[index] {
101 return Err(ProviderError::CircularDependency {
102 provider: self.providers[index].name().to_string(),
103 });
104 }
105
106 if visited[index] {
107 return Ok(());
108 }
109
110 temp_mark[index] = true;
111
112 let dependencies = self.providers[index].dependencies();
114 for dep_name in dependencies {
115 if let Some(&dep_index) = name_to_index.get(dep_name) {
116 self.visit_provider(dep_index, name_to_index, visited, temp_mark, result, consider_defer)?;
117 } else {
118 return Err(ProviderError::MissingDependency {
119 provider: self.providers[index].name().to_string(),
120 dependency: dep_name.to_string(),
121 });
122 }
123 }
124
125 if consider_defer && self.providers[index].defer_boot() {
127 }
130
131 temp_mark[index] = false;
132 visited[index] = true;
133 result.push(index);
134
135 Ok(())
136 }
137
138 pub fn register_all(&self, mut builder: ContainerBuilder) -> Result<ContainerBuilder, ProviderError> {
140 for &index in &self.registration_order {
141 let provider = &self.providers[index];
142 builder = provider.register(builder)
143 .map_err(|e| ProviderError::RegistrationFailed {
144 provider: provider.name().to_string(),
145 error: Box::new(e),
146 })?;
147 }
148 Ok(builder)
149 }
150
151 pub fn boot_all(&self, container: &Container) -> Result<(), ProviderError> {
153 let mut non_deferred = Vec::new();
155 let mut deferred = Vec::new();
156
157 for &index in &self.boot_order {
158 if self.providers[index].defer_boot() {
159 deferred.push(index);
160 } else {
161 non_deferred.push(index);
162 }
163 }
164
165 for index in non_deferred {
167 self.providers[index].boot(container)
168 .map_err(|e| ProviderError::BootFailed {
169 provider: self.providers[index].name().to_string(),
170 error: Box::new(e),
171 })?;
172 }
173
174 for index in deferred {
176 self.providers[index].boot(container)
177 .map_err(|e| ProviderError::BootFailed {
178 provider: self.providers[index].name().to_string(),
179 error: Box::new(e),
180 })?;
181 }
182
183 Ok(())
184 }
185
186 pub fn provider_names(&self) -> Vec<&str> {
188 self.providers.iter().map(|p| p.name()).collect()
189 }
190
191 pub fn registration_order(&self) -> &[usize] {
193 &self.registration_order
194 }
195
196 pub fn boot_order(&self) -> &[usize] {
198 &self.boot_order
199 }
200}
201
202impl Default for ProviderRegistry {
203 fn default() -> Self {
204 Self::new()
205 }
206}
207
208#[derive(Error, Debug)]
209pub enum ProviderError {
210 #[error("Provider registration failed for '{provider}': {error}")]
211 RegistrationFailed {
212 provider: String,
213 error: Box<dyn std::error::Error + Send + Sync>
214 },
215
216 #[error("Provider boot failed for '{provider}': {error}")]
217 BootFailed {
218 provider: String,
219 error: Box<dyn std::error::Error + Send + Sync>
220 },
221
222 #[error("Circular dependency detected for provider '{provider}'")]
223 CircularDependency { provider: String },
224
225 #[error("Missing dependency '{dependency}' for provider '{provider}'")]
226 MissingDependency { provider: String, dependency: String },
227
228 #[error("Provider '{provider}' is already registered")]
229 AlreadyRegistered { provider: String },
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::app_config::{AppConfig, Environment};
    use crate::container::DatabaseConnection;
    use std::sync::{Arc, Mutex};

    /// Builds a fully populated `AppConfig` for the testing environment.
    fn create_test_config() -> AppConfig {
        AppConfig {
            name: "test-app".to_string(),
            environment: Environment::Testing,
            database_url: "sqlite::memory:".to_string(),
            jwt_secret: Some("test-secret".to_string()),
            server: crate::app_config::ServerConfig {
                host: "127.0.0.1".to_string(),
                port: 8080,
                workers: 4,
            },
            logging: crate::app_config::LoggingConfig {
                level: "info".to_string(),
                format: "compact".to_string(),
            },
        }
    }

    /// Stand-in database whose connection state is fixed at construction.
    struct TestDatabase {
        connected: bool,
    }

    impl DatabaseConnection for TestDatabase {
        fn is_connected(&self) -> bool {
            self.connected
        }

        fn execute(&self, _query: &str) -> Result<(), crate::container::DatabaseError> {
            Ok(())
        }
    }

    /// Dependency-free provider that installs the test configuration.
    struct ConfigProvider;

    impl ServiceProvider for ConfigProvider {
        fn name(&self) -> &'static str {
            "config"
        }

        fn register(&self, builder: ContainerBuilder) -> Result<ContainerBuilder, ProviderError> {
            let config = Arc::new(create_test_config());

            Ok(builder.config(config))
        }
    }

    /// Provider that installs a connected `TestDatabase`; depends on "config".
    struct DatabaseProvider;

    impl ServiceProvider for DatabaseProvider {
        fn name(&self) -> &'static str {
            "database"
        }

        fn dependencies(&self) -> Vec<&'static str> {
            vec!["config"]
        }

        fn register(&self, builder: ContainerBuilder) -> Result<ContainerBuilder, ProviderError> {
            let database = Arc::new(TestDatabase { connected: true }) as Arc<dyn DatabaseConnection>;
            Ok(builder.database(database))
        }

        fn boot(&self, container: &Container) -> Result<(), ProviderError> {
            let database = container.database();
            if !database.is_connected() {
                return Err(ProviderError::BootFailed {
                    provider: "database".to_string(),
                    error: Box::new(std::io::Error::new(
                        std::io::ErrorKind::ConnectionRefused,
                        "Database connection failed",
                    )),
                });
            }
            Ok(())
        }
    }

    lazy_static::lazy_static! {
        // Names of `BootTrackingProvider`s in the order their `boot` hook ran.
        static ref BOOT_ORDER: Mutex<Vec<String>> = Mutex::new(Vec::new());
    }

    /// Provider that appends its name to `BOOT_ORDER` when booted.
    struct BootTrackingProvider {
        name: &'static str,
        defer: bool,
        provide_services: bool,
    }

    impl ServiceProvider for BootTrackingProvider {
        fn name(&self) -> &'static str {
            self.name
        }

        fn register(&self, builder: ContainerBuilder) -> Result<ContainerBuilder, ProviderError> {
            if self.provide_services {
                let config = Arc::new(create_test_config());
                let database = Arc::new(TestDatabase { connected: true }) as Arc<dyn DatabaseConnection>;
                Ok(builder.config(config).database(database))
            } else {
                Ok(builder)
            }
        }

        fn boot(&self, _container: &Container) -> Result<(), ProviderError> {
            BOOT_ORDER.lock().unwrap().push(self.name.to_string());
            Ok(())
        }

        fn defer_boot(&self) -> bool {
            self.defer
        }
    }

    /// Provider with a configurable dependency list and no services, for ordering tests.
    struct DependentProvider {
        name: &'static str,
        deps: &'static [&'static str],
    }

    impl ServiceProvider for DependentProvider {
        fn name(&self) -> &'static str {
            self.name
        }

        fn dependencies(&self) -> Vec<&'static str> {
            self.deps.to_vec()
        }

        fn register(&self, builder: ContainerBuilder) -> Result<ContainerBuilder, ProviderError> {
            Ok(builder)
        }
    }

    #[test]
    fn test_provider_registration_and_boot() {
        let mut registry = ProviderRegistry::new();
        registry.register(ConfigProvider);
        registry.register(DatabaseProvider);

        registry.resolve_dependencies().unwrap();

        let builder = Container::builder();
        let builder = registry.register_all(builder).unwrap();
        let container = builder.build().unwrap();

        registry.boot_all(&container).unwrap();

        let config = container.config();
        assert_eq!(config.name, "test-app");

        let database = container.database();
        assert!(database.is_connected());
    }

    #[test]
    fn test_dependency_resolution() {
        let mut registry = ProviderRegistry::new();
        // The dependent provider is inserted first, so only the sort can put "config" ahead.
        registry.register(DatabaseProvider);
        registry.register(ConfigProvider);
        registry.resolve_dependencies().unwrap();

        let order = registry.registration_order();
        let position_of = |name: &str| {
            order
                .iter()
                .position(|&i| registry.providers[i].name() == name)
                .unwrap()
        };

        assert!(position_of("config") < position_of("database"));
    }

    #[test]
    fn test_missing_dependency_error() {
        let mut registry = ProviderRegistry::new();
        registry.register(DatabaseProvider);

        let result = registry.resolve_dependencies();

        assert!(matches!(
            result,
            Err(ProviderError::MissingDependency { ref provider, ref dependency })
                if provider == "database" && dependency == "config"
        ));
    }

    #[test]
    fn test_circular_dependency_error() {
        let mut registry = ProviderRegistry::new();
        registry.register(DependentProvider { name: "a", deps: &["b"] });
        registry.register(DependentProvider { name: "b", deps: &["a"] });

        let result = registry.resolve_dependencies();

        assert!(matches!(result, Err(ProviderError::CircularDependency { .. })));
    }

    #[test]
    fn test_self_dependency_is_circular() {
        let mut registry = ProviderRegistry::new();
        registry.register(DependentProvider { name: "loop", deps: &["loop"] });

        let result = registry.resolve_dependencies();

        assert!(matches!(result, Err(ProviderError::CircularDependency { .. })));
    }

    #[test]
    fn test_defer_boot_ordering() {
        BOOT_ORDER.lock().unwrap().clear();

        let mut registry = ProviderRegistry::new();
        // Register the deferred provider FIRST: with no dependencies, insertion order alone
        // would boot it first, so this test only passes if deferral is honored.
        registry.register(BootTrackingProvider { name: "deferred", defer: true, provide_services: false });
        registry.register(BootTrackingProvider { name: "normal", defer: false, provide_services: true });

        registry.resolve_dependencies().unwrap();

        let builder = Container::builder();
        let builder = registry.register_all(builder).unwrap();
        let container = builder.build().unwrap();

        registry.boot_all(&container).unwrap();

        let boot_order = BOOT_ORDER.lock().unwrap();
        assert_eq!(*boot_order, ["normal", "deferred"]);
    }
}
422}