1use serde::{Deserialize, Serialize};
12use std::fmt::Debug;
13use std::time::{Duration, Instant};
14
/// Outcome of a single evaluation of a retryable assertion check.
///
/// A check closure returns [`AssertionCheckResult::Pass`] once its condition
/// holds, or [`AssertionCheckResult::Fail`] with a human-readable reason.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AssertionCheckResult {
    /// The checked condition holds.
    Pass,
    /// The checked condition does not hold; the string explains why.
    Fail(String),
}

impl AssertionCheckResult {
    /// Returns `true` if this result is [`AssertionCheckResult::Pass`].
    #[must_use]
    pub const fn is_pass(&self) -> bool {
        matches!(self, Self::Pass)
    }

    /// Returns `true` if this result is [`AssertionCheckResult::Fail`].
    #[must_use]
    pub const fn is_fail(&self) -> bool {
        matches!(self, Self::Fail(_))
    }
}
37
38#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
40pub struct RetryConfig {
41 pub timeout: Duration,
43 pub poll_interval: Duration,
45 pub max_retries: usize,
47}
48
49impl Default for RetryConfig {
50 fn default() -> Self {
51 Self {
52 timeout: Duration::from_secs(5),
53 poll_interval: Duration::from_millis(100),
54 max_retries: 0,
55 }
56 }
57}
58
59impl RetryConfig {
60 #[must_use]
62 pub const fn new(timeout: Duration) -> Self {
63 Self {
64 timeout,
65 poll_interval: Duration::from_millis(100),
66 max_retries: 0,
67 }
68 }
69
70 #[must_use]
72 pub const fn with_poll_interval(mut self, interval: Duration) -> Self {
73 self.poll_interval = interval;
74 self
75 }
76
77 #[must_use]
79 pub const fn with_max_retries(mut self, max: usize) -> Self {
80 self.max_retries = max;
81 self
82 }
83
84 #[must_use]
86 pub const fn fast() -> Self {
87 Self {
88 timeout: Duration::from_millis(500),
89 poll_interval: Duration::from_millis(50),
90 max_retries: 0,
91 }
92 }
93
94 #[must_use]
96 pub const fn slow() -> Self {
97 Self {
98 timeout: Duration::from_secs(30),
99 poll_interval: Duration::from_millis(500),
100 max_retries: 0,
101 }
102 }
103}
104
105pub struct RetryAssertion<F>
122where
123 F: Fn() -> AssertionCheckResult,
124{
125 check: F,
126 config: RetryConfig,
127 description: Option<String>,
128}
129
130impl<F> RetryAssertion<F>
131where
132 F: Fn() -> AssertionCheckResult,
133{
134 #[must_use]
136 pub fn new(check: F) -> Self {
137 Self {
138 check,
139 config: RetryConfig::default(),
140 description: None,
141 }
142 }
143
144 #[must_use]
146 pub const fn with_timeout(mut self, timeout: Duration) -> Self {
147 self.config.timeout = timeout;
148 self
149 }
150
151 #[must_use]
153 pub const fn with_poll_interval(mut self, interval: Duration) -> Self {
154 self.config.poll_interval = interval;
155 self
156 }
157
158 #[must_use]
160 pub const fn with_max_retries(mut self, max: usize) -> Self {
161 self.config.max_retries = max;
162 self
163 }
164
165 #[must_use]
167 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
168 self.description = Some(desc.into());
169 self
170 }
171
172 #[must_use]
174 pub const fn with_config(mut self, config: RetryConfig) -> Self {
175 self.config = config;
176 self
177 }
178
179 #[must_use]
181 pub const fn config(&self) -> &RetryConfig {
182 &self.config
183 }
184
185 #[allow(unused_assignments)]
191 pub fn verify(&self) -> Result<RetryResult, RetryError> {
192 let start = Instant::now();
193 let mut attempts = 0;
194 let mut last_error: Option<String> = None;
195
196 loop {
197 attempts += 1;
198
199 match (self.check)() {
200 AssertionCheckResult::Pass => {
201 return Ok(RetryResult {
202 attempts,
203 duration: start.elapsed(),
204 });
205 }
206 AssertionCheckResult::Fail(msg) => {
207 last_error = Some(msg);
208 }
209 }
210
211 if start.elapsed() >= self.config.timeout {
213 return Err(RetryError {
214 message: last_error.unwrap_or_default(),
215 attempts,
216 duration: start.elapsed(),
217 description: self.description.clone(),
218 });
219 }
220
221 if self.config.max_retries > 0 && attempts >= self.config.max_retries {
223 return Err(RetryError {
224 message: last_error.unwrap_or_default(),
225 attempts,
226 duration: start.elapsed(),
227 description: self.description.clone(),
228 });
229 }
230
231 std::thread::sleep(self.config.poll_interval);
233 }
234 }
235
236 pub fn verify_once(&self) -> Result<(), RetryError> {
242 match (self.check)() {
243 AssertionCheckResult::Pass => Ok(()),
244 AssertionCheckResult::Fail(msg) => Err(RetryError {
245 message: msg,
246 attempts: 1,
247 duration: Duration::ZERO,
248 description: self.description.clone(),
249 }),
250 }
251 }
252}
253
254impl<F> Debug for RetryAssertion<F>
255where
256 F: Fn() -> AssertionCheckResult,
257{
258 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259 f.debug_struct("RetryAssertion")
260 .field("config", &self.config)
261 .field("description", &self.description)
262 .finish()
263 }
264}
265
/// Summary of a successful [`RetryAssertion::verify`] run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RetryResult {
    /// Number of times the check was evaluated, including the passing attempt.
    pub attempts: usize,
    /// Wall-clock time from just before the first attempt until the check passed.
    pub duration: Duration,
}
274
/// Failure report from [`RetryAssertion::verify`] or [`RetryAssertion::verify_once`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryError {
    /// Failure message produced by the last attempt.
    pub message: String,
    /// Number of times the check was evaluated.
    pub attempts: usize,
    /// Wall-clock time spent before giving up.
    pub duration: Duration,
    /// Optional caller-supplied label; prefixed to the `Display` output.
    pub description: Option<String>,
}

impl std::fmt::Display for RetryError {
    /// Renders `[<description>: ]assertion failed after N attempt(s) (S.SSs): <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Some(desc) = &self.description {
            write!(f, "{desc}: ")?;
        }
        write!(
            f,
            "assertion failed after {} attempt(s) ({:.2}s): {}",
            self.attempts,
            self.duration.as_secs_f64(),
            self.message
        )
    }
}

impl std::error::Error for RetryError {}
304
305pub fn retry_eq<T: PartialEq + Debug + Clone + 'static>(
311 get_actual: impl Fn() -> T + 'static,
312 expected: T,
313) -> RetryAssertion<impl Fn() -> AssertionCheckResult> {
314 let expected = expected;
315 RetryAssertion::new(move || {
316 let actual = get_actual();
317 if actual == expected {
318 AssertionCheckResult::Pass
319 } else {
320 AssertionCheckResult::Fail(format!("expected {expected:?}, got {actual:?}"))
321 }
322 })
323}
324
325pub fn retry_true(
327 check: impl Fn() -> bool + 'static,
328 message: impl Into<String>,
329) -> RetryAssertion<impl Fn() -> AssertionCheckResult> {
330 let message = message.into();
331 RetryAssertion::new(move || {
332 if check() {
333 AssertionCheckResult::Pass
334 } else {
335 AssertionCheckResult::Fail(message.clone())
336 }
337 })
338}
339
340pub fn retry_some<T>(
342 get_opt: impl Fn() -> Option<T> + 'static,
343) -> RetryAssertion<impl Fn() -> AssertionCheckResult> {
344 RetryAssertion::new(move || {
345 if get_opt().is_some() {
346 AssertionCheckResult::Pass
347 } else {
348 AssertionCheckResult::Fail("expected Some, got None".into())
349 }
350 })
351}
352
353pub fn retry_none<T>(
355 get_opt: impl Fn() -> Option<T> + 'static,
356) -> RetryAssertion<impl Fn() -> AssertionCheckResult> {
357 RetryAssertion::new(move || {
358 if get_opt().is_none() {
359 AssertionCheckResult::Pass
360 } else {
361 AssertionCheckResult::Fail("expected None, got Some".into())
362 }
363 })
364}
365
366pub fn retry_contains(
368 get_haystack: impl Fn() -> String + 'static,
369 needle: impl Into<String>,
370) -> RetryAssertion<impl Fn() -> AssertionCheckResult> {
371 let needle = needle.into();
372 RetryAssertion::new(move || {
373 let haystack = get_haystack();
374 if haystack.contains(&needle) {
375 AssertionCheckResult::Pass
376 } else {
377 AssertionCheckResult::Fail(format!("expected '{haystack}' to contain '{needle}'"))
378 }
379 })
380}
381
#[cfg(test)]
// Tests deliberately panic on unexpected `Err`/`None`; that is the failure signal.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    mod assertion_check_result {
        use super::*;

        #[test]
        fn test_pass() {
            let result = AssertionCheckResult::Pass;
            assert!(result.is_pass());
            assert!(!result.is_fail());
        }

        #[test]
        fn test_fail() {
            let result = AssertionCheckResult::Fail("error".into());
            assert!(result.is_fail());
            assert!(!result.is_pass());
        }
    }

    mod retry_config {
        use super::*;

        #[test]
        fn test_default() {
            let config = RetryConfig::default();
            assert_eq!(config.timeout, Duration::from_secs(5));
            assert_eq!(config.poll_interval, Duration::from_millis(100));
            assert_eq!(config.max_retries, 0);
        }

        #[test]
        fn test_new() {
            let config = RetryConfig::new(Duration::from_secs(10));
            assert_eq!(config.timeout, Duration::from_secs(10));
            assert_eq!(config.poll_interval, Duration::from_millis(100));
            assert_eq!(config.max_retries, 0);
        }

        #[test]
        fn test_with_poll_interval() {
            let config = RetryConfig::default().with_poll_interval(Duration::from_millis(50));
            assert_eq!(config.poll_interval, Duration::from_millis(50));
        }

        #[test]
        fn test_with_max_retries() {
            let config = RetryConfig::default().with_max_retries(3);
            assert_eq!(config.max_retries, 3);
        }

        #[test]
        fn test_fast() {
            let config = RetryConfig::fast();
            assert_eq!(config.timeout, Duration::from_millis(500));
            assert_eq!(config.poll_interval, Duration::from_millis(50));
        }

        #[test]
        fn test_slow() {
            let config = RetryConfig::slow();
            assert_eq!(config.timeout, Duration::from_secs(30));
            assert_eq!(config.poll_interval, Duration::from_millis(500));
        }
    }

    mod retry_assertion {
        use super::*;

        #[test]
        fn test_immediate_pass() {
            let assertion = RetryAssertion::new(|| AssertionCheckResult::Pass);
            let result = assertion.verify().unwrap();
            assert_eq!(result.attempts, 1);
        }

        #[test]
        fn test_immediate_fail_with_timeout() {
            let assertion =
                RetryAssertion::new(|| AssertionCheckResult::Fail("always fails".into()))
                    .with_timeout(Duration::from_millis(100))
                    .with_poll_interval(Duration::from_millis(20));

            let err = assertion.verify().unwrap_err();
            assert!(err.attempts > 1);
            assert!(err.message.contains("always fails"));
        }

        #[test]
        fn test_eventual_pass() {
            let counter = Arc::new(AtomicUsize::new(0));
            let calls = Arc::clone(&counter);

            let assertion = RetryAssertion::new(move || {
                let count = calls.fetch_add(1, Ordering::SeqCst);
                if count >= 2 {
                    AssertionCheckResult::Pass
                } else {
                    AssertionCheckResult::Fail("not yet".into())
                }
            })
            .with_timeout(Duration::from_secs(1))
            .with_poll_interval(Duration::from_millis(10));

            let result = assertion.verify().unwrap();
            assert_eq!(result.attempts, 3);
            // The reported attempt count must match the real number of calls.
            assert_eq!(counter.load(Ordering::SeqCst), 3);
        }

        #[test]
        fn test_max_retries() {
            let counter = Arc::new(AtomicUsize::new(0));
            let calls = Arc::clone(&counter);

            let assertion = RetryAssertion::new(move || {
                let _ = calls.fetch_add(1, Ordering::SeqCst);
                AssertionCheckResult::Fail("always fails".into())
            })
            .with_max_retries(3)
            .with_timeout(Duration::from_secs(10));

            let err = assertion.verify().unwrap_err();
            assert_eq!(err.attempts, 3);
            assert_eq!(counter.load(Ordering::SeqCst), 3);
            assert_eq!(err.message, "always fails");
        }

        #[test]
        fn test_with_description() {
            let assertion = RetryAssertion::new(|| AssertionCheckResult::Fail("error".into()))
                .with_description("checking visibility")
                .with_max_retries(1);

            let err = assertion.verify().unwrap_err();
            assert_eq!(err.description, Some("checking visibility".to_string()));
            assert_eq!(err.attempts, 1);
        }

        #[test]
        fn test_with_config() {
            let config = RetryConfig::fast();
            let assertion = RetryAssertion::new(|| AssertionCheckResult::Pass).with_config(config);
            assert_eq!(assertion.config().timeout, Duration::from_millis(500));
            assert_eq!(assertion.config().poll_interval, Duration::from_millis(50));
        }

        #[test]
        fn test_verify_once_pass() {
            let assertion = RetryAssertion::new(|| AssertionCheckResult::Pass);
            assert!(assertion.verify_once().is_ok());
        }

        #[test]
        fn test_verify_once_fail() {
            let assertion = RetryAssertion::new(|| AssertionCheckResult::Fail("error".into()));
            let err = assertion.verify_once().unwrap_err();
            assert_eq!(err.attempts, 1);
            assert_eq!(err.message, "error");
            assert_eq!(err.duration, Duration::ZERO);
        }

        #[test]
        fn test_verify_once_does_not_retry() {
            let counter = Arc::new(AtomicUsize::new(0));
            let calls = Arc::clone(&counter);

            let assertion = RetryAssertion::new(move || {
                let _ = calls.fetch_add(1, Ordering::SeqCst);
                AssertionCheckResult::Fail("nope".into())
            });

            assert!(assertion.verify_once().is_err());
            assert_eq!(counter.load(Ordering::SeqCst), 1);
        }

        #[test]
        fn test_debug() {
            let assertion =
                RetryAssertion::new(|| AssertionCheckResult::Pass).with_description("test");
            let debug = format!("{assertion:?}");
            assert!(debug.contains("RetryAssertion"));
            assert!(debug.contains("\"test\""));
        }
    }

    mod retry_error {
        use super::*;

        #[test]
        fn test_display_without_description() {
            let err = RetryError {
                message: "failed".into(),
                attempts: 5,
                duration: Duration::from_millis(500),
                description: None,
            };
            let display = format!("{err}");
            assert!(display.contains("5 attempt(s)"));
            assert!(display.contains("failed"));
            assert_eq!(display, "assertion failed after 5 attempt(s) (0.50s): failed");
        }

        #[test]
        fn test_display_with_description() {
            let err = RetryError {
                message: "failed".into(),
                attempts: 3,
                duration: Duration::from_secs(1),
                description: Some("visibility check".into()),
            };
            let display = format!("{err}");
            assert!(display.contains("visibility check"));
            assert!(display.contains("failed"));
            assert!(display.starts_with("visibility check: "));
        }
    }

    mod helper_functions {
        use super::*;

        #[test]
        fn test_retry_eq_pass() {
            let assertion = retry_eq(|| 42, 42).with_max_retries(1);
            assert!(assertion.verify().is_ok());
        }

        #[test]
        fn test_retry_eq_fail() {
            let assertion = retry_eq(|| 1, 2).with_max_retries(1);
            let err = assertion.verify().unwrap_err();
            assert!(err.message.contains("expected"));
            assert_eq!(err.message, "expected 2, got 1");
        }

        #[test]
        fn test_retry_true_pass() {
            let assertion = retry_true(|| true, "should be true").with_max_retries(1);
            assert!(assertion.verify().is_ok());
        }

        #[test]
        fn test_retry_true_fail() {
            let assertion = retry_true(|| false, "should be true").with_max_retries(1);
            let err = assertion.verify().unwrap_err();
            assert!(err.message.contains("should be true"));
        }

        #[test]
        fn test_retry_some_pass() {
            let assertion = retry_some(|| Some(42)).with_max_retries(1);
            assert!(assertion.verify().is_ok());
        }

        #[test]
        fn test_retry_some_fail() {
            let assertion = retry_some::<i32>(|| None).with_max_retries(1);
            let err = assertion.verify().unwrap_err();
            assert_eq!(err.message, "expected Some, got None");
        }

        #[test]
        fn test_retry_none_pass() {
            let assertion = retry_none::<i32>(|| None).with_max_retries(1);
            assert!(assertion.verify().is_ok());
        }

        #[test]
        fn test_retry_none_fail() {
            let assertion = retry_none(|| Some(42)).with_max_retries(1);
            let err = assertion.verify().unwrap_err();
            assert_eq!(err.message, "expected None, got Some");
        }

        #[test]
        fn test_retry_contains_pass() {
            let assertion =
                retry_contains(|| "hello world".to_string(), "world").with_max_retries(1);
            assert!(assertion.verify().is_ok());
        }

        #[test]
        fn test_retry_contains_fail() {
            let assertion = retry_contains(|| "hello".to_string(), "world").with_max_retries(1);
            let err = assertion.verify().unwrap_err();
            assert!(err.message.contains("contain"));
            assert_eq!(err.message, "expected 'hello' to contain 'world'");
        }
    }

    mod retry_result {
        use super::*;

        #[test]
        fn test_result_fields() {
            let result = RetryResult {
                attempts: 3,
                duration: Duration::from_millis(100),
            };
            assert_eq!(result.attempts, 3);
            assert_eq!(result.duration, Duration::from_millis(100));
        }
    }
}