1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::fmt::Debug;
4use std::future::Future;
5use std::marker::Unsize;
6use std::mem;
7use std::pin::Pin;
8use std::sync::Arc;
9
10use async_lock::RwLock;
11use async_trait::async_trait;
12
13use crate::error::Error;
14use crate::lazy_cell::LazyCell;
15use crate::{Injectable, Singleton};
16
/// Key under which a binding is stored in the injector.
///
/// Every key carries the `TypeId` of the bound interface type; named keys additionally
/// carry a name so that several bindings for the same type can coexist.
#[derive(Debug, PartialEq, Eq, Hash)]
enum BindingKey {
    /// A binding qualified by a name (as registered by `bind_named`).
    Named(TypeId, String),
    /// The default binding for a type (as registered by `bind`).
    Unnamed(TypeId),
}
22
23enum Binding {
24 Singleton(LazyCell<Box<dyn Any + Send + Sync + 'static>>),
25 Constructor(
26 Box<
27 dyn Fn() -> Pin<
28 Box<
29 dyn Future<
30 Output = Result<
31 Box<dyn Any + Send + Sync>,
32 Box<dyn std::error::Error + Send + Sync>,
33 >,
34 > + Sync
35 + Send,
36 >,
37 > + Send
38 + Sync,
39 >,
40 ),
41}
42
43#[async_trait]
44trait SpecializationBinder<T: ?Sized, I> {
45 async fn bind_internal(&self, key: BindingKey);
46}
47
48#[derive(Clone)]
49pub struct Injector {
50 inner: Arc<InjectorInner>,
51}
52
53struct InjectorInner {
54 bindings: RwLock<HashMap<BindingKey, Binding>>,
55}
56
57impl Injector {
58 pub fn new() -> Self {
59 Self {
60 inner: Arc::new(InjectorInner {
61 bindings: RwLock::new(HashMap::new()),
62 }),
63 }
64 }
65
66 async fn get_internal<T: ?Sized + Injectable + 'static>(
67 &self,
68 key: BindingKey,
69 ) -> Result<Arc<T>, Error<T::Error>> {
70 if let Some(binding) = self.inner.bindings.read().await.get(&key) {
71 match binding {
72 Binding::Singleton(cell) => match cell.get().await {
73 Ok(instance) => {
74 let instance: &Box<Arc<T>> = unsafe {
75 mem::transmute::<
76 &Box<dyn Any + Sync + Send + 'static>,
77 &Box<Arc<T>>,
78 >(instance)
79 };
80 Ok(*instance.clone())
81 }
82 Err(dyn_error) => {
83 let error: Box<T::Error> =
84 dyn_error.downcast::<T::Error>().unwrap();
85 Err(Error::InstanceCreationFailed(*error))
86 }
87 },
88 Binding::Constructor(constructor) => {
89 match Box::pin(constructor()).await {
90 Ok(instance) => {
91 let instance = instance.downcast::<Arc<T>>().unwrap();
92 Ok(*instance)
93 }
94 Err(dyn_error) => {
95 let error: Box<T::Error> =
96 dyn_error.downcast::<T::Error>().unwrap();
97 Err(Error::InstanceCreationFailed(*error))
98 }
99 }
100 }
101 }
102 } else {
103 Err(Error::NoBinding)
104 }
105 }
106
107 pub async fn bind<
108 T: ?Sized + Sync + Send + 'static,
109 I: Unsize<T> + Injectable + 'static,
110 >(
111 &self,
112 ) {
113 let fut = (self as &dyn SpecializationBinder<T, I>)
114 .bind_internal(BindingKey::Unnamed(TypeId::of::<T>()));
115 fut.await;
116 }
117
118 pub async fn bind_named<
119 T: ?Sized + Sync + Send + 'static,
120 I: Unsize<T> + Singleton + Injectable + 'static,
121 >(
122 &self,
123 name: impl ToString + 'static,
124 ) {
125 let fut = (self as &dyn SpecializationBinder<T, I>)
126 .bind_internal(BindingKey::Named(TypeId::of::<T>(), name.to_string()));
127 fut.await;
128 }
129
130 pub fn bind_cloneable<
131 T: ?Sized + Sync + 'static,
132 I: Unsize<T> + Clone + Injectable + 'static,
133 >(
134 &self,
135 ) {
136 todo!()
137 }
138
139 pub fn get<T: ?Sized + Injectable + 'static>(
140 &self,
141 ) -> impl Future<Output = Result<Arc<T>, Error<T::Error>>> + Send + Sync + '_
142 {
143 async {
144 self
145 .get_internal(BindingKey::Unnamed(TypeId::of::<T>()))
146 .await
147 }
148 }
149
150 pub async fn get_named<T: ?Sized + Injectable + 'static>(
151 &self,
152 name: impl ToString,
153 ) -> Result<Arc<T>, Error<T::Error>> {
154 self
155 .get_internal(BindingKey::Named(TypeId::of::<T>(), name.to_string()))
156 .await
157 }
158
159 pub async fn get_box<T: ?Sized + Injectable + 'static>(
160 &self,
161 ) -> Result<Box<T>, Error<T::Error>> {
162 todo!()
163 }
164}
165
166#[async_trait]
167impl<T, I> SpecializationBinder<T, I> for Injector
168where
169 T: ?Sized + Sync + Send + 'static,
170 I: Unsize<T> + Injectable + 'static,
171{
172 default async fn bind_internal(&self, _key: BindingKey) {
173 let type_id = TypeId::of::<T>();
174 let injector = self.clone();
175
176 self.inner.bindings.write().await.insert(
177 BindingKey::Unnamed(type_id),
178 Binding::Constructor(Box::new(move || {
179 let injector = injector.clone();
180 Box::pin(async move {
181 match I::create(injector).await {
182 Ok(instance) => Ok(Box::new(Arc::new(instance) as Arc<T>)
183 as Box<dyn Any + Sync + Send>),
184 Err(err) => {
185 Err(Box::new(err) as Box<dyn std::error::Error + Send + Sync>)
186 }
187 }
188 })
189 })),
190 );
191 }
192}
193
194#[async_trait]
195impl<T, I> SpecializationBinder<T, I> for Injector
196where
197 T: ?Sized + Sync + Send + 'static,
198 I: Unsize<T> + Injectable + Singleton + 'static,
199{
200 default async fn bind_internal(&self, key: BindingKey) {
201 let generator = I::create(self.clone());
202 let inner = self.inner.clone();
203
204 let cell: LazyCell<Box<dyn Any + Sync + Send + 'static>> =
205 LazyCell::new(Box::pin(async move {
206 Ok(Box::new(Arc::new(generator.await?) as Arc<T>)
207 as Box<dyn Any + Sync + Send>)
208 }));
209
210 inner
211 .bindings
212 .write()
213 .await
214 .insert(key, Binding::Singleton(cell));
215 }
216}