stochastic-routing-extended 1.0.2

SRX (Stochastic Routing eXtended) — a next-generation VPN protocol with stochastic routing, DPI evasion, post-quantum cryptography, and multi-transport channel splitting
Documentation
//! Reconnecting transport wrapper — re-establishes the connection on failure,
//! retrying with configurable exponential backoff.

use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use bytes::Bytes;
use tokio::sync::Mutex;

use super::{Transport, TransportKind};
use crate::error::{SrxError, TransportError};

/// Configuration for reconnection behavior.
///
/// Controls how many times the transport factory is retried and how long the
/// wrapper waits between failed attempts.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Starting delay of the backoff sequence. No delay precedes the first
    /// attempt; this is the wait that follows the first failure.
    pub initial_delay: Duration,
    /// Upper bound for the delay as it grows between retries.
    pub max_delay: Duration,
    /// Multiplier applied to the delay after each failed attempt (exponential backoff).
    pub backoff_factor: f64,
    /// Maximum number of reconnection attempts (0 = no limit, implemented as
    /// `u32::MAX` attempts).
    pub max_attempts: u32,
}

impl Default for ReconnectConfig {
    fn default() -> Self {
        Self {
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(30),
            backoff_factor: 2.0,
            max_attempts: 10,
        }
    }
}

/// Factory function type for creating new transport connections.
///
/// The factory is invoked once per reconnection attempt. Each call returns a
/// boxed, `Send` future that resolves to either a freshly created boxed
/// [`Transport`] or an error. The factory itself is shared behind an `Arc` and
/// must be `Send + Sync`.
pub type TransportFactory = Arc<
    dyn Fn() -> Pin<Box<dyn Future<Output = crate::error::Result<Box<dyn Transport>>> + Send>>
        + Send
        + Sync,
>;

/// Wraps any [`Transport`] with automatic reconnection on send/recv failure.
///
/// When a send or recv fails, the wrapper calls the factory to create a new
/// transport instance and retries the operation once on that instance.
pub struct ReconnectingTransport {
    /// Currently installed transport; `None` after `close` has taken it.
    inner: Mutex<Option<Box<dyn Transport>>>,
    /// Creates replacement transports when a reconnection is needed.
    factory: TransportFactory,
    /// Backoff and attempt-limit policy used while reconnecting.
    config: ReconnectConfig,
    /// Kind reported by the initial transport, captured at construction.
    kind: TransportKind,
    /// Total number of reconnections performed.
    reconnect_count: std::sync::atomic::AtomicU64,
}

impl ReconnectingTransport {
    /// Create a new reconnecting wrapper around an initial transport.
    ///
    /// The wrapper records `transport.kind()` once and reports that kind for
    /// its whole lifetime. `factory` is used to build replacements and
    /// `config` governs the retry/backoff policy.
    pub fn new(
        transport: Box<dyn Transport>,
        factory: TransportFactory,
        config: ReconnectConfig,
    ) -> Self {
        let kind = transport.kind();
        Self {
            inner: Mutex::new(Some(transport)),
            factory,
            config,
            kind,
            reconnect_count: std::sync::atomic::AtomicU64::new(0),
        }
    }

    /// Total number of reconnections performed since creation.
    ///
    /// Only successful factory calls are counted.
    pub fn reconnect_count(&self) -> u64 {
        self.reconnect_count
            .load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Attempt to reconnect using the factory with exponential backoff.
    ///
    /// The first attempt is made immediately. After every failed attempt
    /// except the last, the task sleeps for the current delay and then scales
    /// it by `backoff_factor`, never exceeding `max_delay`.
    ///
    /// # Errors
    ///
    /// Returns `TransportError::ConnectionFailed` when every attempt failed.
    async fn reconnect(&self) -> crate::error::Result<Box<dyn Transport>> {
        let max_delay = self.config.max_delay;
        // Honor `max_delay` from the very first sleep, even if `initial_delay`
        // was configured larger than the cap.
        let mut delay = self.config.initial_delay.min(max_delay);
        // `0` means "no limit"; in practice that is `u32::MAX` attempts.
        let max = if self.config.max_attempts == 0 {
            u32::MAX
        } else {
            self.config.max_attempts
        };

        for attempt in 1..=max {
            tracing::info!(
                kind = ?self.kind,
                attempt,
                max_attempts = max,
                "reconnecting transport"
            );
            match (self.factory)().await {
                Ok(t) => {
                    self.reconnect_count
                        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    tracing::info!(kind = ?self.kind, attempt, "reconnected successfully");
                    return Ok(t);
                }
                Err(e) => {
                    tracing::warn!(
                        kind = ?self.kind,
                        attempt,
                        error = %e,
                        delay_ms = delay.as_millis(),
                        "reconnection failed, backing off"
                    );
                    if attempt < max {
                        tokio::time::sleep(delay).await;
                        let scaled = delay.as_secs_f64() * self.config.backoff_factor;
                        // `Duration::from_secs_f64` panics on negative, NaN or
                        // overflowing input. Only convert values known to be
                        // in `[0, max_delay)`; anything else (including a
                        // nonsensical negative factor) falls back to the cap.
                        delay = if scaled >= 0.0 && scaled < max_delay.as_secs_f64() {
                            Duration::from_secs_f64(scaled)
                        } else {
                            max_delay
                        };
                    }
                }
            }
        }

        Err(SrxError::Transport(TransportError::ConnectionFailed(
            format!(
                "{:?}: reconnection failed after {} attempts",
                self.kind, max
            ),
        )))
    }

    /// Execute an operation, reconnecting on failure and retrying once.
    ///
    /// `op` is first run against the currently installed transport while the
    /// `inner` lock is held. If that fails (or no transport is installed), a
    /// replacement is obtained via [`Self::reconnect`], `op` is run once on
    /// it, and the replacement is installed whatever the outcome.
    ///
    /// # Errors
    ///
    /// Returns the reconnection error if no replacement could be created,
    /// otherwise the result of `op` on the replacement transport.
    async fn with_reconnect<F, T>(&self, op: F) -> crate::error::Result<T>
    where
        F: Fn(&dyn Transport) -> Pin<Box<dyn Future<Output = crate::error::Result<T>> + Send + '_>>,
    {
        // First attempt with existing transport.
        {
            let guard = self.inner.lock().await;
            if let Some(ref t) = *guard {
                match op(t.as_ref()).await {
                    Ok(v) => return Ok(v),
                    Err(e) => {
                        // Record why we are reconnecting instead of silently
                        // discarding the original failure.
                        tracing::warn!(
                            kind = ?self.kind,
                            error = %e,
                            "transport operation failed, reconnecting"
                        );
                    }
                }
            }
        }

        // Reconnect. The lock is not held here, so backoff sleeps do not
        // block `is_healthy` or `close`.
        let new_transport = self.reconnect().await?;

        // Retry with fresh transport.
        let result = op(new_transport.as_ref()).await;

        // Store the new transport regardless of result (it's the freshest).
        // NOTE(review): the superseded transport is dropped without an
        // explicit `close()`; confirm its `Drop` releases all resources.
        let mut guard = self.inner.lock().await;
        *guard = Some(new_transport);

        result
    }
}

#[async_trait]
impl Transport for ReconnectingTransport {
    /// Kind of the transport this wrapper was created around.
    fn kind(&self) -> TransportKind {
        self.kind
    }

    /// Send `data`, reconnecting and retrying once if the send fails.
    async fn send(&self, data: Bytes) -> crate::error::Result<()> {
        // The operation may run twice (original + retry), so each run gets
        // its own cheap handle to the payload.
        self.with_reconnect(|transport| {
            let payload = data.clone();
            Box::pin(async move { transport.send(payload).await })
        })
        .await
    }

    /// Receive the next message, reconnecting and retrying once on failure.
    async fn recv(&self) -> crate::error::Result<Bytes> {
        self.with_reconnect(|transport| {
            Box::pin(async move { transport.recv().await })
        })
        .await
    }

    /// Health of the installed transport; `false` when none is installed.
    async fn is_healthy(&self) -> bool {
        let slot = self.inner.lock().await;
        if let Some(transport) = &*slot {
            transport.is_healthy().await
        } else {
            false
        }
    }

    /// Take the installed transport out of the wrapper and close it.
    ///
    /// Succeeds without doing anything when no transport is installed.
    async fn close(&self) -> crate::error::Result<()> {
        // Keep the lock for the duration of the close so no operation can
        // observe a half-closed transport.
        let mut slot = self.inner.lock().await;
        match slot.take() {
            Some(transport) => transport.close().await,
            None => Ok(()),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Mock transport that succeeds for its first N operations and fails
    /// every operation after that.
    struct FailAfter {
        remaining: AtomicU32,
    }

    impl FailAfter {
        fn new(n: u32) -> Self {
            Self {
                remaining: AtomicU32::new(n),
            }
        }

        /// Spend one unit of the operation budget.
        ///
        /// Returns `true` if budget was available. The counter saturates at
        /// zero: a plain `fetch_sub` would wrap to `u32::MAX` on the first
        /// failure and make the mock succeed (and look healthy) again.
        fn try_consume(&self) -> bool {
            self.remaining
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |left| {
                    left.checked_sub(1)
                })
                .is_ok()
        }
    }

    #[async_trait]
    impl Transport for FailAfter {
        fn kind(&self) -> TransportKind {
            TransportKind::Tcp
        }

        async fn send(&self, _data: Bytes) -> crate::error::Result<()> {
            if self.try_consume() {
                Ok(())
            } else {
                Err(SrxError::Transport(TransportError::ConnectionFailed(
                    "mock failure".into(),
                )))
            }
        }

        async fn recv(&self) -> crate::error::Result<Bytes> {
            if self.try_consume() {
                Ok(Bytes::from_static(b"data"))
            } else {
                Err(SrxError::Transport(TransportError::ChannelClosed))
            }
        }

        async fn is_healthy(&self) -> bool {
            self.remaining.load(Ordering::SeqCst) > 0
        }

        async fn close(&self) -> crate::error::Result<()> {
            Ok(())
        }
    }

    /// Factory that always succeeds with a mock good for 10 operations.
    fn test_factory() -> TransportFactory {
        Arc::new(|| Box::pin(async { Ok(Box::new(FailAfter::new(10)) as Box<dyn Transport>) }))
    }

    /// Millisecond-scale backoff so failing tests stay fast.
    fn fast_config() -> ReconnectConfig {
        ReconnectConfig {
            initial_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
            backoff_factor: 2.0,
            max_attempts: 3,
        }
    }

    #[tokio::test]
    async fn mock_stays_failed_once_exhausted() {
        let mock = FailAfter::new(1);
        assert!(mock.send(Bytes::from_static(b"a")).await.is_ok());
        assert!(mock.send(Bytes::from_static(b"b")).await.is_err());
        // The budget must not wrap around after the first failure.
        assert!(mock.send(Bytes::from_static(b"c")).await.is_err());
        assert!(mock.recv().await.is_err());
        assert!(!mock.is_healthy().await);
    }

    #[tokio::test]
    async fn send_succeeds_without_reconnect() {
        let t =
            ReconnectingTransport::new(Box::new(FailAfter::new(5)), test_factory(), fast_config());
        t.send(Bytes::from_static(b"hello")).await.unwrap();
        assert_eq!(t.reconnect_count(), 0);
    }

    #[tokio::test]
    async fn recv_succeeds_without_reconnect() {
        let t =
            ReconnectingTransport::new(Box::new(FailAfter::new(5)), test_factory(), fast_config());
        let data = t.recv().await.unwrap();
        assert_eq!(data.as_ref(), b"data");
        assert_eq!(t.reconnect_count(), 0);
    }

    #[tokio::test]
    async fn reconnects_on_send_failure() {
        // Transport that fails immediately.
        let t =
            ReconnectingTransport::new(Box::new(FailAfter::new(0)), test_factory(), fast_config());
        t.send(Bytes::from_static(b"hello")).await.unwrap();
        assert_eq!(t.reconnect_count(), 1);
    }

    #[tokio::test]
    async fn reconnects_on_recv_failure() {
        let t =
            ReconnectingTransport::new(Box::new(FailAfter::new(0)), test_factory(), fast_config());
        let data = t.recv().await.unwrap();
        assert_eq!(data.as_ref(), b"data");
        assert_eq!(t.reconnect_count(), 1);
    }

    #[tokio::test]
    async fn factory_failure_exhausts_attempts() {
        let factory: TransportFactory = Arc::new(|| {
            Box::pin(async {
                Err(SrxError::Transport(TransportError::ConnectionFailed(
                    "always fail".into(),
                )))
            })
        });
        let t = ReconnectingTransport::new(Box::new(FailAfter::new(0)), factory, fast_config());
        let err = t.send(Bytes::from_static(b"hello")).await;
        assert!(err.is_err());
        // Only successful factory calls count as reconnections.
        assert_eq!(t.reconnect_count(), 0);
    }

    #[tokio::test]
    async fn close_clears_inner() {
        let t =
            ReconnectingTransport::new(Box::new(FailAfter::new(5)), test_factory(), fast_config());
        assert!(t.is_healthy().await);
        t.close().await.unwrap();
        assert!(!t.is_healthy().await);
    }

    #[tokio::test]
    async fn kind_matches_original() {
        let t =
            ReconnectingTransport::new(Box::new(FailAfter::new(5)), test_factory(), fast_config());
        assert_eq!(t.kind(), TransportKind::Tcp);
    }
}