1use std::thread;
2use std::time::{Duration, Instant};
3
4use crate::error::CliError;
5
/// Summary of a successful wait: how many polls it took and how long it ran.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct WaitOutcome {
    /// Number of times the check was invoked, counting the call that finally succeeded.
    pub attempts: u32,
    /// Wall-clock milliseconds between the start of the wait and the successful check.
    pub elapsed_ms: u64,
}
11
/// Timing parameters for a polling wait: the overall time budget and the pause between checks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct WaitPolicy {
    /// Total time budget for the wait, in milliseconds.
    pub timeout_ms: u64,
    /// Pause between consecutive checks, in milliseconds.
    pub poll_ms: u64,
}

impl WaitPolicy {
    /// Builds a policy from raw millisecond values.
    ///
    /// Each value is raised to at least 1ms, so `new(0, 0)` yields a 1ms timeout with a 1ms
    /// poll interval. Values set by constructing the struct directly are not clamped.
    pub fn new(timeout_ms: u64, poll_ms: u64) -> Self {
        let at_least_one_ms = |ms: u64| std::cmp::max(ms, 1);
        let timeout_ms = at_least_one_ms(timeout_ms);
        let poll_ms = at_least_one_ms(poll_ms);
        Self { timeout_ms, poll_ms }
    }
}
26
/// Blocks the current thread for `ms` milliseconds.
///
/// A zero argument returns immediately without calling [`thread::sleep`].
pub fn sleep_ms(ms: u64) {
    if ms == 0 {
        return;
    }
    let pause = Duration::from_millis(ms);
    thread::sleep(pause);
}
32
33pub fn wait_until<F>(
34 condition_name: &str,
35 timeout_ms: u64,
36 poll_ms: u64,
37 mut check: F,
38) -> Result<WaitOutcome, CliError>
39where
40 F: FnMut() -> Result<bool, CliError>,
41{
42 wait_until_with_policy(
43 condition_name,
44 WaitPolicy::new(timeout_ms, poll_ms),
45 &mut check,
46 )
47}
48
49pub fn wait_until_with_policy<F>(
50 condition_name: &str,
51 policy: WaitPolicy,
52 mut check: F,
53) -> Result<WaitOutcome, CliError>
54where
55 F: FnMut() -> Result<bool, CliError>,
56{
57 let started = Instant::now();
58 let deadline = started + Duration::from_millis(policy.timeout_ms);
59 let mut attempts = 0u32;
60
61 loop {
62 attempts = attempts.saturating_add(1);
63 if check()? {
64 return Ok(WaitOutcome {
65 attempts,
66 elapsed_ms: started.elapsed().as_millis() as u64,
67 });
68 }
69
70 if Instant::now() >= deadline {
71 return Err(CliError::runtime(format!(
72 "timed out waiting for {condition_name} after {}ms",
73 policy.timeout_ms
74 )));
75 }
76
77 sleep_ms(policy.poll_ms);
78 }
79}
80
#[cfg(test)]
mod tests {
    use super::{wait_until, WaitPolicy};

    #[test]
    fn wait_until_succeeds_before_timeout() {
        // `check` is `FnMut`, so a captured local counter suffices; no shared static needed.
        let mut calls = 0u32;
        let outcome = wait_until("ready", 200, 1, || {
            calls += 1;
            Ok(calls >= 3)
        })
        .expect("should succeed");

        // The predicate first holds on its third invocation and is not called again.
        assert_eq!(outcome.attempts, 3);
        assert_eq!(calls, 3);
        assert!(outcome.elapsed_ms <= 200);
    }

    #[test]
    fn wait_until_errors_on_timeout() {
        let err = wait_until("never", 5, 1, || Ok(false)).expect_err("should timeout");
        assert_eq!(err.exit_code(), 1);
        assert!(err.to_string().contains("timed out waiting"));
    }

    #[test]
    fn wait_policy_new_clamps_zero_values() {
        assert_eq!(
            WaitPolicy::new(0, 0),
            WaitPolicy { timeout_ms: 1, poll_ms: 1 }
        );
        assert_eq!(
            WaitPolicy::new(50, 5),
            WaitPolicy { timeout_ms: 50, poll_ms: 5 }
        );
    }
}