flo_state/
registry.rs

1use async_trait::async_trait;
2use std::any::{Any, TypeId};
3use std::collections::HashMap;
4use std::marker::PhantomData;
5use std::sync::{Arc, Weak};
6use thiserror::Error;
7use tokio::sync::{Mutex, OwnedMutexGuard};
8
9use crate::{Actor, Addr, Owner};
10
11#[derive(Error, Debug)]
12pub enum RegistryError {
13  #[error("registry gone")]
14  RegistryGone,
15}
16
/// Everything the registry keeps behind its mutex.
#[derive(Debug)]
struct State<D> {
  /// User-supplied data, handed out read-only via `RegistryRef::data`.
  data: D,
  /// Services that have already been started, keyed by the `TypeId` of the
  /// service type. Each value is a boxed `Owner<S>` for the matching `S`.
  map: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}
22
23#[derive(Debug)]
24pub struct Registry<D = ()> {
25  state: Arc<Mutex<State<D>>>,
26}
27
28impl<D> Registry<D> {
29  pub fn new() -> Self
30  where
31    D: Default,
32  {
33    Self::with_data(Default::default())
34  }
35
36  pub fn with_data(data: D) -> Self {
37    Self {
38      state: Arc::new(Mutex::new(State {
39        data,
40        map: HashMap::new(),
41      })),
42    }
43  }
44}
45
46pub struct RegistryRef<D = ()> {
47  r: Weak<Mutex<State<D>>>,
48  guard: OwnedMutexGuard<State<D>>,
49}
50
51impl<D> RegistryRef<D> {
52  pub fn deferred<S>(&self) -> Deferred<S, D>
53  where
54    S: Service<D>,
55  {
56    Deferred::new(self.r.clone())
57  }
58
59  pub async fn resolve<S>(&mut self) -> Result<Addr<S>, S::Error>
60  where
61    S: Service<D>,
62  {
63    let type_id = TypeId::of::<S>();
64    if let Some(addr) = self
65      .guard
66      .map
67      .get(&type_id)
68      .map(|v| v.downcast_ref::<Owner<S>>().unwrap().addr())
69    {
70      return Ok(addr);
71    }
72    let container = S::create(self).await?.start();
73    let addr = container.addr();
74    self.guard.map.insert(type_id, Box::new(container));
75    Ok(addr)
76  }
77
78  pub fn data(&self) -> &D {
79    &self.guard.data
80  }
81}
82
83#[async_trait]
84pub trait Service<D = ()>: Actor {
85  type Error: From<RegistryError>;
86  async fn create(registry: &mut RegistryRef<D>) -> Result<Self, Self::Error>;
87}
88
89impl<D> Registry<D> {
90  pub fn deferred<S>(&self) -> Deferred<S, D>
91  where
92    S: Service<D>,
93  {
94    Deferred::new(Arc::downgrade(&self.state))
95  }
96
97  pub async fn resolve<S>(&self) -> Result<Addr<S>, S::Error>
98  where
99    S: Service<D>,
100  {
101    let guard = self.state.clone().lock_owned().await;
102    let mut r = RegistryRef {
103      r: Arc::downgrade(&self.state),
104      guard,
105    };
106    r.resolve().await
107  }
108}
109
110#[derive(Debug)]
111pub struct Deferred<S, D = ()> {
112  r: Weak<Mutex<State<D>>>,
113  resolved: Option<Addr<S>>,
114  _phantom: PhantomData<S>,
115}
116
117impl<S, D> Deferred<S, D>
118where
119  S: Service<D>,
120{
121  fn new(r: Weak<Mutex<State<D>>>) -> Self {
122    Self {
123      r,
124      resolved: None,
125      _phantom: PhantomData,
126    }
127  }
128
129  pub async fn resolve(&mut self) -> Result<Addr<S>, S::Error> {
130    if let Some(addr) = self.resolved.clone() {
131      Ok(addr)
132    } else {
133      let map = self
134        .r
135        .upgrade()
136        .ok_or_else(|| S::Error::from(RegistryError::RegistryGone))?;
137      let guard = map.lock_owned().await;
138      let mut registry = RegistryRef {
139        r: self.r.clone(),
140        guard,
141      };
142      let addr = registry.resolve::<S>().await?;
143      self.resolved = Some(addr.clone());
144      Ok(addr)
145    }
146  }
147}
148
149impl<S, D> Clone for Deferred<S, D> {
150  fn clone(&self) -> Self {
151    Self {
152      r: self.r.clone(),
153      resolved: self.resolved.clone(),
154      _phantom: PhantomData,
155    }
156  }
157}
158
#[cfg(test)]
mod test {
  use crate::registry::RegistryRef;
  use crate::*;

  /// Resolves the same service several times and checks that it is created
  /// once, keeps its state between resolves, sees the registry data, and can
  /// resolve a deferred dependency later from inside a handler.
  #[tokio::test]
  async fn test_registry() {
    use once_cell::sync::OnceCell;

    // ---- registry data -------------------------------------------------
    struct Data {
      id: i32,
    }

    // ---- messages ------------------------------------------------------
    struct Increase;
    impl Message for Increase {
      type Result = ();
    }

    struct GetValue;
    impl Message for GetValue {
      type Result = i32;
    }

    struct Resolve;
    impl Message for Resolve {
      type Result = Addr<Dep>;
    }

    // ---- dependency service --------------------------------------------
    struct Dep;
    impl Actor for Dep {}

    #[async_trait]
    impl Service<Data> for Dep {
      type Error = registry::RegistryError;

      async fn create(_: &mut RegistryRef<Data>) -> Result<Self, Self::Error> {
        // Setting the cell a second time fails, so this panics if `Dep`
        // is ever created more than once.
        static CREATED: OnceCell<bool> = OnceCell::new();
        CREATED.set(true).unwrap();

        Ok(Dep)
      }
    }

    // ---- service under test --------------------------------------------
    struct Number {
      dep_deferred: Deferred<Dep, Data>,
      value: i32,
    }
    impl Actor for Number {}

    #[async_trait]
    impl Service<Data> for Number {
      type Error = registry::RegistryError;

      async fn create(registry: &mut RegistryRef<Data>) -> Result<Self, Self::Error> {
        // Eagerly resolve the dependency while the registry is locked.
        let _dep = registry.resolve::<Dep>().await?;

        assert_eq!(registry.data().id, 42);

        let dep_deferred = registry.deferred();
        Ok(Number {
          dep_deferred,
          value: 0,
        })
      }
    }

    #[async_trait]
    impl Handler<Increase> for Number {
      async fn handle(&mut self, _: &mut Context<Self>, _: Increase) {
        self.value += 1;
      }
    }

    #[async_trait]
    impl Handler<GetValue> for Number {
      async fn handle(&mut self, _: &mut Context<Self>, _: GetValue) -> i32 {
        self.value
      }
    }

    #[async_trait]
    impl Handler<Resolve> for Number {
      async fn handle(&mut self, _ctx: &mut Context<Self>, _message: Resolve) -> Addr<Dep> {
        self.dep_deferred.resolve().await.unwrap()
      }
    }

    // ---- scenario ------------------------------------------------------
    let registry = Registry::with_data(Data { id: 42 });

    // Resolve afresh for every increment; all three must reach the same actor.
    for _ in 0..3 {
      let number = registry.resolve::<Number>().await.unwrap();
      number.send(Increase).await.unwrap();
    }

    let value = registry
      .resolve::<Number>()
      .await
      .unwrap()
      .send(GetValue)
      .await
      .unwrap();

    // Resolving the deferred dependency from inside a handler must succeed.
    let number = registry.resolve::<Number>().await.unwrap();
    let _dep = number.send(Resolve).await.unwrap();

    // Three `Increase` messages were delivered to a counter starting at 0.
    assert_eq!(value, 3);
  }
}