#[cfg(feature = "rand")]
use rand::{distributions::OpenClosed01, thread_rng, Rng};
use std::{cmp::min, future::Future, time::Duration};
use wasm_timer::Delay;
pub async fn retry<T>(task: T) -> Result<T::Item, T::Error>
where
T: Task,
{
crate::retry_if(task, Always).await
}
pub async fn retry_if<T, C>(
task: T,
condition: C,
) -> Result<T::Item, T::Error>
where
T: Task,
C: Condition<T::Error>,
{
RetryPolicy::default().retry_if(task, condition).await
}
/// Strategy that controls how the delay multiplier evolves between retries.
#[derive(Clone, Copy)]
enum Backoff {
/// The multiplier is left unchanged, so every retry waits the base delay.
Fixed,
/// The multiplier doubles after each retry.
Exponential,
}
impl Default for Backoff {
fn default() -> Self {
Backoff::Exponential
}
}
impl Backoff {
fn iter(
self,
policy: &RetryPolicy,
) -> BackoffIter {
BackoffIter {
backoff: self,
current: 1,
#[cfg(feature = "rand")]
jitter: policy.jitter,
delay: policy.delay,
max_delay: policy.max_delay,
max_retries: policy.max_retries,
}
}
}
/// Iterator state that yields the delay to wait before each retry.
struct BackoffIter {
/// Strategy deciding whether `current` grows after each item.
backoff: Backoff,
/// Multiplier applied to `delay` for the next item.
current: u32,
/// When `true`, each computed delay is passed through `jitter`.
#[cfg(feature = "rand")]
jitter: bool,
/// Base delay that gets multiplied by `current`.
delay: Duration,
/// Optional upper bound applied to every yielded delay.
max_delay: Option<Duration>,
/// Number of delays that may still be yielded.
max_retries: usize,
}
impl Iterator for BackoffIter {
    type Item = Duration;

    /// Yields the delay to wait before the next retry.
    ///
    /// Returns `None` once the retry budget is used up, or when multiplying
    /// the base delay by the current factor overflows `Duration`. In the
    /// overflow case the retry budget is left untouched, but the factor has
    /// already been advanced.
    fn next(&mut self) -> Option<Self::Item> {
        if self.max_retries == 0 {
            return None;
        }

        // The factor used for this item is the value held *before* growth.
        let factor = self.current;
        match self.backoff {
            Backoff::Fixed => {}
            // Doubling sticks at `u32::MAX` instead of wrapping.
            Backoff::Exponential => self.current = factor.saturating_mul(2),
        }

        let mut delay = self.delay.checked_mul(factor)?;
        #[cfg(feature = "rand")]
        {
            if self.jitter {
                delay = jitter(delay);
            }
        }
        // The cap is applied after jitter, so it always holds for the result.
        if let Some(cap) = self.max_delay {
            delay = min(delay, cap);
        }

        self.max_retries -= 1;
        Some(delay)
    }
}
/// Configuration describing how often a task is retried and how long to
/// wait between attempts.
#[derive(Clone)]
pub struct RetryPolicy {
/// Strategy used to grow the delay between retries.
backoff: Backoff,
/// Whether computed delays are randomized through `jitter`.
#[cfg(feature = "rand")]
jitter: bool,
/// Base delay before a retry.
delay: Duration,
/// Optional upper bound on any single delay.
max_delay: Option<Duration>,
/// Maximum number of retries after the initial attempt.
max_retries: usize,
}
impl Default for RetryPolicy {
fn default() -> Self {
Self {
backoff: Backoff::default(),
delay: Duration::from_secs(1),
#[cfg(feature = "rand")]
jitter: false,
max_delay: None,
max_retries: 5,
}
}
}
/// Scales `duration` by a random factor sampled from the half-open range
/// `(0, 1]`, so the result is never longer than the input.
///
/// The scaling is done in nanoseconds. Going through whole milliseconds
/// would truncate every sub-millisecond delay to zero and drop the
/// sub-millisecond part of longer ones.
#[cfg(feature = "rand")]
fn jitter(duration: Duration) -> Duration {
    let factor: f64 = thread_rng().sample(OpenClosed01);
    let nanos = duration.as_nanos() as f64 * factor;
    // The float-to-integer cast saturates, so an extreme input is clamped to
    // `u64::MAX` nanoseconds instead of wrapping or panicking.
    Duration::from_nanos(nanos as u64)
}
impl RetryPolicy {
    /// Produces the delays this policy allows, one item per permitted retry.
    fn backoffs(&self) -> impl Iterator<Item = Duration> {
        Backoff::iter(self.backoff, self)
    }

    /// Creates a policy whose delay multiplier doubles after every retry,
    /// starting from `delay`. All other settings come from
    /// [`RetryPolicy::default`].
    pub fn exponential(delay: Duration) -> Self {
        Self {
            delay,
            backoff: Backoff::Exponential,
            ..Default::default()
        }
    }

    /// Creates a policy that waits the same `delay` before every retry.
    /// All other settings come from [`RetryPolicy::default`].
    pub fn fixed(delay: Duration) -> Self {
        Self {
            delay,
            backoff: Backoff::Fixed,
            ..Default::default()
        }
    }

    /// Enables or disables randomization of the computed delays.
    #[cfg(feature = "rand")]
    pub fn with_jitter(
        self,
        jitter: bool,
    ) -> Self {
        Self { jitter, ..self }
    }

    /// Caps every delay at `max`.
    pub fn with_max_delay(
        self,
        max: Duration,
    ) -> Self {
        Self {
            max_delay: Some(max),
            ..self
        }
    }

    /// Sets the maximum number of retries after the initial attempt.
    pub fn with_max_retries(
        self,
        max: usize,
    ) -> Self {
        Self {
            max_retries: max,
            ..self
        }
    }

    /// Runs `task` under this policy, retrying on every error.
    ///
    /// # Errors
    ///
    /// Returns the error of the last attempt once no retries remain.
    pub async fn retry<T>(
        &self,
        task: T,
    ) -> Result<T::Item, T::Error>
    where
        T: Task,
    {
        self.retry_if(task, Always).await
    }

    /// Runs `task` under this policy, retrying only while `condition`
    /// reports the most recent error as retryable.
    ///
    /// Each failed attempt first consults `condition`; only if it approves
    /// is the next delay taken from the backoff schedule and awaited.
    ///
    /// # Errors
    ///
    /// Returns the error of the last attempt when `condition` rejects it or
    /// when the backoff schedule is exhausted.
    pub async fn retry_if<T, C>(
        &self,
        task: T,
        condition: C,
    ) -> Result<T::Item, T::Error>
    where
        T: Task,
        C: Condition<T::Error>,
    {
        let mut schedule = self.backoffs();
        let (mut task, mut condition) = (task, condition);
        loop {
            let err = match task.call().await {
                Ok(value) => return Ok(value),
                Err(err) => err,
            };
            if !condition.is_retryable(&err) {
                return Err(err);
            }
            let delay = match schedule.next() {
                Some(delay) => delay,
                None => return Err(err),
            };
            #[cfg(feature = "log")]
            {
                log::trace!(
                    "task failed with error {:?}. will try again in {:?}",
                    err,
                    delay
                );
            }
            // The timer's own result is deliberately discarded.
            let _ = Delay::new(delay).await;
        }
    }
}
/// Decides whether a failed attempt should be retried.
pub trait Condition<E> {
/// Returns `true` when `error` should lead to another attempt.
///
/// Takes `&mut self`, so an implementation may update its own state on
/// every call.
fn is_retryable(
&mut self,
error: &E,
) -> bool;
}
/// Condition that treats every error as retryable.
struct Always;
impl<E> Condition<E> for Always {
    /// Approves a retry unconditionally; the error is not inspected.
    #[inline]
    fn is_retryable(&mut self, _error: &E) -> bool {
        true
    }
}
/// Lets any closure or function taking `&E` and returning `bool` act as a
/// [`Condition`].
///
/// The bound is `FnMut` rather than `Fn`: `is_retryable` already receives
/// `&mut self`, so closures that keep state between calls (for example a
/// counter limiting how many errors are accepted) work too. Every `Fn`
/// closure is also `FnMut`, so stateless predicates are unaffected.
impl<F, E> Condition<E> for F
where
    F: FnMut(&E) -> bool,
{
    /// Calls the wrapped closure with `error` and returns its verdict.
    fn is_retryable(
        &mut self,
        error: &E,
    ) -> bool {
        self(error)
    }
}
/// A repeatable unit of fallible asynchronous work.
pub trait Task {
/// Value produced by a successful attempt.
type Item;
/// Error produced by a failed attempt; must implement `Debug`.
type Error: std::fmt::Debug;
/// Future returned by each attempt.
type Fut: Future<Output = Result<Self::Item, Self::Error>>;
/// Starts a new attempt and returns its future.
fn call(&mut self) -> Self::Fut;
}
/// Lets any zero-argument closure or function that returns a future of
/// `Result<I, E>` act as a [`Task`].
impl<F, Fut, I, E> Task for F
where
    E: std::fmt::Debug,
    Fut: Future<Output = Result<I, E>>,
    F: FnMut() -> Fut,
{
    type Item = I;
    type Error = E;
    type Fut = Fut;

    /// Invokes the closure once to start a fresh attempt.
    fn call(&mut self) -> Self::Fut {
        (*self)()
    }
}
// Unit tests for the backoff schedule, the blanket trait impls and the
// retry entry points.
#[cfg(test)]
mod tests {
use super::*;
use std::error::Error;
// Compile-time check: a policy value satisfies `Send`.
#[test]
fn retry_policy_is_send() {
fn test(_: impl Send) {}
test(RetryPolicy::default())
}
// A jittered one second delay should differ from the input.
// NOTE(review): `OpenClosed01` includes 1.0, so in principle this could
// fail on an exact 1.0 sample; presumably rare enough to ignore, confirm.
#[test]
#[cfg(feature = "rand")]
fn jitter_adds_variance_to_durations() {
assert!(jitter(Duration::from_secs(1)) != Duration::from_secs(1));
}
// The default strategy is exponential.
#[test]
fn backoff_default() {
assert!(matches!(Backoff::default(), Backoff::Exponential));
}
// A fixed policy yields the same delay for consecutive retries.
#[test]
fn fixed_backoff() {
let mut iter = RetryPolicy::fixed(Duration::from_secs(1)).backoffs();
assert_eq!(iter.next(), Some(Duration::from_secs(1)));
assert_eq!(iter.next(), Some(Duration::from_secs(1)));
assert_eq!(iter.next(), Some(Duration::from_secs(1)));
assert_eq!(iter.next(), Some(Duration::from_secs(1)));
}
// An exponential policy doubles the delay on each retry.
#[test]
fn exponential_backoff() {
let mut iter = RetryPolicy::exponential(Duration::from_secs(1)).backoffs();
assert_eq!(iter.next(), Some(Duration::from_secs(1)));
assert_eq!(iter.next(), Some(Duration::from_secs(2)));
assert_eq!(iter.next(), Some(Duration::from_secs(4)));
assert_eq!(iter.next(), Some(Duration::from_secs(8)));
}
// `Always` accepts any error.
#[test]
fn always_is_always_retryable() {
assert!(Always.is_retryable(&()))
}
// Compile-time check: both functions and closures satisfy `Condition`.
#[test]
fn closures_impl_condition() {
fn test(_: impl Condition<()>) {}
#[allow(clippy::trivially_copy_pass_by_ref)]
fn foo(_err: &()) -> bool {
true
}
test(foo);
test(|_err: &()| true);
}
// Compile-time check: both async functions and closures satisfy `Task`.
#[test]
fn closures_impl_task() {
fn test(_: impl Task) {}
async fn foo() -> Result<u32, ()> {
Ok(42)
}
test(foo);
test(|| async { Ok::<u32, ()>(42) });
}
// Compile-time check: the retry future is `Send` for this `Send` task.
#[test]
fn retried_futures_are_send_when_tasks_are_send() {
fn test(_: impl Send) {}
test(RetryPolicy::default().retry(|| async { Ok::<u32, ()>(42) }))
}
// A task that succeeds yields its value.
#[tokio::test]
async fn ok_futures_yield_ok() -> Result<(), Box<dyn Error>> {
let result = RetryPolicy::default()
.retry(|| async { Ok::<u32, ()>(42) })
.await;
assert_eq!(result, Ok(42));
Ok(())
}
// A task that always fails yields its error once retries run out; the
// 1 ms fixed delay keeps the test fast.
#[tokio::test]
async fn failed_futures_yield_err() -> Result<(), Box<dyn Error>> {
let result = RetryPolicy::fixed(Duration::from_millis(1))
.retry(|| async { Err::<u32, ()>(()) })
.await;
assert_eq!(result, Err(()));
Ok(())
}
}