1use crate::{DiResult, Lifetime, ServiceKey};
2use std::any::Any;
3use std::sync::Arc;
4
/// Type-erased constructor stored in a [`ServiceDescriptor`].
///
/// The closure is handed the [`ServiceProvider`] performing the resolution
/// (presumably so it can resolve its own dependencies through it) and returns
/// either the freshly built instance boxed as `dyn Any + Send + Sync`, or an
/// error. It is an `Fn`, so it may be invoked any number of times, and it is
/// `Send + Sync`, so it can be shared between threads.
pub type ServiceFactory =
    Box<dyn Fn(&dyn ServiceProvider) -> DiResult<Box<dyn Any + Send + Sync>> + Send + Sync>;
9
/// Object-safe, type-erased service lookup.
///
/// Implementors map a [`ServiceKey`] to a shared, type-erased instance. The
/// strongly typed convenience accessors live in [`ServiceProviderExt`].
pub trait ServiceProvider: Send + Sync {
    /// Looks up whatever is registered under `key`.
    ///
    /// `ServiceProviderExt` treats `Ok(None)` as "not registered" and an `Err`
    /// as a resolution failure. NOTE(review): confirm every implementor follows
    /// that convention.
    fn get_service_raw(&self, key: &ServiceKey) -> DiResult<Option<Arc<dyn Any + Send + Sync>>>;
}
15
16pub trait ServiceProviderExt: ServiceProvider {
18 fn get_service<T: 'static + Send + Sync>(&self) -> DiResult<Option<Arc<T>>> {
20 let key = ServiceKey::of_type::<T>();
21 if let Some(any_arc) = self.get_service_raw(&key)? {
22 self.downcast_arc::<T>(any_arc)
24 } else {
25 Ok(None)
26 }
27 }
28
29 fn get_required_service<T: 'static + Send + Sync>(&self) -> DiResult<Arc<T>> {
31 match self.get_service::<T>()? {
32 Some(service) => Ok(service),
33 None => Err(crate::DiError::service_not_registered::<T>()),
34 }
35 }
36
37 fn get_keyed_service<T: 'static + Send + Sync>(&self, key: &str) -> DiResult<Option<Arc<T>>> {
39 let service_key = ServiceKey::named::<T>(key);
40 if let Some(any_arc) = self.get_service_raw(&service_key)? {
41 self.downcast_arc::<T>(any_arc)
42 } else {
43 Ok(None)
44 }
45 }
46
47 fn get_required_keyed_service<T: 'static + Send + Sync>(&self, key: &str) -> DiResult<Arc<T>> {
49 match self.get_keyed_service::<T>(key)? {
50 Some(service) => Ok(service),
51 None => Err(crate::DiError::keyed_service_not_registered::<T>(key)),
52 }
53 }
54
55 fn downcast_arc<T: 'static + Send + Sync>(
57 &self,
58 any_arc: Arc<dyn Any + Send + Sync>,
59 ) -> DiResult<Option<Arc<T>>> {
60 match any_arc.downcast::<T>() {
62 Ok(typed_arc) => Ok(Some(typed_arc)),
63 Err(original_arc) => {
64 match original_arc.downcast::<Arc<T>>() {
66 Ok(arc_of_arc) => Ok(Some((*arc_of_arc).clone())),
67 Err(_) => Err(crate::DiError::type_casting_failed::<T>()),
68 }
69 }
70 }
71 }
72}
73
74impl<T: ServiceProvider + ?Sized> ServiceProviderExt for T {}
76
/// Registration record describing one service.
///
/// It holds the key the service is resolved under, its lifetime, the factory
/// that builds it, and the `TypeId`s of the service and implementation types.
#[derive(Clone)]
pub struct ServiceDescriptor {
    /// Key the service is looked up by. The constructors build either a plain
    /// type key (`ServiceKey::of_type`) or a type + name key (`ServiceKey::named`).
    pub service_key: ServiceKey,

    /// Lifetime the service was registered with (e.g. transient, scoped, singleton).
    pub lifetime: Lifetime,

    /// Factory that builds instances.
    ///
    /// It is wrapped in `Arc` because `ServiceFactory` is a boxed closure and
    /// not `Clone`. Cloning a descriptor therefore shares the same closure.
    pub factory: Arc<ServiceFactory>,

    /// `TypeId` of the service type (the `TService` parameter of `new`).
    pub service_type: std::any::TypeId,

    /// `TypeId` of the implementation type (the `TImplementation` parameter of `new`).
    pub implementation_type: std::any::TypeId,
}
95
96impl ServiceDescriptor {
97 pub fn new<TService, TImplementation>(
99 service_key: ServiceKey,
100 lifetime: Lifetime,
101 factory: ServiceFactory,
102 ) -> Self
103 where
104 TService: 'static,
105 TImplementation: 'static,
106 {
107 Self {
108 service_key,
109 lifetime,
110 factory: Arc::new(factory),
111 service_type: std::any::TypeId::of::<TService>(),
112 implementation_type: std::any::TypeId::of::<TImplementation>(),
113 }
114 }
115
116 pub fn transient<TService, TImplementation>(factory: ServiceFactory) -> Self
118 where
119 TService: 'static,
120 TImplementation: 'static,
121 {
122 Self::new::<TService, TImplementation>(
123 ServiceKey::of_type::<TService>(),
124 Lifetime::Transient,
125 factory,
126 )
127 }
128
129 pub fn scoped<TService, TImplementation>(factory: ServiceFactory) -> Self
131 where
132 TService: 'static,
133 TImplementation: 'static,
134 {
135 Self::new::<TService, TImplementation>(
136 ServiceKey::of_type::<TService>(),
137 Lifetime::Scoped,
138 factory,
139 )
140 }
141
142 pub fn singleton<TService, TImplementation>(factory: ServiceFactory) -> Self
144 where
145 TService: 'static,
146 TImplementation: 'static,
147 {
148 Self::new::<TService, TImplementation>(
149 ServiceKey::of_type::<TService>(),
150 Lifetime::Singleton,
151 factory,
152 )
153 }
154
155 pub fn named_transient<TService, TImplementation>(
157 name: impl Into<String>,
158 factory: ServiceFactory,
159 ) -> Self
160 where
161 TService: 'static,
162 TImplementation: 'static,
163 {
164 Self::new::<TService, TImplementation>(
165 ServiceKey::named::<TService>(name),
166 Lifetime::Transient,
167 factory,
168 )
169 }
170
171 pub fn named_scoped<TService, TImplementation>(
173 name: impl Into<String>,
174 factory: ServiceFactory,
175 ) -> Self
176 where
177 TService: 'static,
178 TImplementation: 'static,
179 {
180 Self::new::<TService, TImplementation>(
181 ServiceKey::named::<TService>(name),
182 Lifetime::Scoped,
183 factory,
184 )
185 }
186
187 pub fn named_singleton<TService, TImplementation>(
189 name: impl Into<String>,
190 factory: ServiceFactory,
191 ) -> Self
192 where
193 TService: 'static,
194 TImplementation: 'static,
195 {
196 Self::new::<TService, TImplementation>(
197 ServiceKey::named::<TService>(name),
198 Lifetime::Singleton,
199 factory,
200 )
201 }
202
203 pub fn from_instance<TService>(instance: TService) -> Self
205 where
206 TService: Send + Sync + 'static,
207 {
208 let instance = Arc::new(instance);
209 Self::singleton::<TService, TService>(Box::new(move |_| {
210 Ok(Box::new(Arc::clone(&instance)))
211 }))
212 }
213
214 pub fn from_named_instance<TService>(name: impl Into<String>, instance: TService) -> Self
216 where
217 TService: Send + Sync + 'static,
218 {
219 let instance = Arc::new(instance);
220 Self::named_singleton::<TService, TService>(
221 name,
222 Box::new(move |_| Ok(Box::new(Arc::clone(&instance)))),
223 )
224 }
225
226 pub fn matches_key(&self, key: &ServiceKey) -> bool {
228 &self.service_key == key
229 }
230
231 pub fn matches_service_type<T: 'static>(&self) -> bool {
233 self.service_type == std::any::TypeId::of::<T>()
234 }
235
236 pub fn create_instance(
238 &self,
239 provider: &dyn ServiceProvider,
240 ) -> DiResult<Box<dyn Any + Send + Sync>> {
241 (self.factory)(provider)
242 }
243}
244
245impl std::fmt::Debug for ServiceDescriptor {
246 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247 f.debug_struct("ServiceDescriptor")
248 .field("service_key", &self.service_key)
249 .field("lifetime", &self.lifetime)
250 .field("service_type", &self.service_type)
251 .field("implementation_type", &self.implementation_type)
252 .finish()
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone)]
    struct TestService {
        value: i32,
    }

    /// Provider stub that never resolves anything (`get_service_raw` is always `Ok(None)`).
    struct MockServiceProvider;

    impl ServiceProvider for MockServiceProvider {
        fn get_service_raw(
            &self,
            _key: &ServiceKey,
        ) -> DiResult<Option<Arc<dyn Any + Send + Sync>>> {
            Ok(None)
        }
    }

    /// Runs an instance descriptor's factory and unwraps the boxed `Arc<TestService>`.
    fn create_shared(
        descriptor: &ServiceDescriptor,
        provider: &dyn ServiceProvider,
    ) -> Arc<TestService> {
        let boxed = match descriptor.create_instance(provider) {
            Ok(boxed) => boxed,
            Err(_) => panic!("instance factory should not fail"),
        };
        *boxed
            .downcast::<Arc<TestService>>()
            .expect("instance factory should box an Arc<TestService>")
    }

    #[test]
    fn test_descriptor_creation() {
        let descriptor = ServiceDescriptor::transient::<TestService, TestService>(Box::new(|_| {
            Ok(Box::new(TestService { value: 42 }))
        }));

        assert_eq!(descriptor.lifetime, Lifetime::Transient);
        assert!(descriptor.matches_service_type::<TestService>());
        assert!(descriptor.matches_key(&ServiceKey::of_type::<TestService>()));

        // The factory output must be the concrete value it built.
        let boxed = match descriptor.create_instance(&MockServiceProvider) {
            Ok(boxed) => boxed,
            Err(_) => panic!("transient factory should not fail"),
        };
        let service = boxed
            .downcast::<TestService>()
            .expect("transient factory should produce a TestService");
        assert_eq!(service.value, 42);
    }

    #[test]
    fn test_from_instance() {
        let descriptor = ServiceDescriptor::from_instance(TestService { value: 100 });

        assert_eq!(descriptor.lifetime, Lifetime::Singleton);
        assert!(descriptor.matches_service_type::<TestService>());
        assert!(descriptor.matches_key(&ServiceKey::of_type::<TestService>()));

        let provider = MockServiceProvider;
        let first = create_shared(&descriptor, &provider);
        let second = create_shared(&descriptor, &provider);
        assert_eq!(first.value, 100);
        // Every call hands out the same captured instance, not a copy.
        assert!(Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn test_from_named_instance() {
        let descriptor = ServiceDescriptor::from_named_instance("primary", TestService { value: 7 });

        assert_eq!(descriptor.lifetime, Lifetime::Singleton);
        assert!(descriptor.matches_key(&ServiceKey::named::<TestService>("primary")));
        assert!(!descriptor.matches_key(&ServiceKey::of_type::<TestService>()));
        assert_eq!(create_shared(&descriptor, &MockServiceProvider).value, 7);
    }

    #[test]
    fn test_named_descriptor() {
        let descriptor = ServiceDescriptor::named_scoped::<TestService, TestService>(
            "test-service",
            Box::new(|_| Ok(Box::new(TestService { value: 200 }))),
        );

        assert_eq!(descriptor.lifetime, Lifetime::Scoped);
        assert!(descriptor.matches_key(&ServiceKey::named::<TestService>("test-service")));
        assert!(!descriptor.matches_key(&ServiceKey::of_type::<TestService>()));
    }

    #[test]
    fn test_downcast_arc_accepts_value_and_nested_arc() {
        let provider = MockServiceProvider;

        // Erased value is a `TestService` itself.
        let direct: Arc<dyn Any + Send + Sync> = Arc::new(TestService { value: 1 });
        let resolved = provider
            .downcast_arc::<TestService>(direct)
            .ok()
            .flatten()
            .expect("a bare TestService should downcast");
        assert_eq!(resolved.value, 1);

        // Erased value is an `Arc<TestService>`; the inner handle is shared.
        let inner = Arc::new(TestService { value: 2 });
        let nested: Arc<dyn Any + Send + Sync> = Arc::new(Arc::clone(&inner));
        let resolved = provider
            .downcast_arc::<TestService>(nested)
            .ok()
            .flatten()
            .expect("an Arc<TestService> should downcast");
        assert!(Arc::ptr_eq(&resolved, &inner));

        // Anything else is a casting error.
        let wrong: Arc<dyn Any + Send + Sync> = Arc::new(5_u8);
        assert!(provider.downcast_arc::<TestService>(wrong).is_err());
    }

    #[test]
    fn test_missing_service_resolution() {
        let provider = MockServiceProvider;

        assert!(matches!(provider.get_service::<TestService>(), Ok(None)));
        assert!(matches!(provider.get_keyed_service::<TestService>("absent"), Ok(None)));
        assert!(provider.get_required_service::<TestService>().is_err());
        assert!(provider
            .get_required_keyed_service::<TestService>("absent")
            .is_err());
    }
}