use rand::Rng;
use std::error::Error;
use std::future::Future;
use std::marker::PhantomData;
use std::time::Duration;
/// Runs `operation` repeatedly until it succeeds or `retry_policy` declines to retry.
///
/// After every failed attempt the attempt counter is incremented. The policy is then
/// consulted with a [`RetryContext`] carrying:
/// - that counter,
/// - the total time already slept between attempts,
/// - the error just returned.
///
/// A `Some(delay)` answer sleeps for `delay` via `tokio::time::sleep` and tries again.
/// `None` stops the loop.
///
/// # Errors
///
/// Returns the error produced by the last attempt once the policy answers `None`.
pub async fn retry<R, F, Fut, T, E>(mut operation: F, retry_policy: &R) -> Result<T, E>
where
    R: RetryPolicy<E = E>,
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    E: Error,
{
    let mut attempt_count: u32 = 0;
    let mut total_slept = Duration::ZERO;
    loop {
        let failure = match operation().await {
            Ok(value) => return Ok(value),
            Err(failure) => failure,
        };
        attempt_count += 1;

        // Ask the policy before sleeping; the context only borrows `failure`.
        let decision = retry_policy.next_delay(&RetryContext {
            attempts_made: attempt_count,
            accumulated_delay: total_slept,
            error: &failure,
        });

        match decision {
            Some(pause) => {
                tokio::time::sleep(pause).await;
                total_slept += pause;
            }
            None => return Err(failure),
        }
    }
}
/// Strategy deciding whether, and after how long, a failed operation is attempted again.
///
/// Implementors only provide [`RetryPolicy::next_delay`]. The provided combinators wrap
/// `self` in another policy, so limits, jitter and error filters can be chained.
pub trait RetryPolicy: Sized {
    /// Error type the policy receives through [`RetryContext`].
    type E: Error;

    /// Returns the delay to wait before the next attempt, or `None` to stop retrying.
    fn next_delay(&self, context: &RetryContext<Self::E>) -> Option<Duration>;

    /// Wraps `self` in a [`MaxAttemptsRetryPolicy`] limited to `max_attempts` attempts.
    fn with_max_attempts(self, max_attempts: u32) -> MaxAttemptsRetryPolicy<Self> {
        let inner_policy = self;
        MaxAttemptsRetryPolicy { max_attempts, inner_policy }
    }

    /// Wraps `self` so retrying stops once sleeping for the next delay would push
    /// the total delay above `max_total_delay`.
    fn with_max_total_delay(self, max_total_delay: Duration) -> MaxTotalDelayRetryPolicy<Self> {
        let inner_policy = self;
        MaxTotalDelayRetryPolicy { max_total_delay, inner_policy }
    }

    /// Wraps `self` so a random jitter bounded by `max_jitter` is added to each delay.
    fn with_max_jitter(self, max_jitter: Duration) -> JitteredRetryPolicy<Self> {
        let inner_policy = self;
        JitteredRetryPolicy { max_jitter, inner_policy }
    }

    /// Wraps `self` so retrying stops whenever `function` returns `true` for the
    /// error of the latest attempt.
    fn skip_retry_on_error<F>(self, function: F) -> FilteredRetryPolicy<Self, F>
    where
        F: Fn(&Self::E) -> bool + 'static,
    {
        let inner_policy = self;
        FilteredRetryPolicy { function, inner_policy }
    }
}
/// State of a retry loop, handed to [`RetryPolicy::next_delay`] after a failed attempt.
///
/// The fields are private, so read-only accessors are provided for policies
/// implemented outside this module.
#[derive(Debug)]
pub struct RetryContext<'a, E: Error> {
    /// Attempts made so far; `retry` increments it before building the context,
    /// so it includes the attempt that just failed.
    attempts_made: u32,
    /// Sum of the delays already slept between earlier attempts.
    accumulated_delay: Duration,
    /// Error returned by the attempt that just failed.
    error: &'a E,
}

impl<'a, E: Error> RetryContext<'a, E> {
    /// Number of attempts made so far, including the one that just failed.
    pub fn attempts_made(&self) -> u32 {
        self.attempts_made
    }

    /// Total time already spent sleeping between attempts.
    pub fn accumulated_delay(&self) -> Duration {
        self.accumulated_delay
    }

    /// Error returned by the attempt that just failed.
    pub fn error(&self) -> &'a E {
        self.error
    }
}
/// Retry policy whose delay is `base_delay * (2^attempts_made - 1)`.
///
/// `E` is the error type the policy is used with. The policy never stores an `E`, so
/// the marker is `PhantomData<fn() -> E>` rather than `PhantomData<E>`. This keeps the
/// policy `Send + Sync` whatever `E` is. That matters because `retry` holds `&R`
/// across `.await`.
#[derive(Debug)]
pub struct ExponentialBackoffRetryPolicy<E> {
    base_delay: Duration,
    phantom: PhantomData<fn() -> E>,
}

impl<E: Error> ExponentialBackoffRetryPolicy<E> {
    /// Creates a policy that scales `base_delay` exponentially with the attempt count.
    pub fn new(base_delay: Duration) -> ExponentialBackoffRetryPolicy<E> {
        Self { base_delay, phantom: PhantomData }
    }
}
impl<E: Error> RetryPolicy for ExponentialBackoffRetryPolicy<E> {
    type E = E;

    /// Returns `base_delay * (2^attempts_made - 1)`: 0, 1, 3, 7, 15, ... times
    /// `base_delay` for 0, 1, 2, 3, 4, ... attempts made. Never returns `None`.
    ///
    /// Both steps saturate instead of overflowing. `2_u32.pow` overflows once
    /// `attempts_made` reaches 32, which panics in debug builds and wraps in release.
    /// `Duration * u32` panics past `Duration::MAX`. So an unbounded retry loop now
    /// plateaus rather than crashing.
    fn next_delay(&self, context: &RetryContext<Self::E>) -> Option<Duration> {
        let backoff_factor = 2_u32
            .checked_pow(context.attempts_made)
            .map_or(u32::MAX, |power| power - 1);
        Some(self.base_delay.saturating_mul(backoff_factor))
    }
}
/// Wraps another policy and stops retrying once `max_attempts` attempts have been made.
///
/// Attempts are counted as `retry` counts them, so the first call of the operation is
/// included. For example, `max_attempts = 3` allows one initial attempt plus up to two
/// retries. A value of `0` behaves like `1`: the operation runs once and is not retried.
pub struct MaxAttemptsRetryPolicy<T: RetryPolicy> {
    inner_policy: T,
    max_attempts: u32,
}

impl<T: RetryPolicy> RetryPolicy for MaxAttemptsRetryPolicy<T> {
    type E = T::E;

    fn next_delay(&self, context: &RetryContext<Self::E>) -> Option<Duration> {
        // Use `>=` rather than `==`. The count starts at 1, so an equality check
        // never fires for `max_attempts == 0` and the loop would retry forever.
        if context.attempts_made >= self.max_attempts {
            None
        } else {
            self.inner_policy.next_delay(context)
        }
    }
}
/// Wraps another policy and stops retrying once sleeping for the inner policy's next
/// delay would push the total delay above `max_total_delay`.
pub struct MaxTotalDelayRetryPolicy<T: RetryPolicy> {
    inner_policy: T,
    max_total_delay: Duration,
}

impl<T: RetryPolicy> RetryPolicy for MaxTotalDelayRetryPolicy<T> {
    type E = T::E;

    fn next_delay(&self, context: &RetryContext<Self::E>) -> Option<Duration> {
        self.inner_policy.next_delay(context).filter(|&delay| {
            // `checked_add` avoids the panic that `Duration + Duration` raises on
            // overflow. A total that overflows is certainly over the budget.
            matches!(
                context.accumulated_delay.checked_add(delay),
                Some(total) if total <= self.max_total_delay
            )
        })
    }
}
/// Wraps another policy and adds a random jitter to each of its delays.
///
/// The jitter is drawn uniformly from `[0, max_jitter]` at microsecond granularity.
/// When the inner policy returns `None`, this policy returns `None` as well.
pub struct JitteredRetryPolicy<T: RetryPolicy> {
    inner_policy: T,
    max_jitter: Duration,
}

impl<T: RetryPolicy> RetryPolicy for JitteredRetryPolicy<T> {
    type E = T::E;

    fn next_delay(&self, context: &RetryContext<Self::E>) -> Option<Duration> {
        let base_delay = self.inner_policy.next_delay(context)?;
        // Clamp before narrowing. A bare `as u64` would silently truncate a huge
        // `max_jitter`; after the clamp the cast is lossless.
        let max_jitter_micros = self.max_jitter.as_micros().min(u128::from(u64::MAX)) as u64;
        // Use an inclusive range. `0..max` is empty, and `gen_range` panics on it,
        // whenever `max_jitter` is below one microsecond (e.g. `Duration::ZERO`).
        let jitter_micros = rand::thread_rng().gen_range(0..=max_jitter_micros);
        Some(base_delay.saturating_add(Duration::from_micros(jitter_micros)))
    }
}
/// Wraps another policy and stops retrying whenever `function` matches the error.
///
/// If `function` returns `true` for the error of the latest attempt, no further
/// attempt is scheduled. Otherwise the inner policy decides.
pub struct FilteredRetryPolicy<T: RetryPolicy, F> {
    inner_policy: T,
    function: F,
}

impl<T, F> RetryPolicy for FilteredRetryPolicy<T, F>
where
    T: RetryPolicy,
    F: Fn(&T::E) -> bool,
{
    type E = T::E;

    fn next_delay(&self, context: &RetryContext<Self::E>) -> Option<Duration> {
        // `context.error` is already a reference; pass it straight through.
        let should_skip = (self.function)(context.error);
        if should_skip {
            return None;
        }
        self.inner_policy.next_delay(context)
    }
}