use futures::{FutureExt, Stream};
use std::{
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use tokio::time::sleep;
use crate::ConnOptions;
/// A stream of retry delays, each item a [`Duration`] produced by some backoff strategy.
///
/// The `Unpin` supertrait lets implementors be polled through a plain `&mut` reference
/// without pinning them first.
pub trait Backoff: Unpin + Stream<Item = Duration> {}

/// Any `Unpin` stream of `Duration`s automatically qualifies as a [`Backoff`].
impl<S> Backoff for S where S: Unpin + Stream<Item = Duration> {}
/// Retry-delay generator whose delay doubles after every completed wait.
///
/// Implements [`Stream`], yielding one [`Duration`] per elapsed wait.
pub struct ExponentialBackoff {
    /// How many waits have completed so far.
    retry_count: usize,
    /// Retry limit; `None` means retry indefinitely.
    max_retries: Option<usize>,
    /// Delay to use for the next wait.
    backoff: Duration,
    /// The currently armed delay timer, if any.
    timeout: Option<Pin<Box<tokio::time::Sleep>>>,
}
impl ExponentialBackoff {
    /// Builds a backoff that starts at `initial` and is limited by `max_retries`
    /// (`None` = never give up). No timer is armed until the stream is first polled.
    pub fn new(initial: Duration, max_retries: Option<usize>) -> Self {
        Self {
            backoff: initial,
            max_retries,
            retry_count: 0,
            timeout: None,
        }
    }

    /// Replaces any armed timer with a fresh `sleep` lasting the current `backoff`.
    fn reset_timeout(&mut self) {
        let delay = self.backoff;
        let timer = Box::pin(sleep(delay));
        self.timeout = Some(timer);
    }
}
impl From<&ConnOptions> for ExponentialBackoff {
fn from(options: &ConnOptions) -> Self {
Self::new(options.backoff_duration, options.retry_attempts)
}
}
/// Yields one item per completed backoff wait.
///
/// Each poll arms a `tokio` sleep of the current `backoff` if none is in flight, and
/// resolves once that sleep elapses. The first item therefore arrives only after a wait
/// of the initial delay. The yielded value is the (already doubled) delay that the *next*
/// wait will use. When the next completed wait would reach `max_retries`, the stream ends
/// with `None` immediately instead of sleeping first; after that it keeps returning `None`.
impl Stream for ExponentialBackoff {
    type Item = Duration;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // All fields are `Unpin`, so unpinning is free.
        let this = self.get_mut();

        // Give up before starting a wait whose completion would exhaust the budget.
        // Sleeping first and then returning `None` would only delay the caller's failure.
        let next_retry = this.retry_count.saturating_add(1);
        if this.max_retries.is_some_and(|max| next_retry >= max) {
            return Poll::Ready(None);
        }

        // Arm the timer lazily, so each delay is measured from when the consumer asks
        // for the next retry rather than from when the previous item was yielded.
        if this.timeout.is_none() {
            this.reset_timeout();
        }
        let timeout = this.timeout.as_mut().expect("timeout is armed just above");

        // Polling the sleep registers `cx`'s waker; we are woken when it elapses.
        if timeout.poll_unpin(cx).is_pending() {
            return Poll::Pending;
        }

        // Wait finished: disarm, count it, and grow the delay. No self-wake is needed,
        // because returning `Ready` hands control back to the consumer.
        // `saturating_mul` avoids the overflow panic of `Duration *= 2` in long
        // unbounded retry loops.
        this.timeout = None;
        this.retry_count = next_retry;
        this.backoff = this.backoff.saturating_mul(2);
        Poll::Ready(Some(this.backoff))
    }
}