async_di/
resolver.rs

1use crate::error::Error;
2use crate::helpers::BoxAny;
3use crate::helpers::Named;
4use crate::provider::{Provider, ProviderObject, Resolvable};
5use futures::future::{abortable, AbortHandle};
6use std::any::{type_name, TypeId};
7use std::collections::BTreeMap;
8use std::marker::PhantomData;
9use std::sync::Arc;
10use tokio::sync::{mpsc, oneshot};
11
12#[derive(Debug)]
13pub struct Resolver {
14    worker_tx: mpsc::Sender<WorkerMessage>,
15    abort: AbortHandle,
16}
17
18impl Resolver {
19    pub(crate) async fn resolve<S>(&self) -> Result<S, Error>
20    where
21        S: Resolvable,
22    {
23        resolve(&self.worker_tx).await
24    }
25}
26
27impl Drop for Resolver {
28    fn drop(&mut self) {
29        self.abort.abort()
30    }
31}
32
33#[derive(Debug)]
34pub struct ResolverBuilder {
35    provider_map: BTreeMap<TypeId, Named<Arc<dyn ProviderObject>>>,
36}
37
38impl ResolverBuilder {
39    pub(crate) fn new() -> Self {
40        Self {
41            provider_map: BTreeMap::new(),
42        }
43    }
44
45    pub fn register<P>(&mut self, provider: P) -> &mut Self
46    where
47        P: Provider,
48    {
49        let type_id = TypeId::of::<P::Ref>();
50        let provider: Arc<dyn ProviderObject> = Arc::new(provider);
51        self.provider_map.insert(
52            type_id,
53            Named {
54                name: type_name::<P::Ref>(),
55                value: provider,
56            },
57        );
58        self
59    }
60
61    pub(crate) fn finalize(self) -> Resolver {
62        let (worker_tx, worker_rx) = mpsc::channel(32);
63        let worker = Worker {
64            provider_map: self.provider_map,
65            instance_map: BTreeMap::new(),
66        };
67        let (task, abort) = abortable(worker.start(worker_tx.clone(), worker_rx));
68
69        tokio::spawn(task);
70
71        Resolver { worker_tx, abort }
72    }
73}
74
75#[derive(Debug)]
76enum WorkerMessage {
77    ResolveRequest {
78        type_info: Named<TypeId>,
79        tx: oneshot::Sender<Result<Arc<BoxAny>, Error>>,
80    },
81    ProviderCallback {
82        type_info: Named<TypeId>,
83        result: Result<Arc<BoxAny>, Error>,
84    },
85}
86
87struct Worker {
88    provider_map: BTreeMap<TypeId, Named<Arc<dyn ProviderObject>>>,
89    instance_map: BTreeMap<TypeId, Named<InstanceSlot>>,
90}
91
92enum InstanceSlot {
93    ProviderRunning {
94        // We cannot use Arc<Any> because Arc<Any>::downcast<T> requires T: Sized,
95        // but we need T: ?Sized to cast Arc<Any> to Arc<dyn T> so we must box Arc<dyn T>
96        txs: Vec<oneshot::Sender<Result<Arc<BoxAny>, Error>>>,
97    },
98    Resolved(Arc<BoxAny>),
99}
100
101impl Worker {
102    async fn start(
103        mut self,
104        tx: mpsc::Sender<WorkerMessage>,
105        mut rx: mpsc::Receiver<WorkerMessage>,
106    ) {
107        let r = ResolverRef(tx);
108        loop {
109            tokio::select! {
110              Some(msg) = rx.recv() => {
111                tracing::debug!("msg: {:?}", msg);
112
113                match msg {
114                  WorkerMessage::ResolveRequest {
115                    type_info,
116                    tx
117                  } => {
118                    self.dispatch(&r, type_info, tx);
119                  },
120                  WorkerMessage::ProviderCallback { type_info, result } => {
121                    match result {
122                      Ok(instance) => {
123                        let replaced = self.instance_map.insert(type_info.value, type_info.map(|_| {
124                          InstanceSlot::Resolved(instance.clone())
125                        }));
126                        match replaced.map(|v| v.value) {
127                          Some(InstanceSlot::ProviderRunning { txs }) => {
128                            for tx in txs {
129                              tx.send(Ok(instance.clone())).ok();
130                            }
131                          },
132                          _ => unreachable!()
133                        }
134                      },
135                      Err(err) => {
136                        match self.instance_map.remove(&type_info.value).map(|v| v.value) {
137                          Some(InstanceSlot::ProviderRunning { txs }) => {
138                            for tx in txs {
139                              tx.send(Err(err.clone())).ok();
140                            }
141                          },
142                          _ => unreachable!()
143                        }
144                      }
145                    }
146                  }
147                }
148              }
149            }
150        }
151    }
152
153    fn dispatch(
154        &mut self,
155        r: &ResolverRef,
156        type_info: Named<TypeId>,
157        reply_tx: oneshot::Sender<Result<Arc<BoxAny>, Error>>,
158    ) {
159        let type_id = type_info.value;
160        if let Some(slot) = self.instance_map.get_mut(&type_id) {
161            match slot.value {
162                InstanceSlot::ProviderRunning { ref mut txs } => {
163                    txs.push(reply_tx);
164                }
165                InstanceSlot::Resolved(ref instance) => {
166                    reply_tx.send(Ok(instance.clone())).ok();
167                }
168            }
169            return;
170        }
171
172        if let Some(provider) = self.provider_map.get(&type_id).cloned() {
173            self.instance_map.insert(
174                type_id,
175                type_info.with_value(InstanceSlot::ProviderRunning {
176                    txs: vec![reply_tx],
177                }),
178            );
179            let r = r.clone();
180            tokio::spawn(async move {
181                let result = provider.value.provide(&r).await;
182                r.0.send(WorkerMessage::ProviderCallback {
183                    type_info,
184                    result: result.map(Arc::new),
185                })
186                .await
187                .ok();
188            });
189        } else {
190            reply_tx
191                .send(Err(Error::UnregisteredServiceType(type_info.name)))
192                .ok();
193        }
194    }
195}
196
197#[derive(Debug, Clone)]
198pub struct ResolverRef(mpsc::Sender<WorkerMessage>);
199
200impl ResolverRef {
201    pub fn deferred<S>(&self) -> Deferred<S> {
202        Deferred::new(self.clone())
203    }
204
205    pub async fn resolve<S>(&self) -> Result<S, Error>
206    where
207        S: Resolvable,
208    {
209        resolve(&self.0).await
210    }
211}
212
213async fn resolve<S>(worker_tx: &mpsc::Sender<WorkerMessage>) -> Result<S, Error>
214where
215    S: Resolvable,
216{
217    let type_name = type_name::<S>();
218    let type_id = TypeId::of::<S>();
219
220    let (tx, rx) = oneshot::channel();
221    let req = WorkerMessage::ResolveRequest {
222        type_info: Named {
223            name: type_name,
224            value: type_id,
225        },
226        tx,
227    };
228    worker_tx.send(req).await.map_err(|_| Error::WorkerGone)?;
229    let res = rx.await.map_err(|_| Error::WorkerGone)?;
230    Ok(res.map(|v| v.downcast_ref::<S>().unwrap().clone())?)
231}
232
233#[derive(Debug, Clone)]
234pub struct Deferred<S> {
235    r: ResolverRef,
236    _p: PhantomData<S>,
237}
238
239impl<S> Deferred<S> {
240    fn new(r: ResolverRef) -> Self {
241        Self { r, _p: PhantomData }
242    }
243
244    pub async fn resolve(&self) -> Result<S, Error>
245    where
246        S: Resolvable,
247    {
248        self.r.resolve::<S>().await
249    }
250}