use crate::error::{Result, ZinitError};
use futures::Future;
use rand::Rng;
use std::time::Duration;
use tokio::time::sleep;
use tracing::{debug, warn};
/// Configuration for retrying a fallible async operation with exponential backoff.
///
/// Construct it with [`RetryStrategy::new`] or via [`Default`].
#[derive(Clone, Debug)]
pub struct RetryStrategy {
    /// Number of retries allowed after the initial attempt.
    max_retries: usize,
    /// Backoff delay used before the first retry; doubled for each later retry.
    base_delay: Duration,
    /// Upper bound applied to the exponential backoff before any jitter is added.
    max_delay: Duration,
    /// Whether each computed delay is randomly scaled.
    jitter: bool,
}
impl RetryStrategy {
pub fn new(
max_retries: usize,
base_delay: Duration,
max_delay: Duration,
jitter: bool,
) -> Self {
Self {
max_retries,
base_delay,
max_delay,
jitter,
}
}
pub async fn retry<F, Fut, T>(&self, operation: F) -> Result<T>
where
F: Fn() -> Fut,
Fut: Future<Output = Result<T>>,
{
let mut attempt = 0;
loop {
attempt += 1;
debug!("Attempt {}/{}", attempt, self.max_retries + 1);
match operation().await {
Ok(result) => return Ok(result),
Err(err) => {
match &err {
ZinitError::UnknownService(_)
| ZinitError::ServiceAlreadyMonitored(_)
| ZinitError::ServiceIsUp(_)
| ZinitError::ServiceIsDown(_)
| ZinitError::InvalidSignal(_)
| ZinitError::ShuttingDown => return Err(err),
_ => {
warn!("Attempt {} failed: {}", attempt, err);
}
}
}
}
if attempt > self.max_retries {
return Err(ZinitError::RetryLimitReached(self.max_retries));
}
let delay = self.calculate_delay(attempt);
debug!("Retrying after {:?}", delay);
sleep(delay).await;
}
}
fn calculate_delay(&self, attempt: usize) -> Duration {
let exp_backoff = self.base_delay.as_millis() * 2u128.pow((attempt - 1) as u32);
let capped_delay = std::cmp::min(exp_backoff, self.max_delay.as_millis());
let delay_ms = if self.jitter {
let jitter_factor = rand::thread_rng().gen_range(0.8..1.2);
(capped_delay as f64 * jitter_factor) as u64
} else {
capped_delay as u64
};
Duration::from_millis(delay_ms)
}
}
impl Default for RetryStrategy {
    /// Returns a strategy with 3 retries, a 100 ms base delay, a 5 s backoff cap,
    /// and jitter enabled.
    fn default() -> Self {
        let base_delay = Duration::from_millis(100);
        let max_delay = Duration::from_secs(5);
        Self {
            jitter: true,
            max_retries: 3,
            max_delay,
            base_delay,
        }
    }
}