use std::time::Duration;
use atomr_core::actor::Inbox;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum TestProbeError {
#[error("probe timed out waiting for message")]
Timeout,
#[error("probe sender dropped")]
Dropped,
#[error("unexpected message")]
Unexpected,
}
/// A stand-in actor for tests: messages sent to [`TestProbe::actor_ref`] land in the
/// probe's inbox, where the `expect_*` / `receive_*` helpers assert on them.
pub struct TestProbe<M: Send + 'static> {
    /// Mailbox that backs the probe's actor ref.
    inbox: Inbox<M>,
    /// A message already taken off `inbox` but handed back unconsumed (the message
    /// that ended a [`TestProbe::receive_while`]). It is served before the inbox.
    pending: Option<M>,
}

impl<M: Send + 'static> TestProbe<M> {
    /// Creates a probe backed by a fresh [`Inbox`] built with `name`.
    pub fn new(name: &str) -> Self {
        Self { inbox: Inbox::new(name), pending: None }
    }

    /// The actor ref that delivers messages into this probe.
    pub fn actor_ref(&self) -> &atomr_core::actor::ActorRef<M> {
        self.inbox.actor_ref()
    }

    /// Absolute deadline `timeout` from now.
    ///
    /// Measured on tokio's clock (not `std::time::Instant`) so multi-message deadlines
    /// agree with `tokio::time::timeout`, which [`within`] uses, including under a
    /// paused test runtime.
    fn deadline_after(timeout: Duration) -> tokio::time::Instant {
        tokio::time::Instant::now() + timeout
    }

    /// Time left before `deadline`, or `None` once it has passed.
    fn remaining_until(deadline: tokio::time::Instant) -> Option<Duration> {
        deadline.checked_duration_since(tokio::time::Instant::now())
    }

    /// Waits up to `timeout` for the next message and returns it.
    ///
    /// A message handed back by [`TestProbe::receive_while`] is returned immediately.
    ///
    /// # Errors
    /// * [`TestProbeError::Timeout`] if nothing arrives within `timeout`.
    /// * [`TestProbeError::Dropped`] if the inbox fails in any other way.
    pub async fn expect_msg(&mut self, timeout: Duration) -> Result<M, TestProbeError> {
        // A handed-back message was already received; serve it without waiting.
        if let Some(m) = self.pending.take() {
            return Ok(m);
        }
        match self.inbox.receive(timeout).await {
            Ok(m) => Ok(m),
            Err(atomr_core::actor::AskError::Timeout) => Err(TestProbeError::Timeout),
            Err(_) => Err(TestProbeError::Dropped),
        }
    }

    /// Receives the next message and checks it against `pred`.
    ///
    /// # Errors
    /// Same as [`TestProbe::expect_msg`], plus [`TestProbeError::Unexpected`] if `pred`
    /// rejects the message (the message is consumed either way).
    pub async fn expect_msg_pf<F>(&mut self, timeout: Duration, mut pred: F) -> Result<M, TestProbeError>
    where
        F: FnMut(&M) -> bool,
    {
        let m = self.expect_msg(timeout).await?;
        if pred(&m) {
            Ok(m)
        } else {
            Err(TestProbeError::Unexpected)
        }
    }

    /// Asserts that no message arrives within `timeout`.
    ///
    /// # Errors
    /// * [`TestProbeError::Unexpected`] if a message arrives (it is consumed).
    /// * [`TestProbeError::Dropped`] if the inbox fails for a reason other than timing
    ///   out; previously this was misreported as `Unexpected`.
    pub async fn expect_no_msg(&mut self, timeout: Duration) -> Result<(), TestProbeError> {
        match self.expect_msg(timeout).await {
            Ok(_) => Err(TestProbeError::Unexpected),
            // Timing out is the success case for this assertion.
            Err(TestProbeError::Timeout) => Ok(()),
            Err(e) => Err(e),
        }
    }

    /// Receives the next message and converts it with `extract`.
    ///
    /// # Errors
    /// Same as [`TestProbe::expect_msg`], plus [`TestProbeError::Unexpected`] if
    /// `extract` returns `None`.
    pub async fn expect_msg_class<T, F>(&mut self, timeout: Duration, extract: F) -> Result<T, TestProbeError>
    where
        F: FnOnce(M) -> Option<T>,
    {
        let m = self.expect_msg(timeout).await?;
        extract(m).ok_or(TestProbeError::Unexpected)
    }

    /// Receives exactly `n` messages, all within one overall `timeout`.
    ///
    /// # Errors
    /// [`TestProbeError::Timeout`] if fewer than `n` arrive before the deadline (the ones
    /// that did arrive are discarded), or any other error from [`TestProbe::expect_msg`].
    pub async fn receive_n(&mut self, n: usize, timeout: Duration) -> Result<Vec<M>, TestProbeError> {
        let deadline = Self::deadline_after(timeout);
        let mut out = Vec::with_capacity(n);
        while out.len() < n {
            let remaining = Self::remaining_until(deadline).ok_or(TestProbeError::Timeout)?;
            out.push(self.expect_msg(remaining).await?);
        }
        Ok(out)
    }

    /// Collects messages for as long as `pred` accepts them, up to an overall `timeout`.
    ///
    /// Stops at the deadline, on a receive timeout, or at the first message `pred`
    /// rejects. That rejected message is NOT consumed: the next receive call on this
    /// probe returns it first.
    ///
    /// # Errors
    /// Any non-timeout error from [`TestProbe::expect_msg`] (e.g. `Dropped`).
    pub async fn receive_while<F>(&mut self, timeout: Duration, mut pred: F) -> Result<Vec<M>, TestProbeError>
    where
        F: FnMut(&M) -> bool,
    {
        let deadline = Self::deadline_after(timeout);
        let mut out = Vec::new();
        while let Some(remaining) = Self::remaining_until(deadline) {
            match self.expect_msg(remaining).await {
                Ok(m) if pred(&m) => out.push(m),
                Ok(m) => {
                    // Hand the terminating message back so later assertions still see it.
                    self.pending = Some(m);
                    break;
                }
                Err(TestProbeError::Timeout) => break,
                Err(e) => return Err(e),
            }
        }
        Ok(out)
    }

    /// Discards messages until one satisfies `pred`, then returns it.
    ///
    /// # Errors
    /// [`TestProbeError::Timeout`] if no matching message arrives within the overall
    /// `timeout`, or any other error from [`TestProbe::expect_msg`].
    pub async fn fish_for_message<F>(&mut self, timeout: Duration, mut pred: F) -> Result<M, TestProbeError>
    where
        F: FnMut(&M) -> bool,
    {
        let deadline = Self::deadline_after(timeout);
        loop {
            let remaining = Self::remaining_until(deadline).ok_or(TestProbeError::Timeout)?;
            let m = self.expect_msg(remaining).await?;
            if pred(&m) {
                return Ok(m);
            }
        }
    }

    /// Receives `expected.len()` messages and checks they equal `expected` as a
    /// multiset (order-insensitive).
    ///
    /// # Errors
    /// Errors from [`TestProbe::receive_n`], or [`TestProbeError::Unexpected`] if the
    /// received messages differ from `expected`.
    pub async fn expect_all_of(&mut self, timeout: Duration, expected: Vec<M>) -> Result<(), TestProbeError>
    where
        M: PartialEq + std::fmt::Debug,
    {
        let mut received = self.receive_n(expected.len(), timeout).await?;
        for want in &expected {
            let idx = received.iter().position(|m| m == want).ok_or(TestProbeError::Unexpected)?;
            // Order is irrelevant here, so the O(1) swap_remove is sufficient.
            received.swap_remove(idx);
        }
        Ok(())
    }

    /// Receives the next message and checks it equals `expected`.
    ///
    /// # Errors
    /// Same as [`TestProbe::expect_msg`], plus [`TestProbeError::Unexpected`] on mismatch.
    pub async fn expect_msg_eq(&mut self, timeout: Duration, expected: M) -> Result<M, TestProbeError>
    where
        M: PartialEq + std::fmt::Debug,
    {
        let m = self.expect_msg(timeout).await?;
        if m == expected {
            Ok(m)
        } else {
            Err(TestProbeError::Unexpected)
        }
    }

    /// Receives `expected.len()` messages and checks they equal `expected` in order.
    ///
    /// # Errors
    /// Errors from [`TestProbe::receive_n`], or [`TestProbeError::Unexpected`] if the
    /// received sequence differs from `expected`.
    pub async fn expect_msg_all_of_in_order(
        &mut self,
        timeout: Duration,
        expected: Vec<M>,
    ) -> Result<(), TestProbeError>
    where
        M: PartialEq + std::fmt::Debug,
    {
        let received = self.receive_n(expected.len(), timeout).await?;
        if received == expected {
            Ok(())
        } else {
            Err(TestProbeError::Unexpected)
        }
    }
}
/// Runs `body` under an overall time limit.
///
/// `body` is handed the whole `timeout` as its budget so nested `expect_*` calls can
/// use it; the future it returns is then given at most `timeout` to complete.
///
/// # Errors
/// [`TestProbeError::Timeout`] if `body`'s future has not completed within `timeout`;
/// otherwise whatever error `body` itself produced.
pub async fn within<F, Fut, T>(timeout: Duration, body: F) -> Result<T, TestProbeError>
where
    F: FnOnce(Duration) -> Fut,
    Fut: std::future::Future<Output = Result<T, TestProbeError>>,
{
    let work = body(timeout);
    tokio::time::timeout(timeout, work)
        .await
        .unwrap_or_else(|_elapsed| Err(TestProbeError::Timeout))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default budget for assertions in these tests.
    const WAIT: Duration = Duration::from_millis(100);

    /// Sends each of `msgs` to `probe`, in iteration order.
    fn send_all(probe: &TestProbe<u32>, msgs: impl IntoIterator<Item = u32>) {
        for msg in msgs {
            probe.actor_ref().tell(msg);
        }
    }

    #[tokio::test]
    async fn probe_receives_message() {
        let mut probe = TestProbe::<u32>::new("p");
        probe.actor_ref().tell(42);
        let got = probe.expect_msg(WAIT).await.unwrap();
        assert_eq!(got, 42);
    }

    #[tokio::test]
    async fn probe_no_msg() {
        let mut probe = TestProbe::<u32>::new("q");
        probe.expect_no_msg(Duration::from_millis(20)).await.unwrap();
    }

    #[tokio::test]
    async fn receive_n_collects_messages() {
        let mut probe = TestProbe::<u32>::new("rn");
        send_all(&probe, 0..3);
        let got = probe.receive_n(3, WAIT).await.unwrap();
        assert_eq!(got, vec![0, 1, 2]);
    }

    #[tokio::test]
    async fn receive_n_times_out_partial() {
        let mut probe = TestProbe::<u32>::new("rnt");
        // Only one of the three requested messages is ever sent.
        probe.actor_ref().tell(7);
        let outcome = probe.receive_n(3, Duration::from_millis(20)).await;
        assert!(matches!(outcome, Err(TestProbeError::Timeout)));
    }

    #[tokio::test]
    async fn fish_for_message_skips_mismatches() {
        let mut probe = TestProbe::<u32>::new("fish");
        send_all(&probe, [1, 2, 99]);
        let caught = probe.fish_for_message(WAIT, |m| *m >= 50).await.unwrap();
        assert_eq!(caught, 99);
    }

    #[tokio::test]
    async fn receive_while_stops_on_predicate() {
        let mut probe = TestProbe::<u32>::new("rw");
        send_all(&probe, 1..=4);
        let collected = probe.receive_while(WAIT, |m| *m < 3).await.unwrap();
        assert_eq!(collected, vec![1, 2]);
    }

    #[tokio::test]
    async fn expect_all_of_order_insensitive() {
        let mut probe = TestProbe::<u32>::new("alf");
        send_all(&probe, [3, 1, 2]);
        probe.expect_all_of(WAIT, vec![1, 2, 3]).await.unwrap();
    }

    #[tokio::test]
    async fn expect_msg_eq_succeeds_on_match() {
        let mut probe = TestProbe::<u32>::new("eq");
        probe.actor_ref().tell(42);
        let got = probe.expect_msg_eq(WAIT, 42).await.unwrap();
        assert_eq!(got, 42);
    }

    #[tokio::test]
    async fn expect_msg_eq_fails_on_mismatch() {
        let mut probe = TestProbe::<u32>::new("eq2");
        probe.actor_ref().tell(42);
        let outcome = probe.expect_msg_eq(WAIT, 7).await;
        assert!(matches!(outcome, Err(TestProbeError::Unexpected)));
    }

    #[tokio::test]
    async fn expect_msg_all_of_in_order_matches_sequence() {
        let mut probe = TestProbe::<u32>::new("seq");
        send_all(&probe, [1, 2, 3]);
        probe.expect_msg_all_of_in_order(WAIT, vec![1, 2, 3]).await.unwrap();
    }

    #[tokio::test]
    async fn within_returns_inner_result() {
        let got = within(WAIT, |budget| async move {
            let mut probe = TestProbe::<u32>::new("w");
            probe.actor_ref().tell(11);
            probe.expect_msg(budget).await
        })
        .await
        .unwrap();
        assert_eq!(got, 11);
    }

    #[tokio::test]
    async fn within_times_out_when_inner_blocks() {
        // Nothing is ever sent, so the body can only finish by timing out.
        let outcome: Result<u32, _> = within(Duration::from_millis(10), |budget| async move {
            let mut probe = TestProbe::<u32>::new("wt");
            probe.expect_msg(budget).await
        })
        .await;
        assert!(matches!(outcome, Err(TestProbeError::Timeout)));
    }

    #[tokio::test]
    async fn expect_msg_class_extracts_variant() {
        #[derive(Debug, PartialEq)]
        // `A` exists only to give the extractor a variant to reject; it is never
        // constructed, which would otherwise trip the dead_code lint.
        #[allow(dead_code)]
        enum E {
            A(u32),
            B(String),
        }
        let mut probe = TestProbe::<E>::new("cls");
        probe.actor_ref().tell(E::B("hi".into()));
        let text = probe
            .expect_msg_class(WAIT, |m| match m {
                E::B(s) => Some(s),
                E::A(_) => None,
            })
            .await
            .unwrap();
        assert_eq!(text, "hi");
    }
}