backon 0.4.3

Retry with backoff without effort.
Documentation
use std::future::Future;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use std::time::Duration;

use futures_core::ready;
use pin_project::pin_project;

use crate::backoff::BackoffBuilder;
use crate::Backoff;

/// `Retryable` adds retry support to every function that produces a future
/// resolving to a `Result`.
///
/// Any type implementing `FnMut() -> impl Future<Output = Result<T, E>>`
/// gains a `retry` method.
///
/// For example:
///
/// - Functions without extra arguments:
///
/// ```ignore
/// async fn fetch() -> Result<String> {
///     Ok(reqwest::get("https://www.rust-lang.org").await?.text().await?)
/// }
/// ```
///
/// - Closures:
///
/// ```ignore
/// || async {
///     let x = reqwest::get("https://www.rust-lang.org")
///         .await?
///         .text()
///         .await?;
///
///     Err(anyhow::anyhow!(x))
/// }
/// ```
///
/// # Example
///
/// ```no_run
/// use anyhow::Result;
/// use backon::ExponentialBuilder;
/// use backon::Retryable;
///
/// async fn fetch() -> Result<String> {
///     Ok(reqwest::get("https://www.rust-lang.org")
///         .await?
///         .text()
///         .await?)
/// }
///
/// #[tokio::main(flavor = "current_thread")]
/// async fn main() -> Result<()> {
///     let content = fetch.retry(&ExponentialBuilder::default()).await?;
///     println!("fetch succeeded: {}", content);
///
///     Ok(())
/// }
/// ```
pub trait Retryable<B, T, E, Fut, FutureFn>
where
    B: BackoffBuilder,
    Fut: Future<Output = Result<T, E>>,
    FutureFn: FnMut() -> Fut,
{
    /// Wrap `self` into a [`Retry`] future whose delays come from a backoff
    /// produced by `builder`.
    fn retry(self, builder: &B) -> Retry<B::Backoff, T, E, Fut, FutureFn>;
}

impl<B, T, E, Fut, FutureFn> Retryable<B, T, E, Fut, FutureFn> for FutureFn
where
    B: BackoffBuilder,
    Fut: Future<Output = Result<T, E>>,
    FutureFn: FnMut() -> Fut,
{
    /// Build a fresh backoff from `builder` and pair it with this function.
    fn retry(self, builder: &B) -> Retry<B::Backoff, T, E, Fut, FutureFn> {
        let backoff = builder.build();
        Retry::new(self, backoff)
    }
}

/// Future returned by [`Retryable::retry`].
///
/// Each attempt calls `future_fn` and awaits the future it produces. The retry
/// stops in three cases:
///
/// - the attempt succeeds;
/// - the error is rejected by the `retryable` predicate;
/// - the backoff has no more delays.
///
/// Otherwise `notify` is called with the error and the chosen delay, and the
/// next attempt starts after that delay has elapsed.
///
/// Like every future it does nothing until it is awaited or polled, so
/// dropping it unused is almost certainly a bug.
#[pin_project]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Retry<
    B: Backoff,
    T,
    E,
    Fut: Future<Output = Result<T, E>>,
    FutureFn: FnMut() -> Fut,
    RF = fn(&E) -> bool,
    NF = fn(&E, Duration),
> {
    /// Source of the delays between attempts; retrying stops once it yields `None`.
    backoff: B,
    /// Predicate deciding whether an error may be retried (by default every error is).
    retryable: RF,
    /// Callback invoked with the error and the delay just before sleeping (no-op by default).
    notify: NF,
    /// Factory producing a fresh future for each attempt.
    future_fn: FutureFn,

    /// Current step of the retry state machine.
    #[pin]
    state: State<T, E, Fut>,
}

impl<B, T, E, Fut, FutureFn> Retry<B, T, E, Fut, FutureFn>
where
    B: Backoff,
    Fut: Future<Output = Result<T, E>>,
    FutureFn: FnMut() -> Fut,
{
    /// Build a retry future that drives `future_fn` with delays from `backoff`.
    ///
    /// By default every error is treated as retryable and no notification is
    /// emitted; use `when` / `notify` to customize either.
    fn new(future_fn: FutureFn, backoff: B) -> Self {
        let retry_everything: fn(&E) -> bool = |_: &E| true;
        let ignore_notification: fn(&E, Duration) = |_: &E, _: Duration| {};

        Self {
            backoff,
            retryable: retry_everything,
            notify: ignore_notification,
            future_fn,
            state: State::Idle,
        }
    }
}

impl<B, T, E, Fut, FutureFn, RF, NF> Retry<B, T, E, Fut, FutureFn, RF, NF>
where
    B: Backoff,
    Fut: Future<Output = Result<T, E>>,
    FutureFn: FnMut() -> Fut,
    RF: FnMut(&E) -> bool,
    NF: FnMut(&E, Duration),
{
    /// Restrict retrying to the errors accepted by `retryable`.
    ///
    /// If this is never called, every error is retried. When the predicate
    /// returns `false`, the error is returned immediately and the backoff is
    /// not consulted.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use anyhow::Result;
    /// use backon::ExponentialBuilder;
    /// use backon::Retryable;
    ///
    /// async fn fetch() -> Result<String> {
    ///     Ok(reqwest::get("https://www.rust-lang.org")
    ///         .await?
    ///         .text()
    ///         .await?)
    /// }
    ///
    /// #[tokio::main(flavor = "current_thread")]
    /// async fn main() -> Result<()> {
    ///     let content = fetch
    ///         .retry(&ExponentialBuilder::default())
    ///         .when(|e| e.to_string() == "EOF")
    ///         .await?;
    ///     println!("fetch succeeded: {}", content);
    ///
    ///     Ok(())
    /// }
    /// ```
    pub fn when<RN: FnMut(&E) -> bool>(
        self,
        retryable: RN,
    ) -> Retry<B, T, E, Fut, FutureFn, RN, NF> {
        // Carry every field over except the predicate, which is replaced.
        let Retry {
            backoff,
            notify,
            future_fn,
            state,
            ..
        } = self;

        Retry {
            backoff,
            retryable,
            notify,
            future_fn,
            state,
        }
    }

    /// Register a callback that runs right before every retry.
    ///
    /// The callback receives the error that triggered the retry and the delay
    /// that will be waited before the next attempt. It is not called for an
    /// error that ends the retry (rejected by the predicate or backoff
    /// exhausted). If this is never called, nothing is notified.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use std::time::Duration;
    ///
    /// use anyhow::Result;
    /// use backon::ExponentialBuilder;
    /// use backon::Retryable;
    ///
    /// async fn fetch() -> Result<String> {
    ///     Ok(reqwest::get("https://www.rust-lang.org")
    ///         .await?
    ///         .text()
    ///         .await?)
    /// }
    ///
    /// #[tokio::main(flavor = "current_thread")]
    /// async fn main() -> Result<()> {
    ///     let content = fetch
    ///         .retry(&ExponentialBuilder::default())
    ///         .notify(|err: &anyhow::Error, dur: Duration| {
    ///             println!("retrying error {:?} with sleeping {:?}", err, dur);
    ///         })
    ///         .await?;
    ///     println!("fetch succeeded: {}", content);
    ///
    ///     Ok(())
    /// }
    /// ```
    pub fn notify<NN: FnMut(&E, Duration)>(
        self,
        notify: NN,
    ) -> Retry<B, T, E, Fut, FutureFn, RF, NN> {
        // Carry every field over except the notifier, which is replaced.
        let Retry {
            backoff,
            retryable,
            future_fn,
            state,
            ..
        } = self;

        Retry {
            backoff,
            retryable,
            notify,
            future_fn,
            state,
        }
    }
}

/// State maintains the internal state of a retry.
///
/// `poll` moves through the states as follows:
///
/// - `Idle` becomes `Polling` when a new attempt starts.
/// - `Polling` becomes `Sleeping` when the attempt fails with a retryable
///   error and the backoff still yields a delay.
/// - `Sleeping` becomes `Idle` once that delay has elapsed.
///
/// # Notes
///
/// `tokio::time::Sleep` is a very large struct (it occupies 640B), so we wrap
/// it in a `Pin<Box<_>>` to keep this enum from becoming too large.
#[derive(Default)]
#[pin_project(project = StateProject)]
enum State<T, E, Fut: Future<Output = Result<T, E>>> {
    /// No attempt in flight; the next poll calls `future_fn` to start one.
    #[default]
    Idle,
    /// An attempt's future is being polled.
    Polling(#[pin] Fut),
    // TODO: we need to support other sleeper
    /// Waiting for the backoff delay to elapse before the next attempt.
    Sleeping(#[pin] Pin<Box<tokio::time::Sleep>>),
}

impl<B, T, E, Fut, FutureFn, RF, NF> Future for Retry<B, T, E, Fut, FutureFn, RF, NF>
where
    B: Backoff,
    Fut: Future<Output = Result<T, E>>,
    FutureFn: FnMut() -> Fut,
    RF: FnMut(&E) -> bool,
    NF: FnMut(&E, Duration),
{
    type Output = Result<T, E>;

    /// Drive the retry state machine until it reaches a final result.
    ///
    /// The machine cycles Idle -> Polling -> Sleeping -> Idle -> ... and
    /// stops in three cases:
    ///
    /// - an attempt succeeds;
    /// - an error is rejected by the `retryable` predicate;
    /// - the backoff runs out of delays.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut me = self.project();

        loop {
            match me.state.as_mut().project() {
                StateProject::Idle => {
                    // Start a brand-new attempt.
                    let attempt = (me.future_fn)();
                    me.state.set(State::Polling(attempt));
                }
                StateProject::Sleeping(timer) => {
                    ready!(timer.poll(cx));
                    // Delay elapsed: the next iteration starts a new attempt.
                    me.state.set(State::Idle);
                }
                StateProject::Polling(attempt) => {
                    let err = match ready!(attempt.poll(cx)) {
                        Ok(value) => return Poll::Ready(Ok(value)),
                        Err(err) => err,
                    };

                    // Errors the predicate rejects are surfaced immediately.
                    if !(me.retryable)(&err) {
                        return Poll::Ready(Err(err));
                    }

                    // An exhausted backoff surfaces the last error.
                    let delay = match me.backoff.next() {
                        Some(delay) => delay,
                        None => return Poll::Ready(Err(err)),
                    };

                    (me.notify)(&err, delay);
                    me.state
                        .set(State::Sleeping(Box::pin(tokio::time::sleep(delay))));
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::sync::Mutex;

    use super::*;
    use crate::exponential::ExponentialBuilder;

    /// Operation that fails on every call.
    async fn always_error() -> anyhow::Result<()> {
        Err(anyhow::anyhow!("test_query meets error"))
    }

    #[tokio::test]
    async fn test_retry() -> anyhow::Result<()> {
        let builder = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));
        let result = always_error.retry(&builder).await;

        assert!(result.is_err());
        let last_error = result.unwrap_err();
        assert_eq!("test_query meets error", last_error.to_string());
        Ok(())
    }

    #[tokio::test]
    async fn test_retry_with_not_retryable_error() -> anyhow::Result<()> {
        let error_times = Mutex::new(0);

        let op = || async {
            *error_times.lock().await += 1;
            Err::<(), anyhow::Error>(anyhow::anyhow!("not retryable"))
        };

        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));
        // Only errors whose message is exactly `retryable` get retried.
        let result = op
            .retry(&backoff)
            .when(|e| e.to_string() == "retryable")
            .await;

        assert!(result.is_err());
        assert_eq!("not retryable", result.unwrap_err().to_string());
        // The predicate rejects the very first error, so `op` runs exactly once.
        assert_eq!(*error_times.lock().await, 1);
        Ok(())
    }

    #[tokio::test]
    async fn test_retry_with_retryable_error() -> anyhow::Result<()> {
        let error_times = Mutex::new(0);

        let op = || async {
            *error_times.lock().await += 1;
            Err::<(), anyhow::Error>(anyhow::anyhow!("retryable"))
        };

        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));
        // Only errors whose message is exactly `retryable` get retried.
        let result = op
            .retry(&backoff)
            .when(|e| e.to_string() == "retryable")
            .await;

        assert!(result.is_err());
        assert_eq!("retryable", result.unwrap_err().to_string());
        // Every error is retryable: one initial attempt plus three retries.
        assert_eq!(*error_times.lock().await, 4);
        Ok(())
    }

    #[tokio::test]
    async fn test_fn_mut_when_and_notify() -> anyhow::Result<()> {
        // Plain counters mutated from inside the `FnMut` callbacks.
        let mut retryable_checks = 0usize;
        let mut notifications = 0usize;

        let op = || async { Err::<(), anyhow::Error>(anyhow::anyhow!("retryable")) };

        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));
        let result = op
            .retry(&backoff)
            .when(|_| {
                retryable_checks += 1;
                true
            })
            .notify(|_, _| notifications += 1)
            .await;

        assert!(result.is_err());
        assert_eq!("retryable", result.unwrap_err().to_string());
        // Four attempts are each checked by the predicate; only the three that
        // lead to a retry trigger a notification.
        assert_eq!(retryable_checks, 4);
        assert_eq!(notifications, 3);
        Ok(())
    }
}