use std::cmp::min;
use std::future::Future;
use std::time::Duration;
use futures_timer::Delay;
#[cfg(feature = "rand")]
use rand::Rng;
#[cfg(feature = "rand")]
use rand::distr::OpenClosed01;
#[cfg(feature = "rand")]
use rand::rng;
pub async fn retry<T>(task: T) -> Result<T::Item, T::Error>
where
T: Task,
{
retry_if(task, Always).await
}
/// Runs `task` under [`RetryPolicy::default`], retrying only those errors for
/// which `condition` reports `true`.
///
/// A non-retryable error, or any error once the retry budget is spent, is
/// returned to the caller.
pub async fn retry_if<T, C>(task: T, condition: C) -> Result<T::Item, T::Error>
where
    T: Task,
    C: Condition<T::Error>,
{
    let policy = RetryPolicy::default();
    policy.retry_if(task, condition).await
}
/// Repeatedly runs a parameterised `task` under [`RetryPolicy::default`] and
/// gathers every successful result.
///
/// The first run receives `start_value`. After each success `condition` may
/// supply the parameter for another run. The first error is returned as-is and
/// is not retried.
pub async fn collect<T, C, S>(
    task: T,
    condition: C,
    start_value: S,
) -> Result<Vec<T::Item>, T::Error>
where
    T: TaskWithParameter<S>,
    C: SuccessCondition<T::Item, S>,
{
    let policy = RetryPolicy::default();
    policy.collect(task, condition, start_value).await
}
/// Like [`collect`], but errors accepted by `error_condition` are retried with
/// the same parameter instead of ending the run.
///
/// Uses [`RetryPolicy::default`] for both the success and the error delays.
pub async fn collect_and_retry<T, C, D, S>(
    task: T,
    success_condition: C,
    error_condition: D,
    start_value: S,
) -> Result<Vec<T::Item>, T::Error>
where
    T: TaskWithParameter<S>,
    C: SuccessCondition<T::Item, S>,
    D: Condition<T::Error>,
    S: Clone,
{
    let policy = RetryPolicy::default();
    policy
        .collect_and_retry(task, success_condition, error_condition, start_value)
        .await
}
/// How the delay between consecutive attempts evolves.
#[derive(Copy, Clone)]
enum Backoff {
    /// Every delay equals the policy's base delay.
    Fixed,
    /// The delay multiplier is scaled by `exponent` after every attempt.
    Exponential { exponent: f64 },
}

impl Default for Backoff {
    /// Exponential backoff with an exponent of `2.0` (the delay doubles).
    fn default() -> Self {
        Self::Exponential { exponent: 2.0 }
    }
}
impl Backoff {
fn iter(self, policy: &RetryPolicy) -> BackoffIter {
BackoffIter {
backoff: self,
current: 1.0,
#[cfg(feature = "rand")]
jitter: policy.jitter,
delay: policy.delay,
max_delay: policy.max_delay,
max_retries: policy.max_retries,
}
}
}
/// Iterator over the delays to wait before successive retries.
///
/// Each call to `next` yields `delay * current`. For exponential backoff it
/// then multiplies `current` by the exponent. The delay is optionally jittered
/// and finally capped at `max_delay`. At most `max_retries` delays are yielded.
struct BackoffIter {
    /// Strategy deciding how `current` evolves between items.
    backoff: Backoff,
    /// Multiplier applied to `delay` for the next yielded item.
    current: f64,
    /// Whether each delay is randomised (only with the `rand` feature).
    #[cfg(feature = "rand")]
    jitter: bool,
    /// Base delay that `current` scales.
    delay: Duration,
    /// Optional upper bound applied to every yielded delay.
    max_delay: Option<Duration>,
    /// Number of delays still to be yielded.
    max_retries: usize,
}

impl Iterator for BackoffIter {
    type Item = Duration;

    fn next(&mut self) -> Option<Self::Item> {
        if self.max_retries == 0 {
            return None;
        }
        self.max_retries -= 1;

        let factor = self.current;
        if let Backoff::Exponential { exponent } = self.backoff {
            self.current *= exponent;
        }

        // `Duration::mul_f64` panics once the product no longer fits in a
        // `Duration` (about 2^64 seconds: ~64 doublings of a 1s base). That
        // would abort long retry sequences even when `max_delay` caps them,
        // because the panic happens before the cap is applied. Saturate
        // instead. Too-large or infinite products become `Duration::MAX`.
        // NaN or negative products (`0 * inf`, negative exponents) become zero.
        let secs = self.delay.as_secs_f64() * factor;
        let mut delay = match Duration::try_from_secs_f64(secs) {
            Ok(delay) => delay,
            Err(_) if secs > 0.0 => Duration::MAX,
            Err(_) => Duration::ZERO,
        };

        #[cfg(feature = "rand")]
        {
            if self.jitter {
                delay = jitter(delay);
            }
        }
        if let Some(max_delay) = self.max_delay {
            delay = min(delay, max_delay);
        }
        Some(delay)
    }
}
/// Configuration describing how many times to retry and how long to wait
/// between attempts.
///
/// Build one with [`RetryPolicy::exponential`], [`RetryPolicy::fixed`] or
/// [`Default::default`], then refine it with the `with_*` builder methods.
#[derive(Clone)]
pub struct RetryPolicy {
    /// Growth strategy for successive delays.
    backoff: Backoff,
    /// Base delay before the first retry.
    delay: Duration,
    /// Upper bound for any single delay, if set.
    max_delay: Option<Duration>,
    /// Maximum number of retries following the initial attempt.
    max_retries: usize,
    /// Whether delays are randomised.
    #[cfg(feature = "rand")]
    jitter: bool,
}

impl Default for RetryPolicy {
    /// Default backoff strategy (exponential), a 1 second base delay, at most
    /// 5 retries, no delay cap, and jitter disabled.
    fn default() -> Self {
        let backoff = Backoff::default();
        Self {
            max_retries: 5,
            max_delay: None,
            delay: Duration::from_secs(1),
            #[cfg(feature = "rand")]
            jitter: false,
            backoff,
        }
    }
}
/// Scales `duration` by a random factor drawn from `(0, 1]`.
///
/// The result keeps full `Duration` (nanosecond) precision.
#[cfg(feature = "rand")]
fn jitter(duration: Duration) -> Duration {
    let factor: f64 = rng().sample(OpenClosed01);
    // Scaling the Duration directly avoids truncating to whole milliseconds,
    // which used to turn every sub-millisecond delay into zero. Because
    // `factor <= 1`, the product never exceeds `duration`. The fallback only
    // covers f64 rounding for durations close to `Duration::MAX`.
    Duration::try_from_secs_f64(duration.as_secs_f64() * factor).unwrap_or(duration)
}
impl RetryPolicy {
    /// Returns the delays to wait before each retry, at most `max_retries` of them.
    fn backoffs(&self) -> impl Iterator<Item = Duration> + '_ {
        let strategy = self.backoff;
        strategy.iter(self)
    }

    /// Creates a policy whose delay starts at `delay` and doubles after every retry.
    ///
    /// All remaining settings are taken from [`RetryPolicy::default`].
    pub fn exponential(delay: Duration) -> Self {
        let mut policy = Self::default();
        policy.backoff = Backoff::Exponential { exponent: 2.0 };
        policy.delay = delay;
        policy
    }

    /// Creates a policy that waits exactly `delay` before every retry.
    ///
    /// All remaining settings are taken from [`RetryPolicy::default`].
    pub fn fixed(delay: Duration) -> Self {
        let mut policy = Self::default();
        policy.backoff = Backoff::Fixed;
        policy.delay = delay;
        policy
    }

    /// Replaces the growth factor of an exponential policy with `exp`.
    ///
    /// A fixed policy is returned unchanged.
    pub fn with_backoff_exponent(mut self, exp: f64) -> Self {
        match &mut self.backoff {
            Backoff::Exponential { exponent } => *exponent = exp,
            Backoff::Fixed => {}
        }
        self
    }

    /// Enables or disables random jitter on each delay.
    #[cfg(feature = "rand")]
    pub fn with_jitter(self, jitter: bool) -> Self {
        Self { jitter, ..self }
    }

    /// Caps every individual delay at `max`.
    pub fn with_max_delay(self, max: Duration) -> Self {
        Self {
            max_delay: Some(max),
            ..self
        }
    }

    /// Sets how many retries may follow the initial attempt.
    pub fn with_max_retries(self, max: usize) -> Self {
        Self {
            max_retries: max,
            ..self
        }
    }

    /// Runs `task` until it succeeds or the retry budget is spent, treating
    /// every error as retryable.
    pub async fn retry<T>(&self, task: T) -> Result<T::Item, T::Error>
    where
        T: Task,
    {
        self.retry_if(task, Always).await
    }

    /// Runs `task` with `start_value`, then keeps re-running it while
    /// `condition` yields a new parameter, collecting every successful result.
    ///
    /// Stops when `condition` returns `None` or the retry budget is spent.
    /// The first error is returned immediately; results gathered so far are
    /// dropped.
    pub async fn collect<T, C, S>(
        &self,
        task: T,
        condition: C,
        start_value: S,
    ) -> Result<Vec<T::Item>, T::Error>
    where
        T: TaskWithParameter<S>,
        C: SuccessCondition<T::Item, S>,
    {
        let (mut task, mut condition) = (task, condition);
        let mut delays = self.backoffs();
        let mut collected = Vec::new();
        let mut parameter = start_value;
        loop {
            let item = task.call(parameter).await?;
            let follow_up = condition.retry_with(&item);
            collected.push(item);
            let Some(next_parameter) = follow_up else {
                break;
            };
            let Some(delay) = delays.next() else {
                break;
            };
            #[cfg(feature = "log")]
            log::trace!(
                "task succeeded and condition is met. will run again in {:?}",
                delay
            );
            Delay::new(delay).await;
            parameter = next_parameter;
        }
        Ok(collected)
    }

    /// Like [`RetryPolicy::collect`], but errors accepted by `error_condition`
    /// are retried with the same parameter that just failed.
    ///
    /// Successful re-runs and error retries each draw on their own budget of
    /// `max_retries` delays.
    pub async fn collect_and_retry<T, C, D, S>(
        &self,
        task: T,
        success_condition: C,
        error_condition: D,
        start_value: S,
    ) -> Result<Vec<T::Item>, T::Error>
    where
        T: TaskWithParameter<S>,
        C: SuccessCondition<T::Item, S>,
        D: Condition<T::Error>,
        S: Clone,
    {
        let (mut task, mut success_condition, mut error_condition) =
            (task, success_condition, error_condition);
        let mut success_delays = self.backoffs();
        let mut error_delays = self.backoffs();
        let mut collected = Vec::new();
        // Parameter of the most recent run; replayed after a retryable error.
        let mut last_input = start_value;
        loop {
            match task.call(last_input.clone()).await {
                Ok(item) => {
                    let follow_up = success_condition.retry_with(&item);
                    collected.push(item);
                    let Some(next_input) = follow_up else {
                        return Ok(collected);
                    };
                    let Some(delay) = success_delays.next() else {
                        return Ok(collected);
                    };
                    #[cfg(feature = "log")]
                    log::trace!(
                        "task succeeded and condition is met. will run again in {:?}",
                        delay
                    );
                    Delay::new(delay).await;
                    last_input = next_input;
                }
                Err(err) => {
                    if !error_condition.is_retryable(&err) {
                        return Err(err);
                    }
                    let Some(delay) = error_delays.next() else {
                        return Err(err);
                    };
                    #[cfg(feature = "log")]
                    log::trace!(
                        "task failed with error {:?}. will try again in {:?}",
                        err,
                        delay
                    );
                    Delay::new(delay).await;
                }
            }
        }
    }

    /// Runs `task` until it succeeds, retrying only errors for which
    /// `condition` returns `true` and only while the retry budget lasts.
    ///
    /// Otherwise the last error is returned.
    pub async fn retry_if<T, C>(&self, task: T, condition: C) -> Result<T::Item, T::Error>
    where
        T: Task,
        C: Condition<T::Error>,
    {
        let (mut task, mut condition) = (task, condition);
        let mut delays = self.backoffs();
        loop {
            let err = match task.call().await {
                Ok(value) => return Ok(value),
                Err(err) => err,
            };
            if !condition.is_retryable(&err) {
                return Err(err);
            }
            let Some(delay) = delays.next() else {
                return Err(err);
            };
            #[cfg(feature = "log")]
            log::trace!(
                "task failed with error {:?}. will try again in {:?}",
                err,
                delay
            );
            Delay::new(delay).await;
        }
    }
}
/// Decides whether a failed attempt is worth retrying.
pub trait Condition<E> {
    /// Returns `true` when the task should be retried after failing with `error`.
    fn is_retryable(&mut self, error: &E) -> bool;
}
/// A [`Condition`] under which every error is retryable.
struct Always;

impl<E> Condition<E> for Always {
    #[inline]
    fn is_retryable(&mut self, _error: &E) -> bool {
        true
    }
}
/// Any `FnMut(&E) -> bool` closure or function can act as a [`Condition`];
/// it is invoked once per failed attempt.
impl<F, E> Condition<E> for F
where
    F: FnMut(&E) -> bool,
{
    fn is_retryable(&mut self, error: &E) -> bool {
        let verdict = self(error);
        verdict
    }
}
/// Decides, after a successful run, whether to run again and with what input.
pub trait SuccessCondition<R, S> {
    /// Returns `Some(next_input)` to schedule another run, or `None` to stop.
    fn retry_with(&mut self, result: &R) -> Option<S>;
}
/// Any closure or function mapping a successful result to `Option<S>` is a
/// [`SuccessCondition`]: `Some(next)` requests another run with `next`, and
/// `None` stops.
///
/// The bound is `FnMut`, matching the blanket [`Condition`] impl and the
/// `&mut self` receiver, so stateful closures (e.g. a page counter) are
/// accepted. Every `Fn` closure is also `FnMut`, so existing callers keep
/// working.
impl<F, R, S> SuccessCondition<R, S> for F
where
    F: FnMut(&R) -> Option<S>,
{
    fn retry_with(&mut self, result: &R) -> Option<S> {
        self(result)
    }
}
/// An asynchronous operation that takes an input parameter and can be re-run.
pub trait TaskWithParameter<P> {
    /// Error produced by a failed run; `Debug` so it can be logged.
    type Error: std::fmt::Debug;
    /// Value produced by a successful run.
    type Item;
    /// Future returned by a single invocation.
    type Fut: Future<Output = Result<Self::Item, Self::Error>>;

    /// Starts one run of the task with `parameter`.
    fn call(&mut self, parameter: P) -> Self::Fut;
}
/// Any `FnMut(P) -> impl Future<Output = Result<I, E>>` closure or function is a
/// [`TaskWithParameter`]; each call produces a fresh future.
impl<F, Fut, I, P, E> TaskWithParameter<P> for F
where
    F: FnMut(P) -> Fut,
    Fut: Future<Output = Result<I, E>>,
    E: std::fmt::Debug,
{
    type Item = I;
    type Error = E;
    type Fut = Fut;

    fn call(&mut self, parameter: P) -> Self::Fut {
        self(parameter)
    }
}
/// An asynchronous operation without inputs that can be re-run.
pub trait Task {
    /// Error produced by a failed run; `Debug` so it can be logged.
    type Error: std::fmt::Debug;
    /// Value produced by a successful run.
    type Item;
    /// Future returned by a single invocation.
    type Fut: Future<Output = Result<Self::Item, Self::Error>>;

    /// Starts one run of the task.
    fn call(&mut self) -> Self::Fut;
}
/// Any `FnMut() -> impl Future<Output = Result<I, E>>` closure or function
/// (including `async fn` items) is a [`Task`]; each call produces a fresh future.
impl<F, Fut, I, E> Task for F
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<I, E>>,
    E: std::fmt::Debug,
{
    type Item = I;
    type Error = E;
    type Fut = Fut;

    fn call(&mut self) -> Self::Fut {
        self()
    }
}
#[cfg(test)]
mod tests {
    use std::error::Error;

    use approx::assert_relative_eq;

    use super::*;

    /// Compile-time check that its argument is `Send`.
    fn require_send(_: impl Send) {}

    #[test]
    fn retry_policy_is_send() {
        require_send(RetryPolicy::default());
    }

    #[test]
    #[cfg(feature = "rand")]
    fn jitter_adds_variance_to_durations() {
        let base = Duration::from_secs(1);
        assert_ne!(jitter(base), base);
    }

    #[test]
    fn backoff_default() {
        match Backoff::default() {
            Backoff::Exponential { exponent } => {
                assert_relative_eq!(exponent, 2.0);
            }
            Backoff::Fixed => panic!("Default backoff expected to be exponential!"),
        }
    }

    #[test]
    fn fixed_backoff() {
        let policy = RetryPolicy::fixed(Duration::from_secs(1));
        let mut delays = policy.backoffs();
        for _ in 0..4 {
            assert_eq!(delays.next(), Some(Duration::from_secs(1)));
        }
    }

    #[test]
    fn exponential_backoff() {
        let policy = RetryPolicy::exponential(Duration::from_secs(1));
        let mut delays = policy.backoffs();
        for expected in [1.0, 2.0, 4.0, 8.0] {
            assert_relative_eq!(delays.next().unwrap().as_secs_f64(), expected);
        }
    }

    #[test]
    fn exponential_backoff_factor() {
        let policy =
            RetryPolicy::exponential(Duration::from_secs(1)).with_backoff_exponent(1.5);
        let mut delays = policy.backoffs();
        for expected in [1.0, 1.5, 2.25, 3.375] {
            assert_relative_eq!(delays.next().unwrap().as_secs_f64(), expected);
        }
    }

    #[test]
    fn always_is_always_retryable() {
        let mut always = Always;
        assert!(always.is_retryable(&()));
    }

    #[test]
    fn closures_impl_condition() {
        fn accepts(_: impl Condition<()>) {}
        // The `&()` parameter is required to match `Condition<()>`'s signature.
        #[allow(clippy::trivially_copy_pass_by_ref)]
        fn never_fails(_err: &()) -> bool {
            true
        }
        accepts(never_fails);
        accepts(|_err: &()| true);
    }

    #[test]
    fn closures_impl_task() {
        fn accepts(_: impl Task) {}
        async fn answer() -> Result<u32, ()> {
            Ok(42)
        }
        accepts(answer);
        accepts(|| async { Ok::<u32, ()>(42) });
    }

    #[test]
    fn retried_futures_are_send_when_tasks_are_send() {
        require_send(RetryPolicy::default().retry(|| async { Ok::<u32, ()>(42) }));
    }

    #[tokio::test]
    async fn collect_retries_when_condition_is_met() -> Result<(), Box<dyn Error>> {
        let policy = RetryPolicy::fixed(Duration::from_millis(1));
        let outcome = policy
            .collect(
                |input: u32| async move { Ok::<u32, ()>(input + 1) },
                |result: &u32| (*result < 2).then_some(*result),
                0,
            )
            .await;
        assert_eq!(outcome, Ok(vec![1, 2]));
        Ok(())
    }

    #[tokio::test]
    async fn collect_does_not_retry_when_condition_is_not_met() -> Result<(), Box<dyn Error>> {
        let policy = RetryPolicy::fixed(Duration::from_millis(1));
        let outcome = policy
            .collect(
                |input: u32| async move { Ok::<u32, ()>(input + 1) },
                |result: &u32| (*result < 1).then_some(*result),
                0,
            )
            .await;
        assert_eq!(outcome, Ok(vec![1]));
        Ok(())
    }

    #[tokio::test]
    async fn collect_and_retry_retries_when_success_condition_is_met() -> Result<(), Box<dyn Error>>
    {
        let policy = RetryPolicy::fixed(Duration::from_millis(1));
        let outcome = policy
            .collect_and_retry(
                |input: u32| async move { Ok::<u32, u32>(input + 1) },
                |result: &u32| (*result < 2).then_some(*result),
                |err: &u32| *err > 1,
                0,
            )
            .await;
        assert_eq!(outcome, Ok(vec![1, 2]));
        Ok(())
    }

    #[tokio::test]
    async fn collect_and_retry_does_not_retry_when_success_condition_is_not_met()
    -> Result<(), Box<dyn Error>> {
        let policy = RetryPolicy::fixed(Duration::from_millis(1));
        let outcome = policy
            .collect_and_retry(
                |input: u32| async move { Ok::<u32, u32>(input + 1) },
                |result: &u32| (*result < 1).then_some(*result),
                |err: &u32| *err > 1,
                0,
            )
            .await;
        assert_eq!(outcome, Ok(vec![1]));
        Ok(())
    }

    #[tokio::test]
    async fn collect_and_retry_retries_when_error_condition_is_met() -> Result<(), Box<dyn Error>> {
        let mut attempts = 0;
        let _ = RetryPolicy::fixed(Duration::from_millis(1))
            .collect_and_retry(
                |_input: u32| {
                    attempts += 1;
                    async move { Err::<u32, u32>(0) }
                },
                |result: &u32| (*result < 2).then_some(*result),
                |err: &u32| *err == 0,
                0,
            )
            .await;
        // One initial attempt plus the default budget of five retries.
        assert_eq!(attempts, 6);
        Ok(())
    }

    #[tokio::test]
    async fn collect_and_retry_does_not_retry_when_error_condition_is_not_met()
    -> Result<(), Box<dyn Error>> {
        let policy = RetryPolicy::fixed(Duration::from_millis(1));
        let outcome = policy
            .collect_and_retry(
                |input: u32| async move { Err::<u32, u32>(input + 1) },
                |result: &u32| (*result < 1).then_some(*result),
                |err: &u32| *err > 1,
                0,
            )
            .await;
        assert_eq!(outcome, Err(1));
        Ok(())
    }

    #[tokio::test]
    async fn ok_futures_yield_ok() -> Result<(), Box<dyn Error>> {
        let policy = RetryPolicy::default();
        let outcome = policy.retry(|| async { Ok::<u32, ()>(42) }).await;
        assert_eq!(outcome, Ok(42));
        Ok(())
    }

    #[tokio::test]
    async fn failed_futures_yield_err() -> Result<(), Box<dyn Error>> {
        let policy = RetryPolicy::fixed(Duration::from_millis(1));
        let outcome = policy.retry(|| async { Err::<u32, ()>(()) }).await;
        assert_eq!(outcome, Err(()));
        Ok(())
    }
}