1use std::time::Duration;
4
5use atomr_core::actor::Inbox;
6use thiserror::Error;
7
8#[derive(Debug, Error)]
9pub enum TestProbeError {
10 #[error("probe timed out waiting for message")]
11 Timeout,
12 #[error("probe sender dropped")]
13 Dropped,
14 #[error("unexpected message")]
15 Unexpected,
16}
17
/// Test helper that owns an [`Inbox`] and offers expectation methods over the
/// messages sent to its [`actor_ref`](TestProbe::actor_ref).
pub struct TestProbe<M: Send + 'static> {
    // Every expectation method pulls its messages from this inbox.
    inbox: Inbox<M>,
}
21
22impl<M: Send + 'static> TestProbe<M> {
23 pub fn new(name: &str) -> Self {
24 Self { inbox: Inbox::new(name) }
25 }
26
27 pub fn actor_ref(&self) -> &atomr_core::actor::ActorRef<M> {
28 self.inbox.actor_ref()
29 }
30
31 pub async fn expect_msg(&mut self, timeout: Duration) -> Result<M, TestProbeError> {
33 match self.inbox.receive(timeout).await {
34 Ok(m) => Ok(m),
35 Err(atomr_core::actor::AskError::Timeout) => Err(TestProbeError::Timeout),
36 Err(_) => Err(TestProbeError::Dropped),
37 }
38 }
39
40 pub async fn expect_msg_pf<F>(&mut self, timeout: Duration, mut pred: F) -> Result<M, TestProbeError>
42 where
43 F: FnMut(&M) -> bool,
44 {
45 let m = self.expect_msg(timeout).await?;
46 if pred(&m) {
47 Ok(m)
48 } else {
49 Err(TestProbeError::Unexpected)
50 }
51 }
52
53 pub async fn expect_no_msg(&mut self, timeout: Duration) -> Result<(), TestProbeError> {
55 match tokio::time::timeout(timeout, self.inbox.receive(Duration::from_secs(3600))).await {
56 Ok(_) => Err(TestProbeError::Unexpected),
57 Err(_) => Ok(()),
58 }
59 }
60
61 pub async fn expect_msg_class<T, F>(&mut self, timeout: Duration, extract: F) -> Result<T, TestProbeError>
68 where
69 F: FnOnce(M) -> Option<T>,
70 {
71 let m = self.expect_msg(timeout).await?;
72 extract(m).ok_or(TestProbeError::Unexpected)
73 }
74
75 pub async fn receive_n(&mut self, n: usize, timeout: Duration) -> Result<Vec<M>, TestProbeError> {
78 let deadline = std::time::Instant::now() + timeout;
79 let mut out = Vec::with_capacity(n);
80 while out.len() < n {
81 let remaining =
82 deadline.checked_duration_since(std::time::Instant::now()).ok_or(TestProbeError::Timeout)?;
83 out.push(self.expect_msg(remaining).await?);
84 }
85 Ok(out)
86 }
87
88 pub async fn receive_while<F>(&mut self, timeout: Duration, mut pred: F) -> Result<Vec<M>, TestProbeError>
92 where
93 F: FnMut(&M) -> bool,
94 {
95 let deadline = std::time::Instant::now() + timeout;
96 let mut out = Vec::new();
97 loop {
98 let remaining = match deadline.checked_duration_since(std::time::Instant::now()) {
99 Some(d) => d,
100 None => return Ok(out),
101 };
102 match self.expect_msg(remaining).await {
103 Ok(m) => {
104 if pred(&m) {
105 out.push(m);
106 } else {
107 return Ok(out);
108 }
109 }
110 Err(TestProbeError::Timeout) => return Ok(out),
111 Err(e) => return Err(e),
112 }
113 }
114 }
115
116 pub async fn fish_for_message<F>(&mut self, timeout: Duration, mut pred: F) -> Result<M, TestProbeError>
118 where
119 F: FnMut(&M) -> bool,
120 {
121 let deadline = std::time::Instant::now() + timeout;
122 loop {
123 let remaining =
124 deadline.checked_duration_since(std::time::Instant::now()).ok_or(TestProbeError::Timeout)?;
125 let m = self.expect_msg(remaining).await?;
126 if pred(&m) {
127 return Ok(m);
128 }
129 }
130 }
131
132 pub async fn expect_all_of(&mut self, timeout: Duration, expected: Vec<M>) -> Result<(), TestProbeError>
135 where
136 M: PartialEq + std::fmt::Debug,
137 {
138 let n = expected.len();
139 let received = self.receive_n(n, timeout).await?;
140 let mut remaining: Vec<M> = received;
142 for want in expected {
143 if let Some(idx) = remaining.iter().position(|m| m == &want) {
144 remaining.remove(idx);
145 } else {
146 return Err(TestProbeError::Unexpected);
147 }
148 }
149 Ok(())
150 }
151
152 pub async fn expect_msg_eq(&mut self, timeout: Duration, expected: M) -> Result<M, TestProbeError>
155 where
156 M: PartialEq + std::fmt::Debug,
157 {
158 let m = self.expect_msg(timeout).await?;
159 if m == expected {
160 Ok(m)
161 } else {
162 Err(TestProbeError::Unexpected)
163 }
164 }
165
166 pub async fn expect_msg_all_of_in_order(
170 &mut self,
171 timeout: Duration,
172 expected: Vec<M>,
173 ) -> Result<(), TestProbeError>
174 where
175 M: PartialEq + std::fmt::Debug,
176 {
177 let received = self.receive_n(expected.len(), timeout).await?;
178 if received == expected {
179 Ok(())
180 } else {
181 Err(TestProbeError::Unexpected)
182 }
183 }
184}
185
186pub async fn within<F, Fut, T>(timeout: Duration, body: F) -> Result<T, TestProbeError>
192where
193 F: FnOnce(Duration) -> Fut,
194 Fut: std::future::Future<Output = Result<T, TestProbeError>>,
195{
196 match tokio::time::timeout(timeout, body(timeout)).await {
197 Ok(r) => r,
198 Err(_) => Err(TestProbeError::Timeout),
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn probe_receives_message() {
        let mut p = TestProbe::<u32>::new("p");
        p.actor_ref().tell(42);
        let m = p.expect_msg(Duration::from_millis(100)).await.unwrap();
        assert_eq!(m, 42);
    }

    #[tokio::test]
    async fn probe_no_msg() {
        let mut p = TestProbe::<u32>::new("q");
        p.expect_no_msg(Duration::from_millis(20)).await.unwrap();
    }

    // Failure path: an empty probe must report a timeout, not a dropped sender.
    #[tokio::test]
    async fn expect_msg_times_out_when_empty() {
        let mut p = TestProbe::<u32>::new("empty");
        let r = p.expect_msg(Duration::from_millis(10)).await;
        assert!(matches!(r, Err(TestProbeError::Timeout)));
    }

    // Failure path: a queued message must make `expect_no_msg` fail.
    #[tokio::test]
    async fn expect_no_msg_fails_when_message_queued() {
        let mut p = TestProbe::<u32>::new("nm");
        p.actor_ref().tell(5);
        let r = p.expect_no_msg(Duration::from_millis(20)).await;
        assert!(matches!(r, Err(TestProbeError::Unexpected)));
    }

    #[tokio::test]
    async fn expect_msg_pf_rejects_non_matching() {
        let mut p = TestProbe::<u32>::new("pf");
        p.actor_ref().tell(3);
        let r = p.expect_msg_pf(Duration::from_millis(100), |m| *m > 10).await;
        assert!(matches!(r, Err(TestProbeError::Unexpected)));
    }

    #[tokio::test]
    async fn receive_n_collects_messages() {
        let mut p = TestProbe::<u32>::new("rn");
        for i in 0..3u32 {
            p.actor_ref().tell(i);
        }
        let msgs = p.receive_n(3, Duration::from_millis(100)).await.unwrap();
        assert_eq!(msgs, vec![0, 1, 2]);
    }

    #[tokio::test]
    async fn receive_n_times_out_partial() {
        let mut p = TestProbe::<u32>::new("rnt");
        p.actor_ref().tell(7);
        let r = p.receive_n(3, Duration::from_millis(20)).await;
        assert!(matches!(r, Err(TestProbeError::Timeout)));
    }

    #[tokio::test]
    async fn fish_for_message_skips_mismatches() {
        let mut p = TestProbe::<u32>::new("fish");
        p.actor_ref().tell(1);
        p.actor_ref().tell(2);
        p.actor_ref().tell(99);
        let m = p.fish_for_message(Duration::from_millis(100), |m| *m >= 50).await.unwrap();
        assert_eq!(m, 99);
    }

    #[tokio::test]
    async fn receive_while_stops_on_predicate() {
        let mut p = TestProbe::<u32>::new("rw");
        for i in 1..=4u32 {
            p.actor_ref().tell(i);
        }
        let collected = p.receive_while(Duration::from_millis(100), |m| *m < 3).await.unwrap();
        assert_eq!(collected, vec![1, 2]);
    }

    // Running out of time in `receive_while` is not an error: the messages
    // gathered so far are returned.
    #[tokio::test]
    async fn receive_while_returns_collected_on_timeout() {
        let mut p = TestProbe::<u32>::new("rwt");
        p.actor_ref().tell(1);
        p.actor_ref().tell(2);
        let collected = p.receive_while(Duration::from_millis(20), |_| true).await.unwrap();
        assert_eq!(collected, vec![1, 2]);
    }

    #[tokio::test]
    async fn expect_all_of_order_insensitive() {
        let mut p = TestProbe::<u32>::new("alf");
        for i in [3u32, 1, 2] {
            p.actor_ref().tell(i);
        }
        p.expect_all_of(Duration::from_millis(100), vec![1, 2, 3]).await.unwrap();
    }

    // Multiset semantics: a duplicate cannot stand in for a missing value.
    #[tokio::test]
    async fn expect_all_of_detects_missing_message() {
        let mut p = TestProbe::<u32>::new("alf2");
        for i in [1u32, 2, 2] {
            p.actor_ref().tell(i);
        }
        let r = p.expect_all_of(Duration::from_millis(100), vec![1, 2, 3]).await;
        assert!(matches!(r, Err(TestProbeError::Unexpected)));
    }

    #[tokio::test]
    async fn expect_msg_eq_succeeds_on_match() {
        let mut p = TestProbe::<u32>::new("eq");
        p.actor_ref().tell(42);
        assert_eq!(p.expect_msg_eq(Duration::from_millis(100), 42).await.unwrap(), 42);
    }

    #[tokio::test]
    async fn expect_msg_eq_fails_on_mismatch() {
        let mut p = TestProbe::<u32>::new("eq2");
        p.actor_ref().tell(42);
        let r = p.expect_msg_eq(Duration::from_millis(100), 7).await;
        assert!(matches!(r, Err(TestProbeError::Unexpected)));
    }

    #[tokio::test]
    async fn expect_msg_all_of_in_order_matches_sequence() {
        let mut p = TestProbe::<u32>::new("seq");
        for i in [1u32, 2, 3] {
            p.actor_ref().tell(i);
        }
        p.expect_msg_all_of_in_order(Duration::from_millis(100), vec![1, 2, 3]).await.unwrap();
    }

    #[tokio::test]
    async fn expect_msg_all_of_in_order_rejects_reordering() {
        let mut p = TestProbe::<u32>::new("seq2");
        for i in [2u32, 1] {
            p.actor_ref().tell(i);
        }
        let r = p.expect_msg_all_of_in_order(Duration::from_millis(100), vec![1, 2]).await;
        assert!(matches!(r, Err(TestProbeError::Unexpected)));
    }

    #[tokio::test]
    async fn within_returns_inner_result() {
        let r = within(Duration::from_millis(100), |budget| async move {
            let mut p = TestProbe::<u32>::new("w");
            p.actor_ref().tell(11);
            p.expect_msg(budget).await
        })
        .await
        .unwrap();
        assert_eq!(r, 11);
    }

    #[tokio::test]
    async fn within_times_out_when_inner_blocks() {
        let r: Result<u32, _> = within(Duration::from_millis(10), |budget| async move {
            let mut p = TestProbe::<u32>::new("wt");
            p.expect_msg(budget).await
        })
        .await;
        assert!(matches!(r, Err(TestProbeError::Timeout)));
    }

    #[tokio::test]
    async fn expect_msg_class_extracts_variant() {
        #[derive(Debug, PartialEq)]
        #[allow(dead_code)] // `A` exists only so the extractor has a variant to reject.
        enum E {
            A(u32),
            B(String),
        }
        let mut p = TestProbe::<E>::new("cls");
        p.actor_ref().tell(E::B("hi".into()));
        let s = p
            .expect_msg_class(Duration::from_millis(100), |m| match m {
                E::B(s) => Some(s),
                _ => None,
            })
            .await
            .unwrap();
        assert_eq!(s, "hi");
    }
}