use std::future::Future;
use std::time::Duration;
use tokio::time::{Instant, sleep_until, timeout_at};
/// Smallest gap ever requested between two polls of a condition.
///
/// Intervals shorter than this (including zero) are raised to it. The
/// requested gap can still be cut short by the overall deadline.
const MIN_POLL_INTERVAL: Duration = Duration::from_micros(1_000);
/// Failure reported by the polling assertions `eventually` and `consistently`.
///
/// `E` is the error type produced by the polled condition.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum AsyncAssertError<E> {
    /// The condition never returned `Ok(())` before the deadline.
    Timeout {
        /// Error from the most recent attempt that completed with `Err`.
        /// `None` if no attempt completed with an error.
        last_error: Option<E>,
    },
    /// A condition that was required to keep succeeding returned an error.
    BecameUnstable {
        /// The error that broke the condition.
        error: E,
    },
}
impl<E> std::fmt::Display for AsyncAssertError<E>
where
E: std::fmt::Display,
{
fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Timeout {
last_error: Some(error),
} => write!(
formatter,
"condition did not succeed before timeout: {error}"
),
Self::Timeout { last_error: None } => {
formatter.write_str("condition did not succeed before timeout")
}
Self::BecameUnstable { error } => {
write!(formatter, "condition became unstable: {error}")
}
}
}
}
/// Exposes the wrapped condition error, if any, as the error source.
///
/// `Timeout` without a recorded error has no source.
impl<E> std::error::Error for AsyncAssertError<E>
where
    E: std::error::Error + 'static,
{
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let inner = match self {
            Self::BecameUnstable { error } => Some(error),
            Self::Timeout { last_error } => last_error.as_ref(),
        };
        inner.map(|error| error as &(dyn std::error::Error + 'static))
    }
}
/// Polls `condition` until it resolves to `Ok(())` or `timeout` elapses.
///
/// Every attempt is raced against the overall deadline with `timeout_at`, so
/// an attempt still pending when the deadline is reached is dropped.
///
/// After a failed attempt, the function sleeps for `interval`. The interval
/// is first raised to at least 1 ms, and the sleep never extends past the
/// deadline. If the last sleep ends at the deadline, one more attempt is
/// started then.
///
/// A `timeout` too large to add to the current instant (e.g.
/// `Duration::MAX`) is clamped to roughly 30 years from now instead of
/// panicking on `Instant` overflow.
///
/// # Errors
///
/// Returns [`AsyncAssertError::Timeout`] when the deadline passes without a
/// successful attempt. Its `last_error` holds the error from the most recent
/// attempt that completed with `Err`. It is `None` if no attempt completed
/// with an error.
pub async fn eventually<F, Fut, E>(
    timeout: Duration,
    interval: Duration,
    mut condition: F,
) -> Result<(), AsyncAssertError<E>>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<(), E>>,
{
    // Fallback horizon (~30 years) used when `now + timeout` is not
    // representable. Plain `Instant + Duration` would panic in that case.
    const FAR_FUTURE: Duration = Duration::from_secs(86_400 * 365 * 30);

    let start = Instant::now();
    let deadline = start
        .checked_add(timeout)
        .unwrap_or_else(|| start + FAR_FUTURE);
    let interval = normalize_poll_interval(interval);
    let mut last_error: Option<E> = None;

    loop {
        match timeout_at(deadline, condition()).await {
            // The in-flight attempt was cut off by the deadline.
            Err(_elapsed) => return Err(AsyncAssertError::Timeout { last_error }),
            Ok(Ok(())) => return Ok(()),
            Ok(Err(error)) => last_error = Some(error),
        }
        if Instant::now() >= deadline {
            return Err(AsyncAssertError::Timeout { last_error });
        }
        sleep_until(next_poll_deadline(interval, deadline)).await;
    }
}
/// Polls `condition` repeatedly for `duration`, requiring every completed
/// attempt to succeed.
///
/// Each attempt is raced against the end of the window with `timeout_at`. An
/// attempt still pending when the window closes is dropped and counts as
/// success.
///
/// Between attempts, the function sleeps for `interval`. The interval is
/// first raised to at least 1 ms, and the sleep never extends past the end of
/// the window.
///
/// A `duration` too large to add to the current instant (e.g.
/// `Duration::MAX`) is clamped to roughly 30 years from now instead of
/// panicking on `Instant` overflow.
///
/// # Errors
///
/// Returns [`AsyncAssertError::BecameUnstable`] carrying the first error any
/// attempt produces. Polling stops at that point.
pub async fn consistently<F, Fut, E>(
    duration: Duration,
    interval: Duration,
    mut condition: F,
) -> Result<(), AsyncAssertError<E>>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<(), E>>,
{
    // Fallback horizon (~30 years) used when `now + duration` is not
    // representable. Plain `Instant + Duration` would panic in that case.
    const FAR_FUTURE: Duration = Duration::from_secs(86_400 * 365 * 30);

    let start = Instant::now();
    let deadline = start
        .checked_add(duration)
        .unwrap_or_else(|| start + FAR_FUTURE);
    let interval = normalize_poll_interval(interval);

    loop {
        match timeout_at(deadline, condition()).await {
            // The window closed while an attempt was still running; no
            // failure was observed, so the condition held.
            Err(_elapsed) => return Ok(()),
            Ok(Ok(())) => {}
            Ok(Err(error)) => return Err(AsyncAssertError::BecameUnstable { error }),
        }
        if Instant::now() >= deadline {
            return Ok(());
        }
        sleep_until(next_poll_deadline(interval, deadline)).await;
    }
}
fn normalize_poll_interval(interval: Duration) -> Duration {
interval.max(MIN_POLL_INTERVAL)
}
/// Returns the instant at which the next poll should happen.
///
/// The result is `interval` from now, but never later than `deadline`. If
/// `now + interval` cannot be represented (e.g. `interval == Duration::MAX`),
/// `deadline` is returned instead of panicking on `Instant` overflow.
fn next_poll_deadline(interval: Duration, deadline: Instant) -> Instant {
    Instant::now()
        .checked_add(interval)
        .map_or(deadline, |next| next.min(deadline))
}