1use crate::error::Error;
2use crate::helpers::BoxAny;
3use crate::helpers::Named;
4use crate::provider::{Provider, ProviderObject, Resolvable};
5use futures::future::{abortable, AbortHandle};
6use std::any::{type_name, TypeId};
7use std::collections::BTreeMap;
8use std::marker::PhantomData;
9use std::sync::Arc;
10use tokio::sync::{mpsc, oneshot};
11
12#[derive(Debug)]
13pub struct Resolver {
14 worker_tx: mpsc::Sender<WorkerMessage>,
15 abort: AbortHandle,
16}
17
18impl Resolver {
19 pub(crate) async fn resolve<S>(&self) -> Result<S, Error>
20 where
21 S: Resolvable,
22 {
23 resolve(&self.worker_tx).await
24 }
25}
26
27impl Drop for Resolver {
28 fn drop(&mut self) {
29 self.abort.abort()
30 }
31}
32
33#[derive(Debug)]
34pub struct ResolverBuilder {
35 provider_map: BTreeMap<TypeId, Named<Arc<dyn ProviderObject>>>,
36}
37
38impl ResolverBuilder {
39 pub(crate) fn new() -> Self {
40 Self {
41 provider_map: BTreeMap::new(),
42 }
43 }
44
45 pub fn register<P>(&mut self, provider: P) -> &mut Self
46 where
47 P: Provider,
48 {
49 let type_id = TypeId::of::<P::Ref>();
50 let provider: Arc<dyn ProviderObject> = Arc::new(provider);
51 self.provider_map.insert(
52 type_id,
53 Named {
54 name: type_name::<P::Ref>(),
55 value: provider,
56 },
57 );
58 self
59 }
60
61 pub(crate) fn finalize(self) -> Resolver {
62 let (worker_tx, worker_rx) = mpsc::channel(32);
63 let worker = Worker {
64 provider_map: self.provider_map,
65 instance_map: BTreeMap::new(),
66 };
67 let (task, abort) = abortable(worker.start(worker_tx.clone(), worker_rx));
68
69 tokio::spawn(task);
70
71 Resolver { worker_tx, abort }
72 }
73}
74
75#[derive(Debug)]
76enum WorkerMessage {
77 ResolveRequest {
78 type_info: Named<TypeId>,
79 tx: oneshot::Sender<Result<Arc<BoxAny>, Error>>,
80 },
81 ProviderCallback {
82 type_info: Named<TypeId>,
83 result: Result<Arc<BoxAny>, Error>,
84 },
85}
86
87struct Worker {
88 provider_map: BTreeMap<TypeId, Named<Arc<dyn ProviderObject>>>,
89 instance_map: BTreeMap<TypeId, Named<InstanceSlot>>,
90}
91
92enum InstanceSlot {
93 ProviderRunning {
94 txs: Vec<oneshot::Sender<Result<Arc<BoxAny>, Error>>>,
97 },
98 Resolved(Arc<BoxAny>),
99}
100
101impl Worker {
102 async fn start(
103 mut self,
104 tx: mpsc::Sender<WorkerMessage>,
105 mut rx: mpsc::Receiver<WorkerMessage>,
106 ) {
107 let r = ResolverRef(tx);
108 loop {
109 tokio::select! {
110 Some(msg) = rx.recv() => {
111 tracing::debug!("msg: {:?}", msg);
112
113 match msg {
114 WorkerMessage::ResolveRequest {
115 type_info,
116 tx
117 } => {
118 self.dispatch(&r, type_info, tx);
119 },
120 WorkerMessage::ProviderCallback { type_info, result } => {
121 match result {
122 Ok(instance) => {
123 let replaced = self.instance_map.insert(type_info.value, type_info.map(|_| {
124 InstanceSlot::Resolved(instance.clone())
125 }));
126 match replaced.map(|v| v.value) {
127 Some(InstanceSlot::ProviderRunning { txs }) => {
128 for tx in txs {
129 tx.send(Ok(instance.clone())).ok();
130 }
131 },
132 _ => unreachable!()
133 }
134 },
135 Err(err) => {
136 match self.instance_map.remove(&type_info.value).map(|v| v.value) {
137 Some(InstanceSlot::ProviderRunning { txs }) => {
138 for tx in txs {
139 tx.send(Err(err.clone())).ok();
140 }
141 },
142 _ => unreachable!()
143 }
144 }
145 }
146 }
147 }
148 }
149 }
150 }
151 }
152
153 fn dispatch(
154 &mut self,
155 r: &ResolverRef,
156 type_info: Named<TypeId>,
157 reply_tx: oneshot::Sender<Result<Arc<BoxAny>, Error>>,
158 ) {
159 let type_id = type_info.value;
160 if let Some(slot) = self.instance_map.get_mut(&type_id) {
161 match slot.value {
162 InstanceSlot::ProviderRunning { ref mut txs } => {
163 txs.push(reply_tx);
164 }
165 InstanceSlot::Resolved(ref instance) => {
166 reply_tx.send(Ok(instance.clone())).ok();
167 }
168 }
169 return;
170 }
171
172 if let Some(provider) = self.provider_map.get(&type_id).cloned() {
173 self.instance_map.insert(
174 type_id,
175 type_info.with_value(InstanceSlot::ProviderRunning {
176 txs: vec![reply_tx],
177 }),
178 );
179 let r = r.clone();
180 tokio::spawn(async move {
181 let result = provider.value.provide(&r).await;
182 r.0.send(WorkerMessage::ProviderCallback {
183 type_info,
184 result: result.map(Arc::new),
185 })
186 .await
187 .ok();
188 });
189 } else {
190 reply_tx
191 .send(Err(Error::UnregisteredServiceType(type_info.name)))
192 .ok();
193 }
194 }
195}
196
197#[derive(Debug, Clone)]
198pub struct ResolverRef(mpsc::Sender<WorkerMessage>);
199
200impl ResolverRef {
201 pub fn deferred<S>(&self) -> Deferred<S> {
202 Deferred::new(self.clone())
203 }
204
205 pub async fn resolve<S>(&self) -> Result<S, Error>
206 where
207 S: Resolvable,
208 {
209 resolve(&self.0).await
210 }
211}
212
213async fn resolve<S>(worker_tx: &mpsc::Sender<WorkerMessage>) -> Result<S, Error>
214where
215 S: Resolvable,
216{
217 let type_name = type_name::<S>();
218 let type_id = TypeId::of::<S>();
219
220 let (tx, rx) = oneshot::channel();
221 let req = WorkerMessage::ResolveRequest {
222 type_info: Named {
223 name: type_name,
224 value: type_id,
225 },
226 tx,
227 };
228 worker_tx.send(req).await.map_err(|_| Error::WorkerGone)?;
229 let res = rx.await.map_err(|_| Error::WorkerGone)?;
230 Ok(res.map(|v| v.downcast_ref::<S>().unwrap().clone())?)
231}
232
233#[derive(Debug, Clone)]
234pub struct Deferred<S> {
235 r: ResolverRef,
236 _p: PhantomData<S>,
237}
238
239impl<S> Deferred<S> {
240 fn new(r: ResolverRef) -> Self {
241 Self { r, _p: PhantomData }
242 }
243
244 pub async fn resolve(&self) -> Result<S, Error>
245 where
246 S: Resolvable,
247 {
248 self.r.resolve::<S>().await
249 }
250}