1use crate::provider::{Provider, ProviderScope};
2use cardinal_config::CardinalConfig;
3use cardinal_errors::CardinalError;
4use parking_lot::{Mutex, RwLock};
5use std::any::{Any, TypeId};
6use std::collections::{HashMap, HashSet};
7use std::future::Future;
8use std::marker::PhantomData;
9use std::pin::Pin;
10use std::sync::Arc;
11
12pub struct CardinalContext {
13 pub config: Arc<CardinalConfig>,
14 scopes: RwLock<HashMap<TypeId, ProviderScope>>, singletons: RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>, constructing: Mutex<HashSet<TypeId>>, factories: RwLock<HashMap<TypeId, Arc<dyn ProviderFactory>>>,
18}
19
20impl CardinalContext {
21 pub fn new(config: CardinalConfig) -> Self {
22 Self {
23 config: Arc::new(config),
24 scopes: RwLock::new(HashMap::new()),
25 singletons: RwLock::new(HashMap::new()),
26 constructing: Mutex::new(HashSet::new()),
27 factories: RwLock::new(HashMap::new()),
28 }
29 }
30
31 pub fn register<T>(&self, scope: ProviderScope)
33 where
34 T: Provider + Send + Sync + 'static,
35 {
36 let tid = TypeId::of::<T>();
37 let mut map = self.scopes.write();
38 map.insert(tid, scope);
39 }
40
41 pub fn register_with_factory<T, F, Fut>(&self, scope: ProviderScope, factory: F)
42 where
43 T: Provider + Send + Sync + 'static,
44 F: Fn(&CardinalContext) -> Fut + Send + Sync + 'static,
45 Fut: Future<Output = Result<T, CardinalError>> + Send + 'static,
46 {
47 let tid = TypeId::of::<T>();
48 let factory = Arc::new(TypedFactory::<T, F> {
49 inner: factory,
50 _marker: PhantomData,
51 }) as Arc<dyn ProviderFactory>;
52
53 self.factories.write().insert(tid, factory);
54 self.register::<T>(scope);
55 }
56
57 pub fn register_singleton_instance<T>(&self, instance: Arc<T>)
58 where
59 T: Provider + Send + Sync + 'static,
60 {
61 let tid = TypeId::of::<T>();
62 self.register::<T>(ProviderScope::Singleton);
63 let erased: Arc<dyn Any + Send + Sync> = instance;
64 self.singletons.write().insert(tid, erased);
65 self.factories.write().remove(&tid);
66 }
67
68 pub fn is_registered<T>(&self) -> bool
69 where
70 T: Provider + Send + Sync + 'static,
71 {
72 let tid = TypeId::of::<T>();
73 self.scopes.read().contains_key(&tid)
74 }
75
76 pub async fn get<T>(&self) -> Result<Arc<T>, CardinalError>
78 where
79 T: Provider + Send + Sync + 'static,
80 {
81 let tid = TypeId::of::<T>();
82
83 let scope = {
85 let map = self.scopes.read();
86 match map.get(&tid) {
87 Some(s) => *s,
88 None => {
89 return Err(CardinalError::InternalError(
90 cardinal_errors::internal::CardinalInternalError::ProviderNotRegistered,
91 ))
92 }
93 }
94 };
95
96 match scope {
97 ProviderScope::Singleton => {
98 if let Some(existing) = self.singletons.read().get(&tid).cloned() {
100 return existing
101 .downcast::<T>()
102 .map_err(|_| CardinalError::InternalError(cardinal_errors::internal::CardinalInternalError::DependencyTypeMismatch));
103 }
104
105 let guard = match self.try_mark_constructing(tid) {
107 Ok(g) => g,
108 Err(e) => return Err(e),
109 };
110 let factory = self.factory_for::<T>();
111 let erased: Arc<dyn Any + Send + Sync> = match factory {
112 Some(factory) => factory.create(self).await?,
113 None => Arc::new(T::provide(self).await?) as Arc<dyn Any + Send + Sync>,
114 };
115 drop(guard);
116
117 {
119 let mut cache = self.singletons.write();
120 cache.entry(tid).or_insert(erased.clone());
121 }
122
123 Arc::downcast::<T>(erased).map_err(|_| {
125 CardinalError::InternalError(
126 cardinal_errors::internal::CardinalInternalError::DependencyTypeMismatch,
127 )
128 })
129 }
130 ProviderScope::Transient => {
131 let guard = match self.try_mark_constructing(tid) {
133 Ok(g) => g,
134 Err(e) => return Err(e),
135 };
136 let factory = self.factory_for::<T>();
137 let erased: Arc<dyn Any + Send + Sync> = match factory {
138 Some(factory) => factory.create(self).await?,
139 None => Arc::new(T::provide(self).await?) as Arc<dyn Any + Send + Sync>,
140 };
141 drop(guard);
142 Arc::downcast::<T>(erased).map_err(|_| {
143 CardinalError::InternalError(
144 cardinal_errors::internal::CardinalInternalError::DependencyTypeMismatch,
145 )
146 })
147 }
148 }
149 }
150
151 pub async fn build_eager<T>(&self) -> Result<Arc<T>, CardinalError>
153 where
154 T: Provider + Send + Sync + 'static,
155 {
156 self.get::<T>().await
157 }
158
159 fn try_mark_constructing(&self, tid: TypeId) -> Result<ConstructGuard<'_>, CardinalError> {
160 let mut set = self.constructing.lock();
161 if set.contains(&tid) {
162 return Err(CardinalError::InternalError(
163 cardinal_errors::internal::CardinalInternalError::DependencyCycleDetected,
164 ));
165 }
166 set.insert(tid);
167 Ok(ConstructGuard { ctx: self, tid })
168 }
169
170 fn unmark_constructing(&self, tid: TypeId) {
171 let mut set = self.constructing.lock();
172 set.remove(&tid);
173 }
174
175 fn factory_for<T>(&self) -> Option<Arc<dyn ProviderFactory>>
176 where
177 T: Provider + Send + Sync + 'static,
178 {
179 let tid = TypeId::of::<T>();
180 self.factories.read().get(&tid).cloned()
181 }
182}
183
184struct ConstructGuard<'a> {
186 ctx: &'a CardinalContext,
187 tid: TypeId,
188}
189
190impl<'a> Drop for ConstructGuard<'a> {
191 fn drop(&mut self) {
192 self.ctx.unmark_constructing(self.tid);
193 }
194}
195
196type ProviderFuture =
197 Pin<Box<dyn Future<Output = Result<Arc<dyn Any + Send + Sync>, CardinalError>> + Send>>;
198
199trait ProviderFactory: Send + Sync {
200 fn create(&self, ctx: &CardinalContext) -> ProviderFuture;
201}
202
203struct TypedFactory<T, F> {
204 inner: F,
205 _marker: PhantomData<T>,
206}
207
208impl<T, F, Fut> ProviderFactory for TypedFactory<T, F>
209where
210 T: Provider + Send + Sync + 'static,
211 F: Fn(&CardinalContext) -> Fut + Send + Sync + 'static,
212 Fut: Future<Output = Result<T, CardinalError>> + Send + 'static,
213{
214 fn create(&self, ctx: &CardinalContext) -> ProviderFuture {
215 let fut = (self.inner)(ctx);
216 Box::pin(async move {
217 let value = fut.await?;
218 Ok(Arc::new(value) as Arc<dyn Any + Send + Sync>)
219 })
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use cardinal_errors::CardinalError;

    #[derive(Debug)]
    struct Db {
        dsn: String,
    }

    #[derive(Debug)]
    struct Repo {
        db: Arc<Db>,
    }

    #[derive(Debug)]
    struct Service {
        repo: Arc<Repo>,
    }

    #[async_trait]
    impl Provider for Db {
        async fn provide(_ctx: &CardinalContext) -> Result<Self, CardinalError> {
            Ok(Db { dsn: "dsn".into() })
        }
    }

    #[async_trait]
    impl Provider for Repo {
        async fn provide(ctx: &CardinalContext) -> Result<Self, CardinalError> {
            Ok(Repo {
                db: ctx.get::<Db>().await?,
            })
        }
    }

    #[async_trait]
    impl Provider for Service {
        async fn provide(ctx: &CardinalContext) -> Result<Self, CardinalError> {
            Ok(Service {
                repo: ctx.get::<Repo>().await?,
            })
        }
    }

    /// Builds a fresh context with the default configuration.
    fn get_context() -> CardinalContext {
        CardinalContext::new(CardinalConfig::default())
    }

    #[tokio::test]
    async fn register_with_factory_preempts_default_provider() {
        #[derive(Debug, Clone, PartialEq, Eq)]
        struct StringProvider(pub String);

        #[async_trait]
        impl Provider for StringProvider {
            async fn provide(_ctx: &CardinalContext) -> Result<Self, CardinalError> {
                Ok(StringProvider("Hello".to_string()))
            }
        }

        let ctx = get_context();
        ctx.register_with_factory::<StringProvider, _, _>(ProviderScope::Singleton, |_ctx| async {
            Ok(StringProvider("Overridden".to_string()))
        });

        assert!(ctx.is_registered::<StringProvider>());
        let a = ctx.get::<StringProvider>().await.unwrap();
        let b = ctx.get::<StringProvider>().await.unwrap();
        assert!(Arc::ptr_eq(&a, &b));
        assert_eq!(a.0, "Overridden");
    }

    #[tokio::test]
    async fn register_singleton_instance_returns_same_arc() {
        #[derive(Debug)]
        struct Static;

        #[async_trait]
        impl Provider for Static {
            async fn provide(_ctx: &CardinalContext) -> Result<Self, CardinalError> {
                Ok(Static)
            }
        }

        let ctx = get_context();
        let instance = Arc::new(Static);
        ctx.register_singleton_instance::<Static>(instance.clone());

        let a = ctx.get::<Static>().await.unwrap();
        let b = ctx.get::<Static>().await.unwrap();
        assert!(Arc::ptr_eq(&a, &b));
        assert!(Arc::ptr_eq(&a, &instance));
    }

    #[tokio::test]
    async fn singleton_reuse_same_arc() {
        let ctx = get_context();
        ctx.register::<Db>(ProviderScope::Singleton);

        let a = ctx.get::<Db>().await.unwrap();
        let b = ctx.get::<Db>().await.unwrap();
        assert!(Arc::ptr_eq(&a, &b));
        // The instance really came from `Db::provide`.
        assert_eq!(a.dsn, "dsn");
    }

    #[tokio::test]
    async fn transient_returns_new_arc_each_time() {
        let ctx = get_context();
        ctx.register::<Db>(ProviderScope::Singleton);
        ctx.register::<Repo>(ProviderScope::Singleton);
        ctx.register::<Service>(ProviderScope::Transient);

        let a = ctx.get::<Service>().await.unwrap();
        let b = ctx.get::<Service>().await.unwrap();
        assert!(!Arc::ptr_eq(&a, &b));
    }

    #[tokio::test]
    async fn nested_dependencies_singletons_reused_transient_recreated() {
        let ctx = get_context();
        ctx.register::<Db>(ProviderScope::Singleton);
        ctx.register::<Repo>(ProviderScope::Singleton);
        ctx.register::<Service>(ProviderScope::Transient);

        let s1 = ctx.get::<Service>().await.unwrap();
        let s2 = ctx.get::<Service>().await.unwrap();

        assert!(!Arc::ptr_eq(&s1, &s2));
        assert!(Arc::ptr_eq(&s1.repo, &s2.repo));
        assert!(Arc::ptr_eq(&s1.repo.db, &s2.repo.db));
    }

    struct UnregisteredType;

    #[async_trait]
    impl Provider for UnregisteredType {
        async fn provide(_ctx: &CardinalContext) -> Result<Self, CardinalError> {
            Ok(UnregisteredType)
        }
    }

    #[tokio::test]
    async fn unregistered_type_errors() {
        let ctx = get_context();
        let res = ctx.get::<UnregisteredType>().await;
        assert!(matches!(
            res,
            Err(CardinalError::InternalError(
                cardinal_errors::internal::CardinalInternalError::ProviderNotRegistered
            ))
        ));
    }

    // The fields exist only to make `A` and `B` depend on each other; they are never read.
    #[allow(dead_code)]
    #[derive(Debug)]
    struct A(Arc<B>);
    #[allow(dead_code)]
    #[derive(Debug)]
    struct B(Arc<A>);

    #[async_trait]
    impl Provider for A {
        async fn provide(ctx: &CardinalContext) -> Result<Self, CardinalError> {
            Ok(A(ctx.get::<B>().await?))
        }
    }

    #[async_trait]
    impl Provider for B {
        async fn provide(ctx: &CardinalContext) -> Result<Self, CardinalError> {
            Ok(B(ctx.get::<A>().await?))
        }
    }

    #[tokio::test]
    async fn simple_cycle_errors() {
        let ctx = get_context();
        ctx.register::<A>(ProviderScope::Transient);
        ctx.register::<B>(ProviderScope::Transient);

        let res = ctx.get::<A>().await;
        assert!(matches!(
            res,
            Err(CardinalError::InternalError(
                cardinal_errors::internal::CardinalInternalError::DependencyCycleDetected
            ))
        ));
    }

    #[tokio::test]
    async fn singleton_cycle_errors() {
        let ctx = get_context();
        ctx.register::<A>(ProviderScope::Singleton);
        ctx.register::<B>(ProviderScope::Singleton);

        let res = ctx.get::<A>().await;
        assert!(matches!(
            res,
            Err(CardinalError::InternalError(
                cardinal_errors::internal::CardinalInternalError::DependencyCycleDetected
            ))
        ));
    }

    #[tokio::test]
    async fn failed_construction_releases_cycle_mark() {
        use std::sync::atomic::{AtomicBool, Ordering};

        static FAIL_NEXT: AtomicBool = AtomicBool::new(true);

        #[derive(Debug)]
        struct Flaky;

        #[async_trait]
        impl Provider for Flaky {
            async fn provide(_ctx: &CardinalContext) -> Result<Self, CardinalError> {
                if FAIL_NEXT.swap(false, Ordering::SeqCst) {
                    return Err(CardinalError::InternalError(
                        cardinal_errors::internal::CardinalInternalError::ProviderNotRegistered,
                    ));
                }
                Ok(Flaky)
            }
        }

        let ctx = get_context();
        ctx.register::<Flaky>(ProviderScope::Singleton);

        assert!(ctx.get::<Flaky>().await.is_err());
        // The construction mark must be released on the error path; otherwise this
        // retry would be misreported as a dependency cycle.
        assert!(ctx.get::<Flaky>().await.is_ok());
    }
}