Skip to main content

mm1_test_rt/rt/
context.rs

1use futures::FutureExt;
2use mm1_address::address::Address;
3use mm1_address::subnet::NetAddress;
4use mm1_common::errors::error_of::ErrorOf;
5use mm1_common::types::Never;
6use mm1_core::context::{
7    Bind, BindArgs, BindErrorKind, Fork, ForkErrorKind, InitDone, Linking, Messaging, Now, Quit,
8    RecvErrorKind, Start, Stop, Watching,
9};
10use mm1_core::envelope::Envelope;
11use mm1_proto_system::{SpawnErrorKind, StartErrorKind};
12use tokio::sync::{mpsc, oneshot};
13use tokio::time::Instant;
14
15use crate::rt::{Query, TaskKey, TestContext, query};
16
17impl TaskKey {
18    pub fn actor(address: Address) -> Self {
19        Self {
20            actor:   address,
21            context: address,
22        }
23    }
24
25    pub fn fork(actor: Address, fork: Address) -> Self {
26        assert_ne!(actor, fork);
27        Self {
28            actor,
29            context: fork,
30        }
31    }
32}
33
34impl<R> Now for TestContext<R>
35where
36    R: Send,
37{
38    type Instant = Instant;
39
40    fn now(&self) -> Self::Instant {
41        Instant::now()
42    }
43}
44
45impl<R> Start<R> for TestContext<R>
46where
47    R: Send,
48{
49    async fn spawn(&mut self, runnable: R, link: bool) -> Result<Address, ErrorOf<SpawnErrorKind>> {
50        let task_key = self.task_key;
51        invoke(
52            &self.queries_tx,
53            move |outcome_tx| {
54                query::Spawn {
55                    task_key,
56                    runnable,
57                    link,
58                    outcome_tx,
59                }
60            },
61            OnRxFailure::Panic,
62        )
63        .await
64    }
65
66    async fn start(
67        &mut self,
68        runnable: R,
69        link: bool,
70        start_timeout: std::time::Duration,
71    ) -> Result<Address, ErrorOf<StartErrorKind>> {
72        let task_key = self.task_key;
73        invoke(
74            &self.queries_tx,
75            move |outcome_tx| {
76                query::Start {
77                    task_key,
78                    runnable,
79                    link,
80                    start_timeout,
81                    outcome_tx,
82                }
83            },
84            OnRxFailure::Panic,
85        )
86        .await
87    }
88}
89
90impl<R> Bind<NetAddress> for TestContext<R>
91where
92    R: Send + 'static,
93{
94    async fn bind(&mut self, args: BindArgs<NetAddress>) -> Result<(), ErrorOf<BindErrorKind>> {
95        let task_key = self.task_key;
96        invoke(
97            &self.queries_tx,
98            |outcome_tx| {
99                query::Bind {
100                    task_key,
101                    args,
102                    outcome_tx,
103                }
104            },
105            OnRxFailure::Panic,
106        )
107        .await
108    }
109}
110
111impl<R> Fork for TestContext<R>
112where
113    R: Send + 'static,
114{
115    async fn fork(&mut self) -> Result<Self, ErrorOf<ForkErrorKind>> {
116        let task_key = self.task_key;
117        invoke(
118            &self.queries_tx,
119            |outcome_tx| {
120                query::Fork {
121                    task_key,
122                    outcome_tx,
123                }
124            },
125            OnRxFailure::Panic,
126        )
127        .await
128    }
129
130    async fn run<F, Fut>(mut self, fun: F)
131    where
132        F: FnOnce(Self) -> Fut,
133        F: Send + 'static,
134        Fut: Future + Send + 'static,
135    {
136        let queries_tx = self.queries_tx.clone();
137        let task_key = self.task_key;
138        let address_lease = self.address_lease.take();
139        let task_fut = async move {
140            let _ = fun(self).await;
141        }
142        .boxed();
143        invoke(
144            &queries_tx,
145            |outcome_tx| {
146                query::ForkRun {
147                    task_key,
148                    address_lease,
149                    task_fut,
150                    outcome_tx,
151                }
152            },
153            OnRxFailure::Panic,
154        )
155        .await
156    }
157}
158
159impl<R> Quit for TestContext<R>
160where
161    R: Send,
162{
163    async fn quit_ok(&mut self) -> Never {
164        let task_key = self.task_key;
165        invoke(
166            &self.queries_tx,
167            move |outcome_tx| {
168                query::Quit {
169                    task_key,
170                    result: Ok(()),
171                    outcome_tx,
172                }
173            },
174            OnRxFailure::Freeze,
175        )
176        .await
177    }
178
179    async fn quit_err<E>(&mut self, reason: E) -> Never
180    where
181        E: std::error::Error + Send + Sync + 'static,
182    {
183        let task_key = self.task_key;
184        invoke(
185            &self.queries_tx,
186            move |outcome_tx| {
187                query::Quit {
188                    task_key,
189                    result: Err(Box::new(reason)),
190                    outcome_tx,
191                }
192            },
193            OnRxFailure::Freeze,
194        )
195        .await
196    }
197}
198
199impl<R> Messaging for TestContext<R>
200where
201    R: Send,
202{
203    fn address(&self) -> Address {
204        self.task_key.context
205    }
206
207    async fn close(&mut self) {
208        let task_key = self.task_key;
209        invoke(
210            &self.queries_tx,
211            move |outcome_tx| {
212                query::RecvClose {
213                    task_key,
214                    outcome_tx,
215                }
216            },
217            OnRxFailure::Panic,
218        )
219        .await
220    }
221
222    async fn recv(&mut self) -> Result<Envelope, ErrorOf<RecvErrorKind>> {
223        let task_key = self.task_key;
224        invoke(
225            &self.queries_tx,
226            move |outcome_tx| {
227                query::Recv {
228                    task_key,
229                    outcome_tx,
230                }
231            },
232            OnRxFailure::Panic,
233        )
234        .await
235    }
236
237    async fn send(
238        &mut self,
239        envelope: Envelope,
240    ) -> Result<(), ErrorOf<mm1_core::context::SendErrorKind>> {
241        let task_key = self.task_key;
242
243        invoke(
244            &self.queries_tx,
245            move |outcome_tx| {
246                query::Tell {
247                    task_key,
248                    to: envelope.header().to,
249                    envelope,
250                    outcome_tx,
251                }
252            },
253            OnRxFailure::Panic,
254        )
255        .await
256    }
257
258    async fn forward(
259        &mut self,
260        to: Address,
261        envelope: Envelope,
262    ) -> Result<(), ErrorOf<mm1_core::context::SendErrorKind>> {
263        let task_key = self.task_key;
264
265        invoke(
266            &self.queries_tx,
267            move |outcome_tx| {
268                query::Tell {
269                    task_key,
270                    to,
271                    envelope,
272                    outcome_tx,
273                }
274            },
275            OnRxFailure::Panic,
276        )
277        .await
278    }
279}
280
281impl<R> Watching for TestContext<R>
282where
283    R: Send,
284{
285    async fn watch(&mut self, peer: Address) -> mm1_proto_system::WatchRef {
286        let task_key = self.task_key;
287        invoke(
288            &self.queries_tx,
289            move |outcome_tx| {
290                query::Watch {
291                    task_key,
292                    peer,
293                    outcome_tx,
294                }
295            },
296            OnRxFailure::Panic,
297        )
298        .await
299    }
300
301    async fn unwatch(&mut self, watch_ref: mm1_proto_system::WatchRef) {
302        let task_key = self.task_key;
303        invoke(
304            &self.queries_tx,
305            move |outcome_tx| {
306                query::Unwatch {
307                    task_key,
308                    watch_ref,
309                    outcome_tx,
310                }
311            },
312            OnRxFailure::Panic,
313        )
314        .await
315    }
316}
317
318impl<R> Linking for TestContext<R>
319where
320    R: Send,
321{
322    async fn link(&mut self, peer: Address) {
323        let task_key = self.task_key;
324        invoke(
325            &self.queries_tx,
326            |outcome_tx| {
327                query::Link {
328                    task_key,
329                    peer,
330                    outcome_tx,
331                }
332            },
333            OnRxFailure::Panic,
334        )
335        .await
336    }
337
338    async fn unlink(&mut self, peer: Address) {
339        let task_key = self.task_key;
340        invoke(
341            &self.queries_tx,
342            |outcome_tx| {
343                query::Unlink {
344                    task_key,
345                    peer,
346                    outcome_tx,
347                }
348            },
349            OnRxFailure::Panic,
350        )
351        .await
352    }
353
354    async fn set_trap_exit(&mut self, enable: bool) {
355        let task_key = self.task_key;
356        invoke(
357            &self.queries_tx,
358            |outcome_tx| {
359                query::SetTrapExit {
360                    task_key,
361                    enable,
362                    outcome_tx,
363                }
364            },
365            OnRxFailure::Panic,
366        )
367        .await
368    }
369}
370
371impl<R> InitDone for TestContext<R>
372where
373    R: Send,
374{
375    async fn init_done(&mut self, address: Address) {
376        let task_key = self.task_key;
377        invoke(
378            &self.queries_tx,
379            move |outcome_tx| {
380                query::InitDone {
381                    task_key,
382                    address,
383                    outcome_tx,
384                }
385            },
386            OnRxFailure::Panic,
387        )
388        .await
389    }
390}
391
392impl<R> Stop for TestContext<R>
393where
394    R: Send,
395{
396    async fn exit(&mut self, peer: Address) -> bool {
397        let task_key = self.task_key;
398        invoke(
399            &self.queries_tx,
400            move |outcome_tx| {
401                query::Exit {
402                    task_key,
403                    peer,
404                    outcome_tx,
405                }
406            },
407            OnRxFailure::Panic,
408        )
409        .await
410    }
411
412    async fn kill(&mut self, peer: Address) -> bool {
413        let task_key = self.task_key;
414        invoke(
415            &self.queries_tx,
416            move |outcome_tx| {
417                query::Kill {
418                    task_key,
419                    peer,
420                    outcome_tx,
421                }
422            },
423            OnRxFailure::Panic,
424        )
425        .await
426    }
427}
428
429async fn invoke<R, F, Q, Out>(
430    queries_tx: &mpsc::UnboundedSender<Query<R>>,
431    make_query: F,
432    on_rx_failure: OnRxFailure,
433) -> Out
434where
435    F: FnOnce(oneshot::Sender<Out>) -> Q,
436    Q: Into<Query<R>>,
437{
438    let (outcome_tx, outcome_rx) = oneshot::channel();
439    let query = make_query(outcome_tx);
440    queries_tx.send(query.into()).expect("tx failed");
441    match (outcome_rx.await, on_rx_failure) {
442        (Ok(ret), _) => ret,
443        (Err(reason), OnRxFailure::Panic) => panic!("rx failed: {reason}"),
444        (Err(_), OnRxFailure::Freeze) => std::future::pending().await,
445    }
446}
447
/// Policy applied by `invoke` when the reply half of a query's oneshot channel fails, i.e.
/// no reply value was received.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum OnRxFailure {
    /// Panic with an "rx failed" message.
    Panic,
    /// Never resolve: the caller keeps awaiting a future that stays pending forever.
    Freeze,
}