use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use bytes::Bytes;
use tokio::sync::Mutex;
use super::{Transport, TransportKind};
use crate::error::{SrxError, TransportError};
/// Retry policy used by `ReconnectingTransport` when it re-creates a failed transport.
///
/// After each failed factory call the wrapper sleeps, then multiplies the delay by
/// `backoff_factor` for the next round, never letting a grown delay exceed `max_delay`.
#[derive(Debug, Clone, PartialEq)]
pub struct ReconnectConfig {
    /// Sleep after the first failed attempt (expected to be no larger than `max_delay`).
    pub initial_delay: Duration,
    /// Upper bound for the delay once it has been grown by `backoff_factor`.
    pub max_delay: Duration,
    /// Multiplier applied to the delay after every failed attempt.
    pub backoff_factor: f64,
    /// Maximum number of factory calls per reconnect; `0` means effectively unlimited
    /// (`u32::MAX` attempts).
    pub max_attempts: u32,
}

impl Default for ReconnectConfig {
    /// 100 ms initial delay, doubling each time, capped at 30 s, at most 10 attempts.
    fn default() -> Self {
        Self {
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(30),
            backoff_factor: 2.0,
            max_attempts: 10,
        }
    }
}
/// Shared async constructor that builds a brand-new transport.
///
/// It takes no arguments and is invoked once per reconnect attempt. The closure itself
/// must be `Send + Sync`, and the future it returns must be `Send`.
pub type TransportFactory = Arc<
    dyn Fn() -> Pin<Box<dyn Future<Output = crate::error::Result<Box<dyn Transport>>> + Send>>
        + Sync
        + Send,
>;
/// A transport decorator that swaps in a freshly built transport, obtained from a
/// [`TransportFactory`], whenever an operation on the current one fails.
pub struct ReconnectingTransport {
    // Active transport; becomes `None` once `close` has taken it out.
    inner: Mutex<Option<Box<dyn Transport>>>,
    // Source of replacement transports, called on each reconnect attempt.
    factory: TransportFactory,
    // Attempt budget and backoff policy for reconnects.
    config: ReconnectConfig,
    // Kind reported to callers, captured from the transport passed to `new`.
    kind: TransportKind,
    // Count of factory calls that succeeded (i.e. completed reconnects).
    reconnect_count: std::sync::atomic::AtomicU64,
}
impl ReconnectingTransport {
pub fn new(
transport: Box<dyn Transport>,
factory: TransportFactory,
config: ReconnectConfig,
) -> Self {
let kind = transport.kind();
Self {
inner: Mutex::new(Some(transport)),
factory,
config,
kind,
reconnect_count: std::sync::atomic::AtomicU64::new(0),
}
}
pub fn reconnect_count(&self) -> u64 {
self.reconnect_count
.load(std::sync::atomic::Ordering::Relaxed)
}
async fn reconnect(&self) -> crate::error::Result<Box<dyn Transport>> {
let mut delay = self.config.initial_delay;
let max = if self.config.max_attempts == 0 {
u32::MAX
} else {
self.config.max_attempts
};
for attempt in 1..=max {
tracing::info!(
kind = ?self.kind,
attempt,
max_attempts = max,
"reconnecting transport"
);
match (self.factory)().await {
Ok(t) => {
self.reconnect_count
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
tracing::info!(kind = ?self.kind, attempt, "reconnected successfully");
return Ok(t);
}
Err(e) => {
tracing::warn!(
kind = ?self.kind,
attempt,
error = %e,
delay_ms = delay.as_millis(),
"reconnection failed, backing off"
);
if attempt < max {
tokio::time::sleep(delay).await;
let next = delay.as_secs_f64() * self.config.backoff_factor;
delay =
Duration::from_secs_f64(next.min(self.config.max_delay.as_secs_f64()));
}
}
}
}
Err(SrxError::Transport(TransportError::ConnectionFailed(
format!(
"{:?}: reconnection failed after {} attempts",
self.kind, max
),
)))
}
async fn with_reconnect<F, T>(&self, op: F) -> crate::error::Result<T>
where
F: Fn(&dyn Transport) -> Pin<Box<dyn Future<Output = crate::error::Result<T>> + Send + '_>>,
{
{
let guard = self.inner.lock().await;
if let Some(ref t) = *guard {
match op(t.as_ref()).await {
Ok(v) => return Ok(v),
Err(_) => {
}
}
}
}
let new_transport = self.reconnect().await?;
let result = op(new_transport.as_ref()).await;
let mut guard = self.inner.lock().await;
*guard = Some(new_transport);
result
}
}
#[async_trait]
impl Transport for ReconnectingTransport {
    /// Kind captured from the transport passed to `new`.
    fn kind(&self) -> TransportKind {
        self.kind
    }

    /// Sends `data`. If the current transport fails, it reconnects and resends once.
    async fn send(&self, data: Bytes) -> crate::error::Result<()> {
        // The closure may run twice (original + retry), and each run needs its own owned
        // `Bytes`, so clone from the borrowed `data` inside it.
        self.with_reconnect(|t| {
            let payload = data.clone();
            Box::pin(async move { t.send(payload).await })
        })
        .await
    }

    /// Receives one message. If the current transport fails, it reconnects and receives again.
    async fn recv(&self) -> crate::error::Result<Bytes> {
        self.with_reconnect(|t| Box::pin(async move { t.recv().await }))
            .await
    }

    /// Delegates to the inner transport; `false` once it has been closed.
    async fn is_healthy(&self) -> bool {
        let guard = self.inner.lock().await;
        match &*guard {
            Some(t) => t.is_healthy().await,
            None => false,
        }
    }

    /// Takes the inner transport out and closes it.
    ///
    /// The transport is removed even if its `close` fails. A later `send`/`recv` finds the
    /// slot empty and builds a new transport through the factory.
    async fn close(&self) -> crate::error::Result<()> {
        let mut guard = self.inner.lock().await;
        if let Some(t) = guard.take() {
            t.close().await?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Mock transport whose first `n` send/recv calls succeed; every later call fails.
    struct FailAfter {
        remaining: AtomicU32,
    }

    impl FailAfter {
        fn new(n: u32) -> Self {
            Self {
                remaining: AtomicU32::new(n),
            }
        }

        /// Consumes one successful call. Returns `false` once the budget is exhausted.
        ///
        /// The decrement saturates at zero. A plain `fetch_sub` would wrap `0` to
        /// `u32::MAX`, and the mock would silently "recover" after its first failure.
        fn take_success(&self) -> bool {
            self.remaining
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| n.checked_sub(1))
                .is_ok()
        }
    }

    #[async_trait]
    impl Transport for FailAfter {
        fn kind(&self) -> TransportKind {
            TransportKind::Tcp
        }

        async fn send(&self, _data: Bytes) -> crate::error::Result<()> {
            if self.take_success() {
                Ok(())
            } else {
                Err(SrxError::Transport(TransportError::ConnectionFailed(
                    "mock failure".into(),
                )))
            }
        }

        async fn recv(&self) -> crate::error::Result<Bytes> {
            if self.take_success() {
                Ok(Bytes::from_static(b"data"))
            } else {
                Err(SrxError::Transport(TransportError::ChannelClosed))
            }
        }

        async fn is_healthy(&self) -> bool {
            self.remaining.load(Ordering::SeqCst) > 0
        }

        async fn close(&self) -> crate::error::Result<()> {
            Ok(())
        }
    }

    fn test_factory() -> TransportFactory {
        Arc::new(|| Box::pin(async { Ok(Box::new(FailAfter::new(10)) as Box<dyn Transport>) }))
    }

    fn fast_config() -> ReconnectConfig {
        ReconnectConfig {
            initial_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
            backoff_factor: 2.0,
            max_attempts: 3,
        }
    }

    #[tokio::test]
    async fn mock_stays_failed_once_exhausted() {
        let mock = FailAfter::new(1);
        mock.send(Bytes::from_static(b"a")).await.unwrap();
        assert!(mock.send(Bytes::from_static(b"b")).await.is_err());
        assert!(mock.send(Bytes::from_static(b"c")).await.is_err());
        assert!(mock.recv().await.is_err());
        assert!(!mock.is_healthy().await);
    }

    #[tokio::test]
    async fn send_succeeds_without_reconnect() {
        let t =
            ReconnectingTransport::new(Box::new(FailAfter::new(5)), test_factory(), fast_config());
        t.send(Bytes::from_static(b"hello")).await.unwrap();
        assert_eq!(t.reconnect_count(), 0);
    }

    #[tokio::test]
    async fn recv_succeeds_without_reconnect() {
        let t =
            ReconnectingTransport::new(Box::new(FailAfter::new(5)), test_factory(), fast_config());
        let data = t.recv().await.unwrap();
        assert_eq!(data.as_ref(), b"data");
        assert_eq!(t.reconnect_count(), 0);
    }

    #[tokio::test]
    async fn reconnects_on_send_failure() {
        let t =
            ReconnectingTransport::new(Box::new(FailAfter::new(0)), test_factory(), fast_config());
        t.send(Bytes::from_static(b"hello")).await.unwrap();
        assert_eq!(t.reconnect_count(), 1);
    }

    #[tokio::test]
    async fn reconnects_on_recv_failure() {
        let t =
            ReconnectingTransport::new(Box::new(FailAfter::new(0)), test_factory(), fast_config());
        let data = t.recv().await.unwrap();
        assert_eq!(data.as_ref(), b"data");
        assert_eq!(t.reconnect_count(), 1);
    }

    #[tokio::test]
    async fn factory_failure_exhausts_attempts() {
        let factory: TransportFactory = Arc::new(|| {
            Box::pin(async {
                Err(SrxError::Transport(TransportError::ConnectionFailed(
                    "always fail".into(),
                )))
            })
        });
        let t = ReconnectingTransport::new(Box::new(FailAfter::new(0)), factory, fast_config());
        let err = t.send(Bytes::from_static(b"hello")).await;
        assert!(err.is_err());
        // Failed factory calls must not be counted as reconnects.
        assert_eq!(t.reconnect_count(), 0);
    }

    #[tokio::test]
    async fn close_clears_inner() {
        let t =
            ReconnectingTransport::new(Box::new(FailAfter::new(5)), test_factory(), fast_config());
        assert!(t.is_healthy().await);
        t.close().await.unwrap();
        assert!(!t.is_healthy().await);
    }

    #[tokio::test]
    async fn kind_matches_original() {
        let t =
            ReconnectingTransport::new(Box::new(FailAfter::new(5)), test_factory(), fast_config());
        assert_eq!(t.kind(), TransportKind::Tcp);
    }
}