#![deny(unsafe_code)]
#![warn(
clippy::all,
clippy::await_holding_lock,
clippy::dbg_macro,
clippy::debug_assert_with_mut_call,
clippy::doc_markdown,
clippy::empty_enum,
clippy::enum_glob_use,
clippy::exit,
clippy::explicit_into_iter_loop,
clippy::filter_map_next,
clippy::fn_params_excessive_bools,
clippy::if_let_mutex,
clippy::imprecise_flops,
clippy::inefficient_to_string,
clippy::large_types_passed_by_value,
clippy::let_unit_value,
clippy::linkedlist,
clippy::lossy_float_literal,
clippy::macro_use_imports,
clippy::map_err_ignore,
clippy::map_flatten,
clippy::map_unwrap_or,
clippy::match_on_vec_items,
clippy::match_same_arms,
clippy::match_wildcard_for_single_variants,
clippy::mem_forget,
clippy::needless_borrow,
clippy::needless_continue,
clippy::option_option,
clippy::ref_option_ref,
clippy::rest_pat_in_fully_bound_structs,
clippy::string_add_assign,
clippy::string_add,
clippy::string_to_string,
clippy::suboptimal_flops,
clippy::todo,
clippy::unimplemented,
clippy::unnested_or_patterns,
clippy::unused_self,
clippy::verbose_file_reads,
unexpected_cfgs,
future_incompatible,
nonstandard_style,
rust_2018_idioms
)]
#![warn(missing_docs)]
#![deny(rustdoc::broken_intra_doc_links)]
use backoff_strategies::{
BackoffStrategy, ExponentialBackoff, FixedBackoff, LinearBackoff, NoBackoff,
};
use pin_project_lite::pin_project;
use std::time::Duration;
use std::{
fmt,
future::Future,
pin::Pin,
task::{Context, Poll},
};
mod on_retry;
pub mod backoff_strategies;
pub use on_retry::{NoOnRetry, OnRetry};
/// Wraps `f`, a closure that creates a new future for every attempt, so it can be retried.
///
/// The returned [`RetryFn`] does nothing by itself: turn it into a [`RetryFuture`] with
/// [`RetryFn::retries`] or [`RetryFn::with_config`] and `.await` that future.
///
/// `f` is called once per attempt (the first try plus every retry) and must return a future
/// resolving to a `Result`; the first `Ok` ends retrying immediately.
pub fn retry_fn<F>(f: F) -> RetryFn<F> {
    RetryFn { f }
}

/// A future factory that has not been configured for retrying yet.
///
/// Created by [`retry_fn`]. Call [`RetryFn::retries`] or [`RetryFn::with_config`] to obtain a
/// [`RetryFuture`].
#[derive(Debug)]
#[must_use = "call `.retries(..)` or `.with_config(..)` and await the resulting future"]
pub struct RetryFn<F> {
    // The future factory; `RetryFuture` invokes it once per attempt.
    f: F,
}
impl<F, Fut, T, E> RetryFn<F>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
{
    /// Retries up to `max_retries` times using the default configuration (no backoff, no
    /// maximum delay, no retry hook).
    ///
    /// The factory is therefore invoked at most `max_retries + 1` times. Backoff, maximum delay
    /// and the retry hook can still be changed on the returned [`RetryFuture`].
    pub fn retries(self, max_retries: u32) -> RetryFuture<F, Fut, NoBackoff, NoOnRetry> {
        let default_config = RetryFutureConfig::new(max_retries);
        self.with_config(default_config)
    }

    /// Builds a [`RetryFuture`] from an already assembled [`RetryFutureConfig`].
    ///
    /// This lets one configuration be shared between several retried operations.
    pub fn with_config<BackoffT, OnRetryT>(
        self,
        config: RetryFutureConfig<BackoffT, OnRetryT>,
    ) -> RetryFuture<F, Fut, BackoffT, OnRetryT> {
        // Read the retry budget before `config` is moved into the future.
        let attempts_remaining = config.max_retries;
        RetryFuture {
            config,
            make_future: self.f,
            state: RetryState::NotStarted,
            attempt: 0,
            attempts_remaining,
        }
    }
}
pin_project! {
    /// A future that retries an operation according to a [`RetryFutureConfig`].
    ///
    /// Created by [`RetryFn::retries`] or [`RetryFn::with_config`]. It resolves to the first
    /// `Ok` produced by the operation, or to the most recent `Err` once retries are exhausted
    /// or the backoff strategy returns [`RetryPolicy::Break`].
    #[must_use = "futures do nothing unless you `.await` or poll them"]
    pub struct RetryFuture<MakeFutureT, FutureT, BackoffT, OnRetryT> {
        // Factory invoked once per attempt to create a fresh future.
        make_future: MakeFutureT,
        // Retries still allowed; starts at `max_retries` and counts down to zero.
        attempts_remaining: u32,
        // Current phase: not started, awaiting an attempt, or sleeping between attempts.
        #[pin]
        state: RetryState<FutureT>,
        // Number of retries started so far (0 during the first attempt).
        attempt: u32,
        // Backoff strategy, delay ceiling, retry hook and retry budget.
        config: RetryFutureConfig<BackoffT, OnRetryT>,
    }
}
impl<MakeFutureT, FutureT, BackoffT, T, E, OnRetryT>
RetryFuture<MakeFutureT, FutureT, BackoffT, OnRetryT>
where
MakeFutureT: FnMut() -> FutureT,
FutureT: Future<Output = Result<T, E>>,
{
#[inline]
pub fn max_delay(mut self, delay: Duration) -> Self {
self.config = self.config.max_delay(delay);
self
}
#[inline]
pub fn no_backoff(self) -> RetryFuture<MakeFutureT, FutureT, NoBackoff, OnRetryT> {
self.custom_backoff(NoBackoff)
}
#[inline]
pub fn exponential_backoff(
self,
initial_delay: Duration,
) -> RetryFuture<MakeFutureT, FutureT, ExponentialBackoff, OnRetryT> {
self.custom_backoff(ExponentialBackoff {
delay: initial_delay,
})
}
#[inline]
pub fn fixed_backoff(
self,
delay: Duration,
) -> RetryFuture<MakeFutureT, FutureT, FixedBackoff, OnRetryT> {
self.custom_backoff(FixedBackoff { delay })
}
#[inline]
pub fn linear_backoff(
self,
delay: Duration,
) -> RetryFuture<MakeFutureT, FutureT, LinearBackoff, OnRetryT> {
self.custom_backoff(LinearBackoff { delay })
}
#[inline]
pub fn custom_backoff<B>(
self,
backoff_strategy: B,
) -> RetryFuture<MakeFutureT, FutureT, B, OnRetryT>
where
for<'a> B: BackoffStrategy<'a, E>,
{
RetryFuture {
make_future: self.make_future,
attempts_remaining: self.attempts_remaining,
state: self.state,
attempt: self.attempt,
config: self.config.custom_backoff(backoff_strategy),
}
}
#[inline]
pub fn on_retry<F, OnRetryFut>(self, f: F) -> RetryFuture<MakeFutureT, FutureT, BackoffT, F>
where
F: Fn(u32, Option<Duration>, &E) -> OnRetryFut,
{
RetryFuture {
make_future: self.make_future,
attempts_remaining: self.attempts_remaining,
state: self.state,
attempt: self.attempt,
config: self.config.on_retry(f),
}
}
}
/// Reusable retry settings: the retry budget, a backoff strategy, an optional ceiling on
/// delays and an optional retry hook.
///
/// Start from [`RetryFutureConfig::new`], adjust it with the builder methods and apply it with
/// [`RetryFn::with_config`]. The config is `Copy` whenever its backoff strategy and hook are,
/// so one value can configure many retried operations.
#[derive(Clone, Copy, PartialEq, Eq)]
#[must_use = "a `RetryFutureConfig` has no effect until it is passed to `RetryFn::with_config`"]
pub struct RetryFutureConfig<BackoffT, OnRetryT> {
    // Strategy consulted after each failure to compute the next delay.
    backoff_strategy: BackoffT,
    // Optional ceiling applied to every computed delay.
    max_delay: Option<Duration>,
    // Optional hook spawned before each retry and when giving up.
    on_retry: Option<OnRetryT>,
    // Number of retries allowed after the first attempt.
    max_retries: u32,
}
impl RetryFutureConfig<NoBackoff, NoOnRetry> {
    /// Creates a config allowing `max_retries` retries with [`NoBackoff`], no maximum delay and
    /// no retry hook.
    pub fn new(max_retries: u32) -> Self {
        Self {
            max_retries,
            backoff_strategy: NoBackoff,
            on_retry: None,
            max_delay: None,
        }
    }
}
impl<BackoffT, OnRetryT> RetryFutureConfig<BackoffT, OnRetryT> {
    /// Caps every delay produced by the backoff strategy at `delay`.
    #[inline]
    pub fn max_delay(self, delay: Duration) -> Self {
        Self {
            max_delay: Some(delay),
            ..self
        }
    }

    /// Switches to [`NoBackoff`].
    #[inline]
    pub fn no_backoff(self) -> RetryFutureConfig<NoBackoff, OnRetryT> {
        self.custom_backoff(NoBackoff)
    }

    /// Switches to an [`ExponentialBackoff`] seeded with `initial_delay`.
    #[inline]
    pub fn exponential_backoff(
        self,
        initial_delay: Duration,
    ) -> RetryFutureConfig<ExponentialBackoff, OnRetryT> {
        let strategy = ExponentialBackoff {
            delay: initial_delay,
        };
        self.custom_backoff(strategy)
    }

    /// Switches to a [`FixedBackoff`] configured with `delay`.
    #[inline]
    pub fn fixed_backoff(self, delay: Duration) -> RetryFutureConfig<FixedBackoff, OnRetryT> {
        let strategy = FixedBackoff { delay };
        self.custom_backoff(strategy)
    }

    /// Switches to a [`LinearBackoff`] configured with `delay`.
    #[inline]
    pub fn linear_backoff(self, delay: Duration) -> RetryFutureConfig<LinearBackoff, OnRetryT> {
        let strategy = LinearBackoff { delay };
        self.custom_backoff(strategy)
    }

    /// Replaces the backoff strategy with `backoff_strategy`, keeping all other settings.
    #[inline]
    pub fn custom_backoff<B>(self, backoff_strategy: B) -> RetryFutureConfig<B, OnRetryT> {
        let RetryFutureConfig {
            max_delay,
            on_retry,
            max_retries,
            ..
        } = self;
        RetryFutureConfig {
            backoff_strategy,
            max_delay,
            on_retry,
            max_retries,
        }
    }

    /// Installs `f` as the retry hook, replacing any previous hook.
    #[inline]
    pub fn on_retry<F>(self, f: F) -> RetryFutureConfig<BackoffT, F> {
        let RetryFutureConfig {
            backoff_strategy,
            max_delay,
            max_retries,
            ..
        } = self;
        RetryFutureConfig {
            backoff_strategy,
            max_delay,
            max_retries,
            on_retry: Some(f),
        }
    }
}
// Hooks (often closures) are not required to implement `Debug`, so the hook is rendered only
// by its type name.
impl<BackoffT, OnRetryT> fmt::Debug for RetryFutureConfig<BackoffT, OnRetryT>
where
    BackoffT: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let hook_type = std::any::type_name::<OnRetryT>();
        let mut out = f.debug_struct("RetryFutureConfig");
        out.field("backoff_strategy", &self.backoff_strategy);
        out.field("max_delay", &self.max_delay);
        out.field("max_retries", &self.max_retries);
        out.field("on_retry", &format_args!("<{}>", hook_type));
        out.finish()
    }
}
// Internal state machine driven by `RetryFuture::poll`.
pin_project! {
    #[project = RetryStateProj]
    // NOTE(review): this `allow` had no rationale; clippy flags the size gap between the
    // variants. Presumably boxing the larger one was avoided to skip an allocation per
    // attempt — confirm.
    #[allow(clippy::large_enum_variant)]
    enum RetryState<F> {
        // No attempt made yet; the first poll calls the factory.
        NotStarted,
        // An attempt is in flight; `future` is the value returned by the factory.
        WaitingForFuture { #[pin] future: F },
        // Waiting out the backoff delay before the next attempt.
        TimerActive { #[pin] sleep: tokio::time::Sleep },
    }
}
// Drives the retry state machine.
//
// Each failed attempt either ends the future (no retries left, or the backoff strategy
// returned `RetryPolicy::Break`) or starts a delay after which `make_future` is called again.
// The optional `on_retry` hook's future is handed to `tokio::spawn` and the returned handle
// is discarded, so retrying never waits for the hook to finish.
// NOTE(review): `tokio::spawn` needs an ambient Tokio runtime — confirm this is documented for
// callers who configure a hook.
impl<F, Fut, B, T, E, OnRetryT> Future for RetryFuture<F, Fut, B, OnRetryT>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    // The strategy must accept an error borrowed for any lifetime, and its output (for example
    // `RetryPolicy` or `Duration`) must convert into a `RetryPolicy`.
    for<'a> B: BackoffStrategy<'a, E>,
    for<'a> <B as BackoffStrategy<'a, E>>::Output: Into<RetryPolicy>,
    OnRetryT: OnRetry<E>,
{
    type Output = Result<T, E>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Each iteration computes the next state. The loop exits only by returning: either
        // `Pending` (the inner future or timer is not ready) or the final result.
        loop {
            let this = self.as_mut().project();
            let new_state = match this.state.project() {
                // First poll: start the initial attempt.
                RetryStateProj::NotStarted => RetryState::WaitingForFuture {
                    future: (this.make_future)(),
                },
                // Once the backoff delay has elapsed, start a fresh attempt.
                RetryStateProj::TimerActive { sleep } => match sleep.poll(cx) {
                    Poll::Ready(()) => RetryState::WaitingForFuture {
                        future: (this.make_future)(),
                    },
                    Poll::Pending => return Poll::Pending,
                },
                RetryStateProj::WaitingForFuture { future } => match future.poll(cx) {
                    Poll::Pending => return Poll::Pending,
                    // Success ends retrying immediately.
                    Poll::Ready(Ok(value)) => {
                        return Poll::Ready(Ok(value));
                    }
                    Poll::Ready(Err(error)) => {
                        if *this.attempts_remaining == 0 {
                            // Out of retries. Tell the hook no further attempt follows
                            // (`next_delay == None`). `attempt` is not incremented here, so it
                            // reports the number of retries already made.
                            if let Some(on_retry) = &mut this.config.on_retry {
                                tokio::spawn(on_retry.on_retry(*this.attempt, None, &error));
                            }
                            return Poll::Ready(Err(error));
                        } else {
                            // Count this retry before consulting the strategy, so the first
                            // retry is reported as attempt 1.
                            *this.attempt += 1;
                            *this.attempts_remaining -= 1;
                            let delay: RetryPolicy = this
                                .config
                                .backoff_strategy
                                .delay(*this.attempt, &error)
                                .into();
                            let mut delay_duration = match delay {
                                RetryPolicy::Delay(duration) => duration,
                                // The strategy gave up early. The hook again gets `None`, but
                                // with the already-incremented attempt number.
                                RetryPolicy::Break => {
                                    if let Some(on_retry) = &mut this.config.on_retry {
                                        tokio::spawn(on_retry.on_retry(
                                            *this.attempt,
                                            None,
                                            &error,
                                        ));
                                    }
                                    return Poll::Ready(Err(error));
                                }
                            };
                            // Clamp to the configured ceiling, if any.
                            if let Some(max_delay) = this.config.max_delay {
                                delay_duration = delay_duration.min(max_delay);
                            }
                            // The hook sees the clamped delay that is actually used.
                            if let Some(on_retry) = &mut this.config.on_retry {
                                tokio::spawn(on_retry.on_retry(
                                    *this.attempt,
                                    Some(delay_duration),
                                    &error,
                                ));
                            }
                            let sleep = tokio::time::sleep(delay_duration);
                            RetryState::TimerActive { sleep }
                        }
                    }
                },
            };
            // Replacing the pinned state drops the previous future or timer in place.
            self.as_mut().project().state.set(new_state);
        }
    }
}
/// The decision a backoff strategy makes after a failed attempt.
///
/// Strategies may return this type directly or anything convertible into it, such as a plain
/// [`Duration`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RetryPolicy {
    /// Wait for the given duration (capped by any configured maximum delay), then try again.
    Delay(Duration),
    /// Stop retrying and resolve with the error from the attempt that just failed.
    Break,
}

impl From<Duration> for RetryPolicy {
    /// Treats a bare duration as a request to retry after that delay.
    fn from(duration: Duration) -> Self {
        Self::Delay(duration)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::{convert::Infallible, time::Instant};

    #[tokio::test]
    async fn succeed() {
        retry_fn(|| async { Ok::<_, Infallible>(true) })
            .retries(10)
            .await
            .unwrap();
    }

    // `retries(n)` allows n retries on top of the first attempt.
    #[tokio::test]
    async fn retrying_correct_amount_of_times() {
        let counter = AtomicUsize::new(0);
        let err = retry_fn(|| async {
            counter.fetch_add(1, Ordering::SeqCst);
            Err::<Infallible, _>("error")
        })
        .retries(10)
        .await
        .unwrap_err();
        assert_eq!(err, "error");
        assert_eq!(counter.load(Ordering::SeqCst), 11);
    }

    #[tokio::test]
    async fn retry_0_times() {
        let counter = AtomicUsize::new(0);
        retry_fn(|| async {
            counter.fetch_add(1, Ordering::SeqCst);
            Err::<Infallible, _>("error")
        })
        .retries(0)
        .await
        .unwrap_err();
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn the_backoff_strategy_gets_used() {
        async fn make_future() -> Result<Infallible, &'static str> {
            Err("foo")
        }
        let start = Instant::now();
        retry_fn(make_future)
            .retries(10)
            .no_backoff()
            .await
            .unwrap_err();
        let time_with_none = start.elapsed();
        let start = Instant::now();
        retry_fn(make_future)
            .retries(10)
            .fixed_backoff(Duration::from_millis(10))
            .await
            .unwrap_err();
        let time_with_fixed = start.elapsed();
        assert!(time_with_fixed >= time_with_none);
        // Ten retries, each preceded by a fixed 10ms sleep (assuming `FixedBackoff` yields its
        // `delay` every time), must take at least 100ms: Tokio sleeps never end before their
        // deadline.
        assert!(time_with_fixed >= Duration::from_millis(100));
    }

    #[test]
    fn is_send() {
        fn assert_send<T: Send>(_: T) {}
        async fn some_future() -> Result<(), Infallible> {
            Ok(())
        }
        assert_send(retry_fn(some_future).retries(10));
    }

    #[tokio::test]
    async fn stop_retrying() {
        let mut n = 0;
        let make_future = || {
            n += 1;
            if n == 8 {
                panic!("retried too many times");
            }
            async { Err::<Infallible, _>("foo") }
        };
        let error = retry_fn(make_future)
            .retries(10)
            .custom_backoff(|n, _: &&'static str| {
                if n >= 3 {
                    RetryPolicy::Break
                } else {
                    RetryPolicy::Delay(Duration::from_nanos(10))
                }
            })
            .await
            .unwrap_err();
        assert_eq!(error, "foo");
    }

    #[tokio::test]
    async fn custom_returning_duration() {
        retry_fn(|| async { Ok::<_, Infallible>(true) })
            .retries(10)
            .custom_backoff(|_, _: &Infallible| Duration::from_nanos(10))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn retry_hook_succeed() {
        use tokio::sync::Mutex;
        let errors = Arc::new(Mutex::new(Vec::new()));
        retry_fn(|| async { Err::<Infallible, String>("error".to_string()) })
            .retries(10)
            .on_retry(|attempt, next_delay, error: &String| {
                let errors = Arc::clone(&errors);
                let error = error.clone();
                async move {
                    errors.lock().await.push((attempt, next_delay, error));
                }
            })
            .await
            .unwrap_err();
        // NOTE(review): the final hook call `(10, None, ..)` is spawned when retries run out.
        // On the single-threaded test runtime it presumably has not run yet when this lock is
        // taken, which is why exactly 10 entries are expected — confirm this is intended.
        let errors = errors.lock().await;
        assert_eq!(errors.len(), 10);
        for n in 1_u32..=10 {
            assert_eq!(
                &errors[(n - 1) as usize],
                &(n, Some(Duration::new(0, 0)), "error".to_string())
            );
        }
    }

    #[tokio::test]
    async fn reusing_the_config() {
        let counter = Arc::new(AtomicUsize::new(0));
        let config = RetryFutureConfig::new(10)
            .linear_backoff(Duration::from_millis(10))
            .on_retry(|_, _, _: &&'static str| {
                let counter = Arc::clone(&counter);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            });
        let ok_value = retry_fn(|| async { Ok::<_, &str>(true) })
            .with_config(config)
            .await
            .unwrap();
        assert!(ok_value);
        assert_eq!(counter.load(Ordering::SeqCst), 0);
        let err_value = retry_fn(|| async { Err::<(), _>("foo") })
            .with_config(config)
            .await
            .unwrap_err();
        assert_eq!(err_value, "foo");
        // NOTE(review): as in `retry_hook_succeed`, the spawned give-up hook presumably has not
        // run yet at this point, so only the 10 per-retry calls are counted.
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[tokio::test]
    async fn custom_backoff_wrapping_another_strategy() {
        #[derive(Clone)]
        struct MyBackoffStrategy {
            inner: ExponentialBackoff,
        }
        impl<'a> BackoffStrategy<'a, std::io::Error> for MyBackoffStrategy {
            type Output = RetryPolicy;
            fn delay(&mut self, attempt: u32, error: &'a std::io::Error) -> Self::Output {
                if error.kind() == std::io::ErrorKind::NotFound {
                    RetryPolicy::Break
                } else {
                    RetryPolicy::Delay(self.inner.delay(attempt, error))
                }
            }
        }
        #[derive(Clone)]
        struct MyOnRetry;
        impl OnRetry<std::io::Error> for MyOnRetry {
            type Future = futures::future::BoxFuture<'static, ()>;
            fn on_retry(
                &mut self,
                attempt: u32,
                next_delay: Option<Duration>,
                previous_error: &std::io::Error,
            ) -> Self::Future {
                let previous_error = previous_error.to_string();
                Box::pin(async move {
                    println!("{} {:?} {}", attempt, next_delay, previous_error);
                })
            }
        }
        let config: RetryFutureConfig<MyBackoffStrategy, MyOnRetry> = RetryFutureConfig::new(10)
            .custom_backoff(MyBackoffStrategy {
                inner: ExponentialBackoff::new(Duration::from_millis(10)),
            })
            .on_retry(MyOnRetry);
        retry_fn(|| async { Ok::<_, std::io::Error>(true) })
            .with_config(config.clone())
            .await
            .unwrap();
        retry_fn(|| async { Ok::<_, std::io::Error>(true) })
            .with_config(config)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn inference_works() {
        std::mem::drop(async {
            let _ = retry_fn(|| async { Result::<_, Infallible>::Ok(()) })
                .retries(0)
                .on_retry(|_, _, _| async {})
                .await;
        });
    }
}