// rakka_testkit/probe.rs

//! `TestProbe` — typed message receiver used in test assertions.
//! akka.net: `Akka.TestKit/TestProbe.cs`.
3
4use std::time::Duration;
5
6use rakka_core::actor::Inbox;
7use thiserror::Error;
8
9#[derive(Debug, Error)]
10pub enum TestProbeError {
11    #[error("probe timed out waiting for message")]
12    Timeout,
13    #[error("probe sender dropped")]
14    Dropped,
15    #[error("unexpected message")]
16    Unexpected,
17}
18
19pub struct TestProbe<M: Send + 'static> {
20    inbox: Inbox<M>,
21}
22
23impl<M: Send + 'static> TestProbe<M> {
24    pub fn new(name: &str) -> Self {
25        Self { inbox: Inbox::new(name) }
26    }
27
28    pub fn actor_ref(&self) -> &rakka_core::actor::ActorRef<M> {
29        self.inbox.actor_ref()
30    }
31
32    /// Wait for a single message (akka.net: `ExpectMsg`).
33    pub async fn expect_msg(&mut self, timeout: Duration) -> Result<M, TestProbeError> {
34        match self.inbox.receive(timeout).await {
35            Ok(m) => Ok(m),
36            Err(rakka_core::actor::AskError::Timeout) => Err(TestProbeError::Timeout),
37            Err(_) => Err(TestProbeError::Dropped),
38        }
39    }
40
41    /// Wait for a message that matches the given predicate.
42    /// akka.net: `ExpectMsg<T>(Func<T, bool>)`.
43    pub async fn expect_msg_pf<F>(&mut self, timeout: Duration, mut pred: F) -> Result<M, TestProbeError>
44    where
45        F: FnMut(&M) -> bool,
46    {
47        let m = self.expect_msg(timeout).await?;
48        if pred(&m) {
49            Ok(m)
50        } else {
51            Err(TestProbeError::Unexpected)
52        }
53    }
54
55    /// Assert that no message arrives within the given timeout.
56    pub async fn expect_no_msg(&mut self, timeout: Duration) -> Result<(), TestProbeError> {
57        match tokio::time::timeout(timeout, self.inbox.receive(Duration::from_secs(3600))).await {
58            Ok(_) => Err(TestProbeError::Unexpected),
59            Err(_) => Ok(()),
60        }
61    }
62
63    // -- Phase 4 matchers ------------------------------------------
64
65    /// Wait for a message and assert it matches the variant returned
66    /// by `extract`. Akka.NET: `ExpectMsg<T>(...)` where `T` selects
67    /// a sub-variant of the message enum. The `extract` closure
68    /// returns `Some(payload)` for the desired variant.
69    pub async fn expect_msg_class<T, F>(&mut self, timeout: Duration, extract: F) -> Result<T, TestProbeError>
70    where
71        F: FnOnce(M) -> Option<T>,
72    {
73        let m = self.expect_msg(timeout).await?;
74        extract(m).ok_or(TestProbeError::Unexpected)
75    }
76
77    /// Receive exactly `n` messages or return [`TestProbeError::Timeout`]
78    /// if `timeout` elapses before they all arrive.
79    /// Akka.NET: `ReceiveN(int n, TimeSpan)`.
80    pub async fn receive_n(&mut self, n: usize, timeout: Duration) -> Result<Vec<M>, TestProbeError> {
81        let deadline = std::time::Instant::now() + timeout;
82        let mut out = Vec::with_capacity(n);
83        while out.len() < n {
84            let remaining =
85                deadline.checked_duration_since(std::time::Instant::now()).ok_or(TestProbeError::Timeout)?;
86            out.push(self.expect_msg(remaining).await?);
87        }
88        Ok(out)
89    }
90
91    /// Receive messages while `pred` returns true, stopping at the
92    /// first message for which `pred` returns false (that message is
93    /// discarded). Akka.NET: `ReceiveWhile`.
94    pub async fn receive_while<F>(&mut self, timeout: Duration, mut pred: F) -> Result<Vec<M>, TestProbeError>
95    where
96        F: FnMut(&M) -> bool,
97    {
98        let deadline = std::time::Instant::now() + timeout;
99        let mut out = Vec::new();
100        loop {
101            let remaining = match deadline.checked_duration_since(std::time::Instant::now()) {
102                Some(d) => d,
103                None => return Ok(out),
104            };
105            match self.expect_msg(remaining).await {
106                Ok(m) => {
107                    if pred(&m) {
108                        out.push(m);
109                    } else {
110                        return Ok(out);
111                    }
112                }
113                Err(TestProbeError::Timeout) => return Ok(out),
114                Err(e) => return Err(e),
115            }
116        }
117    }
118
119    /// Drain messages until one matches `pred`. Discards mismatches.
120    /// Akka.NET: `FishForMessage`.
121    pub async fn fish_for_message<F>(&mut self, timeout: Duration, mut pred: F) -> Result<M, TestProbeError>
122    where
123        F: FnMut(&M) -> bool,
124    {
125        let deadline = std::time::Instant::now() + timeout;
126        loop {
127            let remaining =
128                deadline.checked_duration_since(std::time::Instant::now()).ok_or(TestProbeError::Timeout)?;
129            let m = self.expect_msg(remaining).await?;
130            if pred(&m) {
131                return Ok(m);
132            }
133        }
134    }
135
136    /// Receive `expected.len()` messages and assert that the multi-set
137    /// of received messages equals `expected` (order-insensitive).
138    /// Akka.NET: `ExpectMsgAllOf`.
139    pub async fn expect_all_of(&mut self, timeout: Duration, expected: Vec<M>) -> Result<(), TestProbeError>
140    where
141        M: PartialEq + std::fmt::Debug,
142    {
143        let n = expected.len();
144        let received = self.receive_n(n, timeout).await?;
145        // O(n²) intentional — n is small in practice.
146        let mut remaining: Vec<M> = received;
147        for want in expected {
148            if let Some(idx) = remaining.iter().position(|m| m == &want) {
149                remaining.remove(idx);
150            } else {
151                return Err(TestProbeError::Unexpected);
152            }
153        }
154        Ok(())
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    /// Message enum for the variant-extraction tests.
    #[derive(Debug, PartialEq)]
    // `A`'s payload is never read: the variant only serves as a
    // non-matching alternative for `expect_msg_class`.
    #[allow(dead_code)]
    enum E {
        A(u32),
        B(String),
    }

    #[tokio::test]
    async fn probe_receives_message() {
        let mut p = TestProbe::<u32>::new("p");
        p.actor_ref().tell(42);
        let m = p.expect_msg(Duration::from_millis(100)).await.unwrap();
        assert_eq!(m, 42);
    }

    #[tokio::test]
    async fn expect_msg_times_out_when_empty() {
        let mut p = TestProbe::<u32>::new("empty");
        let r = p.expect_msg(Duration::from_millis(10)).await;
        assert!(matches!(r, Err(TestProbeError::Timeout)));
    }

    #[tokio::test]
    async fn expect_msg_pf_rejects_mismatch() {
        let mut p = TestProbe::<u32>::new("pf");
        p.actor_ref().tell(1);
        let r = p.expect_msg_pf(Duration::from_millis(100), |m| *m == 2).await;
        assert!(matches!(r, Err(TestProbeError::Unexpected)));
    }

    #[tokio::test]
    async fn probe_no_msg() {
        let mut p = TestProbe::<u32>::new("q");
        p.expect_no_msg(Duration::from_millis(20)).await.unwrap();
    }

    #[tokio::test]
    async fn expect_no_msg_rejects_pending_message() {
        let mut p = TestProbe::<u32>::new("nm");
        p.actor_ref().tell(5);
        let r = p.expect_no_msg(Duration::from_millis(20)).await;
        assert!(matches!(r, Err(TestProbeError::Unexpected)));
    }

    #[tokio::test]
    async fn receive_n_collects_messages() {
        let mut p = TestProbe::<u32>::new("rn");
        for i in 0..3u32 {
            p.actor_ref().tell(i);
        }
        let msgs = p.receive_n(3, Duration::from_millis(100)).await.unwrap();
        assert_eq!(msgs, vec![0, 1, 2]);
    }

    #[tokio::test]
    async fn receive_n_times_out_partial() {
        let mut p = TestProbe::<u32>::new("rnt");
        p.actor_ref().tell(7);
        let r = p.receive_n(3, Duration::from_millis(20)).await;
        assert!(matches!(r, Err(TestProbeError::Timeout)));
    }

    #[tokio::test]
    async fn fish_for_message_skips_mismatches() {
        let mut p = TestProbe::<u32>::new("fish");
        p.actor_ref().tell(1);
        p.actor_ref().tell(2);
        p.actor_ref().tell(99);
        let m = p.fish_for_message(Duration::from_millis(100), |m| *m >= 50).await.unwrap();
        assert_eq!(m, 99);
    }

    #[tokio::test]
    async fn receive_while_stops_on_predicate() {
        let mut p = TestProbe::<u32>::new("rw");
        for i in 1..=4u32 {
            p.actor_ref().tell(i);
        }
        let collected = p.receive_while(Duration::from_millis(100), |m| *m < 3).await.unwrap();
        assert_eq!(collected, vec![1, 2]);
        // The message that stopped the run (3) is discarded; 4 is next.
        let next = p.expect_msg(Duration::from_millis(100)).await.unwrap();
        assert_eq!(next, 4);
    }

    #[tokio::test]
    async fn expect_all_of_order_insensitive() {
        let mut p = TestProbe::<u32>::new("alf");
        for i in [3u32, 1, 2] {
            p.actor_ref().tell(i);
        }
        p.expect_all_of(Duration::from_millis(100), vec![1, 2, 3]).await.unwrap();
    }

    #[tokio::test]
    async fn expect_all_of_respects_duplicates() {
        let mut p = TestProbe::<u32>::new("dup");
        for i in [1u32, 2, 1] {
            p.actor_ref().tell(i);
        }
        p.expect_all_of(Duration::from_millis(100), vec![1, 1, 2]).await.unwrap();
    }

    #[tokio::test]
    async fn expect_all_of_detects_mismatch() {
        let mut p = TestProbe::<u32>::new("alm");
        p.actor_ref().tell(1);
        p.actor_ref().tell(2);
        let r = p.expect_all_of(Duration::from_millis(100), vec![1, 3]).await;
        assert!(matches!(r, Err(TestProbeError::Unexpected)));
    }

    #[tokio::test]
    async fn expect_msg_class_extracts_variant() {
        let mut p = TestProbe::<E>::new("cls");
        p.actor_ref().tell(E::B("hi".into()));
        let s = p
            .expect_msg_class(Duration::from_millis(100), |m| match m {
                E::B(s) => Some(s),
                _ => None,
            })
            .await
            .unwrap();
        assert_eq!(s, "hi");
    }

    #[tokio::test]
    async fn expect_msg_class_rejects_other_variant() {
        let mut p = TestProbe::<E>::new("cls2");
        p.actor_ref().tell(E::A(7));
        let r = p
            .expect_msg_class(Duration::from_millis(100), |m| match m {
                E::B(s) => Some(s),
                _ => None,
            })
            .await;
        assert!(matches!(r, Err(TestProbeError::Unexpected)));
    }
}