// jabba/injector.rs

use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::fmt::Debug;
use std::future::Future;
use std::marker::Unsize;
use std::mem;
use std::pin::Pin;
use std::sync::{Arc, Weak};

use async_lock::RwLock;
use async_trait::async_trait;

use crate::error::Error;
use crate::lazy_cell::LazyCell;
use crate::{Injectable, Singleton};
/// Lookup key for one entry of the injector's binding table.
///
/// Entries are keyed by the [`TypeId`] of the requested (possibly unsized)
/// type; the `Named` form additionally carries a caller-chosen name so that
/// several bindings of the same type can coexist.
#[derive(Debug, PartialEq, Eq, Hash)]
enum BindingKey {
  /// A binding registered for a type together with an explicit name.
  Named(TypeId, String),
  /// The default, anonymous binding for a type.
  Unnamed(TypeId),
}
23enum Binding {
24  Singleton(LazyCell<Box<dyn Any + Send + Sync + 'static>>),
25  Constructor(
26    Box<
27      dyn Fn() -> Pin<
28          Box<
29            dyn Future<
30                Output = Result<
31                  Box<dyn Any + Send + Sync>,
32                  Box<dyn std::error::Error + Send + Sync>,
33                >,
34              > + Sync
35              + Send,
36          >,
37        > + Send
38        + Sync,
39    >,
40  ),
41}
42
43#[async_trait]
44trait SpecializationBinder<T: ?Sized, I> {
45  async fn bind_internal(&self, key: BindingKey);
46}
47
48#[derive(Clone)]
49pub struct Injector {
50  inner: Arc<InjectorInner>,
51}
52
53struct InjectorInner {
54  bindings: RwLock<HashMap<BindingKey, Binding>>,
55}
56
57impl Injector {
58  pub fn new() -> Self {
59    Self {
60      inner: Arc::new(InjectorInner {
61        bindings: RwLock::new(HashMap::new()),
62      }),
63    }
64  }
65
66  async fn get_internal<T: ?Sized + Injectable + 'static>(
67    &self,
68    key: BindingKey,
69  ) -> Result<Arc<T>, Error<T::Error>> {
70    if let Some(binding) = self.inner.bindings.read().await.get(&key) {
71      match binding {
72        Binding::Singleton(cell) => match cell.get().await {
73          Ok(instance) => {
74            let instance: &Box<Arc<T>> = unsafe {
75              mem::transmute::<
76                &Box<dyn Any + Sync + Send + 'static>,
77                &Box<Arc<T>>,
78              >(instance)
79            };
80            Ok(*instance.clone())
81          }
82          Err(dyn_error) => {
83            let error: Box<T::Error> =
84              dyn_error.downcast::<T::Error>().unwrap();
85            Err(Error::InstanceCreationFailed(*error))
86          }
87        },
88        Binding::Constructor(constructor) => {
89          match Box::pin(constructor()).await {
90            Ok(instance) => {
91              let instance = instance.downcast::<Arc<T>>().unwrap();
92              Ok(*instance)
93            }
94            Err(dyn_error) => {
95              let error: Box<T::Error> =
96                dyn_error.downcast::<T::Error>().unwrap();
97              Err(Error::InstanceCreationFailed(*error))
98            }
99          }
100        }
101      }
102    } else {
103      Err(Error::NoBinding)
104    }
105  }
106
107  pub async fn bind<
108    T: ?Sized + Sync + Send + 'static,
109    I: Unsize<T> + Injectable + 'static,
110  >(
111    &self,
112  ) {
113    let fut = (self as &dyn SpecializationBinder<T, I>)
114      .bind_internal(BindingKey::Unnamed(TypeId::of::<T>()));
115    fut.await;
116  }
117
118  pub async fn bind_named<
119    T: ?Sized + Sync + Send + 'static,
120    I: Unsize<T> + Singleton + Injectable + 'static,
121  >(
122    &self,
123    name: impl ToString + 'static,
124  ) {
125    let fut = (self as &dyn SpecializationBinder<T, I>)
126      .bind_internal(BindingKey::Named(TypeId::of::<T>(), name.to_string()));
127    fut.await;
128  }
129
130  pub fn bind_cloneable<
131    T: ?Sized + Sync + 'static,
132    I: Unsize<T> + Clone + Injectable + 'static,
133  >(
134    &self,
135  ) {
136    todo!()
137  }
138
139  pub fn get<T: ?Sized + Injectable + 'static>(
140    &self,
141  ) -> impl Future<Output = Result<Arc<T>, Error<T::Error>>> + Send + Sync + '_
142  {
143    async {
144      self
145        .get_internal(BindingKey::Unnamed(TypeId::of::<T>()))
146        .await
147    }
148  }
149
150  pub async fn get_named<T: ?Sized + Injectable + 'static>(
151    &self,
152    name: impl ToString,
153  ) -> Result<Arc<T>, Error<T::Error>> {
154    self
155      .get_internal(BindingKey::Named(TypeId::of::<T>(), name.to_string()))
156      .await
157  }
158
159  pub async fn get_box<T: ?Sized + Injectable + 'static>(
160    &self,
161  ) -> Result<Box<T>, Error<T::Error>> {
162    todo!()
163  }
164}
165
166#[async_trait]
167impl<T, I> SpecializationBinder<T, I> for Injector
168where
169  T: ?Sized + Sync + Send + 'static,
170  I: Unsize<T> + Injectable + 'static,
171{
172  default async fn bind_internal(&self, _key: BindingKey) {
173    let type_id = TypeId::of::<T>();
174    let injector = self.clone();
175
176    self.inner.bindings.write().await.insert(
177      BindingKey::Unnamed(type_id),
178      Binding::Constructor(Box::new(move || {
179        let injector = injector.clone();
180        Box::pin(async move {
181          match I::create(injector).await {
182            Ok(instance) => Ok(Box::new(Arc::new(instance) as Arc<T>)
183              as Box<dyn Any + Sync + Send>),
184            Err(err) => {
185              Err(Box::new(err) as Box<dyn std::error::Error + Send + Sync>)
186            }
187          }
188        })
189      })),
190    );
191  }
192}
193
194#[async_trait]
195impl<T, I> SpecializationBinder<T, I> for Injector
196where
197  T: ?Sized + Sync + Send + 'static,
198  I: Unsize<T> + Injectable + Singleton + 'static,
199{
200  default async fn bind_internal(&self, key: BindingKey) {
201    let generator = I::create(self.clone());
202    let inner = self.inner.clone();
203
204    let cell: LazyCell<Box<dyn Any + Sync + Send + 'static>> =
205      LazyCell::new(Box::pin(async move {
206        Ok(Box::new(Arc::new(generator.await?) as Arc<T>)
207          as Box<dyn Any + Sync + Send>)
208      }));
209
210    inner
211      .bindings
212      .write()
213      .await
214      .insert(key, Binding::Singleton(cell));
215  }
216}