stochastic-routing-extended 1.0.2

SRX (Stochastic Routing eXtended) — a next-generation VPN protocol with stochastic routing, DPI evasion, post-quantum cryptography, and multi-transport channel splitting
Documentation
//! Timeout transport wrapper — enforces deadline on send/recv operations.

use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use bytes::Bytes;

use super::{Transport, TransportKind};
use crate::error::{SrxError, TransportError};

/// Wraps any [`Transport`] with configurable send/recv timeouts.
///
/// `send` and `recv` are forwarded to the wrapped transport under a deadline
/// and fail with [`TransportError::Timeout`] when that deadline elapses.
/// `kind`, `is_healthy` and `close` are forwarded without any deadline.
pub struct TimeoutTransport {
    /// The wrapped transport that performs the actual work.
    inner: Arc<dyn Transport>,
    /// Deadline applied to each individual `send` call.
    send_timeout: Duration,
    /// Deadline applied to each individual `recv` call.
    recv_timeout: Duration,
}

impl TimeoutTransport {
    /// Builds a wrapper around `inner` that applies one shared `timeout` to
    /// both `send` and `recv`.
    pub fn new(inner: Arc<dyn Transport>, timeout: Duration) -> Self {
        // `Duration` is `Copy`, so the same value can feed both deadlines.
        Self::with_separate_timeouts(inner, timeout, timeout)
    }

    /// Builds a wrapper around `inner` with independent deadlines for `send`
    /// and `recv`.
    pub fn with_separate_timeouts(
        inner: Arc<dyn Transport>,
        send_timeout: Duration,
        recv_timeout: Duration,
    ) -> Self {
        Self {
            inner,
            send_timeout,
            recv_timeout,
        }
    }

    /// Returns the deadline applied to each `send` call.
    pub fn send_timeout(&self) -> Duration {
        self.send_timeout
    }

    /// Returns the deadline applied to each `recv` call.
    pub fn recv_timeout(&self) -> Duration {
        self.recv_timeout
    }
}

#[async_trait]
impl Transport for TimeoutTransport {
    /// Reports the kind of the wrapped transport.
    fn kind(&self) -> TransportKind {
        self.inner.kind()
    }

    /// Forwards `data` to the wrapped transport, giving up after `send_timeout`.
    ///
    /// # Errors
    ///
    /// Returns the wrapped transport's own error if it fails in time, or
    /// [`TransportError::Timeout`] if the deadline elapses first.
    async fn send(&self, data: Bytes) -> crate::error::Result<()> {
        let limit = self.send_timeout;
        match tokio::time::timeout(limit, self.inner.send(data)).await {
            // Finished in time: hand back the inner result untouched.
            Ok(outcome) => outcome,
            // Deadline elapsed: the pending inner call has been dropped.
            Err(_elapsed) => Err(SrxError::Transport(TransportError::Timeout {
                transport: format!("{:?}", self.inner.kind()),
                details: format!("send timed out after {:?}", limit),
            })),
        }
    }

    /// Receives from the wrapped transport, giving up after `recv_timeout`.
    ///
    /// # Errors
    ///
    /// Returns the wrapped transport's own error if it fails in time, or
    /// [`TransportError::Timeout`] if the deadline elapses first.
    async fn recv(&self) -> crate::error::Result<Bytes> {
        // NOTE(review): on timeout the inner `recv` future is dropped; whether
        // that can discard partially received data depends on the wrapped
        // transport — confirm against its implementation.
        let limit = self.recv_timeout;
        match tokio::time::timeout(limit, self.inner.recv()).await {
            Ok(outcome) => outcome,
            Err(_elapsed) => Err(SrxError::Transport(TransportError::Timeout {
                transport: format!("{:?}", self.inner.kind()),
                details: format!("recv timed out after {:?}", limit),
            })),
        }
    }

    /// Delegates the health check to the wrapped transport (no deadline applied).
    async fn is_healthy(&self) -> bool {
        self.inner.is_healthy().await
    }

    /// Delegates shutdown to the wrapped transport (no deadline applied).
    async fn close(&self) -> crate::error::Result<()> {
        self.inner.close().await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Transport that hangs forever on send/recv (to test timeouts).
    struct HangingTransport;

    #[async_trait]
    impl Transport for HangingTransport {
        fn kind(&self) -> TransportKind {
            TransportKind::Tcp
        }

        async fn send(&self, _data: Bytes) -> crate::error::Result<()> {
            std::future::pending().await
        }

        async fn recv(&self) -> crate::error::Result<Bytes> {
            std::future::pending().await
        }

        async fn is_healthy(&self) -> bool {
            true
        }

        async fn close(&self) -> crate::error::Result<()> {
            Ok(())
        }
    }

    /// Transport that completes instantly and counts how many sends reached it.
    struct InstantTransport {
        send_count: AtomicU32,
    }

    impl InstantTransport {
        fn new() -> Self {
            Self {
                send_count: AtomicU32::new(0),
            }
        }
    }

    #[async_trait]
    impl Transport for InstantTransport {
        fn kind(&self) -> TransportKind {
            TransportKind::Udp
        }

        async fn send(&self, _data: Bytes) -> crate::error::Result<()> {
            self.send_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn recv(&self) -> crate::error::Result<Bytes> {
            Ok(Bytes::from_static(b"instant"))
        }

        async fn is_healthy(&self) -> bool {
            true
        }

        async fn close(&self) -> crate::error::Result<()> {
            Ok(())
        }
    }

    /// Extracts the `details` text from a timeout error.
    ///
    /// Panics when `err` is any other error, so a test cannot pass merely
    /// because *some* failure occurred.
    fn timeout_details(err: SrxError) -> String {
        match err {
            SrxError::Transport(TransportError::Timeout { details, .. }) => details,
            other => panic!("expected a timeout error, got: {}", other),
        }
    }

    #[tokio::test]
    async fn send_timeout_triggers() {
        let t = TimeoutTransport::new(Arc::new(HangingTransport), Duration::from_millis(10));
        let err = t.send(Bytes::from_static(b"test")).await.unwrap_err();
        // The user-facing message must mention the timeout.
        assert!(err.to_string().contains("timed out"));
        assert!(timeout_details(err).contains("send timed out"));
    }

    #[tokio::test]
    async fn recv_timeout_triggers() {
        let t = TimeoutTransport::new(Arc::new(HangingTransport), Duration::from_millis(10));
        let err = t.recv().await.unwrap_err();
        assert!(timeout_details(err).contains("recv timed out"));
    }

    #[tokio::test]
    async fn fast_operation_succeeds() {
        // Keep a typed handle so the send counter can be inspected afterwards;
        // the method-call `clone()` coerces to `Arc<dyn Transport>` at the call site.
        let inner = Arc::new(InstantTransport::new());
        let t = TimeoutTransport::new(inner.clone(), Duration::from_secs(5));

        t.send(Bytes::from_static(b"ok")).await.unwrap();
        assert_eq!(inner.send_count.load(Ordering::SeqCst), 1);

        let data = t.recv().await.unwrap();
        assert_eq!(data.as_ref(), b"instant");
    }

    #[tokio::test]
    async fn kind_delegates() {
        let t = TimeoutTransport::new(Arc::new(InstantTransport::new()), Duration::from_secs(1));
        assert_eq!(t.kind(), TransportKind::Udp);
    }

    #[tokio::test]
    async fn separate_timeouts() {
        let t = TimeoutTransport::with_separate_timeouts(
            Arc::new(HangingTransport),
            Duration::from_millis(5),
            Duration::from_millis(50),
        );
        assert_eq!(t.send_timeout(), Duration::from_millis(5));
        assert_eq!(t.recv_timeout(), Duration::from_millis(50));

        // Each direction must fail with its own deadline reported.
        let send_err = t.send(Bytes::from_static(b"x")).await.unwrap_err();
        let send_details = timeout_details(send_err);
        assert!(send_details.contains("send timed out"));
        assert!(send_details.contains("5ms"));

        let recv_err = t.recv().await.unwrap_err();
        let recv_details = timeout_details(recv_err);
        assert!(recv_details.contains("recv timed out"));
        assert!(recv_details.contains("50ms"));
    }

    #[tokio::test]
    async fn is_healthy_delegates() {
        let t = TimeoutTransport::new(Arc::new(InstantTransport::new()), Duration::from_secs(1));
        assert!(t.is_healthy().await);
    }

    #[tokio::test]
    async fn close_delegates() {
        let t = TimeoutTransport::new(Arc::new(InstantTransport::new()), Duration::from_secs(1));
        t.close().await.unwrap();
    }
}