1use dashmap::DashMap;
2use once_cell::sync::Lazy;
3use std::any::{Any, TypeId};
4use std::collections::HashMap;
5use std::sync::{Arc, Mutex, RwLock};
6
7use crate::{
8 descriptor::ServiceProvider as DescriptorServiceProvider, DiError, DiResult, Lifetime,
9 ServiceDescriptor, ServiceKey,
10};
11
12static SINGLETON_SERVICES: Lazy<DashMap<ServiceKey, Arc<dyn Any + Send + Sync>>> =
14 Lazy::new(DashMap::new);
15
16pub struct Container {
18 services: Arc<RwLock<HashMap<ServiceKey, ServiceDescriptor>>>,
20 resolution_stack: Arc<Mutex<Vec<ServiceKey>>>,
22}
23
24impl Container {
25 pub fn new() -> Self {
27 Self {
28 services: Arc::new(RwLock::new(HashMap::new())),
29 resolution_stack: Arc::new(Mutex::new(Vec::new())),
30 }
31 }
32
33 pub fn register(&self, descriptor: ServiceDescriptor) -> DiResult<()> {
35 let mut services = self
36 .services
37 .write()
38 .map_err(|_| DiError::generic("Failed to acquire services write lock"))?;
39
40 if services.contains_key(&descriptor.service_key) {
42 return Err(DiError::Generic {
43 message: format!(
44 "Service with key {:?} is already registered",
45 descriptor.service_key
46 ),
47 });
48 }
49
50 services.insert(descriptor.service_key.clone(), descriptor);
51 Ok(())
52 }
53
54 pub fn register_overwrite(&self, descriptor: ServiceDescriptor) -> DiResult<()> {
56 let mut services = self
57 .services
58 .write()
59 .map_err(|_| DiError::generic("Failed to acquire services write lock"))?;
60
61 services.insert(descriptor.service_key.clone(), descriptor);
62 Ok(())
63 }
64
65 pub fn is_registered<T: 'static>(&self) -> DiResult<bool> {
67 let key = ServiceKey::of_type::<T>();
68 self.is_registered_with_key(&key)
69 }
70
71 pub fn is_keyed_registered<T: 'static>(&self, name: &str) -> DiResult<bool> {
73 let key = ServiceKey::named::<T>(name);
74 self.is_registered_with_key(&key)
75 }
76
77 pub fn is_registered_with_key(&self, key: &ServiceKey) -> DiResult<bool> {
79 let services = self
80 .services
81 .read()
82 .map_err(|_| DiError::generic("Failed to acquire services read lock"))?;
83
84 Ok(services.contains_key(key))
85 }
86
87 fn get_descriptor(&self, key: &ServiceKey) -> DiResult<Option<ServiceDescriptor>> {
89 let services = self
90 .services
91 .read()
92 .map_err(|_| DiError::generic("Failed to acquire services read lock"))?;
93
94 Ok(services.get(key).cloned())
95 }
96
97 pub fn build_provider(self) -> ServiceProvider {
99 ServiceProvider::new(Arc::new(self))
100 }
101
102 pub fn build(self) -> ServiceProvider {
104 self.build_provider()
105 }
106}
107
108impl Default for Container {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
114type ScopeStorage = Arc<RwLock<HashMap<ServiceKey, Arc<dyn Any + Send + Sync>>>>;
116
117pub struct ServiceProvider {
119 container: Arc<Container>,
120}
121
122impl ServiceProvider {
123 fn new(container: Arc<Container>) -> Self {
124 Self { container }
125 }
126
127 pub fn get_services<T: 'static + Send + Sync>(&self) -> DiResult<Vec<Arc<T>>> {
129 let descriptors = self.get_all_descriptors_for_type::<T>()?;
130 let mut services = Vec::new();
131
132 for descriptor in descriptors {
133 if let Some(service) = self.resolve_service::<T>(&descriptor.service_key, None)? {
134 services.push(service);
135 }
136 }
137
138 Ok(services)
139 }
140
141 pub fn create_scope(&self) -> DiResult<ServiceScope> {
143 ServiceScope::new(Arc::clone(&self.container))
144 }
145
146 fn resolve_service<T: 'static + Send + Sync>(
148 &self,
149 key: &ServiceKey,
150 scope_storage: Option<&ScopeStorage>,
151 ) -> DiResult<Option<Arc<T>>> {
152 self.begin_resolution(key)?;
154
155 let result = self.internal_resolve_service::<T>(key, scope_storage);
156
157 self.end_resolution(key)?;
159
160 result
161 }
162
163 fn check_circular_dependency(&self, key: &ServiceKey) -> DiResult<()> {
165 let stack = self
166 .container
167 .resolution_stack
168 .lock()
169 .map_err(|_| DiError::generic("Failed to acquire resolution stack lock"))?;
170
171 if stack.contains(key) {
172 return Err(DiError::Generic {
173 message: format!("Circular dependency detected for service key: {key:?}"),
174 });
175 }
176
177 Ok(())
178 }
179
180 fn begin_resolution(&self, key: &ServiceKey) -> DiResult<()> {
182 self.check_circular_dependency(key)?;
183
184 let mut stack = self
185 .container
186 .resolution_stack
187 .lock()
188 .map_err(|_| DiError::generic("Failed to acquire resolution stack lock"))?;
189
190 stack.push(key.clone());
191 Ok(())
192 }
193
194 fn end_resolution(&self, key: &ServiceKey) -> DiResult<()> {
196 let mut stack = self
197 .container
198 .resolution_stack
199 .lock()
200 .map_err(|_| DiError::generic("Failed to acquire resolution stack lock"))?;
201
202 if let Some(pos) = stack.iter().position(|k| k == key) {
203 stack.remove(pos);
204 }
205
206 Ok(())
207 }
208
209 fn internal_resolve_service<T: 'static + Send + Sync>(
211 &self,
212 key: &ServiceKey,
213 scope_storage: Option<&ScopeStorage>,
214 ) -> DiResult<Option<Arc<T>>> {
215 let descriptor = match self.container.get_descriptor(key)? {
216 Some(desc) => desc,
217 None => return Ok(None),
218 };
219
220 match descriptor.lifetime {
221 Lifetime::Singleton => self.resolve_singleton::<T>(&descriptor),
222 Lifetime::Scoped => match scope_storage {
223 Some(storage) => self.resolve_scoped::<T>(&descriptor, storage),
224 None => Err(DiError::Generic {
225 message: format!("Scoped service cannot be resolved without a scope: {key:?}"),
226 }),
227 },
228 Lifetime::Transient => self.resolve_transient::<T>(&descriptor),
229 }
230 }
231
232 fn resolve_singleton<T: 'static + Send + Sync>(
234 &self,
235 descriptor: &ServiceDescriptor,
236 ) -> DiResult<Option<Arc<T>>> {
237 if let Some(cached) = SINGLETON_SERVICES.get(&descriptor.service_key) {
239 let any_arc = Arc::clone(&cached);
240 return self.cast_to_arc::<T>(any_arc);
241 }
242
243 let provider = ContainerServiceProvider::new(Arc::clone(&self.container), None);
245 let instance = descriptor.create_instance(&provider)?;
246
247 let typed_instance = self.box_to_typed_arc::<T>(instance)?;
249 let any_arc: Arc<dyn Any + Send + Sync> = typed_instance.clone();
250 SINGLETON_SERVICES.insert(descriptor.service_key.clone(), any_arc);
251
252 Ok(Some(typed_instance))
253 }
254
255 fn resolve_scoped<T: 'static + Send + Sync>(
257 &self,
258 descriptor: &ServiceDescriptor,
259 scope_storage: &ScopeStorage,
260 ) -> DiResult<Option<Arc<T>>> {
261 {
263 let storage = scope_storage
264 .read()
265 .map_err(|_| DiError::generic("Failed to acquire scope storage read lock"))?;
266
267 if let Some(cached) = storage.get(&descriptor.service_key) {
268 let any_arc = Arc::clone(cached);
269 return self.cast_to_arc::<T>(any_arc);
270 }
271 }
272
273 let provider =
275 ContainerServiceProvider::new(Arc::clone(&self.container), Some(scope_storage.clone()));
276 let instance = descriptor.create_instance(&provider)?;
277
278 let typed_instance = self.box_to_typed_arc::<T>(instance)?;
280 let any_arc: Arc<dyn Any + Send + Sync> = typed_instance.clone();
281
282 {
283 let mut storage = scope_storage
284 .write()
285 .map_err(|_| DiError::generic("Failed to acquire scope storage write lock"))?;
286 storage.insert(descriptor.service_key.clone(), any_arc);
287 }
288
289 Ok(Some(typed_instance))
290 }
291
292 fn resolve_transient<T: 'static + Send + Sync>(
294 &self,
295 descriptor: &ServiceDescriptor,
296 ) -> DiResult<Option<Arc<T>>> {
297 let provider = ContainerServiceProvider::new(Arc::clone(&self.container), None);
298 let instance = descriptor.create_instance(&provider)?;
299 let typed_instance = self.box_to_typed_arc::<T>(instance)?;
300 Ok(Some(typed_instance))
301 }
302
303 fn box_to_typed_arc<T: 'static + Send + Sync>(
305 &self,
306 instance: Box<dyn Any + Send + Sync>,
307 ) -> DiResult<Arc<T>> {
308 match instance.downcast::<T>() {
309 Ok(boxed) => Ok(Arc::new(*boxed)),
310 Err(_) => Err(DiError::type_casting_failed::<T>()),
311 }
312 }
313
314 fn cast_to_arc<T: 'static + Send + Sync>(
316 &self,
317 any_arc: Arc<dyn Any + Send + Sync>,
318 ) -> DiResult<Option<Arc<T>>> {
319 if let Ok(arc_t) = any_arc.downcast::<T>() {
321 return Ok(Some(arc_t));
322 }
323
324 Err(DiError::type_casting_failed::<T>())
325 }
326
327 fn get_all_descriptors_for_type<T: 'static + Send + Sync>(
329 &self,
330 ) -> DiResult<Vec<ServiceDescriptor>> {
331 let services = self
332 .container
333 .services
334 .read()
335 .map_err(|_| DiError::generic("Failed to acquire services read lock"))?;
336
337 let target_type_id = TypeId::of::<T>();
338 let descriptors: Vec<ServiceDescriptor> = services
339 .values()
340 .filter(|desc| desc.service_type == target_type_id)
341 .cloned()
342 .collect();
343
344 Ok(descriptors)
345 }
346}
347
348struct ContainerServiceProvider {
350 container: Arc<Container>,
351 scope_storage: Option<ScopeStorage>,
352}
353
354impl ContainerServiceProvider {
355 fn new(container: Arc<Container>, scope_storage: Option<ScopeStorage>) -> Self {
356 Self {
357 container,
358 scope_storage,
359 }
360 }
361}
362
363impl DescriptorServiceProvider for ServiceProvider {
364 fn get_service_raw(&self, key: &ServiceKey) -> DiResult<Option<Arc<dyn Any + Send + Sync>>> {
365 let inner_provider = ContainerServiceProvider::new(Arc::clone(&self.container), None);
366 inner_provider.get_service_raw(key)
367 }
368}
369
370impl DescriptorServiceProvider for ContainerServiceProvider {
371 fn get_service_raw(&self, key: &ServiceKey) -> DiResult<Option<Arc<dyn Any + Send + Sync>>> {
372 let descriptor = match self.container.get_descriptor(key)? {
374 Some(desc) => desc,
375 None => return Ok(None),
376 };
377
378 match descriptor.lifetime {
380 Lifetime::Singleton => {
381 if let Some(cached) = SINGLETON_SERVICES.get(&descriptor.service_key) {
383 return Ok(Some(Arc::clone(&cached)));
384 }
385
386 let inner_provider =
388 ContainerServiceProvider::new(Arc::clone(&self.container), None);
389 let instance = descriptor.create_instance(&inner_provider)?;
390 let any_arc: Arc<dyn Any + Send + Sync> = Arc::from(instance);
391 SINGLETON_SERVICES.insert(descriptor.service_key.clone(), Arc::clone(&any_arc));
392 Ok(Some(any_arc))
393 }
394 Lifetime::Scoped => {
395 if let Some(storage) = &self.scope_storage {
396 {
398 let storage_guard = storage.read().map_err(|_| {
399 DiError::generic("Failed to acquire scope storage read lock")
400 })?;
401
402 if let Some(cached) = storage_guard.get(&descriptor.service_key) {
403 return Ok(Some(Arc::clone(cached)));
404 }
405 }
406
407 let inner_provider = ContainerServiceProvider::new(
409 Arc::clone(&self.container),
410 Some(storage.clone()),
411 );
412 let instance = descriptor.create_instance(&inner_provider)?;
413 let any_arc: Arc<dyn Any + Send + Sync> = Arc::from(instance);
414
415 {
416 let mut storage_guard = storage.write().map_err(|_| {
417 DiError::generic("Failed to acquire scope storage write lock")
418 })?;
419 storage_guard.insert(descriptor.service_key.clone(), Arc::clone(&any_arc));
420 }
421
422 Ok(Some(any_arc))
423 } else {
424 Err(DiError::Generic {
425 message: format!(
426 "Scoped service cannot be resolved without a scope: {key:?}"
427 ),
428 })
429 }
430 }
431 Lifetime::Transient => {
432 let inner_provider = ContainerServiceProvider::new(
433 Arc::clone(&self.container),
434 self.scope_storage.clone(),
435 );
436 let instance = descriptor.create_instance(&inner_provider)?;
437 let any_arc: Arc<dyn Any + Send + Sync> = Arc::from(instance);
438 Ok(Some(any_arc))
439 }
440 }
441 }
442}
443
444pub struct ServiceScope {
446 container: Arc<Container>,
447 storage: ScopeStorage,
448 disposed: Arc<Mutex<bool>>,
449}
450
451impl ServiceScope {
452 pub fn new(container: Arc<Container>) -> DiResult<Self> {
454 Ok(Self {
455 container,
456 storage: Arc::new(RwLock::new(HashMap::new())),
457 disposed: Arc::new(Mutex::new(false)),
458 })
459 }
460
461 fn ensure_not_disposed(&self) -> DiResult<()> {
463 let disposed = self
464 .disposed
465 .lock()
466 .map_err(|_| DiError::generic("Failed to acquire disposed lock"))?;
467
468 if *disposed {
469 return Err(DiError::ScopeDisposed);
470 }
471
472 Ok(())
473 }
474
475 pub fn get_services<T: 'static + Send + Sync>(&self) -> DiResult<Vec<Arc<T>>> {
477 self.ensure_not_disposed()?;
478 let provider = ServiceProvider::new(Arc::clone(&self.container));
479
480 let descriptors = provider.get_all_descriptors_for_type::<T>()?;
481 let mut services = Vec::new();
482
483 for descriptor in descriptors {
484 if let Some(service) =
485 provider.resolve_service::<T>(&descriptor.service_key, Some(&self.storage))?
486 {
487 services.push(service);
488 }
489 }
490
491 Ok(services)
492 }
493
494 pub fn create_scope(&self) -> DiResult<ServiceScope> {
496 self.ensure_not_disposed()?;
497 ServiceScope::new(Arc::clone(&self.container))
498 }
499
500 pub fn dispose(&mut self) {
502 if let Ok(mut disposed) = self.disposed.lock() {
503 if !*disposed {
504 *disposed = true;
505
506 if let Ok(mut storage) = self.storage.write() {
508 storage.clear();
509 }
510 }
511 }
512 }
513
514 pub fn is_disposed(&self) -> bool {
516 self.disposed
517 .lock()
518 .map(|disposed| *disposed)
519 .unwrap_or(true)
520 }
521}
522
523impl DescriptorServiceProvider for ServiceScope {
524 fn get_service_raw(&self, key: &ServiceKey) -> DiResult<Option<Arc<dyn Any + Send + Sync>>> {
525 self.ensure_not_disposed()?;
526 let inner_provider =
527 ContainerServiceProvider::new(Arc::clone(&self.container), Some(self.storage.clone()));
528 inner_provider.get_service_raw(key)
529 }
530}
531
532impl Drop for ServiceScope {
533 fn drop(&mut self) {
534 self.dispose();
535 }
536}
537
#[cfg(test)]
mod tests {
    use super::*;
    use crate::descriptor::ServiceProviderExt;
    use crate::ServiceDescriptor;

    #[derive(Debug, Clone, PartialEq)]
    struct TestService {
        value: i32,
    }

    // Fixture for dependency-injection scenarios; not constructed by the current tests.
    #[derive(Debug, Clone, PartialEq)]
    #[allow(dead_code)]
    struct DependentService {
        dependency: Arc<TestService>,
    }

    #[test]
    fn test_container_creation() {
        let container = Container::new();
        assert!(!container.is_registered::<TestService>().unwrap());
    }

    #[test]
    fn test_service_registration() {
        let container = Container::new();

        let descriptor = ServiceDescriptor::transient::<TestService, TestService>(Box::new(|_| {
            Ok(Box::new(TestService { value: 42 }))
        }));

        container.register(descriptor).unwrap();
        assert!(container.is_registered::<TestService>().unwrap());
    }

    #[test]
    fn test_duplicate_registration_is_rejected() {
        let container = Container::new();
        let make = || {
            ServiceDescriptor::transient::<TestService, TestService>(Box::new(|_| {
                Ok(Box::new(TestService { value: 1 }))
            }))
        };

        container.register(make()).unwrap();
        assert!(container.register(make()).is_err());
        assert!(container.register_overwrite(make()).is_ok());
        assert!(container.is_registered::<TestService>().unwrap());
    }

    #[test]
    fn test_singleton_service_resolution() {
        let container = Container::new();

        let descriptor = ServiceDescriptor::singleton::<TestService, TestService>(Box::new(|_| {
            Ok(Box::new(TestService { value: 100 }))
        }));

        container.register(descriptor).unwrap();

        let provider = container.build();
        let service1 = provider.get_required_service::<TestService>().unwrap();
        let service2 = provider.get_required_service::<TestService>().unwrap();

        assert_eq!(service1.value, 100);
        assert_eq!(service2.value, 100);
        // A singleton must hand out the very same instance every time.
        assert!(Arc::ptr_eq(&service1, &service2));
    }

    #[test]
    fn test_transient_service_resolution() {
        let container = Container::new();

        let descriptor = ServiceDescriptor::transient::<TestService, TestService>(Box::new(|_| {
            Ok(Box::new(TestService { value: 200 }))
        }));

        container.register(descriptor).unwrap();

        let provider = container.build();
        let service1 = provider.get_required_service::<TestService>().unwrap();
        let service2 = provider.get_required_service::<TestService>().unwrap();

        assert_eq!(service1.value, 200);
        assert_eq!(service2.value, 200);
        // A transient must be a fresh instance on every resolution.
        assert!(!Arc::ptr_eq(&service1, &service2));
    }

    #[test]
    fn test_keyed_service_registration_and_resolution() {
        let container = Container::new();

        let descriptor = ServiceDescriptor::named_singleton::<TestService, TestService>(
            "primary",
            Box::new(|_| Ok(Box::new(TestService { value: 300 }))),
        );

        container.register(descriptor).unwrap();
        assert!(container
            .is_keyed_registered::<TestService>("primary")
            .unwrap());
        assert!(!container
            .is_keyed_registered::<TestService>("secondary")
            .unwrap());

        let provider = container.build();
        let service = provider
            .get_required_keyed_service::<TestService>("primary")
            .unwrap();
        assert_eq!(service.value, 300);

        let result = provider.get_keyed_service::<TestService>("nonexistent");
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn test_scoped_service_with_scope() {
        let container = Container::new();

        let descriptor = ServiceDescriptor::scoped::<TestService, TestService>(Box::new(|_| {
            Ok(Box::new(TestService { value: 400 }))
        }));

        container.register(descriptor).unwrap();

        let provider = container.build();
        let mut scope = provider.create_scope().unwrap();

        let service1 = scope.get_required_service::<TestService>().unwrap();
        let service2 = scope.get_required_service::<TestService>().unwrap();

        assert_eq!(service1.value, 400);
        assert_eq!(service2.value, 400);
        // Within one scope, a scoped service is shared.
        assert!(Arc::ptr_eq(&service1, &service2));

        scope.dispose();
    }

    #[test]
    fn test_disposed_scope_rejects_resolution() {
        let container = Container::new();

        let descriptor = ServiceDescriptor::scoped::<TestService, TestService>(Box::new(|_| {
            Ok(Box::new(TestService { value: 7 }))
        }));

        container.register(descriptor).unwrap();

        let provider = container.build();
        let mut scope = provider.create_scope().unwrap();
        assert!(!scope.is_disposed());

        scope.dispose();
        assert!(scope.is_disposed());
        assert!(matches!(
            scope.get_services::<TestService>(),
            Err(DiError::ScopeDisposed)
        ));
        assert!(scope.create_scope().is_err());
    }

    #[test]
    fn test_service_collection() {
        let container = Container::new();

        let desc1 = ServiceDescriptor::named_transient::<TestService, TestService>(
            "service1",
            Box::new(|_| Ok(Box::new(TestService { value: 1 }))),
        );
        let desc2 = ServiceDescriptor::named_transient::<TestService, TestService>(
            "service2",
            Box::new(|_| Ok(Box::new(TestService { value: 2 }))),
        );

        container.register(desc1).unwrap();
        container.register(desc2).unwrap();

        let provider = container.build();
        let services = provider.get_services::<TestService>().unwrap();

        assert_eq!(services.len(), 2);
        let values: Vec<i32> = services.iter().map(|s| s.value).collect();
        assert!(values.contains(&1));
        assert!(values.contains(&2));
    }
}