1use async_trait::async_trait;
2use std::any::{Any, TypeId};
3use std::collections::HashMap;
4use std::marker::PhantomData;
5use std::sync::{Arc, Weak};
6use thiserror::Error;
7use tokio::sync::{Mutex, OwnedMutexGuard};
8
9use crate::{Actor, Addr, Owner};
10
11#[derive(Error, Debug)]
12pub enum RegistryError {
13 #[error("registry gone")]
14 RegistryGone,
15}
16
/// Registry contents, shared behind a single `tokio::sync::Mutex`.
#[derive(Debug)]
struct State<D> {
    /// User-supplied data, exposed to services via `RegistryRef::data`.
    data: D,
    /// Started services keyed by `TypeId::of::<S>()`; the value stored under
    /// that key is always the `Owner<S>` for the same `S`.
    map: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}
22
23#[derive(Debug)]
24pub struct Registry<D = ()> {
25 state: Arc<Mutex<State<D>>>,
26}
27
28impl<D> Registry<D> {
29 pub fn new() -> Self
30 where
31 D: Default,
32 {
33 Self::with_data(Default::default())
34 }
35
36 pub fn with_data(data: D) -> Self {
37 Self {
38 state: Arc::new(Mutex::new(State {
39 data,
40 map: HashMap::new(),
41 })),
42 }
43 }
44}
45
46pub struct RegistryRef<D = ()> {
47 r: Weak<Mutex<State<D>>>,
48 guard: OwnedMutexGuard<State<D>>,
49}
50
51impl<D> RegistryRef<D> {
52 pub fn deferred<S>(&self) -> Deferred<S, D>
53 where
54 S: Service<D>,
55 {
56 Deferred::new(self.r.clone())
57 }
58
59 pub async fn resolve<S>(&mut self) -> Result<Addr<S>, S::Error>
60 where
61 S: Service<D>,
62 {
63 let type_id = TypeId::of::<S>();
64 if let Some(addr) = self
65 .guard
66 .map
67 .get(&type_id)
68 .map(|v| v.downcast_ref::<Owner<S>>().unwrap().addr())
69 {
70 return Ok(addr);
71 }
72 let container = S::create(self).await?.start();
73 let addr = container.addr();
74 self.guard.map.insert(type_id, Box::new(container));
75 Ok(addr)
76 }
77
78 pub fn data(&self) -> &D {
79 &self.guard.data
80 }
81}
82
83#[async_trait]
84pub trait Service<D = ()>: Actor {
85 type Error: From<RegistryError>;
86 async fn create(registry: &mut RegistryRef<D>) -> Result<Self, Self::Error>;
87}
88
89impl<D> Registry<D> {
90 pub fn deferred<S>(&self) -> Deferred<S, D>
91 where
92 S: Service<D>,
93 {
94 Deferred::new(Arc::downgrade(&self.state))
95 }
96
97 pub async fn resolve<S>(&self) -> Result<Addr<S>, S::Error>
98 where
99 S: Service<D>,
100 {
101 let guard = self.state.clone().lock_owned().await;
102 let mut r = RegistryRef {
103 r: Arc::downgrade(&self.state),
104 guard,
105 };
106 r.resolve().await
107 }
108}
109
110#[derive(Debug)]
111pub struct Deferred<S, D = ()> {
112 r: Weak<Mutex<State<D>>>,
113 resolved: Option<Addr<S>>,
114 _phantom: PhantomData<S>,
115}
116
117impl<S, D> Deferred<S, D>
118where
119 S: Service<D>,
120{
121 fn new(r: Weak<Mutex<State<D>>>) -> Self {
122 Self {
123 r,
124 resolved: None,
125 _phantom: PhantomData,
126 }
127 }
128
129 pub async fn resolve(&mut self) -> Result<Addr<S>, S::Error> {
130 if let Some(addr) = self.resolved.clone() {
131 Ok(addr)
132 } else {
133 let map = self
134 .r
135 .upgrade()
136 .ok_or_else(|| S::Error::from(RegistryError::RegistryGone))?;
137 let guard = map.lock_owned().await;
138 let mut registry = RegistryRef {
139 r: self.r.clone(),
140 guard,
141 };
142 let addr = registry.resolve::<S>().await?;
143 self.resolved = Some(addr.clone());
144 Ok(addr)
145 }
146 }
147}
148
149impl<S, D> Clone for Deferred<S, D> {
150 fn clone(&self) -> Self {
151 Self {
152 r: self.r.clone(),
153 resolved: self.resolved.clone(),
154 _phantom: PhantomData,
155 }
156 }
157}
158
#[cfg(test)]
mod test {
    use crate::registry::RegistryRef;
    use crate::*;

    /// End-to-end check of the registry:
    /// - shared data reaches `create`;
    /// - services are singletons;
    /// - a dependency is created exactly once;
    /// - a `Deferred` stored inside a service resolves later, after the
    ///   construction lock has been released.
    #[tokio::test]
    async fn test_registry() {
        use once_cell::sync::OnceCell;

        struct Data {
            id: i32,
        }

        struct Dep;
        impl Actor for Dep {}
        #[async_trait]
        impl Service<Data> for Dep {
            type Error = registry::RegistryError;

            async fn create(_: &mut RegistryRef<Data>) -> Result<Self, Self::Error> {
                // `OnceCell::set` fails on a second call, so this panics if
                // the registry ever constructs `Dep` more than once.
                static CREATED: OnceCell<bool> = OnceCell::new();
                CREATED.set(true).unwrap();

                Ok(Dep)
            }
        }

        struct Number {
            dep_deferred: Deferred<Dep, Data>,
            value: i32,
        }
        impl Actor for Number {}

        struct Increase;
        impl Message for Increase {
            type Result = ();
        }
        #[async_trait]
        impl Handler<Increase> for Number {
            async fn handle(&mut self, _: &mut Context<Self>, _: Increase) {
                self.value += 1;
            }
        }

        struct GetValue;
        impl Message for GetValue {
            type Result = i32;
        }
        #[async_trait]
        impl Handler<GetValue> for Number {
            async fn handle(&mut self, _: &mut Context<Self>, _: GetValue) -> i32 {
                self.value
            }
        }

        struct Resolve;
        impl Message for Resolve {
            type Result = Addr<Dep>;
        }
        #[async_trait]
        impl Handler<Resolve> for Number {
            async fn handle(&mut self, _ctx: &mut Context<Self>, _message: Resolve) -> Addr<Dep> {
                self.dep_deferred.resolve().await.unwrap()
            }
        }

        #[async_trait]
        impl Service<Data> for Number {
            type Error = registry::RegistryError;

            async fn create(registry: &mut RegistryRef<Data>) -> Result<Self, Self::Error> {
                let _dep = registry.resolve::<Dep>().await?;

                assert_eq!(registry.data().id, 42);

                Ok(Number {
                    dep_deferred: registry.deferred(),
                    value: 0,
                })
            }
        }

        let registry = Registry::with_data(Data { id: 42 });

        // Each lookup must return the same `Number`, so three increments made
        // through three separate resolutions accumulate on one instance.
        for _ in 0..3 {
            registry
                .resolve::<Number>()
                .await
                .unwrap()
                .send(Increase)
                .await
                .unwrap();
        }

        let value = registry
            .resolve::<Number>()
            .await
            .unwrap()
            .send(GetValue)
            .await
            .unwrap();

        let number = registry.resolve::<Number>().await.unwrap();

        // The registry is not locked here, so the deferred handle can take the
        // lock. It gets the `Dep` cached while `Number` was being created.
        let _dep = number.send(Resolve).await.unwrap();

        assert_eq!(value, 3);
    }
}