cardinal_base/
context.rs

1use crate::provider::{Provider, ProviderScope};
2use cardinal_config::CardinalConfig;
3use cardinal_errors::CardinalError;
4use parking_lot::{Mutex, RwLock};
5use std::any::{Any, TypeId};
6use std::collections::{HashMap, HashSet};
7use std::future::Future;
8use std::marker::PhantomData;
9use std::pin::Pin;
10use std::sync::Arc;
11
12pub struct CardinalContext {
13    pub config: Arc<CardinalConfig>,
14    scopes: RwLock<HashMap<TypeId, ProviderScope>>, // registered scopes for types
15    singletons: RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>, // cached singleton instances
16    constructing: Mutex<HashSet<TypeId>>,           // basic cycle detection
17    factories: RwLock<HashMap<TypeId, Arc<dyn ProviderFactory>>>,
18}
19
20impl CardinalContext {
21    pub fn new(config: CardinalConfig) -> Self {
22        Self {
23            config: Arc::new(config),
24            scopes: RwLock::new(HashMap::new()),
25            singletons: RwLock::new(HashMap::new()),
26            constructing: Mutex::new(HashSet::new()),
27            factories: RwLock::new(HashMap::new()),
28        }
29    }
30
31    // Register a scope for concrete type T. Overwrites existing scope if re-registered.
32    pub fn register<T>(&self, scope: ProviderScope)
33    where
34        T: Provider + Send + Sync + 'static,
35    {
36        let tid = TypeId::of::<T>();
37        let mut map = self.scopes.write();
38        map.insert(tid, scope);
39    }
40
41    pub fn register_with_factory<T, F, Fut>(&self, scope: ProviderScope, factory: F)
42    where
43        T: Provider + Send + Sync + 'static,
44        F: Fn(&CardinalContext) -> Fut + Send + Sync + 'static,
45        Fut: Future<Output = Result<T, CardinalError>> + Send + 'static,
46    {
47        let tid = TypeId::of::<T>();
48        let factory = Arc::new(TypedFactory::<T, F> {
49            inner: factory,
50            _marker: PhantomData,
51        }) as Arc<dyn ProviderFactory>;
52
53        self.factories.write().insert(tid, factory);
54        self.register::<T>(scope);
55    }
56
57    pub fn register_singleton_instance<T>(&self, instance: Arc<T>)
58    where
59        T: Provider + Send + Sync + 'static,
60    {
61        let tid = TypeId::of::<T>();
62        self.register::<T>(ProviderScope::Singleton);
63        let erased: Arc<dyn Any + Send + Sync> = instance;
64        self.singletons.write().insert(tid, erased);
65        self.factories.write().remove(&tid);
66    }
67
68    pub fn is_registered<T>(&self) -> bool
69    where
70        T: Provider + Send + Sync + 'static,
71    {
72        let tid = TypeId::of::<T>();
73        self.scopes.read().contains_key(&tid)
74    }
75
76    // Lazily constructs values on first access and caches singletons.
77    pub async fn get<T>(&self) -> Result<Arc<T>, CardinalError>
78    where
79        T: Provider + Send + Sync + 'static,
80    {
81        let tid = TypeId::of::<T>();
82
83        // Determine scope for T
84        let scope = {
85            let map = self.scopes.read();
86            match map.get(&tid) {
87                Some(s) => *s,
88                None => {
89                    return Err(CardinalError::InternalError(
90                        cardinal_errors::internal::CardinalInternalError::ProviderNotRegistered,
91                    ))
92                }
93            }
94        };
95
96        match scope {
97            ProviderScope::Singleton => {
98                // Fast path: already cached
99                if let Some(existing) = self.singletons.read().get(&tid).cloned() {
100                    return existing
101                        .downcast::<T>()
102                        .map_err(|_| CardinalError::InternalError(cardinal_errors::internal::CardinalInternalError::DependencyTypeMismatch));
103                }
104
105                // Build with cycle detection
106                let guard = match self.try_mark_constructing(tid) {
107                    Ok(g) => g,
108                    Err(e) => return Err(e),
109                };
110                let factory = self.factory_for::<T>();
111                let erased: Arc<dyn Any + Send + Sync> = match factory {
112                    Some(factory) => factory.create(self).await?,
113                    None => Arc::new(T::provide(self).await?) as Arc<dyn Any + Send + Sync>,
114                };
115                drop(guard);
116
117                // Insert into cache if still absent; another thread might have inserted meanwhile
118                {
119                    let mut cache = self.singletons.write();
120                    cache.entry(tid).or_insert(erased.clone());
121                }
122
123                // Return the (possibly newly) cached value
124                Arc::downcast::<T>(erased).map_err(|_| {
125                    CardinalError::InternalError(
126                        cardinal_errors::internal::CardinalInternalError::DependencyTypeMismatch,
127                    )
128                })
129            }
130            ProviderScope::Transient => {
131                // Build with cycle detection, do not cache
132                let guard = match self.try_mark_constructing(tid) {
133                    Ok(g) => g,
134                    Err(e) => return Err(e),
135                };
136                let factory = self.factory_for::<T>();
137                let erased: Arc<dyn Any + Send + Sync> = match factory {
138                    Some(factory) => factory.create(self).await?,
139                    None => Arc::new(T::provide(self).await?) as Arc<dyn Any + Send + Sync>,
140                };
141                drop(guard);
142                Arc::downcast::<T>(erased).map_err(|_| {
143                    CardinalError::InternalError(
144                        cardinal_errors::internal::CardinalInternalError::DependencyTypeMismatch,
145                    )
146                })
147            }
148        }
149    }
150
151    // Convenience that just calls get<T>(), intended for startup pre-warming.
152    pub async fn build_eager<T>(&self) -> Result<Arc<T>, CardinalError>
153    where
154        T: Provider + Send + Sync + 'static,
155    {
156        self.get::<T>().await
157    }
158
159    fn try_mark_constructing(&self, tid: TypeId) -> Result<ConstructGuard<'_>, CardinalError> {
160        let mut set = self.constructing.lock();
161        if set.contains(&tid) {
162            return Err(CardinalError::InternalError(
163                cardinal_errors::internal::CardinalInternalError::DependencyCycleDetected,
164            ));
165        }
166        set.insert(tid);
167        Ok(ConstructGuard { ctx: self, tid })
168    }
169
170    fn unmark_constructing(&self, tid: TypeId) {
171        let mut set = self.constructing.lock();
172        set.remove(&tid);
173    }
174
175    fn factory_for<T>(&self) -> Option<Arc<dyn ProviderFactory>>
176    where
177        T: Provider + Send + Sync + 'static,
178    {
179        let tid = TypeId::of::<T>();
180        self.factories.read().get(&tid).cloned()
181    }
182}
183
184// RAII guard for the constructing set, to ensure cleanup on early returns
185struct ConstructGuard<'a> {
186    ctx: &'a CardinalContext,
187    tid: TypeId,
188}
189
190impl<'a> Drop for ConstructGuard<'a> {
191    fn drop(&mut self) {
192        self.ctx.unmark_constructing(self.tid);
193    }
194}
195
196type ProviderFuture =
197    Pin<Box<dyn Future<Output = Result<Arc<dyn Any + Send + Sync>, CardinalError>> + Send>>;
198
199trait ProviderFactory: Send + Sync {
200    fn create(&self, ctx: &CardinalContext) -> ProviderFuture;
201}
202
203struct TypedFactory<T, F> {
204    inner: F,
205    _marker: PhantomData<T>,
206}
207
208impl<T, F, Fut> ProviderFactory for TypedFactory<T, F>
209where
210    T: Provider + Send + Sync + 'static,
211    F: Fn(&CardinalContext) -> Fut + Send + Sync + 'static,
212    Fut: Future<Output = Result<T, CardinalError>> + Send + 'static,
213{
214    fn create(&self, ctx: &CardinalContext) -> ProviderFuture {
215        let fut = (self.inner)(ctx);
216        Box::pin(async move {
217            let value = fut.await?;
218            Ok(Arc::new(value) as Arc<dyn Any + Send + Sync>)
219        })
220    }
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use cardinal_errors::internal::CardinalInternalError;
    use cardinal_errors::CardinalError;

    /// Leaf dependency with no dependencies of its own.
    #[derive(Debug)]
    struct Db {
        dsn: String,
    }

    /// Depends on [`Db`].
    #[derive(Debug)]
    struct Repo {
        db: Arc<Db>,
    }

    /// Depends on [`Repo`].
    #[derive(Debug)]
    struct Service {
        repo: Arc<Repo>,
    }

    #[async_trait]
    impl Provider for Db {
        async fn provide(_ctx: &CardinalContext) -> Result<Self, CardinalError> {
            Ok(Db { dsn: "dsn".into() })
        }
    }

    #[async_trait]
    impl Provider for Repo {
        async fn provide(ctx: &CardinalContext) -> Result<Self, CardinalError> {
            let db = ctx.get::<Db>().await?;
            Ok(Repo { db })
        }
    }

    #[async_trait]
    impl Provider for Service {
        async fn provide(ctx: &CardinalContext) -> Result<Self, CardinalError> {
            let repo = ctx.get::<Repo>().await?;
            Ok(Service { repo })
        }
    }

    /// Builds a context around the default configuration.
    fn get_context() -> CardinalContext {
        let config = CardinalConfig::default();
        CardinalContext::new(config)
    }

    #[tokio::test]
    async fn register_with_factory_preempts_default_provider() {
        #[derive(Debug, Clone, PartialEq, Eq)]
        struct StringProvider(pub String);

        #[async_trait]
        impl Provider for StringProvider {
            async fn provide(_ctx: &CardinalContext) -> Result<Self, CardinalError> {
                Ok(StringProvider("Hello".to_string()))
            }
        }

        let ctx = get_context();
        ctx.register_with_factory::<StringProvider, _, _>(ProviderScope::Singleton, |_ctx| async {
            Ok(StringProvider("Overridden".to_string()))
        });

        assert!(ctx.is_registered::<StringProvider>());
        let first = ctx.get::<StringProvider>().await.unwrap();
        let second = ctx.get::<StringProvider>().await.unwrap();
        // Singleton scope: both lookups share one allocation, produced by the factory.
        assert!(Arc::ptr_eq(&first, &second));
        assert_eq!(first.0, "Overridden");
    }

    #[tokio::test]
    async fn register_singleton_instance_returns_same_arc() {
        #[derive(Debug)]
        struct Static;

        #[async_trait]
        impl Provider for Static {
            async fn provide(_ctx: &CardinalContext) -> Result<Self, CardinalError> {
                Ok(Static)
            }
        }

        let ctx = get_context();
        let instance = Arc::new(Static);
        ctx.register_singleton_instance::<Static>(instance.clone());

        let first = ctx.get::<Static>().await.unwrap();
        let second = ctx.get::<Static>().await.unwrap();
        assert!(Arc::ptr_eq(&first, &second));
        // The context must hand back the very instance that was registered.
        assert!(Arc::ptr_eq(&first, &instance));
    }

    #[tokio::test]
    async fn singleton_reuse_same_arc() {
        let ctx = get_context();
        ctx.register::<Db>(ProviderScope::Singleton);

        let first = ctx.get::<Db>().await.unwrap();
        let second = ctx.get::<Db>().await.unwrap();
        assert!(Arc::ptr_eq(&first, &second));
    }

    #[tokio::test]
    async fn transient_returns_new_arc_each_time() {
        // Service/Repo/Db wiring: Service is transient; Repo and Db are singletons.
        let ctx = get_context();
        ctx.register::<Db>(ProviderScope::Singleton);
        ctx.register::<Repo>(ProviderScope::Singleton);
        ctx.register::<Service>(ProviderScope::Transient);

        let first = ctx.get::<Service>().await.unwrap();
        let second = ctx.get::<Service>().await.unwrap();
        assert!(!Arc::ptr_eq(&first, &second));
    }

    #[tokio::test]
    async fn nested_dependencies_singletons_reused_transient_recreated() {
        let ctx = get_context();
        ctx.register::<Db>(ProviderScope::Singleton);
        ctx.register::<Repo>(ProviderScope::Singleton);
        ctx.register::<Service>(ProviderScope::Transient);

        let first = ctx.get::<Service>().await.unwrap();
        let second = ctx.get::<Service>().await.unwrap();

        // The transient root differs, while the singleton dependencies below it are shared.
        assert!(!Arc::ptr_eq(&first, &second));
        assert!(Arc::ptr_eq(&first.repo, &second.repo));
        assert!(Arc::ptr_eq(&first.repo.db, &second.repo.db));
    }

    /// Implements `Provider` but is never registered with a context.
    struct UnregisteredType;

    #[async_trait]
    impl Provider for UnregisteredType {
        async fn provide(_ctx: &CardinalContext) -> Result<Self, CardinalError> {
            Ok(UnregisteredType)
        }
    }

    #[tokio::test]
    async fn unregistered_type_errors() {
        let ctx = get_context();
        let outcome = ctx.get::<UnregisteredType>().await;
        assert!(matches!(
            outcome,
            Err(CardinalError::InternalError(
                CardinalInternalError::ProviderNotRegistered
            ))
        ));
    }

    /// `A` and `B` require each other, forming a two-element dependency cycle.
    #[derive(Debug)]
    struct A(Arc<B>);
    #[derive(Debug)]
    struct B(Arc<A>);

    #[async_trait]
    impl Provider for A {
        async fn provide(ctx: &CardinalContext) -> Result<Self, CardinalError> {
            let b = ctx.get::<B>().await?;
            Ok(A(b))
        }
    }

    #[async_trait]
    impl Provider for B {
        async fn provide(ctx: &CardinalContext) -> Result<Self, CardinalError> {
            let a = ctx.get::<A>().await?;
            Ok(B(a))
        }
    }

    #[tokio::test]
    async fn simple_cycle_errors() {
        let ctx = get_context();
        ctx.register::<A>(ProviderScope::Transient);
        ctx.register::<B>(ProviderScope::Transient);

        let outcome = ctx.get::<A>().await;
        assert!(matches!(
            outcome,
            Err(CardinalError::InternalError(
                CardinalInternalError::DependencyCycleDetected
            ))
        ));
    }
}