1use crate::error::{Result, ZinitError};
2use futures::Future;
3use rand::Rng;
4use std::time::Duration;
5use tokio::time::sleep;
6use tracing::{debug, warn};
7
/// Retry policy with exponential backoff.
///
/// An operation is attempted once and then retried up to `max_retries` times.
/// Before retry `n` (1-based) the caller waits `base_delay * 2^(n - 1)`, capped
/// at `max_delay` and optionally randomized when `jitter` is set.
#[derive(Clone, Debug)]
pub struct RetryStrategy {
    /// Number of retries allowed after the initial attempt.
    max_retries: usize,
    /// Delay before the first retry; doubled for every subsequent retry.
    base_delay: Duration,
    /// Ceiling applied to the exponentially growing delay.
    max_delay: Duration,
    /// When `true`, each delay is scaled by a random factor in `0.8..1.2`.
    jitter: bool,
}
20
21impl RetryStrategy {
22 pub fn new(
24 max_retries: usize,
25 base_delay: Duration,
26 max_delay: Duration,
27 jitter: bool,
28 ) -> Self {
29 Self {
30 max_retries,
31 base_delay,
32 max_delay,
33 jitter,
34 }
35 }
36
37 pub async fn retry<F, Fut, T>(&self, operation: F) -> Result<T>
39 where
40 F: Fn() -> Fut,
41 Fut: Future<Output = Result<T>>,
42 {
43 let mut attempt = 0;
44
45 loop {
46 attempt += 1;
47 debug!("Attempt {}/{}", attempt, self.max_retries + 1);
48
49 match operation().await {
50 Ok(result) => return Ok(result),
51 Err(err) => {
52 match &err {
54 ZinitError::UnknownService(_)
55 | ZinitError::ServiceAlreadyMonitored(_)
56 | ZinitError::ServiceIsUp(_)
57 | ZinitError::ServiceIsDown(_)
58 | ZinitError::InvalidSignal(_)
59 | ZinitError::ShuttingDown => return Err(err),
60 _ => {
61 warn!("Attempt {} failed: {}", attempt, err);
62 }
63 }
64 }
65 }
66
67 if attempt > self.max_retries {
68 return Err(ZinitError::RetryLimitReached(self.max_retries));
69 }
70
71 let delay = self.calculate_delay(attempt);
72 debug!("Retrying after {:?}", delay);
73 sleep(delay).await;
74 }
75 }
76
77 fn calculate_delay(&self, attempt: usize) -> Duration {
79 let exp_backoff = self.base_delay.as_millis() * 2u128.pow((attempt - 1) as u32);
81
82 let capped_delay = std::cmp::min(exp_backoff, self.max_delay.as_millis());
84
85 let delay_ms = if self.jitter {
87 let jitter_factor = rand::thread_rng().gen_range(0.8..1.2);
88 (capped_delay as f64 * jitter_factor) as u64
89 } else {
90 capped_delay as u64
91 };
92
93 Duration::from_millis(delay_ms)
94 }
95}
96
97impl Default for RetryStrategy {
99 fn default() -> Self {
100 Self {
101 max_retries: 3,
102 base_delay: Duration::from_millis(100),
103 max_delay: Duration::from_secs(5),
104 jitter: true,
105 }
106 }
107}