flo_state/
lib.rs

1pub mod error;
2pub mod mock;
3pub mod registry;
4pub mod reply;
5
6pub use async_trait::async_trait;
7pub use registry::{Deferred, Registry, RegistryError, RegistryRef, Service};
8
9use crate::mock::MockMessage;
10use error::{Error, Result};
11use std::marker::PhantomData;
12use std::sync::Arc;
13use std::time::Duration;
14use std::{future::Future, time::Instant};
15use tokio::sync::mpsc;
16use tokio::sync::mpsc::error::SendTimeoutError;
17use tokio::sync::oneshot;
18use tokio::sync::oneshot::Sender;
19use tokio_util::sync::CancellationToken;
20
21#[async_trait]
22pub trait Actor: Send + Sized + 'static {
23  async fn started(&mut self, _ctx: &mut Context<Self>) {}
24  async fn stopped(self) {}
25  fn start(self) -> Owner<Self> {
26    Owner::new(self)
27  }
28}
29
/// A value that can be delivered to an actor.
pub trait Message: 'static + Send {
  /// Type of the reply the actor's [`Handler`] produces for this message.
  /// Messages sent with `notify` must use `()`.
  type Result: Send + 'static;
}
33
34#[async_trait]
35pub trait Handler<M>: Send + Sized
36where
37  M: Message,
38{
39  async fn handle(&mut self, ctx: &mut Context<Self>, message: M) -> M::Result;
40}
41
42pub struct Recipient<M> {
43  tx: Arc<dyn MessageSender<M>>,
44}
45
46impl<M> Clone for Recipient<M> {
47  fn clone(&self) -> Self {
48    Self {
49      tx: self.tx.clone(),
50    }
51  }
52}
53
54impl<M> Recipient<M>
55where
56  M: Message,
57{
58  pub async fn send(&self, message: M) -> Result<M::Result> {
59    self.tx.send(message).await
60  }
61  pub async fn send_timeout(&self, timeout: Duration, message: M) -> Result<M::Result> {
62    self.tx.send_timeout(timeout, message).await
63  }
64}
65
66#[async_trait]
67trait MessageSender<M>
68where
69  M: Message,
70  Self: Send + Sync,
71{
72  async fn send(&self, message: M) -> Result<M::Result>;
73  async fn send_timeout(&self, timeout: Duration, message: M) -> Result<M::Result>;
74}
75
76#[async_trait]
77impl<S, M> MessageSender<M> for Addr<S>
78where
79  S: Handler<M>,
80  M: Message,
81{
82  async fn send(&self, message: M) -> Result<M::Result> {
83    Addr::send(self, message).await
84  }
85  async fn send_timeout(&self, timeout: Duration, message: M) -> Result<M::Result> {
86    Addr::send_timeout(self, timeout, message).await
87  }
88}
89
90type ItemReplySender<T> = Sender<T>;
91
92#[async_trait]
93trait ItemObj<S>: Send + 'static {
94  fn as_mock_message(&mut self) -> MockMessage;
95  async fn handle(&mut self, state: &mut S, ctx: &mut Context<S>);
96}
97
98struct MsgItem<M>
99where
100  M: Message,
101{
102  message: M,
103  tx: ItemReplySender<M::Result>,
104}
105
106impl<M> MsgItem<M>
107where
108  M: Message,
109{
110  async fn handle<S>(self, state: &mut S, ctx: &mut Context<S>)
111  where
112    S: Handler<M>,
113  {
114    self.tx.send(state.handle(ctx, self.message).await).ok();
115  }
116}
117
118#[async_trait]
119impl<S, M> ItemObj<S> for Option<MsgItem<M>>
120where
121  S: Handler<M>,
122  M: Message,
123{
124  fn as_mock_message(&mut self) -> MockMessage {
125    let MsgItem { message, tx } = self.take().expect("item already consumed");
126    MockMessage {
127      message: Box::new(Some(message)),
128      tx: Some(Box::new(Some(tx))),
129    }
130  }
131
132  async fn handle(&mut self, state: &mut S, ctx: &mut Context<S>) {
133    if let Some(item) = self.take() {
134      item.handle(state, ctx).await;
135    }
136  }
137}
138
139struct NotifyItem<M>
140where
141  M: Message,
142{
143  message: M,
144}
145
146#[async_trait]
147impl<S, M> ItemObj<S> for Option<NotifyItem<M>>
148where
149  S: Handler<M>,
150  M: Message<Result = ()>,
151{
152  fn as_mock_message(&mut self) -> MockMessage {
153    let NotifyItem { message } = self.take().expect("item already consumed");
154    MockMessage {
155      message: Box::new(Some(message)),
156      tx: None,
157    }
158  }
159
160  async fn handle(&mut self, state: &mut S, ctx: &mut Context<S>) {
161    if let Some(item) = self.take() {
162      state.handle(ctx, item.message).await;
163    }
164  }
165}
166
167#[derive(Debug)]
168pub struct Addr<S> {
169  tx: mpsc::Sender<ContainerMessage<S>>,
170}
171
172impl<S> Clone for Addr<S> {
173  fn clone(&self) -> Self {
174    Addr {
175      tx: self.tx.clone(),
176    }
177  }
178}
179
180impl<S> Addr<S> {
181  pub fn recipient<M>(self) -> Recipient<M>
182  where
183    S: Handler<M> + 'static,
184    M: Message,
185  {
186    Recipient { tx: Arc::new(self) }
187  }
188
189  /// Sends a message to the actor and wait for the result
190  pub async fn send<M>(&self, message: M) -> Result<M::Result>
191  where
192    M: Message + Send + 'static,
193    S: Handler<M>,
194  {
195    send(&self.tx, message, None).await
196  }
197
198  /// Sends a message to the actor and wait for the result, but only for a limited time.
199  pub async fn send_timeout<M>(&self, timeout: Duration, message: M) -> Result<M::Result>
200  where
201    M: Message + Send + 'static,
202    S: Handler<M>,
203  {
204    send(&self.tx, message, Some(timeout)).await
205  }
206
207  /// Sends a message to the actor without waiting for the result
208  pub async fn notify<M>(&self, message: M) -> Result<()>
209  where
210    M: Message<Result = ()> + Send + 'static,
211    S: Handler<M>,
212  {
213    notify(&self.tx, message, None).await
214  }
215
216  /// Sends a message to the actor without waiting for the result, but only for a limited time.
217  pub async fn notify_timeout<M>(&self, timeout: Duration, message: M) -> Result<()>
218  where
219    M: Message<Result = ()> + Send + 'static,
220    S: Handler<M>,
221  {
222    notify(&self.tx, message, Some(timeout)).await
223  }
224}
225
226pub struct Context<S> {
227  tx: mpsc::Sender<ContainerMessage<S>>,
228  token: CancellationToken,
229}
230
231impl<S> Context<S> {
232  pub fn addr(&self) -> Addr<S> {
233    Addr {
234      tx: self.tx.clone(),
235    }
236  }
237
238  /// Spawns a future into the context.
239  /// All futures spawned into the container will be cancelled if the container dropped.
240  pub fn spawn<F>(&self, f: F)
241  where
242    F: Future<Output = ()> + Send + 'static,
243  {
244    let token = self.token.child_token();
245    tokio::spawn(async move {
246      tokio::select! {
247        _ = token.cancelled() => {}
248        _ = f => {}
249      }
250    });
251  }
252
253  /// Sends the message msg to self after a specified period of time.
254  pub fn send_later<T>(&self, msg: T, after: Duration)
255  where
256    T: Message<Result = ()>,
257    S: Handler<T> + 'static,
258  {
259    let addr = self.addr();
260    self.spawn(async move {
261      tokio::time::sleep(after).await;
262      addr.send(msg).await.ok();
263    });
264  }
265}
266
267impl<S> Drop for Context<S> {
268  fn drop(&mut self) {
269    self.token.cancel();
270  }
271}
272
273#[derive(Debug)]
274pub struct Owner<S> {
275  tx: mpsc::Sender<ContainerMessage<S>>,
276  token: CancellationToken,
277}
278
279impl<S> Drop for Owner<S> {
280  fn drop(&mut self) {
281    self.token.cancel();
282  }
283}
284
285impl<S> Owner<S>
286where
287  S: Actor,
288{
289  pub fn new(initial_state: S) -> Self {
290    OwnerBuilder::new().start(initial_state)
291  }
292
293  pub fn build() -> OwnerBuilder<S> {
294    OwnerBuilder::new()
295  }
296
297  pub fn addr(&self) -> Addr<S> {
298    Addr {
299      tx: self.tx.clone(),
300    }
301  }
302
303  /// Sends a message to the actor and wait for the result
304  pub async fn send<M>(&self, message: M) -> Result<M::Result>
305  where
306    M: Message + Send + 'static,
307    S: Handler<M>,
308  {
309    send(&self.tx, message, None).await
310  }
311
312  /// Sends a message to the actor and wait for the result, but only for a limited time.
313  pub async fn send_timeout<M>(&self, timeout: Duration, message: M) -> Result<M::Result>
314  where
315    M: Message + Send + 'static,
316    S: Handler<M>,
317  {
318    send(&self.tx, message, Some(timeout)).await
319  }
320
321  /// Sends a message to the actor without waiting for the result
322  pub async fn notify<M>(&self, message: M) -> Result<()>
323  where
324    M: Message<Result = ()> + Send + 'static,
325    S: Handler<M>,
326  {
327    notify(&self.tx, message, None).await
328  }
329
330  /// Sends a message to the actor without waiting for the result, but only for a limited time.
331  pub async fn notify_timeout<M>(&self, timeout: Duration, message: M) -> Result<()>
332  where
333    M: Message<Result = ()> + Send + 'static,
334    S: Handler<M>,
335  {
336    notify(&self.tx, message, Some(timeout)).await
337  }
338
339  /// Spawns a future into the container.
340  /// All futures spawned into the container will be cancelled if the container dropped.
341  pub fn spawn<F>(&self, f: F)
342  where
343    F: Future<Output = ()> + Send + 'static,
344  {
345    let token = self.token.child_token();
346    tokio::spawn(async move {
347      tokio::select! {
348        _ = token.cancelled() => {}
349        _ = f => {}
350      }
351    });
352  }
353
354  pub async fn shutdown(self) -> Result<S> {
355    let (tx, rx) = oneshot::channel();
356    self
357      .tx
358      .send(ContainerMessage::Terminate(tx))
359      .await
360      .map_err(|_| Error::WorkerGone)?;
361    rx.await.map_err(|_| Error::WorkerGone)
362  }
363}
364
365#[derive(Debug)]
366pub struct OwnerBuilder<S> {
367  mailbox_capacity: usize,
368  _p: PhantomData<S>,
369}
370
371impl<S> OwnerBuilder<S> {
372  const DEFAULT_MAILBOX_CAPACITY: usize = 32;
373
374  fn new() -> Self {
375    OwnerBuilder {
376      mailbox_capacity: Self::DEFAULT_MAILBOX_CAPACITY,
377      _p: PhantomData,
378    }
379  }
380
381  pub fn mailbox_capacity(self, mailbox_capacity: usize) -> Self {
382    Self {
383      mailbox_capacity,
384      ..self
385    }
386  }
387
388  pub fn start(self, initial_state: S) -> Owner<S>
389  where
390    S: Actor,
391  {
392    let token = CancellationToken::new();
393    let (tx, mut rx) = mpsc::channel(self.mailbox_capacity);
394
395    tokio::spawn({
396      let mut ctx = Context {
397        tx: tx.clone(),
398        token: token.child_token(),
399      };
400      let token = token.child_token();
401      async move {
402        let mut state = initial_state;
403        let mut dropping = false;
404
405        state.started(&mut ctx).await;
406
407        loop {
408          tokio::select! {
409            _ = token.cancelled() => {
410              state.stopped().await;
411              break;
412            }
413            Some(item) = rx.recv(), if !dropping => {
414              match item {
415                ContainerMessage::Item(mut item) => {
416                  let handling = item.handle(&mut state, &mut ctx);
417                  tokio::select! {
418                    _ = handling => {},
419                    _ = token.cancelled() => {
420                      dropping = true;
421                      continue;
422                    }
423                  }
424                }
425                ContainerMessage::Terminate(tx) => {
426                  rx.close();
427
428                  // drain messages
429                  while let Some(item) = rx.recv().await {
430                    match item {
431                      ContainerMessage::Item(mut item) => {
432                        item.handle(&mut state, &mut ctx).await;
433                      }
434                      ContainerMessage::Terminate(_) => unreachable!(),
435                    }
436                  }
437
438                  tx.send(state).ok();
439                  break;
440                }
441              }
442            }
443          }
444        }
445      }
446    });
447
448    Owner { tx, token }
449  }
450}
451
452async fn send<S, M>(
453  tx: &mpsc::Sender<ContainerMessage<S>>,
454  message: M,
455  timeout: Option<Duration>,
456) -> Result<M::Result>
457where
458  S: Handler<M>,
459  M: Message + Send + 'static,
460{
461  let (reply_tx, reply_rx) = oneshot::channel();
462  let boxed = Box::new(Some(MsgItem {
463    message,
464    tx: reply_tx,
465  }));
466
467  if let Some(timeout) = timeout {
468    let t = Instant::now();
469    tx.send_timeout(ContainerMessage::Item(boxed), timeout)
470      .await
471      .map_err(|err| match err {
472        SendTimeoutError::Timeout(_) => Error::SendTimeout,
473        SendTimeoutError::Closed(_) => Error::WorkerGone,
474      })?;
475    let timeout = match timeout.checked_sub(Instant::now() - t) {
476      None => return Err(Error::SendTimeout),
477      Some(v) => v,
478    };
479    let res = tokio::time::timeout(timeout, reply_rx)
480      .await
481      .map_err(|_| Error::SendTimeout)?;
482    let res = res.map_err(|_| Error::WorkerGone)?;
483    Ok(res)
484  } else {
485    tx.send(ContainerMessage::Item(boxed))
486      .await
487      .map_err(|_| Error::WorkerGone)?;
488    let res = reply_rx.await.map_err(|_| Error::WorkerGone)?;
489    Ok(res)
490  }
491}
492
493async fn notify<S, M>(
494  tx: &mpsc::Sender<ContainerMessage<S>>,
495  message: M,
496  timeout: Option<Duration>,
497) -> Result<()>
498where
499  S: Handler<M>,
500  M: Message<Result = ()> + Send + 'static,
501{
502  let boxed = Box::new(Some(NotifyItem { message }));
503  if let Some(timeout) = timeout {
504    tx.send_timeout(ContainerMessage::Item(boxed), timeout)
505      .await
506      .map_err(|err| match err {
507        SendTimeoutError::Timeout(_) => Error::SendTimeout,
508        SendTimeoutError::Closed(_) => Error::WorkerGone,
509      })?;
510  } else {
511    tx.send(ContainerMessage::Item(boxed))
512      .await
513      .map_err(|_| Error::WorkerGone)?;
514  }
515  Ok(())
516}
517
518enum ContainerMessage<S> {
519  Item(Box<dyn ItemObj<S>>),
520  Terminate(oneshot::Sender<S>),
521}
522
523impl Actor for () {}