use std::future::Future;
use std::time::Duration;
use crate::AssertionFailure;
/// Five-second timeout constant for expect-style assertions.
///
/// NOTE(review): not referenced in this part of the file; presumably callers
/// pass it as the `timeout` argument of `poll_until` — confirm against callers.
pub const DEFAULT_EXPECT_TIMEOUT: Duration = Duration::from_secs(5);
/// Back-off schedule, in milliseconds, for the pauses between successive
/// checks in `poll_until`.
///
/// The n-th failed check waits for the n-th entry; once the schedule is
/// exhausted the last entry is reused for every further wait.
pub const POLL_INTERVALS: &[u64] = &[100, 250, 500, 1000];
/// Outcome of a single failed match attempt: what the matcher wanted to see
/// and what it actually observed, both already rendered as text.
#[derive(Debug, Clone)]
pub struct MatchError {
    /// Textual description of the value the matcher was looking for.
    pub expected: String,
    /// Textual description of the value that was actually observed.
    pub received: String,
}

impl MatchError {
    /// Builds a `MatchError` from anything convertible into `String`
    /// (`&str`, `String`, ...), converting `expected` first, then `received`.
    pub fn new(expected: impl Into<String>, received: impl Into<String>) -> Self {
        let expected: String = expected.into();
        let received: String = received.into();
        MatchError { expected, received }
    }
}
/// Describes the assertion being polled; `poll_until` uses it only to build
/// the call log and the failure message.
#[derive(Debug, Clone)]
pub struct ExpectContext {
/// Matcher name, rendered as `expect.<method>` in the call log and as
/// `.<method>()` in the failure headline.
pub method: &'static str,
/// Text identifying what is being asserted on; it appears inside
/// `expect(...)`, on the `Locator:` line, and in the "waiting for" log entry.
pub subject: String,
/// When `true`, `.not` is inserted before the method name in the failure
/// headline. `poll_until` does not invert the check result itself.
pub is_not: bool,
}
/// Repeatedly runs `check` until it succeeds or `timeout` elapses.
///
/// `check` is invoked immediately and then again after each pause taken from
/// [`POLL_INTERVALS`] (the last entry is reused once the schedule runs out).
/// Each pause is clamped to the time remaining before the deadline, so a
/// final attempt is made at the deadline itself instead of giving up as soon
/// as the next full interval would overshoot it.
///
/// # Parameters
/// - `timeout`: total time budget, measured from the moment this function starts.
/// - `ctx`: describes the assertion; used only for the call log and failure message.
/// - `check`: produces a future resolving to `Ok(())` on a match, or a
///   [`MatchError`] describing the mismatch.
///
/// # Errors
/// Returns an `AssertionFailure` once the deadline has passed without a
/// successful check. The message carries the expected/received values of the
/// most recent failed attempt, the timeout, and a call log with one entry per
/// failed attempt; the same expected/received pair is passed as the diff.
pub async fn poll_until<F, Fut>(timeout: Duration, ctx: ExpectContext, mut check: F) -> Result<(), AssertionFailure>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<(), MatchError>>,
{
    let deadline = tokio::time::Instant::now() + timeout;
    let mut call_log: Vec<String> = vec![
        format!("expect.{} with timeout {}ms", ctx.method, timeout.as_millis()),
        format!("waiting for {}", ctx.subject),
    ];
    // Interval used once the back-off schedule has been exhausted.
    let fallback_interval_ms = POLL_INTERVALS.last().copied().unwrap_or(1000);
    let mut interval_idx = 0;

    // The loop yields the most recent mismatch once the time budget is spent;
    // a successful check returns from the function directly.
    let err = loop {
        match check().await {
            Ok(()) => return Ok(()),
            Err(e) => {
                call_log.push(format!(" unexpected value {}", e.received));

                let remaining =
                    deadline.saturating_duration_since(tokio::time::Instant::now());
                if remaining.is_zero() {
                    break e;
                }

                let interval_ms = POLL_INTERVALS
                    .get(interval_idx)
                    .copied()
                    .unwrap_or(fallback_interval_ms);
                interval_idx += 1;

                // Never sleep past the deadline: clamping guarantees one last
                // check at the deadline rather than an early failure.
                let sleep_dur = Duration::from_millis(interval_ms).min(remaining);
                tokio::time::sleep(sleep_dur).await;
            }
        }
    };

    let not_str = if ctx.is_not { ".not" } else { "" };
    let timeout_ms = timeout.as_millis();
    // The log always holds at least the two header entries pushed above.
    let call_log_str = format!(
        "\n\nCall log:\n{}",
        call_log
            .iter()
            .map(|l| format!(" - {l}"))
            .collect::<Vec<_>>()
            .join("\n")
    );
    let message = format!(
        "expect({subject}){not_str}.{method}() failed\n\n\
        Locator: {locator}\n\
        Expected: {expected}\n\
        Received: {received}\n\
        Timeout: {timeout_ms}ms\
        {call_log_str}",
        subject = ctx.subject,
        method = ctx.method,
        locator = ctx.subject,
        expected = err.expected,
        received = err.received,
    );
    let diff = format!("Expected: {}\nReceived: {}", err.expected, err.received);
    Err(AssertionFailure::new(message, Some(diff)))
}