1use crate::{
2 Addr, ContainerMessage, Handler, ItemReplySender,
3 Message,
4};
5use futures::future::{abortable, AbortHandle, BoxFuture, Aborted};
6use futures::FutureExt;
7use std::any::{Any, TypeId};
8use std::collections::HashMap;
9use std::marker::PhantomData;
10use std::future::Future;
11use tokio::sync::mpsc;
12use tokio::task::{JoinHandle, JoinError};
13
14type BoxedHandler = Box<dyn FnMut(MockMessage) -> BoxFuture<'static, ()> + Send>;
15
/// A stand-in for a running actor of type `S`.
///
/// Messages are handled by closures registered through [`MockBuilder`]
/// instead of by `S`'s real `Handler` impls.
///
/// Messages sent to [`Mock::addr`] are processed one at a time by a spawned
/// tokio task. Dropping the mock without calling [`Mock::shutdown`] aborts
/// that task.
pub struct Mock<S> {
    /// Sending half of the mock task's mailbox; cloned into every `Addr`.
    tx: mpsc::Sender<ContainerMessage<S>>,
    /// Aborts the spawned mock task; used by `Drop`.
    abort_handle: AbortHandle,
    /// Handle to the spawned mock task; cleared by `shutdown`.
    join_handle: Option<JoinHandle<Result<(), Aborted>>>,
    _t: PhantomData<S>,
}
22
23impl<S> Mock<S> {
24 pub fn builder() -> MockBuilder<S> {
25 MockBuilder::new()
26 }
27
28 pub fn addr(&self) -> Addr<S> {
29 Addr {
30 tx: self.tx.clone(),
31 }
32 }
33
34 pub async fn shutdown(&mut self) -> Result<(), JoinError> {
35 use tokio::sync::oneshot::channel;
36
37 if let Some(handle) = self.join_handle.take() {
38 let (tx, _rx) = channel();
39 self.tx.send(ContainerMessage::Terminate(tx)).await.ok();
40 handle.await.map(|_| ())
41 } else {
42 Ok(())
43 }
44 }
45}
46
47impl<S> Drop for Mock<S> {
48 fn drop(&mut self) {
49 if self.join_handle.is_some() {
50 self.abort_handle.abort();
51 }
52 }
53}
54
/// Type-erased envelope passed to a mock handler.
///
/// The handler registered in `MockBuilder::handle` downcasts each payload to
/// an `Option<_>` and `take()`s the value out of it.
pub(crate) struct MockMessage {
    /// The message, boxed as `Option<M>`; its dynamic `TypeId` selects the handler.
    pub(crate) message: Box<dyn Any + Send>,
    /// Reply slot, boxed as `Option<ItemReplySender<M::Result>>`.
    /// NOTE(review): presumably `None` for sends that do not wait for a
    /// result — confirm against `as_mock_message`.
    pub(crate) tx: Option<Box<dyn Any + Send>>,
}
59
60pub struct MockBuilder<S> {
61 handler_map: HashMap<TypeId, BoxedHandler>,
62 _t: PhantomData<S>,
63}
64
65impl<S> MockBuilder<S> {
66 fn new() -> Self {
67 Self {
68 handler_map: HashMap::new(),
69 _t: PhantomData,
70 }
71 }
72
73 pub fn handle<M, F, R>(mut self, mut f: F) -> Self
74 where
75 M: Message + Sync,
76 S: Handler<M>,
77 F: FnMut(M) -> R + Send + 'static,
78 R: Future<Output = M::Result> + Send + 'static
79 {
80 self.handler_map.insert(
81 TypeId::of::<Option<M>>(),
82 Box::new(move |MockMessage { mut message, tx }| {
83 let message = message.downcast_mut::<Option<M>>().unwrap().take().unwrap();
84 if let Some(mut tx) = tx {
85 let tx = tx.downcast_mut::<Option<ItemReplySender<M::Result>>>().unwrap().take().unwrap();
86 let task = f(message);
87 async move {
88 tx.send(task.await).ok();
89 }.boxed()
90 } else {
91 let task = f(message);
92 async move {
93 task.await;
94 }.boxed()
95 }
96 }),
97 );
98 self
99 }
100
101 pub fn build(self) -> Mock<S>
102 where
103 S: Send + 'static,
104 {
105 let (tx, mut rx) = mpsc::channel(1);
106
107 let mut handler_map = self.handler_map;
108
109 let (task, abort_handle) = abortable(async move {
110 while let Some(msg) = rx.recv().await {
111 match msg {
112 ContainerMessage::Item(mut boxed) => {
113 let mock_message = boxed.as_mock_message();
114 let type_id = mock_message.message.as_ref().type_id();
115 match handler_map.get_mut(&type_id) {
116 Some(handler) => handler(mock_message).await,
117 None => panic!("message mock handler not provided: {:?}", type_id),
118 }
119 }
120 ContainerMessage::Terminate(_) => {
121 break
122 },
123 }
124 }
125 });
126
127 Mock {
128 tx,
129 abort_handle,
130 join_handle: tokio::spawn(task).into(),
131 _t: PhantomData,
132 }
133 }
134}
135
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Actor, Context};
    use async_trait::async_trait;

    #[tokio::test]
    async fn test_mock() {
        struct TestActor;
        impl Actor for TestActor {}

        struct TestMessage;
        impl Message for TestMessage {
            type Result = i32;
        }

        #[async_trait]
        impl Handler<TestMessage> for TestActor {
            async fn handle(
                &mut self,
                _ctx: &mut Context<Self>,
                _message: TestMessage,
            ) -> <TestMessage as Message>::Result {
                0
            }
        }

        // The real actor answers through its own handler...
        let real = TestActor;
        let real = real.start();
        assert_eq!(real.send(TestMessage).await.unwrap(), 0);

        // ...while the mock answers through the registered closure, whose
        // captured counter keeps its state between messages.
        let mut counter = 41;
        let mock = Mock::<TestActor>::builder()
            .handle(move |_: TestMessage| {
                counter += 1;
                async move { counter }
            })
            .build();

        for expected in 42..=45 {
            assert_eq!(mock.addr().send(TestMessage).await.unwrap(), expected);
        }
    }

    #[tokio::test]
    #[should_panic]
    async fn test_mock_panic() {
        struct A;
        impl Actor for A {}

        struct M;
        impl Message for M {
            type Result = ();
        }

        #[async_trait]
        impl Handler<M> for A {
            async fn handle(
                &mut self,
                _ctx: &mut Context<Self>,
                _message: M,
            ) -> <M as Message>::Result {
            }
        }

        // No handler is registered for `M`, so the mock task panics and the
        // `unwrap` on the send result is expected to panic in turn.
        let mut mock = Mock::<A>::builder().build();
        mock.addr().send(M).await.unwrap();
        mock.shutdown().await.unwrap();
    }
}