1#![allow(dead_code)]
2
3use futures::future::BoxFuture;
4use tokio::sync::{Mutex, RwLock};
5
6use crate::{
7 helpers::service_container, service::Service, Handler, Injectable, Resolver, Singleton,
8};
9use std::{
10 any::{Any, TypeId},
11 collections::HashMap,
12 sync::{Arc, OnceLock},
13};
14
15pub(crate) static SERVICE_CONTAINER: OnceLock<Arc<ServiceContainer>> = OnceLock::new();
16pub(crate) const GLOBAL_INSTANCE_ID: &str = "_global_ci";
17
18type ResolverCollection = HashMap<
19 TypeId,
20 Arc<
21 Mutex<
22 Box<
23 dyn FnMut(
24 ServiceContainer,
25 )
26 -> BoxFuture<'static, Box<dyn Any + Send + Sync + 'static>>
27 + Sync
28 + Send
29 + 'static,
30 >,
31 >,
32 >,
33>;
34
35#[derive(Default, Clone)]
36pub(crate) struct Container {
37 services: Arc<RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync + 'static>>>>,
38 resolvers: Arc<RwLock<ResolverCollection>>,
39}
40
41impl Container {
42 pub(crate) async fn get<T: Clone + 'static>(&self, ci: ServiceContainer) -> Option<T> {
43 let lock = self.services.read().await;
44 if let Some(raw) = lock.get(&TypeId::of::<T>()) {
45 return raw.downcast_ref().cloned();
46 }
47 drop(lock);
48
49 let lock = self.resolvers.read().await;
50
51 if let Some(mutex) = lock.get(&TypeId::of::<T>()).cloned() {
52 drop(lock);
53 let mut callback = mutex.lock().await;
54 return callback(ci).await.downcast_ref::<T>().cloned();
55 }
56
57 None
58 }
59
60 pub(crate) async fn set<T: Send + Sync + 'static>(&self, value: T) -> &Self {
61 let mut lock = self.services.write().await;
62 lock.insert(
63 TypeId::of::<T>(),
64 Box::new(value) as Box<dyn Any + Send + Sync + 'static>,
65 );
66 drop(lock);
67
68 self
69 }
70
71 pub(crate) async fn forget<T: 'static>(&self, ci: ServiceContainer) -> Option<Box<T>> {
72 let mut lock = self.services.write().await;
73 if let Some(raw) = lock.remove(&TypeId::of::<T>()) {
74 self.resolvers.write().await.remove(&TypeId::of::<T>());
75 return raw.downcast().ok();
76 }
77
78 let mut lock = self.resolvers.write().await;
79 if let Some(mutex) = lock.remove(&TypeId::of::<T>()) {
80 drop(lock);
81 let mut callback = mutex.lock().await;
82 return callback(ci).await.downcast::<T>().ok();
83 }
84
85 None
86 }
87
88 pub(crate) async fn remove_resolver<T: 'static>(&self) -> bool {
89 if self.has_resolver::<T>().await {
90 let mut lock = self.resolvers.write().await;
91 lock.remove(&TypeId::of::<T>());
92 true
93 } else {
94 false
95 }
96 }
97
98 pub(crate) async fn resolver<T: Clone + Send + Sync + 'static>(
99 &self,
100 mut callback: impl FnMut(ServiceContainer) -> BoxFuture<'static, T>
101 + Send
102 + Sync
103 + Clone
104 + 'static,
105 ) -> &Self {
106 let mut lock = self.resolvers.write().await;
107 lock.insert(
108 TypeId::of::<T>(),
109 Arc::new(Mutex::new(Box::new(move |c| {
110 let f = (callback)(c);
111 Box::pin(async move { Box::new(f.await) as Box<dyn Any + Send + Sync + 'static> })
112 }))),
113 );
114 self
115 }
116
117 pub(crate) async fn soft_resolver<T: Clone + Send + Sync + 'static>(
118 &self,
119 callback: impl Fn(ServiceContainer) -> BoxFuture<'static, T> + Send + Sync + Clone + 'static,
120 ) -> &Self {
121 if self.has_resolver::<T>().await {
122 return self;
123 }
124
125 self.resolver(callback).await
126 }
127
128 pub(crate) async fn has_resolver<T: 'static>(&self) -> bool {
129 let lock = self.resolvers.read().await;
130 lock.get(&TypeId::of::<T>()).is_some()
131 }
132}
133
134#[derive(Clone)]
135pub struct ServiceContainer {
136 in_proxy_mode: bool,
137 is_reference: bool,
138 container: Container,
139 id: Arc<String>,
140}
141
142impl Default for ServiceContainer {
143 fn default() -> Self {
144 Self::new()
145 }
146}
147
148impl ServiceContainer {
149 pub(crate) fn new() -> Self {
150 let id = GLOBAL_INSTANCE_ID.to_string();
151 Self {
152 id: Arc::new(id),
153 in_proxy_mode: false,
154 is_reference: false,
155 container: Default::default(),
156 }
157 }
158
159 pub(crate) fn make_reference(&self) -> Self {
160 Self {
161 is_reference: true,
162 id: self.id.clone(),
163 in_proxy_mode: self.in_proxy_mode,
164 container: self.container.clone(),
165 }
166 }
167
168 pub fn proxy() -> Self {
176 let id = ulid::Ulid::new().to_string().to_lowercase();
177
178 let mut ci = Self::default();
179 ci.id = Arc::new(id);
180 ci.in_proxy_mode = true;
181 ci
182 }
183
184 pub fn is_proxy(&self) -> bool {
186 self.in_proxy_mode
187 }
188
189 pub async fn proxy_value<T: Clone + 'static>(&self) -> Option<T> {
193 if self.is_proxy() {
194 self.get_type::<T>().await
195 } else {
196 None
197 }
198 }
199
200 pub async fn get<T: 'static>(&self) -> Option<Service<T>> {
202 self.get_type::<Service<T>>().await
203 }
204
205 pub async fn forget_type<T: 'static>(&self) -> Option<Box<T>> {
206 self.container.forget::<T>(self.make_reference()).await
207 }
208
209 pub async fn forget_resolver<T: 'static>(&self) -> bool {
210 self.container.remove_resolver::<T>().await
211 }
212
213 pub async fn forget<T: 'static>(&self) -> Option<Box<Service<T>>> {
214 self.forget_type().await
215 }
216
217 #[deprecated(note = "use `get`")]
220 pub async fn get_or_inject<T: Injectable + Send + Sync + 'static>(&self) -> Service<T> {
221 let result = self.get::<T>().await;
222
223 if result.is_none() {
224 let instance = T::inject(self).await;
225 return self.set(instance).await.get::<T>().await.unwrap();
226 }
227
228 result.unwrap()
229 }
230
231 #[deprecated(note = "use `get_type`")]
234 pub async fn get_type_or_inject<T: Injectable + Clone + Send + Sync + 'static>(&self) -> T {
235 let result = self.get_type::<T>().await;
236 if result.is_none() {
237 let instance = T::inject(self).await;
238 self.set_type(instance.clone()).await;
239 return instance;
240 }
241
242 result.unwrap()
243 }
244
245 pub async fn get_type<T: Clone + 'static>(&self) -> Option<T> {
247 let value = self.container.get::<T>(self.make_reference()).await;
248 if value.is_some() {
249 return value;
250 }
251
252 if self.is_proxy() {
253 return Box::pin(service_container().get_type()).await;
254 }
255
256 None
257 }
258
259 pub(crate) async fn instance<T: Clone + 'static>(&self) -> Option<T> {
260 self.container.get::<T>(self.make_reference()).await
261 }
262
263 pub async fn set_type<T: Clone + Send + Sync + 'static>(&self, value: T) -> &Self {
265 self.resolver(move |_| {
266 let c = value.clone();
267 Box::pin(async move { c })
268 })
269 .await;
270 self
271 }
272
273 pub(crate) async fn remember<T: Clone + Send + Sync + 'static>(&self, value: T) -> &Self {
274 self.container.set(value).await;
275 self
276 }
277
278 pub async fn set<T: Send + Sync + 'static>(&self, ext: T) -> &Self {
281 self.set_type(Service::new(ext)).await
282 }
283
284 pub async fn resolver<T: Clone + Send + Sync + 'static>(
289 &self,
290 callback: impl FnMut(ServiceContainer) -> BoxFuture<'static, T> + Send + Sync + Clone + 'static,
291 ) -> &Self {
292 self.container.resolver(callback).await;
293
294 self
295 }
296 pub async fn resolvable<T: Resolver + Clone + Send + Sync + 'static>(&self) -> &Self {
297 self.container
298 .resolver(|c| Box::pin(async move { T::resolve(&c).await }))
299 .await;
300 self
301 }
302
303 pub async fn resolvable_once<T: Resolver + Clone + Send + Sync + 'static>(&self) -> &Self {
304 self.resolver_once(|c| Box::pin(async move { T::resolve(&c).await }))
305 .await;
306 self
307 }
308
309 pub async fn soft_resolvable<T: Resolver + Clone + Send + Sync + 'static>(&self) -> &Self {
310 self.soft_resolver(|c| Box::pin(async move { T::resolve(&c).await }))
311 .await;
312 self
313 }
314
315 pub async fn soft_resolver<T: Clone + Send + Sync + 'static>(
320 &self,
321 callback: impl Fn(ServiceContainer) -> BoxFuture<'static, T> + Send + Sync + Clone + 'static,
322 ) -> &Self {
323 self.container.soft_resolver(callback).await;
324 self
325 }
326
327 pub async fn resolver_once<T: Clone + Send + Sync + 'static>(
332 &self,
333 callback: impl Fn(ServiceContainer) -> BoxFuture<'static, T> + Send + Sync + Copy + 'static,
334 ) -> &Self {
335 self.container
336 .resolver(move |container| {
337 let f = (callback)(container.clone());
338 Box::pin(async move {
339 let value = f.await;
340 container.set_type(value.clone()).await;
341 value
342 })
343 })
344 .await;
345
346 self
347 }
348
349 pub async fn soft_resolver_once<T: Clone + Send + Sync + 'static>(
357 &self,
358 callback: impl Fn(ServiceContainer) -> BoxFuture<'static, T> + Send + Sync + Copy + 'static,
359 ) -> &Self {
360 if !self.container.has_resolver::<T>().await {
361 self.resolver_once(callback).await;
362 }
363
364 self
365 }
366
367 #[deprecated(note = "use `resolve_and_call`")]
373 pub async fn inject_and_call<F, Args>(&self, mut handler: F) -> F::Output
374 where
375 F: Handler<Args>,
376 Args: Injectable + 'static,
377 {
378 let args = Args::inject(self).await;
379 handler.call(args).await
380 }
381
382 pub async fn resolve_and_call<F, Args>(&self, mut handler: F) -> F::Output
387 where
388 F: Handler<Args>,
389 Args: Resolver,
390 {
391 let args = Args::resolve(self).await;
392 handler.call(args).await
393 }
394
395 pub async fn resolve_all<Args>(&self) -> Args
399 where
400 Args: Resolver,
401 {
402 Args::resolve(self).await
403 }
404
405 #[deprecated(note = "use `resolve_all`")]
411 pub async fn inject_all<Args>(&self) -> Args
412 where
413 Args: Injectable + 'static,
414 {
415 Args::inject(self).await
416 }
417
418 #[deprecated(note = "use `get` or `get_type`")]
422 pub async fn provide<T: Injectable + 'static>(&self) -> T {
423 T::inject(self).await
424 }
425
426 #[deprecated(note = "use `get`")]
430 pub async fn service<T: Send + Sync + 'static>(&self) -> Service<T> {
431 Service::inject(self).await
432 }
433
434 #[deprecated(note = "use `get` or `get_type`")]
441 pub async fn singleton<T: Injectable + Sized + Send + Sync + 'static>(&self) -> Singleton<T> {
442 Singleton::inject(self).await
443 }
444}
445
446pub struct ServiceContainerBuilder {
447 service_container: ServiceContainer,
448}
449
450impl Default for ServiceContainerBuilder {
451 fn default() -> Self {
452 Self::new()
453 }
454}
455
456impl ServiceContainerBuilder {
457 pub fn new() -> Self {
458 Self {
459 service_container: ServiceContainer::new(),
460 }
461 }
462
463 pub fn new_proxy() -> Self {
468 Self {
469 service_container: ServiceContainer::proxy(),
470 }
471 }
472
473 pub async fn register<T: Clone + Send + Sync + 'static>(self, ext: T) -> Self {
475 self.service_container.set_type(ext).await;
476 self
477 }
478
479 pub async fn resolver<T: Clone + Send + Sync + 'static>(
484 self,
485 callback: impl FnMut(ServiceContainer) -> BoxFuture<'static, T> + Send + Sync + Copy + 'static,
486 ) -> Self {
487 self.service_container.resolver(callback).await;
488 self
489 }
490
491 pub async fn resolvable<T: Resolver + Clone + Send + Sync + 'static>(self) -> Self {
495 self.service_container.resolvable::<T>().await;
496 self
497 }
498
499 pub async fn resolvable_once<T: Resolver + Clone + Send + Sync + 'static>(self) -> Self {
504 self.service_container.resolvable_once::<T>().await;
505 self
506 }
507
508 pub async fn soft_resolvable<T: Resolver + Clone + Send + Sync + 'static>(self) -> Self {
512 self.service_container.soft_resolvable::<T>().await;
513 self
514 }
515
516 pub async fn resolver_once<T: Clone + Send + Sync + 'static>(
524 self,
525 callback: impl Fn(ServiceContainer) -> BoxFuture<'static, T> + Send + Sync + Copy + 'static,
526 ) -> Self {
527 self.service_container.resolver_once(callback).await;
528 self
529 }
530
531 pub async fn soft_resolver<T: Clone + Send + Sync + 'static>(
537 self,
538 callback: impl Fn(ServiceContainer) -> BoxFuture<'static, T> + Send + Sync + Clone + 'static,
539 ) -> Self {
540 self.service_container.soft_resolver(callback).await;
541 self
542 }
543
544 pub async fn soft_resolver_once<T: Clone + Send + Sync + 'static>(
554 self,
555 callback: impl Fn(ServiceContainer) -> BoxFuture<'static, T> + Send + Sync + Copy + 'static,
556 ) -> Self {
557 self.service_container.soft_resolver_once(callback).await;
558 self
559 }
560
561 pub async fn service<T: Send + Sync + 'static>(self, ext: T) -> Self {
565 self.service_container.set(ext).await;
566 self
567 }
568
569 pub fn build(self) -> Arc<ServiceContainer> {
571 if self.service_container.id.as_str() == GLOBAL_INSTANCE_ID {
572 SERVICE_CONTAINER
573 .get_or_init(|| Arc::new(self.service_container))
574 .clone()
575 } else {
576 Arc::new(self.service_container)
577 }
578 }
579}
580
#[cfg(test)]
// Several legacy `Injectable`-based helpers are deprecated but are covered here on purpose.
#[allow(deprecated)]
mod test {
    use async_trait::async_trait;

    use super::*;
    use crate::helpers::service_container;

    /// Fixture whose injection registers `Counter { start_point: 44 }` when none exists.
    #[derive(Debug, Clone)]
    struct Counter {
        start_point: usize,
    }

    #[async_trait]
    impl Injectable for Counter {
        async fn inject(container: &ServiceContainer) -> Self {
            if let Some(existing) = container.get_type::<Counter>().await {
                return existing;
            }
            container
                .set_type(Counter { start_point: 44 })
                .await
                .get_type()
                .await
                .unwrap()
        }
    }

    /// Fixture whose injection always yields `User { id: 1000 }`.
    #[derive(Debug, Clone)]
    struct User {
        id: i32,
    }

    #[async_trait]
    impl Injectable for User {
        async fn inject(_: &ServiceContainer) -> Self {
            User { id: 1000 }
        }
    }

    #[tokio::test]
    async fn test_builder() {
        let built = ServiceContainerBuilder::new_proxy()
            .service(5usize)
            .await
            .register(true)
            .await
            .build();

        let five = built.get::<usize>().await.unwrap();
        assert_eq!(*five, 5usize);
        assert_eq!(built.get_type::<bool>().await, Some(true));
    }

    #[tokio::test]
    async fn test_empty_container() {
        let empty = ServiceContainer::proxy();

        assert!(empty.get::<i32>().await.is_none());
        assert_eq!(empty.get_type::<i32>().await, None);
    }

    #[tokio::test]
    async fn test_getting_raw_type() {
        let ci = ServiceContainer::proxy();
        ci.set_type(400).await;
        ci.set_type(300f32).await;
        ci.set_type(true).await;

        assert_eq!(ci.get_type::<i32>().await, Some(400));
        assert_eq!(ci.get_type::<f32>().await, Some(300f32));
        assert_eq!(ci.get_type::<bool>().await, Some(true));
    }

    #[tokio::test]
    async fn test_getting_service_type() {
        let ci = ServiceContainer::proxy();
        ci.set(400).await;
        ci.set(300f32).await;
        ci.set(true).await;

        assert_eq!(*ci.get::<i32>().await.unwrap(), 400);
        assert_eq!(*ci.get::<f32>().await.unwrap(), 300f32);
        assert!(*ci.get::<bool>().await.unwrap());
    }

    #[tokio::test]
    async fn test_proxy_service() {
        service_container().set_type(true).await;
        let proxy = ServiceContainer::proxy();

        // `bool` comes from the global container; `i32` exists nowhere yet.
        assert_eq!(proxy.get_type::<bool>().await, Some(true));
        assert_eq!(proxy.get_type::<i32>().await, None);

        proxy.set_type(30000).await;
        assert_eq!(proxy.get_type::<i32>().await, Some(30000));
    }

    #[tokio::test]
    async fn test_injecting() {
        let ci = ServiceContainer::proxy();

        let counter = ci.inject_all::<Counter>().await;
        assert_eq!(counter.start_point, 44usize);
    }

    #[tokio::test]
    async fn test_injecting_stored_instance() {
        let ci = ServiceContainer::proxy();
        ci.set_type(Counter { start_point: 6000 }).await;

        let counter = ci.inject_all::<Counter>().await;
        assert_eq!(counter.start_point, 6000usize);
    }

    #[tokio::test]
    async fn test_singleton() {
        let ci = ServiceContainer::proxy();

        let first = ci.singleton::<User>().await;
        assert_eq!(first.id, 1000);

        ci.set_type(User { id: 88 }).await;
        let second = ci.singleton::<User>().await;
        assert_eq!(second.id, 1000);
    }

    #[tokio::test]
    async fn test_inject_and_call() {
        let ci = ServiceContainer::proxy();

        let outcome = ci
            .inject_and_call(|user: User, counter: Counter| async move {
                assert_eq!(user.id, 1000);
                assert_eq!(counter.start_point, 44);
                (1, 2, 3)
            })
            .await;

        assert_eq!(outcome, (1, 2, 3));
    }

    #[tokio::test]
    async fn test_get_or_inject_raw_type() {
        let ci = ServiceContainer::proxy();
        assert!(ci.get_type::<User>().await.is_none());

        let injected = ci.get_type_or_inject::<User>().await;
        let fetched = ci.get_type::<User>().await;

        assert_eq!(injected.id, 1000);
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().id, injected.id);
    }

    #[tokio::test]
    async fn test_get_or_inject_service_type() {
        let ci = ServiceContainer::proxy();
        assert!(ci.get::<User>().await.is_none());

        let injected = ci.get_or_inject::<User>().await;
        let fetched = ci.get::<User>().await;

        assert_eq!(injected.id, 1000);
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().id, injected.id);
    }

    #[tokio::test]
    async fn test_forgetting_a_type() {
        let ci = ServiceContainer::proxy();
        assert_eq!(ci.get_type::<usize>().await, None);

        ci.set_type(300_usize).await;
        assert_eq!(ci.get_type::<usize>().await, Some(300_usize));

        let removed = ci.forget_type::<usize>().await;
        assert!(removed.is_some());
        assert_eq!(ci.get_type::<usize>().await, None);
    }

    #[tokio::test]
    async fn test_forgetting_service_a_type() {
        let ci = ServiceContainer::proxy();
        assert!(ci.get::<usize>().await.is_none());

        ci.set(300_usize).await;
        assert_eq!(*ci.get::<usize>().await.unwrap(), 300_usize);

        let removed = ci.forget::<usize>().await;
        assert!(removed.is_some());
        assert!(ci.get::<usize>().await.is_none());
    }

    #[tokio::test]
    async fn test_service_without_clone_type() {
        struct UserName(String);

        let ci = ServiceContainer::proxy();
        ci.set(UserName("foobar".to_string())).await;

        let found: Option<Service<_>> = ci.get::<UserName>().await;
        assert!(found.is_some());
        assert_eq!("foobar", found.unwrap().as_ref().0);
    }

    #[tokio::test]
    async fn test_resolver() {
        let ci = ServiceContainer::proxy();
        ci.resolver::<String>(|_| Box::pin(async { "foo".to_string() }))
            .await;

        let resolved = ci.get_type::<String>().await;
        assert_eq!(resolved, Some("foo".to_string()));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_resolving_once() {
        #[derive(Debug, Clone, PartialEq)]
        struct Special(String);

        let ci = ServiceContainer::proxy();
        ci.resolver_once::<Special>(|c| {
            Box::pin(async move {
                let counter: i32 = c.get_type().await.unwrap_or_default();
                c.set_type(counter + 1).await;
                Special(format!("id:{counter}"))
            })
        })
        .await;

        let expected = Some(Special("id:0".to_string()));
        assert_eq!(ci.get_type::<Special>().await, expected);
        assert_eq!(
            ci.get_type::<Special>().await,
            expected,
            "ID should have been zero (0)"
        );
    }

    #[tokio::test]
    async fn test_soft_resolving() {
        #[derive(Debug, Clone, PartialEq)]
        struct SoftCounter(i32);

        let ci = ServiceContainer::proxy();
        ci.resolver(|_| Box::pin(async { SoftCounter(1) })).await;
        // Must not override the resolver registered above.
        ci.soft_resolver(|_| Box::pin(async { SoftCounter(100) }))
            .await;

        let first: SoftCounter = ci.get_type().await.unwrap();
        assert_eq!(first.0, 1);

        let second: SoftCounter = ci.get_type().await.unwrap();
        assert_ne!(second.0, 100);
    }

    #[tokio::test]
    async fn test_soft_resolving2() {
        #[derive(Debug, Clone, PartialEq)]
        struct SoftCounter(i32);

        let ci = ServiceContainer::proxy();
        ci.soft_resolver(|_| Box::pin(async { SoftCounter(100) }))
            .await;

        let resolved: SoftCounter = ci.get_type().await.unwrap();
        assert_eq!(resolved.0, 100);
    }

    #[tokio::test]
    async fn test_forgetting_resolver() {
        let ci = ServiceContainer::proxy();
        ci.resolver(|_| Box::pin(async { 100 })).await;

        assert!(ci.get_type::<i32>().await.is_some());

        ci.forget_resolver::<i32>().await;
        assert!(ci.get_type::<i32>().await.is_none());
    }
}