use std::{sync::Arc, time::Duration};
use tokio::time::sleep;
use crate::policy::Policy;
use crate::metrics::Metrics;
/// Delay schedule applied between retry attempts.
///
/// All payloads are plain `Copy` data, so the schedule is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Backoff {
    /// Wait the same duration before every retry.
    Fixed(Duration),
    /// Wait `base` multiplied by `factor` raised to the retry index.
    Exponential {
        /// Delay for retry index 0.
        base: Duration,
        /// Multiplier applied once per additional retry index.
        factor: f64,
    },
}
/// A retry policy that re-runs a failing operation up to `max_retries` times,
/// sleeping between attempts according to a [`Backoff`] schedule.
///
/// Cloning is cheap: the optional metrics sink is shared through its `Arc`.
#[derive(Clone)]
pub struct RetryPolicy {
    /// Maximum number of retries after the initial attempt, so the operation
    /// is invoked at most `max_retries + 1` times.
    max_retries: usize,
    /// Delay schedule applied between attempts.
    backoff: Backoff,
    /// Optional sink notified of successes, retries and final failures.
    metrics: Option<Arc<dyn Metrics>>,
}
impl RetryPolicy {
    /// Creates a policy that retries up to `max_retries` times, waiting
    /// `delay` before every retry.
    pub fn fixed(max_retries: usize, delay: Duration) -> Self {
        Self {
            max_retries,
            backoff: Backoff::Fixed(delay),
            metrics: None,
        }
    }

    /// Creates a policy that retries up to `max_retries` times with a delay of
    /// `base * factor^n` for retry index `n`.
    ///
    /// Delays that do not fit in a `Duration` saturate to `Duration::MAX`, and
    /// negative or NaN delays (e.g. from a negative `factor`) become
    /// `Duration::ZERO`, instead of panicking.
    pub fn exponential(max_retries: usize, base: Duration, factor: f64) -> Self {
        Self {
            max_retries,
            backoff: Backoff::Exponential { base, factor },
            metrics: None,
        }
    }

    /// Attaches a metrics sink that is notified of successes, retries and
    /// final failures.
    pub fn with_metrics(mut self, metrics: Arc<dyn Metrics>) -> Self {
        self.metrics = Some(metrics);
        self
    }

    /// Returns the backoff delay for attempt index `attempt`.
    ///
    /// For [`Backoff::Fixed`] the index is ignored. For [`Backoff::Exponential`]
    /// the result is `base * factor^attempt`, so `delay(0)` is `base`. This never
    /// panics: an overflowing (or infinite) result saturates to `Duration::MAX`,
    /// and a negative or NaN result yields `Duration::ZERO`.
    fn delay(&self, attempt: usize) -> Duration {
        match self.backoff {
            Backoff::Fixed(d) => d,
            Backoff::Exponential { base, factor } => {
                // A bare `as i32` would wrap huge indices into negative exponents
                // (shrinking the delay); clamp so the delay keeps growing instead.
                let exponent = attempt.min(i32::MAX as usize) as i32;
                let secs = base.as_secs_f64() * factor.powi(exponent);
                // `Duration::mul_f64` panics on overflow, negative or NaN values;
                // use the fallible conversion and clamp instead.
                Duration::try_from_secs_f64(secs).unwrap_or(if secs > 0.0 {
                    Duration::MAX
                } else {
                    Duration::ZERO
                })
            }
        }
    }
}
#[async_trait::async_trait]
impl<E> Policy<E> for RetryPolicy
where
    E: Send + Sync,
{
    /// Runs `f`, retrying on `Err` up to `max_retries` times.
    ///
    /// The operation is invoked at most `max_retries + 1` times. Before the
    /// retry with zero-based index `n` the policy sleeps for `self.delay(n)`, so
    /// an exponential policy waits exactly `base` before its first retry. On
    /// success the metrics sink (if any) gets `on_success`; before each retry it
    /// gets `on_retry`; when retries are exhausted it gets `on_failure` and the
    /// last error is returned. Errors from intermediate attempts are discarded.
    async fn execute<F, Fut, T>(&self, f: F) -> Result<T, E>
    where
        F: Fn() -> Fut + Send + Sync,
        Fut: std::future::Future<Output = Result<T, E>> + Send,
        T: Send,
    {
        let mut retries = 0;
        loop {
            match f().await {
                Ok(v) => {
                    if let Some(m) = &self.metrics {
                        m.on_success();
                    }
                    return Ok(v);
                }
                Err(_) if retries < self.max_retries => {
                    if let Some(m) = &self.metrics {
                        m.on_retry();
                    }
                    // Use the zero-based retry index *before* incrementing;
                    // otherwise the first retry would wait `base * factor`
                    // and `base` itself would never be used.
                    let wait = self.delay(retries);
                    retries += 1;
                    sleep(wait).await;
                }
                Err(e) => {
                    if let Some(m) = &self.metrics {
                        m.on_failure();
                    }
                    return Err(e);
                }
            }
        }
    }
}