1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::sync::{Arc, RwLock};
4
5use super::{DiError, ProviderDescriptor, Scope};
6
/// Type-erased instance storage: one shared `Arc` per concrete `TypeId`.
type InstanceMap = HashMap<TypeId, Arc<dyn Any + Sync + Send>>;
8
9#[derive(Clone, Default)]
13pub struct Container {
14 providers: Arc<RwLock<HashMap<TypeId, ProviderDescriptor>>>,
15 singletons: Arc<RwLock<InstanceMap>>,
16 resolving: Arc<RwLock<Vec<TypeId>>>,
17}
18
19impl Container {
20 pub fn new() -> Self {
21 Self::default()
22 }
23
24 pub fn register<T: Send + Sync + 'static>(&self, descriptor: ProviderDescriptor) {
26 let type_id = TypeId::of::<T>();
27 self.providers
28 .write()
29 .expect("container providers lock poisoned")
30 .insert(type_id, descriptor);
31 }
32
33 pub fn register_singleton<T: Send + Sync + 'static>(&self, instance: Arc<T>) {
35 let type_id = TypeId::of::<T>();
36 self.singletons
37 .write()
38 .expect("container singletons lock poisoned")
39 .insert(type_id, instance);
40 }
41
42 pub fn register_default<T: Default + Send + Sync + 'static>(&self) {
44 self.register_singleton(Arc::new(T::default()));
45 }
46
47 pub fn register_instance<T: Send + Sync + 'static>(&self, instance: Arc<T>) {
49 self.register_singleton(instance);
50 }
51
52 pub fn get<T: Send + Sync + 'static>(&self) -> Result<Arc<T>, DiError> {
54 let type_id = TypeId::of::<T>();
55
56 if let Some(existing) = self.singletons.read().expect("lock").get(&type_id) {
57 return downcast_arc::<T>(existing.clone());
58 }
59
60 let descriptor = self
61 .providers
62 .read()
63 .expect("lock")
64 .get(&type_id)
65 .cloned()
66 .ok_or(DiError::not_found::<T>())?;
67
68 self.guard_circular(type_id)?;
69
70 if descriptor.scope == Scope::Singleton {
71 if let Some(existing) = self.singletons.read().expect("lock").get(&type_id) {
72 self.resolving.write().expect("lock").pop();
73 return downcast_arc::<T>(existing.clone());
74 }
75 }
76
77 let instance = self.resolve_descriptor(&descriptor)?;
78 self.resolving.write().expect("lock").pop();
79
80 let arc = downcast_arc::<T>(instance)?;
81
82 if descriptor.scope == Scope::Singleton {
83 self.singletons
84 .write()
85 .expect("lock")
86 .insert(type_id, arc.clone());
87 }
88
89 Ok(arc)
90 }
91
92 pub async fn get_async<T: Send + Sync + 'static>(&self) -> Result<Arc<T>, DiError> {
94 let type_id = TypeId::of::<T>();
95
96 if let Some(existing) = self.singletons.read().expect("lock").get(&type_id) {
97 return downcast_arc::<T>(existing.clone());
98 }
99
100 let descriptor = self
101 .providers
102 .read()
103 .expect("lock")
104 .get(&type_id)
105 .cloned()
106 .ok_or(DiError::not_found::<T>())?;
107
108 self.guard_circular(type_id)?;
109
110 let instance = if let Some(fut) = descriptor.factory.create_async(self) {
111 fut.await
112 } else if let Some(sync) = descriptor.factory.create_sync(self) {
113 sync
114 } else {
115 self.resolving.write().expect("lock").pop();
116 return Err(DiError::ResolutionFailed {
117 type_name: descriptor.type_name,
118 reason: "async factory required".into(),
119 });
120 };
121
122 self.resolving.write().expect("lock").pop();
123
124 let arc = downcast_arc::<T>(instance)?;
125
126 if descriptor.scope == Scope::Singleton {
127 self.singletons
128 .write()
129 .expect("lock")
130 .insert(type_id, arc.clone());
131 }
132
133 Ok(arc)
134 }
135
136 pub fn request_scope(&self) -> RequestScope<'_> {
137 RequestScope {
138 parent: self,
139 request_instances: RwLock::new(HashMap::new()),
140 }
141 }
142
143 fn guard_circular(&self, type_id: TypeId) -> Result<(), DiError> {
144 let mut resolving = self.resolving.write().expect("lock");
145 if resolving.contains(&type_id) {
146 return Err(DiError::CircularDependency(format!("{type_id:?}")));
147 }
148 resolving.push(type_id);
149 Ok(())
150 }
151
152 fn resolve_descriptor(
153 &self,
154 descriptor: &ProviderDescriptor,
155 ) -> Result<Arc<dyn Any + Send + Sync>, DiError> {
156 if let Some(instance) = descriptor.factory.create_sync(self) {
157 return Ok(instance);
158 }
159 Err(DiError::ResolutionFailed {
160 type_name: descriptor.type_name,
161 reason: "sync factory required; use get_async for async providers".into(),
162 })
163 }
164}
165
166pub struct RequestScope<'a> {
168 parent: &'a Container,
169 request_instances: RwLock<InstanceMap>,
170}
171
172impl<'a> RequestScope<'a> {
173 pub fn get<T: Send + Sync + 'static>(&self) -> Result<Arc<T>, DiError> {
174 let type_id = TypeId::of::<T>();
175
176 if let Some(existing) = self.request_instances.read().expect("lock").get(&type_id) {
177 return downcast_arc::<T>(existing.clone());
178 }
179
180 let instance = self.parent.get::<T>()?;
181 self.request_instances
182 .write()
183 .expect("lock")
184 .insert(type_id, instance.clone());
185 downcast_arc::<T>(instance)
186 }
187}
188
189fn downcast_arc<T: Send + Sync + 'static>(
190 value: Arc<dyn Any + Send + Sync>,
191) -> Result<Arc<T>, DiError> {
192 Arc::downcast::<T>(value)
193 .map_err(|_| DiError::ResolutionFailed {
194 type_name: std::any::type_name::<T>(),
195 reason: "type mismatch in DI container".into(),
196 })
197}