use std::future::Future;
use std::pin::{pin, Pin};
use std::task::Context;
use std::task::Poll;
use std::time::Duration;
use futures_core::ready;
use pin_project::pin_project;
use crate::backoff::BackoffBuilder;
use crate::Backoff;
/// Turns an async closure that threads a context value through every attempt into a
/// retrying future.
///
/// The closure receives the context by value and must hand it back alongside the
/// attempt's result, so the same context can be passed into the next attempt.
pub trait RetryableWithContext<B, T, E, Ctx, Fut, FutureFn>
where
    B: BackoffBuilder,
    Fut: Future<Output = (Ctx, Result<T, E>)>,
    FutureFn: FnMut(Ctx) -> Fut,
{
    /// Builds a backoff from `builder` and wraps `self` into a [`Retry`] future.
    fn retry(self, builder: &B) -> Retry<B::Backoff, T, E, Ctx, Fut, FutureFn>;
}
/// Every `FnMut(Ctx) -> Fut` closure whose future yields `(Ctx, Result<T, E>)` can be
/// retried with context.
impl<
        B: BackoffBuilder,
        T,
        E,
        Ctx,
        Fut: Future<Output = (Ctx, Result<T, E>)>,
        FutureFn: FnMut(Ctx) -> Fut,
    > RetryableWithContext<B, T, E, Ctx, Fut, FutureFn> for FutureFn
{
    fn retry(self, builder: &B) -> Retry<B::Backoff, T, E, Ctx, Fut, FutureFn> {
        // Each `retry` call gets its own freshly built backoff sequence.
        let backoff = builder.build();
        Retry::new(self, backoff)
    }
}
/// Future that repeatedly runs a context-threading async closure until an attempt
/// succeeds, an error is rejected by the `retryable` predicate, or the backoff yields
/// no further delay.
///
/// Created by [`RetryableWithContext::retry`]. A context must be supplied with
/// [`Retry::context`] before the future is polled; polling without one panics with
/// "context must be valid".
///
/// `RF` and `NF` default to function-pointer types, which is what
/// [`RetryableWithContext::retry`] produces before `when`/`notify` are called.
#[pin_project]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Retry<
    B: Backoff,
    T,
    E,
    Ctx,
    Fut: Future<Output = (Ctx, Result<T, E>)>,
    FutureFn: FnMut(Ctx) -> Fut,
    RF = fn(&E) -> bool,
    NF = fn(&E, Duration),
> {
    /// Supplies the delay before each retry; once it returns `None`, the last error is returned.
    backoff: B,
    /// Decides whether a given error is worth another attempt.
    retryable: RF,
    /// Called with the error and the upcoming delay right before sleeping.
    notify: NF,
    /// Produces a new attempt future from the context.
    future_fn: FutureFn,
    /// Current step of the retry state machine.
    #[pin]
    state: State<T, E, Ctx, Fut>,
}
impl<B, T, E, Ctx, Fut, FutureFn> Retry<B, T, E, Ctx, Fut, FutureFn>
where
    B: Backoff,
    Fut: Future<Output = (Ctx, Result<T, E>)>,
    FutureFn: FnMut(Ctx) -> Fut,
{
    /// Wraps `future_fn` with `backoff`, treating every error as retryable and
    /// ignoring retry notifications.
    ///
    /// No context is stored yet: `Retry::context` must be called before polling.
    fn new(future_fn: FutureFn, backoff: B) -> Self {
        let retry_everything: fn(&E) -> bool = |_: &E| true;
        let ignore_notification: fn(&E, Duration) = |_: &E, _: Duration| {};
        Self {
            backoff,
            retryable: retry_everything,
            notify: ignore_notification,
            future_fn,
            state: State::Idle(None),
        }
    }
}
impl<B, T, E, Ctx, Fut, FutureFn, RF, NF> Retry<B, T, E, Ctx, Fut, FutureFn, RF, NF>
where
    B: Backoff,
    Fut: Future<Output = (Ctx, Result<T, E>)>,
    FutureFn: FnMut(Ctx) -> Fut,
    RF: FnMut(&E) -> bool,
    NF: FnMut(&E, Duration),
{
    /// Stores `context` as the value handed to the first attempt.
    ///
    /// The state is reset to idle with this context, replacing whatever state was
    /// there before (including a context from an earlier `context` call).
    pub fn context(self, context: Ctx) -> Retry<B, T, E, Ctx, Fut, FutureFn, RF, NF> {
        Retry {
            state: State::Idle(Some(context)),
            ..self
        }
    }

    /// Replaces the predicate that decides whether an error is retried.
    ///
    /// When it returns `false`, the future resolves with that error immediately.
    pub fn when<RN: FnMut(&E) -> bool>(
        self,
        retryable: RN,
    ) -> Retry<B, T, E, Ctx, Fut, FutureFn, RN, NF> {
        let Retry {
            backoff,
            notify,
            future_fn,
            state,
            ..
        } = self;
        Retry {
            backoff,
            retryable,
            notify,
            future_fn,
            state,
        }
    }

    /// Replaces the callback invoked with the error and the chosen delay before
    /// each retry sleep.
    pub fn notify<NN: FnMut(&E, Duration)>(
        self,
        notify: NN,
    ) -> Retry<B, T, E, Ctx, Fut, FutureFn, RF, NN> {
        let Retry {
            backoff,
            retryable,
            future_fn,
            state,
            ..
        } = self;
        Retry {
            backoff,
            retryable,
            notify,
            future_fn,
            state,
        }
    }
}
/// Internal state machine of a [`Retry`] future.
#[pin_project(project = StateProject)]
enum State<T, E, Ctx, Fut>
where
    Fut: Future<Output = (Ctx, Result<T, E>)>,
{
    /// About to start an attempt; holds the context to hand to the closure.
    Idle(Option<Ctx>),
    /// An attempt future is in flight.
    Polling(#[pin] Fut),
    /// Waiting out a backoff delay; holds the context for the next attempt and the timer.
    Sleeping((Option<Ctx>, Pin<Box<tokio::time::Sleep>>)),
}
impl<B, T, E, Ctx, Fut, FutureFn, RF, NF> Future for Retry<B, T, E, Ctx, Fut, FutureFn, RF, NF>
where
    B: Backoff,
    Fut: Future<Output = (Ctx, Result<T, E>)>,
    FutureFn: FnMut(Ctx) -> Fut,
    RF: FnMut(&E) -> bool,
    NF: FnMut(&E, Duration),
{
    /// The context returned by the final attempt, paired with that attempt's result.
    type Output = (Ctx, Result<T, E>);

    /// Drives the cycle `Idle` -> `Polling` -> (`Sleeping` -> `Idle` -> ...) until an
    /// attempt succeeds, an error is not retryable, or the backoff is exhausted.
    ///
    /// # Panics
    ///
    /// Panics with "context must be valid" if no context is stored when an attempt is
    /// about to start (e.g. `Retry::context` was never called).
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut me = self.project();

        loop {
            match me.state.as_mut().project() {
                // Start a fresh attempt by moving the stored context into the closure.
                StateProject::Idle(slot) => {
                    let ctx = slot.take().expect("context must be valid");
                    let attempt = (me.future_fn)(ctx);
                    me.state.set(State::Polling(attempt));
                }
                StateProject::Polling(attempt) => {
                    let (ctx, outcome) = ready!(attempt.poll(cx));
                    let err = match outcome {
                        Ok(value) => return Poll::Ready((ctx, Ok(value))),
                        Err(err) => err,
                    };
                    // Both a rejected error and an exhausted backoff end the loop with
                    // the latest error.
                    if !(me.retryable)(&err) {
                        return Poll::Ready((ctx, Err(err)));
                    }
                    let Some(delay) = me.backoff.next() else {
                        return Poll::Ready((ctx, Err(err)));
                    };
                    (me.notify)(&err, delay);
                    let timer = Box::pin(tokio::time::sleep(delay));
                    me.state.set(State::Sleeping((Some(ctx), timer)));
                }
                // Once the delay elapses, park the context again for the next attempt.
                StateProject::Sleeping((slot, timer)) => {
                    ready!(pin!(timer).poll(cx));
                    let ctx = slot.take().expect("context must be valid");
                    me.state.set(State::Idle(Some(ctx)));
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use anyhow::anyhow;
    use tokio::sync::Mutex;

    use super::*;
    use crate::exponential::ExponentialBuilder;
    use anyhow::Result;

    struct Test;

    impl Test {
        async fn hello(&mut self) -> Result<usize> {
            Err(anyhow!("not retryable"))
        }
    }

    /// A non-retryable error is returned after the very first attempt.
    #[tokio::test]
    async fn test_retry_with_not_retryable_error() -> Result<()> {
        let error_times = Mutex::new(0);
        let test = Test;
        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));
        let (_, result) = {
            |mut v: Test| async {
                let mut x = error_times.lock().await;
                *x += 1;
                let res = v.hello().await;
                (v, res)
            }
        }
        .retry(&backoff)
        .context(test)
        .when(|e| e.to_string() == "retryable")
        .await;
        assert!(result.is_err());
        assert_eq!("not retryable", result.unwrap_err().to_string());
        assert_eq!(*error_times.lock().await, 1);
        Ok(())
    }

    /// A retryable error is retried until the backoff runs out. The context returned by
    /// each attempt is fed into the next one, and `notify` fires once per retry.
    #[tokio::test]
    async fn test_retry_with_retryable_error_threads_context() -> Result<()> {
        let backoff = ExponentialBuilder::default()
            .with_min_delay(Duration::from_millis(1))
            .with_max_times(3);
        let mut notified = 0;
        let (attempts, result) = {
            |attempts: usize| async move {
                let res: Result<usize> = Err(anyhow!("retryable"));
                (attempts + 1, res)
            }
        }
        .retry(&backoff)
        .context(0)
        .when(|e| e.to_string() == "retryable")
        .notify(|_, _| notified += 1)
        .await;
        assert_eq!("retryable", result.unwrap_err().to_string());
        // One initial attempt plus `max_times` retries, counted through the context.
        assert_eq!(attempts, 4);
        assert_eq!(notified, 3);
        Ok(())
    }

    /// Retrying stops as soon as an attempt succeeds, and the final context is returned.
    #[tokio::test]
    async fn test_retry_succeeds_after_retryable_errors() -> Result<()> {
        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));
        let (attempts, result) = {
            |attempts: usize| async move {
                let attempts = attempts + 1;
                let res = if attempts < 3 {
                    Err(anyhow!("retryable"))
                } else {
                    Ok(attempts)
                };
                (attempts, res)
            }
        }
        .retry(&backoff)
        .context(0)
        .when(|e| e.to_string() == "retryable")
        .await;
        assert_eq!(result?, 3);
        assert_eq!(attempts, 3);
        Ok(())
    }
}