1use serde::{Deserialize, Serialize};
12use std::fmt::Debug;
13use std::time::{Duration, Instant};
14
/// Outcome of evaluating an assertion's check once.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AssertionCheckResult {
    /// The condition held.
    Pass,
    /// The condition did not hold; the payload is a human-readable reason.
    Fail(String),
}
23
24impl AssertionCheckResult {
25 #[must_use]
27 pub const fn is_pass(&self) -> bool {
28 matches!(self, Self::Pass)
29 }
30
31 #[must_use]
33 pub const fn is_fail(&self) -> bool {
34 matches!(self, Self::Fail(_))
35 }
36}
37
38#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
40pub struct RetryConfig {
41 pub timeout: Duration,
43 pub poll_interval: Duration,
45 pub max_retries: usize,
47}
48
49impl Default for RetryConfig {
50 fn default() -> Self {
51 Self {
52 timeout: Duration::from_secs(5),
53 poll_interval: Duration::from_millis(100),
54 max_retries: 0,
55 }
56 }
57}
58
59impl RetryConfig {
60 #[must_use]
62 pub const fn new(timeout: Duration) -> Self {
63 contract_pre_retry_assertion!();
64 Self {
65 timeout,
66 poll_interval: Duration::from_millis(100),
67 max_retries: 0,
68 }
69 }
70
71 #[must_use]
73 pub const fn with_poll_interval(mut self, interval: Duration) -> Self {
74 self.poll_interval = interval;
75 self
76 }
77
78 #[must_use]
80 pub const fn with_max_retries(mut self, max: usize) -> Self {
81 self.max_retries = max;
82 self
83 }
84
85 #[must_use]
87 pub const fn fast() -> Self {
88 Self {
89 timeout: Duration::from_millis(500),
90 poll_interval: Duration::from_millis(50),
91 max_retries: 0,
92 }
93 }
94
95 #[must_use]
97 pub const fn slow() -> Self {
98 Self {
99 timeout: Duration::from_secs(30),
100 poll_interval: Duration::from_millis(500),
101 max_retries: 0,
102 }
103 }
104}
105
106pub struct RetryAssertion<F>
123where
124 F: Fn() -> AssertionCheckResult,
125{
126 check: F,
127 config: RetryConfig,
128 description: Option<String>,
129}
130
131impl<F> RetryAssertion<F>
132where
133 F: Fn() -> AssertionCheckResult,
134{
135 #[must_use]
137 pub fn new(check: F) -> Self {
138 Self {
139 check,
140 config: RetryConfig::default(),
141 description: None,
142 }
143 }
144
145 #[must_use]
147 pub const fn with_timeout(mut self, timeout: Duration) -> Self {
148 self.config.timeout = timeout;
149 self
150 }
151
152 #[must_use]
154 pub const fn with_poll_interval(mut self, interval: Duration) -> Self {
155 self.config.poll_interval = interval;
156 self
157 }
158
159 #[must_use]
161 pub const fn with_max_retries(mut self, max: usize) -> Self {
162 self.config.max_retries = max;
163 self
164 }
165
166 #[must_use]
168 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
169 self.description = Some(desc.into());
170 self
171 }
172
173 #[must_use]
175 pub const fn with_config(mut self, config: RetryConfig) -> Self {
176 self.config = config;
177 self
178 }
179
180 #[must_use]
182 pub const fn config(&self) -> &RetryConfig {
183 &self.config
184 }
185
186 #[allow(unused_assignments)]
192 pub fn verify(&self) -> Result<RetryResult, RetryError> {
193 contract_pre_retry_assertion!();
194 let start = Instant::now();
195 let mut attempts = 0;
196 let mut last_error: Option<String> = None;
197
198 loop {
199 attempts += 1;
200
201 match (self.check)() {
202 AssertionCheckResult::Pass => {
203 return Ok(RetryResult {
204 attempts,
205 duration: start.elapsed(),
206 });
207 }
208 AssertionCheckResult::Fail(msg) => {
209 last_error = Some(msg);
210 }
211 }
212
213 if start.elapsed() >= self.config.timeout {
215 return Err(RetryError {
216 message: last_error.unwrap_or_default(),
217 attempts,
218 duration: start.elapsed(),
219 description: self.description.clone(),
220 });
221 }
222
223 if self.config.max_retries > 0 && attempts >= self.config.max_retries {
225 return Err(RetryError {
226 message: last_error.unwrap_or_default(),
227 attempts,
228 duration: start.elapsed(),
229 description: self.description.clone(),
230 });
231 }
232
233 std::thread::sleep(self.config.poll_interval);
235 }
236 }
237
238 pub fn verify_once(&self) -> Result<(), RetryError> {
244 match (self.check)() {
245 AssertionCheckResult::Pass => Ok(()),
246 AssertionCheckResult::Fail(msg) => Err(RetryError {
247 message: msg,
248 attempts: 1,
249 duration: Duration::ZERO,
250 description: self.description.clone(),
251 }),
252 }
253 }
254}
255
256impl<F> Debug for RetryAssertion<F>
257where
258 F: Fn() -> AssertionCheckResult,
259{
260 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261 f.debug_struct("RetryAssertion")
262 .field("config", &self.config)
263 .field("description", &self.description)
264 .finish()
265 }
266}
267
/// Outcome of a successful `RetryAssertion::verify` call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RetryResult {
    /// Number of times the check was evaluated, the passing evaluation included.
    pub attempts: usize,
    /// Time from the start of verification until the check passed.
    pub duration: Duration,
}
276
/// Failure report produced when a retrying assertion gives up.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryError {
    /// Failure message from the last evaluation of the check.
    pub message: String,
    /// Number of times the check was evaluated before giving up.
    pub attempts: usize,
    /// Elapsed time recorded for the failed verification.
    pub duration: Duration,
    /// Optional label supplied by the caller to identify the assertion.
    pub description: Option<String>,
}
289
290impl std::fmt::Display for RetryError {
291 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
292 if let Some(ref desc) = self.description {
293 write!(f, "{desc}: ")?;
294 }
295 write!(
296 f,
297 "assertion failed after {} attempt(s) ({:.2}s): {}",
298 self.attempts,
299 self.duration.as_secs_f64(),
300 self.message
301 )
302 }
303}
304
305impl std::error::Error for RetryError {}
306
307pub fn retry_eq<T: PartialEq + Debug + Clone + 'static>(
313 get_actual: impl Fn() -> T + 'static,
314 expected: T,
315) -> RetryAssertion<impl Fn() -> AssertionCheckResult> {
316 let expected = expected;
317 RetryAssertion::new(move || {
318 let actual = get_actual();
319 if actual == expected {
320 AssertionCheckResult::Pass
321 } else {
322 AssertionCheckResult::Fail(format!("expected {expected:?}, got {actual:?}"))
323 }
324 })
325}
326
327pub fn retry_true(
329 check: impl Fn() -> bool + 'static,
330 message: impl Into<String>,
331) -> RetryAssertion<impl Fn() -> AssertionCheckResult> {
332 let message = message.into();
333 RetryAssertion::new(move || {
334 if check() {
335 AssertionCheckResult::Pass
336 } else {
337 AssertionCheckResult::Fail(message.clone())
338 }
339 })
340}
341
342pub fn retry_some<T>(
344 get_opt: impl Fn() -> Option<T> + 'static,
345) -> RetryAssertion<impl Fn() -> AssertionCheckResult> {
346 RetryAssertion::new(move || {
347 if get_opt().is_some() {
348 AssertionCheckResult::Pass
349 } else {
350 AssertionCheckResult::Fail("expected Some, got None".into())
351 }
352 })
353}
354
355pub fn retry_none<T>(
357 get_opt: impl Fn() -> Option<T> + 'static,
358) -> RetryAssertion<impl Fn() -> AssertionCheckResult> {
359 RetryAssertion::new(move || {
360 if get_opt().is_none() {
361 AssertionCheckResult::Pass
362 } else {
363 AssertionCheckResult::Fail("expected None, got Some".into())
364 }
365 })
366}
367
368pub fn retry_contains(
370 get_haystack: impl Fn() -> String + 'static,
371 needle: impl Into<String>,
372) -> RetryAssertion<impl Fn() -> AssertionCheckResult> {
373 let needle = needle.into();
374 RetryAssertion::new(move || {
375 let haystack = get_haystack();
376 if haystack.contains(&needle) {
377 AssertionCheckResult::Pass
378 } else {
379 AssertionCheckResult::Fail(format!("expected '{haystack}' to contain '{needle}'"))
380 }
381 })
382}
383
#[cfg(test)]
// Tests deliberately unwrap: a panic here is the failure signal.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    mod assertion_check_result {
        use super::*;

        #[test]
        fn test_pass() {
            let result = AssertionCheckResult::Pass;
            assert!(result.is_pass());
            assert!(!result.is_fail());
        }

        #[test]
        fn test_fail() {
            let result = AssertionCheckResult::Fail("error".into());
            assert!(result.is_fail());
            assert!(!result.is_pass());
        }
    }

    mod retry_config {
        use super::*;

        #[test]
        fn test_default() {
            let config = RetryConfig::default();
            assert_eq!(config.timeout, Duration::from_secs(5));
            assert_eq!(config.poll_interval, Duration::from_millis(100));
            assert_eq!(config.max_retries, 0);
        }

        #[test]
        fn test_new() {
            let config = RetryConfig::new(Duration::from_secs(10));
            assert_eq!(config.timeout, Duration::from_secs(10));
        }

        #[test]
        fn test_with_poll_interval() {
            let config = RetryConfig::default().with_poll_interval(Duration::from_millis(50));
            assert_eq!(config.poll_interval, Duration::from_millis(50));
        }

        #[test]
        fn test_with_max_retries() {
            let config = RetryConfig::default().with_max_retries(3);
            assert_eq!(config.max_retries, 3);
        }

        #[test]
        fn test_fast() {
            let config = RetryConfig::fast();
            assert_eq!(config.timeout, Duration::from_millis(500));
            assert_eq!(config.poll_interval, Duration::from_millis(50));
        }

        #[test]
        fn test_slow() {
            let config = RetryConfig::slow();
            assert_eq!(config.timeout, Duration::from_secs(30));
            assert_eq!(config.poll_interval, Duration::from_millis(500));
        }
    }

    mod retry_assertion {
        use super::*;

        #[test]
        fn test_immediate_pass() {
            let assertion = RetryAssertion::new(|| AssertionCheckResult::Pass);
            let result = assertion.verify().unwrap();
            assert_eq!(result.attempts, 1);
        }

        #[test]
        fn test_immediate_fail_with_timeout() {
            let assertion =
                RetryAssertion::new(|| AssertionCheckResult::Fail("always fails".into()))
                    .with_timeout(Duration::from_millis(100))
                    .with_poll_interval(Duration::from_millis(20));

            let err = assertion.verify().unwrap_err();
            assert!(err.attempts > 1);
            assert!(err.message.contains("always fails"));
        }

        #[test]
        fn test_eventual_pass() {
            // Borrowed rather than moved so the call count can be checked afterwards.
            let counter = AtomicUsize::new(0);

            let assertion = RetryAssertion::new(|| {
                let count = counter.fetch_add(1, Ordering::SeqCst);
                if count >= 2 {
                    AssertionCheckResult::Pass
                } else {
                    AssertionCheckResult::Fail("not yet".into())
                }
            })
            .with_timeout(Duration::from_secs(1))
            .with_poll_interval(Duration::from_millis(10));

            let result = assertion.verify().unwrap();
            assert_eq!(result.attempts, 3);
            assert_eq!(counter.load(Ordering::SeqCst), 3);
        }

        #[test]
        fn test_max_retries() {
            let counter = AtomicUsize::new(0);

            let assertion = RetryAssertion::new(|| {
                let _ = counter.fetch_add(1, Ordering::SeqCst);
                AssertionCheckResult::Fail("always fails".into())
            })
            .with_max_retries(3)
            .with_timeout(Duration::from_secs(10))
            // Keeps the test fast: the default 100 ms interval would add ~200 ms of sleep.
            .with_poll_interval(Duration::from_millis(1));

            let err = assertion.verify().unwrap_err();
            assert_eq!(err.attempts, 3);
            // `max_retries` caps total evaluations, so the check ran exactly three times.
            assert_eq!(counter.load(Ordering::SeqCst), 3);
        }

        #[test]
        fn test_with_description() {
            let assertion = RetryAssertion::new(|| AssertionCheckResult::Fail("error".into()))
                .with_description("checking visibility")
                .with_max_retries(1);

            let err = assertion.verify().unwrap_err();
            assert_eq!(err.description, Some("checking visibility".to_string()));
        }

        #[test]
        fn test_with_config() {
            let config = RetryConfig::fast();
            let assertion = RetryAssertion::new(|| AssertionCheckResult::Pass).with_config(config);
            assert_eq!(assertion.config().timeout, Duration::from_millis(500));
        }

        #[test]
        fn test_verify_once_pass() {
            let assertion = RetryAssertion::new(|| AssertionCheckResult::Pass);
            assert!(assertion.verify_once().is_ok());
        }

        #[test]
        fn test_verify_once_fail() {
            let assertion = RetryAssertion::new(|| AssertionCheckResult::Fail("error".into()));
            let err = assertion.verify_once().unwrap_err();
            assert_eq!(err.attempts, 1);
        }

        #[test]
        fn test_debug() {
            let assertion =
                RetryAssertion::new(|| AssertionCheckResult::Pass).with_description("test");
            let debug = format!("{assertion:?}");
            assert!(debug.contains("RetryAssertion"));
        }
    }

    mod retry_error {
        use super::*;

        #[test]
        fn test_display_without_description() {
            let err = RetryError {
                message: "failed".into(),
                attempts: 5,
                duration: Duration::from_millis(500),
                description: None,
            };
            let display = format!("{err}");
            assert!(display.contains("5 attempt(s)"));
            assert!(display.contains("failed"));
        }

        #[test]
        fn test_display_with_description() {
            let err = RetryError {
                message: "failed".into(),
                attempts: 3,
                duration: Duration::from_secs(1),
                description: Some("visibility check".into()),
            };
            let display = format!("{err}");
            assert!(display.contains("visibility check"));
            assert!(display.contains("failed"));
        }
    }

    mod helper_functions {
        use super::*;

        #[test]
        fn test_retry_eq_pass() {
            let assertion = retry_eq(|| 42, 42).with_max_retries(1);
            assert!(assertion.verify().is_ok());
        }

        #[test]
        fn test_retry_eq_fail() {
            let assertion = retry_eq(|| 1, 2).with_max_retries(1);
            let err = assertion.verify().unwrap_err();
            assert!(err.message.contains("expected"));
        }

        #[test]
        fn test_retry_true_pass() {
            let assertion = retry_true(|| true, "should be true").with_max_retries(1);
            assert!(assertion.verify().is_ok());
        }

        #[test]
        fn test_retry_true_fail() {
            let assertion = retry_true(|| false, "should be true").with_max_retries(1);
            let err = assertion.verify().unwrap_err();
            assert!(err.message.contains("should be true"));
        }

        #[test]
        fn test_retry_some_pass() {
            let assertion = retry_some(|| Some(42)).with_max_retries(1);
            assert!(assertion.verify().is_ok());
        }

        #[test]
        fn test_retry_some_fail() {
            let assertion = retry_some::<i32>(|| None).with_max_retries(1);
            assert!(assertion.verify().is_err());
        }

        #[test]
        fn test_retry_none_pass() {
            let assertion = retry_none::<i32>(|| None).with_max_retries(1);
            assert!(assertion.verify().is_ok());
        }

        #[test]
        fn test_retry_none_fail() {
            let assertion = retry_none(|| Some(42)).with_max_retries(1);
            assert!(assertion.verify().is_err());
        }

        #[test]
        fn test_retry_contains_pass() {
            let assertion =
                retry_contains(|| "hello world".to_string(), "world").with_max_retries(1);
            assert!(assertion.verify().is_ok());
        }

        #[test]
        fn test_retry_contains_fail() {
            let assertion = retry_contains(|| "hello".to_string(), "world").with_max_retries(1);
            let err = assertion.verify().unwrap_err();
            assert!(err.message.contains("contain"));
        }
    }

    mod retry_result {
        use super::*;

        #[test]
        fn test_result_fields() {
            let result = RetryResult {
                attempts: 3,
                duration: Duration::from_millis(100),
            };
            assert_eq!(result.attempts, 3);
            assert_eq!(result.duration, Duration::from_millis(100));
        }
    }
}