use futures_util::Stream;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::time::sleep;
use crate::error::{Error, Result};
/// Tuning knobs for the retry/backoff behaviour of `ReconnectingStream`.
#[derive(Clone, Debug)]
pub struct ReconnectConfig {
    /// Delay handed out by the backoff before the first reconnection attempt.
    pub initial_delay: Duration,
    /// Ceiling on the delay between reconnection attempts.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by after each backoff step.
    pub multiplier: f64,
    /// Attempt limit: `Some(n)` gives up once the attempt counter reaches `n`
    /// (yielding `Error::ReconnectFailed`); `None` retries indefinitely.
    pub max_attempts: Option<u32>,
}
impl Default for ReconnectConfig {
fn default() -> Self {
Self {
initial_delay: Duration::from_secs(1),
max_delay: Duration::from_secs(60),
multiplier: 2.0,
max_attempts: None,
}
}
}
/// Exponential backoff schedule.
///
/// Each call to [`next_delay`](Self::next_delay) returns the current delay and then
/// multiplies the stored delay by `multiplier`, never exceeding `max_delay`.
/// [`reset`](Self::reset) restores the configured initial delay.
#[derive(Debug, Clone)]
struct ExponentialBackoff {
    /// Delay restored by `reset` (already clamped to `max_delay`).
    initial_delay: Duration,
    /// Delay returned by the next `next_delay` call.
    current_delay: Duration,
    /// Upper bound on every delay this schedule produces.
    max_delay: Duration,
    /// Growth factor applied after each `next_delay` call.
    multiplier: f64,
}

impl ExponentialBackoff {
    /// Creates a schedule starting at `initial_delay` (clamped to `max_delay`).
    fn new(initial_delay: Duration, max_delay: Duration, multiplier: f64) -> Self {
        let initial_delay = initial_delay.min(max_delay);
        Self {
            initial_delay,
            current_delay: initial_delay,
            max_delay,
            multiplier,
        }
    }

    /// Returns the current delay and advances the schedule.
    ///
    /// If scaling by `multiplier` produces a value that is not representable as a
    /// `Duration` (overflow, negative, or NaN), the schedule saturates at `max_delay`
    /// instead of panicking as `Duration::from_secs_f64` would.
    fn next_delay(&mut self) -> Duration {
        let delay = self.current_delay;
        let max = self.max_delay;
        let scaled = delay.as_secs_f64() * self.multiplier;
        self.current_delay =
            Duration::try_from_secs_f64(scaled).map_or(max, |next| next.min(max));
        delay
    }

    /// Restarts the schedule from the configured initial delay
    /// (previously this was hard-coded to one second).
    fn reset(&mut self) {
        self.current_delay = self.initial_delay;
    }
}
/// Internal state machine behind `ReconnectingStream`.
///
/// `S` is the live connection stream; `Fut` is the future produced by the
/// user-supplied connect function.
enum StreamState<S, Fut> {
    /// A connection is up; items are read from `S`.
    Connected(S),
    /// Waiting out a backoff `delay` before the next connection attempt;
    /// `attempts` is the counter carried into that attempt.
    Reconnecting {
        delay: Duration,
        attempts: u32,
    },
    /// Establishing a connection. `future` stays `None` until the connect
    /// function is invoked, then holds the in-flight attempt.
    Connecting {
        future: Option<Pin<Box<Fut>>>,
        attempts: u32,
    },
    /// Retries are exhausted; polling yields `None` from here on.
    Terminated,
}
/// A stream that connects via `connect_fn` and re-establishes the connection with
/// exponential backoff whenever it is lost.
///
/// Retry timing and limits come from [`ReconnectConfig`]. Once `max_attempts` is
/// reached the stream yields an `Error::ReconnectFailed` and then ends.
pub struct ReconnectingStream<T, S, F, Fut>
where
    S: Stream<Item = Result<T>> + Unpin,
    F: Fn() -> Fut,
    Fut: Future<Output = Result<S>>,
{
    /// Factory invoked for every connection attempt.
    connect_fn: F,
    /// Current phase of the connection state machine.
    state: StreamState<S, Fut>,
    /// Retry policy supplied by the caller.
    config: ReconnectConfig,
    /// Delay schedule built from `config`.
    backoff: ExponentialBackoff,
    /// Armed backoff timer while in `StreamState::Reconnecting`.
    sleep_future: Option<Pin<Box<tokio::time::Sleep>>>,
    /// `T` is only mentioned in the `S` bound, which does not count as a use
    /// (E0392). `fn() -> T` marks it without affecting `Send`/`Sync`/`Unpin`.
    _item: std::marker::PhantomData<fn() -> T>,
}

impl<T, S, F, Fut> ReconnectingStream<T, S, F, Fut>
where
    S: Stream<Item = Result<T>> + Unpin,
    F: Fn() -> Fut,
    Fut: Future<Output = Result<S>>,
{
    /// Creates a stream governed by `config` that obtains connections from `connect_fn`.
    ///
    /// No connection is attempted here; `connect_fn` is first called when the
    /// stream is polled.
    pub fn new(config: ReconnectConfig, connect_fn: F) -> Self {
        let backoff =
            ExponentialBackoff::new(config.initial_delay, config.max_delay, config.multiplier);
        Self {
            connect_fn,
            state: StreamState::Connecting {
                attempts: 0,
                future: None,
            },
            config,
            backoff,
            sleep_future: None,
            _item: std::marker::PhantomData,
        }
    }

    /// Schedules the next connection attempt, or gives up.
    ///
    /// If `attempts` has reached `config.max_attempts`, the state becomes
    /// `Terminated` and `Poll::Ready(Some(Err(Error::ReconnectFailed { .. })))` is
    /// returned. Otherwise the next backoff delay is taken, the state becomes
    /// `Reconnecting`, a timer is armed, and `Poll::Pending` is returned.
    ///
    /// The returned `Pending` means "reconnect scheduled", NOT "waker registered":
    /// the timer has not been polled yet. Callers must keep driving the state
    /// machine instead of returning this value from `poll_next`, or the task is
    /// never woken.
    fn handle_disconnection(&mut self, attempts: u32) -> Poll<Option<Result<T>>> {
        if let Some(max) = self.config.max_attempts {
            if attempts >= max {
                self.state = StreamState::Terminated;
                // No further reconnects: drop any leftover timer.
                self.sleep_future = None;
                return Poll::Ready(Some(Err(Error::ReconnectFailed {
                    attempts,
                    last_error: "Maximum reconnection attempts reached".to_string(),
                })));
            }
        }
        let delay = self.backoff.next_delay();
        self.state = StreamState::Reconnecting { attempts, delay };
        self.sleep_future = Some(Box::pin(sleep(delay)));
        Poll::Pending
    }
}
impl<T, S, F, Fut> Stream for ReconnectingStream<T, S, F, Fut>
where
S: Stream<Item = Result<T>> + Unpin,
F: Fn() -> Fut + Unpin,
Fut: Future<Output = Result<S>>,
{
type Item = Result<T>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
match &mut self.state {
StreamState::Connected(stream) => {
match Pin::new(stream).poll_next(cx) {
Poll::Ready(Some(Ok(item))) => {
self.backoff.reset();
return Poll::Ready(Some(Ok(item)));
}
Poll::Ready(Some(Err(Error::ConnectionClosed))) => {
return self.handle_disconnection(1);
}
Poll::Ready(Some(Err(e))) => {
let _ = self.handle_disconnection(1);
return Poll::Ready(Some(Err(e)));
}
Poll::Ready(None) => {
return self.handle_disconnection(1);
}
Poll::Pending => {
return Poll::Pending;
}
}
}
StreamState::Reconnecting { attempts, .. } => {
let attempts = *attempts;
if let Some(mut sleep_fut) = self.sleep_future.take() {
match Pin::new(&mut sleep_fut).poll(cx) {
Poll::Ready(()) => {
self.state = StreamState::Connecting {
attempts,
future: None,
};
continue;
}
Poll::Pending => {
self.sleep_future = Some(sleep_fut);
return Poll::Pending;
}
}
} else {
let delay = match &self.state {
StreamState::Reconnecting { delay, .. } => *delay,
_ => unreachable!(),
};
self.sleep_future = Some(Box::pin(sleep(delay)));
continue;
}
}
StreamState::Connecting { attempts, future } => {
let current_attempts = *attempts;
let mut boxed_fut = if let Some(fut) = future.take() {
fut
} else {
Box::pin((self.connect_fn)())
};
match boxed_fut.as_mut().poll(cx) {
Poll::Ready(Ok(stream)) => {
self.state = StreamState::Connected(stream);
self.backoff.reset();
continue;
}
Poll::Ready(Err(_e)) => {
let next_attempts = if current_attempts == 0 { 1 } else { current_attempts + 1 };
return self.handle_disconnection(next_attempts);
}
Poll::Pending => {
self.state = StreamState::Connecting {
attempts: current_attempts,
future: Some(boxed_fut),
};
return Poll::Pending;
}
}
}
StreamState::Terminated => {
return Poll::Ready(None);
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Doubling backoff with whole-second bounds, shared by every test below.
    fn doubling(initial_secs: u64, max_secs: u64) -> ExponentialBackoff {
        ExponentialBackoff::new(
            Duration::from_secs(initial_secs),
            Duration::from_secs(max_secs),
            2.0,
        )
    }

    /// Asserts that consecutive `next_delay` calls yield `expected_secs`, in order.
    fn expect_delays(backoff: &mut ExponentialBackoff, expected_secs: &[u64]) {
        for &secs in expected_secs {
            assert_eq!(backoff.next_delay(), Duration::from_secs(secs));
        }
    }

    #[test]
    fn test_backoff() {
        let mut backoff = doubling(1, 60);
        expect_delays(&mut backoff, &[1, 2, 4, 8]);
    }

    #[test]
    fn test_backoff_max() {
        // Growth stops at the 5 s ceiling and stays there.
        let mut backoff = doubling(1, 5);
        expect_delays(&mut backoff, &[1, 2, 4, 5, 5]);
    }

    #[test]
    fn test_backoff_reset() {
        let mut backoff = doubling(1, 60);
        expect_delays(&mut backoff, &[1, 2]);
        backoff.reset();
        expect_delays(&mut backoff, &[1]);
    }
}