1use std::time::Duration;
5
6use rakka_core::actor::Inbox;
7use thiserror::Error;
8
9#[derive(Debug, Error)]
10pub enum TestProbeError {
11 #[error("probe timed out waiting for message")]
12 Timeout,
13 #[error("probe sender dropped")]
14 Dropped,
15 #[error("unexpected message")]
16 Unexpected,
17}
18
19pub struct TestProbe<M: Send + 'static> {
20 inbox: Inbox<M>,
21}
22
23impl<M: Send + 'static> TestProbe<M> {
24 pub fn new(name: &str) -> Self {
25 Self { inbox: Inbox::new(name) }
26 }
27
28 pub fn actor_ref(&self) -> &rakka_core::actor::ActorRef<M> {
29 self.inbox.actor_ref()
30 }
31
32 pub async fn expect_msg(&mut self, timeout: Duration) -> Result<M, TestProbeError> {
34 match self.inbox.receive(timeout).await {
35 Ok(m) => Ok(m),
36 Err(rakka_core::actor::AskError::Timeout) => Err(TestProbeError::Timeout),
37 Err(_) => Err(TestProbeError::Dropped),
38 }
39 }
40
41 pub async fn expect_msg_pf<F>(&mut self, timeout: Duration, mut pred: F) -> Result<M, TestProbeError>
44 where
45 F: FnMut(&M) -> bool,
46 {
47 let m = self.expect_msg(timeout).await?;
48 if pred(&m) {
49 Ok(m)
50 } else {
51 Err(TestProbeError::Unexpected)
52 }
53 }
54
55 pub async fn expect_no_msg(&mut self, timeout: Duration) -> Result<(), TestProbeError> {
57 match tokio::time::timeout(timeout, self.inbox.receive(Duration::from_secs(3600))).await {
58 Ok(_) => Err(TestProbeError::Unexpected),
59 Err(_) => Ok(()),
60 }
61 }
62
63 pub async fn expect_msg_class<T, F>(&mut self, timeout: Duration, extract: F) -> Result<T, TestProbeError>
70 where
71 F: FnOnce(M) -> Option<T>,
72 {
73 let m = self.expect_msg(timeout).await?;
74 extract(m).ok_or(TestProbeError::Unexpected)
75 }
76
77 pub async fn receive_n(&mut self, n: usize, timeout: Duration) -> Result<Vec<M>, TestProbeError> {
81 let deadline = std::time::Instant::now() + timeout;
82 let mut out = Vec::with_capacity(n);
83 while out.len() < n {
84 let remaining =
85 deadline.checked_duration_since(std::time::Instant::now()).ok_or(TestProbeError::Timeout)?;
86 out.push(self.expect_msg(remaining).await?);
87 }
88 Ok(out)
89 }
90
91 pub async fn receive_while<F>(&mut self, timeout: Duration, mut pred: F) -> Result<Vec<M>, TestProbeError>
95 where
96 F: FnMut(&M) -> bool,
97 {
98 let deadline = std::time::Instant::now() + timeout;
99 let mut out = Vec::new();
100 loop {
101 let remaining = match deadline.checked_duration_since(std::time::Instant::now()) {
102 Some(d) => d,
103 None => return Ok(out),
104 };
105 match self.expect_msg(remaining).await {
106 Ok(m) => {
107 if pred(&m) {
108 out.push(m);
109 } else {
110 return Ok(out);
111 }
112 }
113 Err(TestProbeError::Timeout) => return Ok(out),
114 Err(e) => return Err(e),
115 }
116 }
117 }
118
119 pub async fn fish_for_message<F>(&mut self, timeout: Duration, mut pred: F) -> Result<M, TestProbeError>
122 where
123 F: FnMut(&M) -> bool,
124 {
125 let deadline = std::time::Instant::now() + timeout;
126 loop {
127 let remaining =
128 deadline.checked_duration_since(std::time::Instant::now()).ok_or(TestProbeError::Timeout)?;
129 let m = self.expect_msg(remaining).await?;
130 if pred(&m) {
131 return Ok(m);
132 }
133 }
134 }
135
136 pub async fn expect_all_of(&mut self, timeout: Duration, expected: Vec<M>) -> Result<(), TestProbeError>
140 where
141 M: PartialEq + std::fmt::Debug,
142 {
143 let n = expected.len();
144 let received = self.receive_n(n, timeout).await?;
145 let mut remaining: Vec<M> = received;
147 for want in expected {
148 if let Some(idx) = remaining.iter().position(|m| m == &want) {
149 remaining.remove(idx);
150 } else {
151 return Err(TestProbeError::Unexpected);
152 }
153 }
154 Ok(())
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn probe_receives_message() {
        let mut p = TestProbe::<u32>::new("p");
        p.actor_ref().tell(42);
        let m = p.expect_msg(Duration::from_millis(100)).await.unwrap();
        assert_eq!(m, 42);
    }

    #[tokio::test]
    async fn expect_msg_times_out_when_empty() {
        let mut p = TestProbe::<u32>::new("to");
        let r = p.expect_msg(Duration::from_millis(20)).await;
        assert!(matches!(r, Err(TestProbeError::Timeout)));
    }

    #[tokio::test]
    async fn probe_no_msg() {
        let mut p = TestProbe::<u32>::new("q");
        p.expect_no_msg(Duration::from_millis(20)).await.unwrap();
    }

    #[tokio::test]
    async fn expect_no_msg_fails_when_message_pending() {
        let mut p = TestProbe::<u32>::new("nm");
        p.actor_ref().tell(1);
        let r = p.expect_no_msg(Duration::from_millis(20)).await;
        assert!(matches!(r, Err(TestProbeError::Unexpected)));
    }

    #[tokio::test]
    async fn expect_msg_pf_rejects_non_matching() {
        let mut p = TestProbe::<u32>::new("pf");
        p.actor_ref().tell(5);
        let r = p.expect_msg_pf(Duration::from_millis(100), |m| *m == 6).await;
        assert!(matches!(r, Err(TestProbeError::Unexpected)));
    }

    #[tokio::test]
    async fn receive_n_collects_messages() {
        let mut p = TestProbe::<u32>::new("rn");
        for i in 0..3u32 {
            p.actor_ref().tell(i);
        }
        let msgs = p.receive_n(3, Duration::from_millis(100)).await.unwrap();
        assert_eq!(msgs, vec![0, 1, 2]);
    }

    #[tokio::test]
    async fn receive_n_times_out_partial() {
        let mut p = TestProbe::<u32>::new("rnt");
        p.actor_ref().tell(7);
        let r = p.receive_n(3, Duration::from_millis(20)).await;
        assert!(matches!(r, Err(TestProbeError::Timeout)));
    }

    #[tokio::test]
    async fn fish_for_message_skips_mismatches() {
        let mut p = TestProbe::<u32>::new("fish");
        p.actor_ref().tell(1);
        p.actor_ref().tell(2);
        p.actor_ref().tell(99);
        let m = p.fish_for_message(Duration::from_millis(100), |m| *m >= 50).await.unwrap();
        assert_eq!(m, 99);
    }

    #[tokio::test]
    async fn receive_while_stops_on_predicate() {
        let mut p = TestProbe::<u32>::new("rw");
        for i in 1..=4u32 {
            p.actor_ref().tell(i);
        }
        let collected = p.receive_while(Duration::from_millis(100), |m| *m < 3).await.unwrap();
        assert_eq!(collected, vec![1, 2]);
    }

    #[tokio::test]
    async fn expect_all_of_order_insensitive() {
        let mut p = TestProbe::<u32>::new("alf");
        for i in [3u32, 1, 2] {
            p.actor_ref().tell(i);
        }
        p.expect_all_of(Duration::from_millis(100), vec![1, 2, 3]).await.unwrap();
    }

    #[tokio::test]
    async fn expect_all_of_detects_missing() {
        let mut p = TestProbe::<u32>::new("alm");
        for i in [1u32, 2] {
            p.actor_ref().tell(i);
        }
        let r = p.expect_all_of(Duration::from_millis(100), vec![1, 3]).await;
        assert!(matches!(r, Err(TestProbeError::Unexpected)));
    }

    #[tokio::test]
    async fn expect_msg_class_extracts_variant() {
        #[derive(Debug, PartialEq)]
        // `A` is never constructed; it exists only so the extractor has a variant to reject.
        #[allow(dead_code)]
        enum E {
            A(u32),
            B(String),
        }
        let mut p = TestProbe::<E>::new("cls");
        p.actor_ref().tell(E::B("hi".into()));
        let s = p
            .expect_msg_class(Duration::from_millis(100), |m| match m {
                E::B(s) => Some(s),
                _ => None,
            })
            .await
            .unwrap();
        assert_eq!(s, "hi");
    }
}