use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
time::Instant,
};
use crate::Backoff;
/// Retries the future produced by `func` after every error, pausing between
/// attempts for the period reported by `backoff`.
///
/// This is [`retry_if`] with a predicate that accepts every error, so the
/// returned future only resolves once an attempt succeeds.
pub fn retry<B, F, T, E, Fut>(
    backoff: B,
    func: F,
) -> RetryFuture<F, Fut, impl Fn(&E, u32) -> bool, B>
where
    B: Backoff,
    F: Fn() -> Fut,
    Fut: Future<Output = Result<T, E>>,
{
    // Every error is treated as retryable, regardless of the attempt count.
    retry_if(backoff, func, |_error, _attempt| true)
}
/// Drives the future produced by `func`, retrying while `predicate` accepts the error.
///
/// After each failure, the failure count (starting at 1) and the error are passed
/// to `predicate`. If it returns `true`, a fresh future is obtained from `func` and
/// polled after the pause reported by `backoff`. Otherwise, that error becomes the
/// result. The first `Ok` value is returned unchanged.
///
/// `func` is called once immediately to build the first attempt.
pub fn retry_if<B, F, P, T, E, Fut>(backoff: B, func: F, predicate: P) -> RetryFuture<F, Fut, P, B>
where
    B: Backoff,
    F: Fn() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    P: Fn(&E, u32) -> bool,
{
    // Struct-literal fields are evaluated in the order written, so `func()` runs
    // (borrowing `func`) before `func` is moved into `factory`.
    RetryFuture {
        future: func(),
        factory: func,
        predicate,
        backoff,
        iterations: 0,
        paused_until: None,
    }
}
/// Future returned by [`retry`] and [`retry_if`].
///
/// Polls the attempt produced by `factory`. When an attempt fails with an error
/// that `predicate` accepts, it builds a fresh attempt and pauses for the period
/// reported by `backoff` before polling again. It resolves to the first success
/// or to the first error that `predicate` rejects.
#[pin_project::pin_project]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct RetryFuture<F, Fut, P, B> {
    /// Builds a fresh attempt for each retry.
    factory: F,
    /// The attempt currently being driven.
    #[pin]
    future: Fut,
    /// Decides, from the error and the 1-based failure count, whether to retry.
    predicate: P,
    /// Supplies the pause length before each retry.
    backoff: B,
    /// End of the current backoff pause, if one is active.
    paused_until: Option<Instant>,
    /// Number of failed attempts so far.
    iterations: u32,
}
impl<T, E, F, Fut, P, B> Future for RetryFuture<F, Fut, P, B>
where
    F: Fn() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    P: Fn(&E, u32) -> bool,
    B: Backoff,
{
    type Output = Result<T, E>;

    /// Drives the current attempt, scheduling a retry after a backoff pause
    /// whenever it fails and the predicate allows another try.
    ///
    /// Resolves to the first `Ok` value, or to the first `Err` the predicate
    /// rejects. Every `Pending` return made while paused arranges a wake-up for
    /// the waker passed to *this* call, as the `Future` contract requires.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        /// Wakes `waker` once `delay` has elapsed, using whichever runtime is enabled.
        fn schedule_wake(delay: std::time::Duration, waker: std::task::Waker) {
            #[cfg(feature = "runtime-tokio")]
            {
                // Cloned per runtime so enabling both features still compiles.
                let waker = waker.clone();
                tokio::spawn(async move {
                    tokio::time::sleep(delay).await;
                    waker.wake();
                });
            }
            #[cfg(feature = "runtime-async-std")]
            {
                let waker = waker.clone();
                async_std::task::spawn(async move {
                    async_std::task::sleep(delay).await;
                    waker.wake();
                });
            }
            // With no runtime feature enabled, nothing would ever wake the task
            // and the future would hang. Fall back to a plain OS-thread timer.
            #[cfg(not(any(feature = "runtime-tokio", feature = "runtime-async-std")))]
            {
                std::thread::spawn(move || {
                    std::thread::sleep(delay);
                    waker.wake();
                });
            }
        }

        let mut this = self.project();

        if let Some(deadline) = *this.paused_until {
            let now = Instant::now();
            if now < deadline {
                // Still backing off. Re-arm with the current waker: the timer
                // scheduled earlier may hold a stale waker (the task moved), or
                // may have fired before `Instant::now()` reached the deadline.
                // Returning Pending without a wake-up would then hang forever.
                schedule_wake(deadline.duration_since(now), cx.waker().clone());
                return Poll::Pending;
            }
            *this.paused_until = None;
        }

        let error = match this.future.as_mut().poll(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Ok(value)) => return Poll::Ready(Ok(value)),
            Poll::Ready(Err(error)) => error,
        };

        // Saturate instead of overflowing on pathologically long retry loops.
        *this.iterations = this.iterations.saturating_add(1);
        if !(this.predicate)(&error, *this.iterations) {
            return Poll::Ready(Err(error));
        }

        // Build the replacement attempt now; it is first polled once the pause ends.
        this.future.set((this.factory)());
        let delay = this.backoff.backoff_period(*this.iterations);
        *this.paused_until = Some(Instant::now() + delay);
        // Always yield here, even for a zero delay, so a fast-failing attempt
        // cannot monopolise the executor thread.
        schedule_wake(delay, cx.waker().clone());
        Poll::Pending
    }
}