// flow_di/descriptor.rs

1use crate::{DiResult, Lifetime, ServiceKey};
2use std::any::Any;
3use std::sync::Arc;
4
5/// Service factory function type
6/// Receives a service resolver and returns a boxed Any object
7pub type ServiceFactory =
8    Box<dyn Fn(&dyn ServiceProvider) -> DiResult<Box<dyn Any + Send + Sync>> + Send + Sync>;
9
10/// Simplified service provider interface - core trait (object-safe)
11pub trait ServiceProvider: Send + Sync {
12    /// Get the raw implementation of a service
13    fn get_service_raw(&self, key: &ServiceKey) -> DiResult<Option<Arc<dyn Any + Send + Sync>>>;
14}
15
16/// Service provider extension trait - contains generic methods
17pub trait ServiceProviderExt: ServiceProvider {
18    /// Get a service of the specified type
19    fn get_service<T: 'static + Send + Sync>(&self) -> DiResult<Option<Arc<T>>> {
20        let key = ServiceKey::of_type::<T>();
21        if let Some(any_arc) = self.get_service_raw(&key)? {
22            // Use custom conversion method
23            self.downcast_arc::<T>(any_arc)
24        } else {
25            Ok(None)
26        }
27    }
28
29    /// Get a required service of the specified type
30    fn get_required_service<T: 'static + Send + Sync>(&self) -> DiResult<Arc<T>> {
31        match self.get_service::<T>()? {
32            Some(service) => Ok(service),
33            None => Err(crate::DiError::service_not_registered::<T>()),
34        }
35    }
36
37    /// Get a service of the specified name and type
38    fn get_keyed_service<T: 'static + Send + Sync>(&self, key: &str) -> DiResult<Option<Arc<T>>> {
39        let service_key = ServiceKey::named::<T>(key);
40        if let Some(any_arc) = self.get_service_raw(&service_key)? {
41            self.downcast_arc::<T>(any_arc)
42        } else {
43            Ok(None)
44        }
45    }
46
47    /// Get a required service of the specified name and type
48    fn get_required_keyed_service<T: 'static + Send + Sync>(&self, key: &str) -> DiResult<Arc<T>> {
49        match self.get_keyed_service::<T>(key)? {
50            Some(service) => Ok(service),
51            None => Err(crate::DiError::keyed_service_not_registered::<T>(key)),
52        }
53    }
54
55    /// Convert Arc<dyn Any> to Arc<T>
56    fn downcast_arc<T: 'static + Send + Sync>(
57        &self,
58        any_arc: Arc<dyn Any + Send + Sync>,
59    ) -> DiResult<Option<Arc<T>>> {
60        // First try to downcast directly to T
61        match any_arc.downcast::<T>() {
62            Ok(typed_arc) => Ok(Some(typed_arc)),
63            Err(original_arc) => {
64                // If failed, try to downcast to Arc<T> (handle double wrapping case)
65                match original_arc.downcast::<Arc<T>>() {
66                    Ok(arc_of_arc) => Ok(Some((*arc_of_arc).clone())),
67                    Err(_) => Err(crate::DiError::type_casting_failed::<T>()),
68                }
69            }
70        }
71    }
72}
73
74/// Automatically implement ServiceProviderExt for all types that implement ServiceProvider
75impl<T: ServiceProvider + ?Sized> ServiceProviderExt for T {}
76
77/// Service descriptor - describes how to create and manage service instances
78#[derive(Clone)]
79pub struct ServiceDescriptor {
80    /// Service key for uniquely identifying services
81    pub service_key: ServiceKey,
82
83    /// Service lifetime
84    pub lifetime: Lifetime,
85
86    /// Service factory function for creating service instances
87    pub factory: Arc<ServiceFactory>,
88
89    /// Service type ID
90    pub service_type: std::any::TypeId,
91
92    /// Implementation type ID (may differ from service type)
93    pub implementation_type: std::any::TypeId,
94}
95
96impl ServiceDescriptor {
97    /// Create a new service descriptor
98    pub fn new<TService, TImplementation>(
99        service_key: ServiceKey,
100        lifetime: Lifetime,
101        factory: ServiceFactory,
102    ) -> Self
103    where
104        TService: 'static,
105        TImplementation: 'static,
106    {
107        Self {
108            service_key,
109            lifetime,
110            factory: Arc::new(factory),
111            service_type: std::any::TypeId::of::<TService>(),
112            implementation_type: std::any::TypeId::of::<TImplementation>(),
113        }
114    }
115
116    /// Create a type-based transient service descriptor
117    pub fn transient<TService, TImplementation>(factory: ServiceFactory) -> Self
118    where
119        TService: 'static,
120        TImplementation: 'static,
121    {
122        Self::new::<TService, TImplementation>(
123            ServiceKey::of_type::<TService>(),
124            Lifetime::Transient,
125            factory,
126        )
127    }
128
129    /// Create a type-based scoped service descriptor
130    pub fn scoped<TService, TImplementation>(factory: ServiceFactory) -> Self
131    where
132        TService: 'static,
133        TImplementation: 'static,
134    {
135        Self::new::<TService, TImplementation>(
136            ServiceKey::of_type::<TService>(),
137            Lifetime::Scoped,
138            factory,
139        )
140    }
141
142    /// Create a type-based singleton service descriptor
143    pub fn singleton<TService, TImplementation>(factory: ServiceFactory) -> Self
144    where
145        TService: 'static,
146        TImplementation: 'static,
147    {
148        Self::new::<TService, TImplementation>(
149            ServiceKey::of_type::<TService>(),
150            Lifetime::Singleton,
151            factory,
152        )
153    }
154
155    /// Create a name-based transient service descriptor
156    pub fn named_transient<TService, TImplementation>(
157        name: impl Into<String>,
158        factory: ServiceFactory,
159    ) -> Self
160    where
161        TService: 'static,
162        TImplementation: 'static,
163    {
164        Self::new::<TService, TImplementation>(
165            ServiceKey::named::<TService>(name),
166            Lifetime::Transient,
167            factory,
168        )
169    }
170
171    /// Create a name-based scoped service descriptor
172    pub fn named_scoped<TService, TImplementation>(
173        name: impl Into<String>,
174        factory: ServiceFactory,
175    ) -> Self
176    where
177        TService: 'static,
178        TImplementation: 'static,
179    {
180        Self::new::<TService, TImplementation>(
181            ServiceKey::named::<TService>(name),
182            Lifetime::Scoped,
183            factory,
184        )
185    }
186
187    /// Create a name-based singleton service descriptor
188    pub fn named_singleton<TService, TImplementation>(
189        name: impl Into<String>,
190        factory: ServiceFactory,
191    ) -> Self
192    where
193        TService: 'static,
194        TImplementation: 'static,
195    {
196        Self::new::<TService, TImplementation>(
197            ServiceKey::named::<TService>(name),
198            Lifetime::Singleton,
199            factory,
200        )
201    }
202
203    /// Create a singleton service descriptor from an instance
204    pub fn from_instance<TService>(instance: TService) -> Self
205    where
206        TService: Send + Sync + 'static,
207    {
208        let instance = Arc::new(instance);
209        Self::singleton::<TService, TService>(Box::new(move |_| {
210            Ok(Box::new(Arc::clone(&instance)))
211        }))
212    }
213
214    /// Create a singleton service descriptor from a named instance
215    pub fn from_named_instance<TService>(name: impl Into<String>, instance: TService) -> Self
216    where
217        TService: Send + Sync + 'static,
218    {
219        let instance = Arc::new(instance);
220        Self::named_singleton::<TService, TService>(
221            name,
222            Box::new(move |_| Ok(Box::new(Arc::clone(&instance)))),
223        )
224    }
225
226    /// Check if the service key matches
227    pub fn matches_key(&self, key: &ServiceKey) -> bool {
228        &self.service_key == key
229    }
230
231    /// Check if the service type matches
232    pub fn matches_service_type<T: 'static>(&self) -> bool {
233        self.service_type == std::any::TypeId::of::<T>()
234    }
235
236    /// Create a service instance
237    pub fn create_instance(
238        &self,
239        provider: &dyn ServiceProvider,
240    ) -> DiResult<Box<dyn Any + Send + Sync>> {
241        (self.factory)(provider)
242    }
243}
244
245impl std::fmt::Debug for ServiceDescriptor {
246    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247        f.debug_struct("ServiceDescriptor")
248            .field("service_key", &self.service_key)
249            .field("lifetime", &self.lifetime)
250            .field("service_type", &self.service_type)
251            .field("implementation_type", &self.implementation_type)
252            .finish()
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone)]
    struct TestService {
        value: i32,
    }

    /// Provider with no registrations: every raw lookup yields `Ok(None)`.
    struct MockServiceProvider;

    impl ServiceProvider for MockServiceProvider {
        fn get_service_raw(
            &self,
            _key: &ServiceKey,
        ) -> DiResult<Option<Arc<dyn Any + Send + Sync>>> {
            Ok(None)
        }
    }

    /// Runs the descriptor's factory, failing the test if it errors.
    ///
    /// Matches on the result instead of unwrapping so the tests do not require
    /// `DiError: Debug`.
    fn build(descriptor: &ServiceDescriptor) -> Box<dyn Any + Send + Sync> {
        match descriptor.create_instance(&MockServiceProvider) {
            Ok(instance) => instance,
            Err(_) => panic!("factory unexpectedly failed"),
        }
    }

    #[test]
    fn test_descriptor_creation() {
        let descriptor = ServiceDescriptor::transient::<TestService, TestService>(Box::new(|_| {
            Ok(Box::new(TestService { value: 42 }))
        }));

        assert_eq!(descriptor.lifetime, Lifetime::Transient);
        assert!(descriptor.matches_service_type::<TestService>());
        assert!(descriptor.matches_key(&ServiceKey::of_type::<TestService>()));

        // The factory must actually produce the configured value.
        let instance = build(&descriptor);
        let service = instance
            .downcast_ref::<TestService>()
            .expect("transient factory should box a TestService");
        assert_eq!(service.value, 42);
    }

    #[test]
    fn test_from_instance() {
        let descriptor = ServiceDescriptor::from_instance(TestService { value: 100 });

        assert_eq!(descriptor.lifetime, Lifetime::Singleton);
        assert!(descriptor.matches_service_type::<TestService>());

        // The instance factory boxes clones of one shared `Arc`, so two builds
        // must hand out the very same object.
        let first_box = build(&descriptor);
        let second_box = build(&descriptor);
        let first = first_box
            .downcast_ref::<Arc<TestService>>()
            .expect("instance factory should box an Arc<TestService>");
        let second = second_box
            .downcast_ref::<Arc<TestService>>()
            .expect("instance factory should box an Arc<TestService>");
        assert_eq!(first.value, 100);
        assert!(Arc::ptr_eq(first, second));
    }

    #[test]
    fn test_from_named_instance() {
        let descriptor =
            ServiceDescriptor::from_named_instance("primary", TestService { value: 7 });

        assert_eq!(descriptor.lifetime, Lifetime::Singleton);
        assert!(descriptor.matches_key(&ServiceKey::named::<TestService>("primary")));
        assert!(!descriptor.matches_key(&ServiceKey::of_type::<TestService>()));

        let instance = build(&descriptor);
        let service = instance
            .downcast_ref::<Arc<TestService>>()
            .expect("named instance factory should box an Arc<TestService>");
        assert_eq!(service.value, 7);
    }

    #[test]
    fn test_named_descriptor() {
        let descriptor = ServiceDescriptor::named_scoped::<TestService, TestService>(
            "test-service",
            Box::new(|_| Ok(Box::new(TestService { value: 200 }))),
        );

        assert_eq!(descriptor.lifetime, Lifetime::Scoped);
        assert!(descriptor.matches_key(&ServiceKey::named::<TestService>("test-service")));
        assert!(!descriptor.matches_key(&ServiceKey::of_type::<TestService>()));

        let instance = build(&descriptor);
        let service = instance
            .downcast_ref::<TestService>()
            .expect("named factory should box a TestService");
        assert_eq!(service.value, 200);
    }

    #[test]
    fn test_missing_services_resolve_to_none_or_error() {
        let provider = MockServiceProvider;

        assert!(matches!(provider.get_service::<TestService>(), Ok(None)));
        assert!(matches!(provider.get_keyed_service::<TestService>("any"), Ok(None)));
        assert!(provider.get_required_service::<TestService>().is_err());
        assert!(provider.get_required_keyed_service::<TestService>("any").is_err());
    }

    #[test]
    fn test_downcast_arc_handles_direct_and_wrapped_values() {
        let provider = MockServiceProvider;

        let direct: Arc<dyn Any + Send + Sync> = Arc::new(TestService { value: 1 });
        match provider.downcast_arc::<TestService>(direct) {
            Ok(Some(service)) => assert_eq!(service.value, 1),
            _ => panic!("a value stored as T should downcast"),
        }

        // `from_instance` factories box an `Arc<T>`, so the erased value may be
        // an `Arc<T>` wrapped in the outer `Arc`.
        let wrapped: Arc<dyn Any + Send + Sync> = Arc::new(Arc::new(TestService { value: 2 }));
        match provider.downcast_arc::<TestService>(wrapped) {
            Ok(Some(service)) => assert_eq!(service.value, 2),
            _ => panic!("a value stored as Arc<T> should downcast"),
        }

        let mistyped: Arc<dyn Any + Send + Sync> = Arc::new(3_u8);
        assert!(provider.downcast_arc::<TestService>(mistyped).is_err());
    }
}
312}