Skip to main content

atomr_testkit/
probe.rs

1//! `TestProbe` — typed message receiver used in assertions.
2
3use std::time::Duration;
4
5use atomr_core::actor::Inbox;
6use thiserror::Error;
7
8#[derive(Debug, Error)]
9pub enum TestProbeError {
10    #[error("probe timed out waiting for message")]
11    Timeout,
12    #[error("probe sender dropped")]
13    Dropped,
14    #[error("unexpected message")]
15    Unexpected,
16}
17
18pub struct TestProbe<M: Send + 'static> {
19    inbox: Inbox<M>,
20}
21
22impl<M: Send + 'static> TestProbe<M> {
23    pub fn new(name: &str) -> Self {
24        Self { inbox: Inbox::new(name) }
25    }
26
27    pub fn actor_ref(&self) -> &atomr_core::actor::ActorRef<M> {
28        self.inbox.actor_ref()
29    }
30
31    /// Wait for a single message.
32    pub async fn expect_msg(&mut self, timeout: Duration) -> Result<M, TestProbeError> {
33        match self.inbox.receive(timeout).await {
34            Ok(m) => Ok(m),
35            Err(atomr_core::actor::AskError::Timeout) => Err(TestProbeError::Timeout),
36            Err(_) => Err(TestProbeError::Dropped),
37        }
38    }
39
40    /// Wait for a message that matches the given predicate.
41    pub async fn expect_msg_pf<F>(&mut self, timeout: Duration, mut pred: F) -> Result<M, TestProbeError>
42    where
43        F: FnMut(&M) -> bool,
44    {
45        let m = self.expect_msg(timeout).await?;
46        if pred(&m) {
47            Ok(m)
48        } else {
49            Err(TestProbeError::Unexpected)
50        }
51    }
52
53    /// Assert that no message arrives within the given timeout.
54    pub async fn expect_no_msg(&mut self, timeout: Duration) -> Result<(), TestProbeError> {
55        match tokio::time::timeout(timeout, self.inbox.receive(Duration::from_secs(3600))).await {
56            Ok(_) => Err(TestProbeError::Unexpected),
57            Err(_) => Ok(()),
58        }
59    }
60
61    // -- Phase 4 matchers ------------------------------------------
62
63    /// Wait for a message and assert it matches the variant returned
64    /// by `extract`.where `T` selects
65    /// a sub-variant of the message enum. The `extract` closure
66    /// returns `Some(payload)` for the desired variant.
67    pub async fn expect_msg_class<T, F>(&mut self, timeout: Duration, extract: F) -> Result<T, TestProbeError>
68    where
69        F: FnOnce(M) -> Option<T>,
70    {
71        let m = self.expect_msg(timeout).await?;
72        extract(m).ok_or(TestProbeError::Unexpected)
73    }
74
75    /// Receive exactly `n` messages or return [`TestProbeError::Timeout`]
76    /// if `timeout` elapses before they all arrive.
77    pub async fn receive_n(&mut self, n: usize, timeout: Duration) -> Result<Vec<M>, TestProbeError> {
78        let deadline = std::time::Instant::now() + timeout;
79        let mut out = Vec::with_capacity(n);
80        while out.len() < n {
81            let remaining =
82                deadline.checked_duration_since(std::time::Instant::now()).ok_or(TestProbeError::Timeout)?;
83            out.push(self.expect_msg(remaining).await?);
84        }
85        Ok(out)
86    }
87
88    /// Receive messages while `pred` returns true, stopping at the
89    /// first message for which `pred` returns false (that message is
90    /// discarded).
91    pub async fn receive_while<F>(&mut self, timeout: Duration, mut pred: F) -> Result<Vec<M>, TestProbeError>
92    where
93        F: FnMut(&M) -> bool,
94    {
95        let deadline = std::time::Instant::now() + timeout;
96        let mut out = Vec::new();
97        loop {
98            let remaining = match deadline.checked_duration_since(std::time::Instant::now()) {
99                Some(d) => d,
100                None => return Ok(out),
101            };
102            match self.expect_msg(remaining).await {
103                Ok(m) => {
104                    if pred(&m) {
105                        out.push(m);
106                    } else {
107                        return Ok(out);
108                    }
109                }
110                Err(TestProbeError::Timeout) => return Ok(out),
111                Err(e) => return Err(e),
112            }
113        }
114    }
115
116    /// Drain messages until one matches `pred`. Discards mismatches.
117    pub async fn fish_for_message<F>(&mut self, timeout: Duration, mut pred: F) -> Result<M, TestProbeError>
118    where
119        F: FnMut(&M) -> bool,
120    {
121        let deadline = std::time::Instant::now() + timeout;
122        loop {
123            let remaining =
124                deadline.checked_duration_since(std::time::Instant::now()).ok_or(TestProbeError::Timeout)?;
125            let m = self.expect_msg(remaining).await?;
126            if pred(&m) {
127                return Ok(m);
128            }
129        }
130    }
131
132    /// Receive `expected.len()` messages and assert that the multi-set
133    /// of received messages equals `expected` (order-insensitive).
134    pub async fn expect_all_of(&mut self, timeout: Duration, expected: Vec<M>) -> Result<(), TestProbeError>
135    where
136        M: PartialEq + std::fmt::Debug,
137    {
138        let n = expected.len();
139        let received = self.receive_n(n, timeout).await?;
140        // O(n²) intentional — n is small in practice.
141        let mut remaining: Vec<M> = received;
142        for want in expected {
143            if let Some(idx) = remaining.iter().position(|m| m == &want) {
144                remaining.remove(idx);
145            } else {
146                return Err(TestProbeError::Unexpected);
147            }
148        }
149        Ok(())
150    }
151
152    /// Wait for a message and assert it equals `expected`.
153    /// `ExpectMsg<T>(T expected)`.
154    pub async fn expect_msg_eq(&mut self, timeout: Duration, expected: M) -> Result<M, TestProbeError>
155    where
156        M: PartialEq + std::fmt::Debug,
157    {
158        let m = self.expect_msg(timeout).await?;
159        if m == expected {
160            Ok(m)
161        } else {
162            Err(TestProbeError::Unexpected)
163        }
164    }
165
166    /// Receive `n` messages, asserting they appear in the exact order
167    /// of `expected`.with sequential
168    /// matching semantics.
169    pub async fn expect_msg_all_of_in_order(
170        &mut self,
171        timeout: Duration,
172        expected: Vec<M>,
173    ) -> Result<(), TestProbeError>
174    where
175        M: PartialEq + std::fmt::Debug,
176    {
177        let received = self.receive_n(expected.len(), timeout).await?;
178        if received == expected {
179            Ok(())
180        } else {
181            Err(TestProbeError::Unexpected)
182        }
183    }
184}
185
186/// Run `body` with the given budget, returning [`TestProbeError::Timeout`]
187/// if it does not finish in time.
188///
189/// `body` receives the original `Duration` so it can pass it down to
190/// `expect_msg`-style helpers and have them inherit the deadline.
191pub async fn within<F, Fut, T>(timeout: Duration, body: F) -> Result<T, TestProbeError>
192where
193    F: FnOnce(Duration) -> Fut,
194    Fut: std::future::Future<Output = Result<T, TestProbeError>>,
195{
196    match tokio::time::timeout(timeout, body(timeout)).await {
197        Ok(r) => r,
198        Err(_) => Err(TestProbeError::Timeout),
199    }
200}
201
#[cfg(test)]
mod tests {
    //! Unit tests for `TestProbe` and `within`, covering both the success
    //! paths and the failure paths of each matcher.
    use super::*;

    #[tokio::test]
    async fn probe_receives_message() {
        let mut p = TestProbe::<u32>::new("p");
        p.actor_ref().tell(42);
        let m = p.expect_msg(Duration::from_millis(100)).await.unwrap();
        assert_eq!(m, 42);
    }

    #[tokio::test]
    async fn probe_no_msg() {
        let mut p = TestProbe::<u32>::new("q");
        p.expect_no_msg(Duration::from_millis(20)).await.unwrap();
    }

    #[tokio::test]
    async fn expect_no_msg_fails_when_message_pending() {
        let mut p = TestProbe::<u32>::new("nm");
        p.actor_ref().tell(5);
        let r = p.expect_no_msg(Duration::from_millis(20)).await;
        assert!(matches!(r, Err(TestProbeError::Unexpected)));
    }

    #[tokio::test]
    async fn expect_msg_pf_rejects_mismatch() {
        let mut p = TestProbe::<u32>::new("pf");
        p.actor_ref().tell(3);
        let r = p.expect_msg_pf(Duration::from_millis(100), |m| *m % 2 == 0).await;
        assert!(matches!(r, Err(TestProbeError::Unexpected)));
    }

    #[tokio::test]
    async fn receive_n_collects_messages() {
        let mut p = TestProbe::<u32>::new("rn");
        for i in 0..3u32 {
            p.actor_ref().tell(i);
        }
        let msgs = p.receive_n(3, Duration::from_millis(100)).await.unwrap();
        assert_eq!(msgs, vec![0, 1, 2]);
    }

    #[tokio::test]
    async fn receive_n_times_out_partial() {
        let mut p = TestProbe::<u32>::new("rnt");
        p.actor_ref().tell(7);
        let r = p.receive_n(3, Duration::from_millis(20)).await;
        assert!(matches!(r, Err(TestProbeError::Timeout)));
    }

    #[tokio::test]
    async fn fish_for_message_skips_mismatches() {
        let mut p = TestProbe::<u32>::new("fish");
        p.actor_ref().tell(1);
        p.actor_ref().tell(2);
        p.actor_ref().tell(99);
        let m = p.fish_for_message(Duration::from_millis(100), |m| *m >= 50).await.unwrap();
        assert_eq!(m, 99);
    }

    #[tokio::test]
    async fn fish_for_message_times_out_without_match() {
        let mut p = TestProbe::<u32>::new("fish2");
        p.actor_ref().tell(1);
        let r = p.fish_for_message(Duration::from_millis(20), |m| *m > 10).await;
        assert!(matches!(r, Err(TestProbeError::Timeout)));
    }

    #[tokio::test]
    async fn receive_while_stops_on_predicate() {
        let mut p = TestProbe::<u32>::new("rw");
        for i in 1..=4u32 {
            p.actor_ref().tell(i);
        }
        let collected = p.receive_while(Duration::from_millis(100), |m| *m < 3).await.unwrap();
        assert_eq!(collected, vec![1, 2]);
    }

    #[tokio::test]
    async fn receive_while_returns_collected_on_timeout() {
        let mut p = TestProbe::<u32>::new("rw2");
        p.actor_ref().tell(1);
        p.actor_ref().tell(2);
        let collected = p.receive_while(Duration::from_millis(20), |_| true).await.unwrap();
        assert_eq!(collected, vec![1, 2]);
    }

    #[tokio::test]
    async fn expect_all_of_order_insensitive() {
        let mut p = TestProbe::<u32>::new("alf");
        for i in [3u32, 1, 2] {
            p.actor_ref().tell(i);
        }
        p.expect_all_of(Duration::from_millis(100), vec![1, 2, 3]).await.unwrap();
    }

    #[tokio::test]
    async fn expect_all_of_rejects_wrong_multiset() {
        let mut p = TestProbe::<u32>::new("alf2");
        for i in [1u32, 1, 2] {
            p.actor_ref().tell(i);
        }
        let r = p.expect_all_of(Duration::from_millis(100), vec![1, 2, 2]).await;
        assert!(matches!(r, Err(TestProbeError::Unexpected)));
    }

    #[tokio::test]
    async fn expect_msg_eq_succeeds_on_match() {
        let mut p = TestProbe::<u32>::new("eq");
        p.actor_ref().tell(42);
        assert_eq!(p.expect_msg_eq(Duration::from_millis(100), 42).await.unwrap(), 42);
    }

    #[tokio::test]
    async fn expect_msg_eq_fails_on_mismatch() {
        let mut p = TestProbe::<u32>::new("eq2");
        p.actor_ref().tell(42);
        let r = p.expect_msg_eq(Duration::from_millis(100), 7).await;
        assert!(matches!(r, Err(TestProbeError::Unexpected)));
    }

    #[tokio::test]
    async fn expect_msg_all_of_in_order_matches_sequence() {
        let mut p = TestProbe::<u32>::new("seq");
        for i in [1u32, 2, 3] {
            p.actor_ref().tell(i);
        }
        p.expect_msg_all_of_in_order(Duration::from_millis(100), vec![1, 2, 3]).await.unwrap();
    }

    #[tokio::test]
    async fn expect_msg_all_of_in_order_rejects_reordering() {
        let mut p = TestProbe::<u32>::new("seq2");
        for i in [2u32, 1] {
            p.actor_ref().tell(i);
        }
        let r = p.expect_msg_all_of_in_order(Duration::from_millis(100), vec![1, 2]).await;
        assert!(matches!(r, Err(TestProbeError::Unexpected)));
    }

    #[tokio::test]
    async fn within_returns_inner_result() {
        let r = within(Duration::from_millis(100), |budget| async move {
            let mut p = TestProbe::<u32>::new("w");
            p.actor_ref().tell(11);
            p.expect_msg(budget).await
        })
        .await
        .unwrap();
        assert_eq!(r, 11);
    }

    #[tokio::test]
    async fn within_times_out_when_inner_blocks() {
        let r: Result<u32, _> = within(Duration::from_millis(10), |budget| async move {
            let mut p = TestProbe::<u32>::new("wt");
            p.expect_msg(budget).await
        })
        .await;
        assert!(matches!(r, Err(TestProbeError::Timeout)));
    }

    #[tokio::test]
    async fn expect_msg_class_extracts_variant() {
        #[derive(Debug, PartialEq)]
        #[allow(dead_code)] // `A` exists only to make the extractor non-trivial.
        enum E {
            A(u32),
            B(String),
        }
        let mut p = TestProbe::<E>::new("cls");
        p.actor_ref().tell(E::B("hi".into()));
        let s = p
            .expect_msg_class(Duration::from_millis(100), |m| match m {
                E::B(s) => Some(s),
                _ => None,
            })
            .await
            .unwrap();
        assert_eq!(s, "hi");
    }
}