ferridriver_expect/
poll.rs1use std::future::Future;
5use std::time::Duration;
6
7use crate::AssertionFailure;
8
/// Default timeout for `expect` assertions: five seconds.
pub const DEFAULT_EXPECT_TIMEOUT: Duration = Duration::from_millis(5_000);
11
/// Back-off schedule, in milliseconds, for the pause between successive polls.
///
/// `poll_until` consumes the entries in order and keeps reusing the final
/// entry once the list is exhausted.
pub const POLL_INTERVALS: &[u64] = &[100, 250, 500, 1_000];
14
/// Outcome of a single failed check: what the assertion wanted versus what
/// it actually observed, both already rendered as text.
#[derive(Clone, Debug)]
pub struct MatchError {
    /// Textual description of the value the assertion expected.
    pub expected: String,
    /// Textual description of the value that was actually observed.
    pub received: String,
}

impl MatchError {
    /// Builds a `MatchError` from anything convertible into `String`
    /// (for example `&str` or `String`).
    pub fn new(expected: impl Into<String>, received: impl Into<String>) -> Self {
        let expected: String = expected.into();
        let received: String = received.into();
        MatchError { expected, received }
    }
}
31
/// Descriptive context for one assertion, used by `poll_until` to build the
/// call log and the failure message.
#[derive(Clone, Debug)]
pub struct ExpectContext {
    /// Assertion method name; rendered as `expect.{method}` in the call log
    /// and as `.{method}()` in the failure headline.
    pub method: &'static str,
    /// Text identifying what is being asserted on; rendered inside
    /// `expect(...)`, on the `Locator:` line, and in the "waiting for" log entry.
    pub subject: String,
    /// When `true`, the failure headline includes `.not` before the method name.
    pub is_not: bool,
}
43
44pub async fn poll_until<F, Fut>(timeout: Duration, ctx: ExpectContext, mut check: F) -> Result<(), AssertionFailure>
48where
49 F: FnMut() -> Fut,
50 Fut: Future<Output = Result<(), MatchError>>,
51{
52 let deadline = tokio::time::Instant::now() + timeout;
53 let mut last_error: Option<MatchError>;
54 let mut interval_idx = 0;
55 let mut call_log: Vec<String> = Vec::new();
56 call_log.push(format!("expect.{} with timeout {}ms", ctx.method, timeout.as_millis()));
57 call_log.push(format!("waiting for {}", ctx.subject));
58
59 loop {
60 match check().await {
61 Ok(()) => return Ok(()),
62 Err(e) => {
63 call_log.push(format!(" unexpected value {}", e.received));
64 last_error = Some(e);
65 let interval_ms = POLL_INTERVALS
66 .get(interval_idx)
67 .copied()
68 .unwrap_or_else(|| POLL_INTERVALS.last().copied().unwrap_or(1000));
69 interval_idx += 1;
70
71 let sleep_dur = Duration::from_millis(interval_ms);
72 if tokio::time::Instant::now() + sleep_dur > deadline {
73 break;
74 }
75 tokio::time::sleep(sleep_dur).await;
76 },
77 }
78 }
79
80 let err = last_error.unwrap_or_else(|| MatchError::new("(unknown)", "(unknown)"));
81
82 let not_str = if ctx.is_not { ".not" } else { "" };
83 let timeout_ms = timeout.as_millis();
84
85 let call_log_str = if call_log.is_empty() {
86 String::new()
87 } else {
88 format!(
89 "\n\nCall log:\n{}",
90 call_log
91 .iter()
92 .map(|l| format!(" - {l}"))
93 .collect::<Vec<_>>()
94 .join("\n")
95 )
96 };
97
98 let message = format!(
99 "expect({subject}){not_str}.{method}() failed\n\n\
100 Locator: {locator}\n\
101 Expected: {expected}\n\
102 Received: {received}\n\
103 Timeout: {timeout_ms}ms\
104 {call_log_str}",
105 subject = ctx.subject,
106 method = ctx.method,
107 locator = ctx.subject,
108 expected = err.expected,
109 received = err.received,
110 );
111
112 let diff = format!("Expected: {}\nReceived: {}", err.expected, err.received);
113
114 Err(AssertionFailure::new(message, Some(diff)))
115}