flo_state/
mock.rs

1use crate::{
2  Addr, ContainerMessage, Handler, ItemReplySender,
3  Message,
4};
5use futures::future::{abortable, AbortHandle, BoxFuture, Aborted};
6use futures::FutureExt;
7use std::any::{Any, TypeId};
8use std::collections::HashMap;
9use std::marker::PhantomData;
10use std::future::Future;
11use tokio::sync::mpsc;
12use tokio::task::{JoinHandle, JoinError};
13
14type BoxedHandler = Box<dyn FnMut(MockMessage) -> BoxFuture<'static, ()> + Send>;
15
16pub struct Mock<S> {
17  tx: mpsc::Sender<ContainerMessage<S>>,
18  abort_handle: AbortHandle,
19  join_handle: Option<JoinHandle<Result<(), Aborted>>>,
20  _t: PhantomData<S>,
21}
22
23impl<S> Mock<S> {
24  pub fn builder() -> MockBuilder<S> {
25    MockBuilder::new()
26  }
27
28  pub fn addr(&self) -> Addr<S> {
29    Addr {
30      tx: self.tx.clone(),
31    }
32  }
33
34  pub async fn shutdown(&mut self) -> Result<(), JoinError> {
35    use tokio::sync::oneshot::channel;
36
37    if let Some(handle) = self.join_handle.take() {
38      let (tx, _rx) = channel();
39      self.tx.send(ContainerMessage::Terminate(tx)).await.ok();
40      handle.await.map(|_| ())
41    } else {
42      Ok(())
43    }
44  }
45}
46
47impl<S> Drop for Mock<S> {
48  fn drop(&mut self) {
49    if self.join_handle.is_some() {
50      self.abort_handle.abort();
51    }
52  }
53}
54
/// A type-erased message as handed to a mock handler.
pub(crate) struct MockMessage {
  /// The message payload, boxed behind `Any` so the handler can downcast it.
  pub(crate) message: Box<dyn Send + Any>,
  /// Type-erased reply sender; when `None`, the handler's result is discarded.
  pub(crate) tx: Option<Box<dyn Send + Any>>,
}
59
60pub struct MockBuilder<S> {
61  handler_map: HashMap<TypeId, BoxedHandler>,
62  _t: PhantomData<S>,
63}
64
65impl<S> MockBuilder<S> {
66  fn new() -> Self {
67    Self {
68      handler_map: HashMap::new(),
69      _t: PhantomData,
70    }
71  }
72
73  pub fn handle<M, F, R>(mut self, mut f: F) -> Self
74  where
75    M: Message + Sync,
76    S: Handler<M>,
77    F: FnMut(M) -> R + Send + 'static,
78    R: Future<Output = M::Result> + Send + 'static
79  {
80    self.handler_map.insert(
81      TypeId::of::<Option<M>>(),
82      Box::new(move |MockMessage { mut message, tx }| {
83        let message = message.downcast_mut::<Option<M>>().unwrap().take().unwrap();
84        if let Some(mut tx) = tx {
85          let tx = tx.downcast_mut::<Option<ItemReplySender<M::Result>>>().unwrap().take().unwrap();
86          let task = f(message);
87          async move {
88            tx.send(task.await).ok();
89          }.boxed()
90        } else {
91          let task = f(message);
92          async move {
93            task.await;
94          }.boxed()
95        }
96      }),
97    );
98    self
99  }
100
101  pub fn build(self) -> Mock<S>
102  where
103    S: Send + 'static,
104  {
105    let (tx, mut rx) = mpsc::channel(1);
106
107    let mut handler_map = self.handler_map;
108
109    let (task, abort_handle) = abortable(async move {
110      while let Some(msg) = rx.recv().await {
111        match msg {
112          ContainerMessage::Item(mut boxed) => {
113            let mock_message = boxed.as_mock_message();
114            let type_id = mock_message.message.as_ref().type_id();
115            match handler_map.get_mut(&type_id) {
116              Some(handler) => handler(mock_message).await,
117              None => panic!("message mock handler not provided: {:?}", type_id),
118            }
119          }
120          ContainerMessage::Terminate(_) => {
121            break
122          },
123        }
124      }
125    });
126
127    Mock {
128      tx,
129      abort_handle,
130      join_handle: tokio::spawn(task).into(),
131      _t: PhantomData,
132    }
133  }
134}
135
#[cfg(test)]
mod test {
  use super::*;
  use crate::{Actor, Context};
  use async_trait::async_trait;

  /// The real actor answers 0. The mock answers with whatever its handler produces, and
  /// the handler's captured state persists across messages.
  #[tokio::test]
  async fn test_mock() {
    struct TestActor;
    impl Actor for TestActor {}

    struct TestMessage;
    impl Message for TestMessage {
      type Result = i32;
    }

    #[async_trait]
    impl Handler<TestMessage> for TestActor {
      async fn handle(&mut self, _ctx: &mut Context<Self>, _message: TestMessage) -> i32 {
        0
      }
    }

    let real = TestActor.start();
    assert_eq!(real.send(TestMessage).await.unwrap(), 0);

    let mut counter = 41;
    let mock = Mock::<TestActor>::builder()
      .handle(move |_: TestMessage| {
        counter += 1;
        let reply = counter;
        async move { reply }
      })
      .build();

    for expected in 42..=45 {
      assert_eq!(mock.addr().send(TestMessage).await.unwrap(), expected);
    }
  }

  /// The mock has no handler for `M`, so the `send(...).unwrap()` round trip is expected
  /// to panic.
  #[tokio::test]
  #[should_panic]
  async fn test_mock_panic() {
    struct A;
    impl Actor for A {}

    struct M;
    impl Message for M {
      type Result = ();
    }

    #[async_trait]
    impl Handler<M> for A {
      async fn handle(&mut self, _ctx: &mut Context<Self>, _message: M) {}
    }

    let mut mock = Mock::<A>::builder().build();
    mock.addr().send(M).await.unwrap();
    mock.shutdown().await.unwrap();
  }
}