stochastic-routing-extended 1.0.2

SRX (Stochastic Routing eXtended) — a next-generation VPN protocol with stochastic routing, DPI evasion, post-quantum cryptography, and multi-transport channel splitting
Documentation
//! TCP transport implementation.

use std::sync::Arc;

use async_trait::async_trait;
use bytes::Bytes;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::Mutex;

use super::{Transport, TransportKind};
use crate::error::{SrxError, TransportError};
use crate::frame::{read_length_prefixed, write_length_prefixed};

/// Upper bound, in bytes, on what a single [`Transport::recv`] call can return.
///
/// `recv` hands back whatever one socket read yields, so message boundaries are the
/// responsibility of the layer above (or of `TcpTransport::recv_framed`).
const RECV_BUF: usize = 64 * 1024;

/// Async TCP [`Transport`] using Tokio.
pub struct TcpTransport {
    stream: Arc<Mutex<Option<TcpStream>>>,
}

impl TcpTransport {
    /// Connect to `addr` (`host:port` or parsed socket address string).
    pub async fn connect(addr: impl tokio::net::ToSocketAddrs) -> crate::error::Result<Self> {
        let stream = TcpStream::connect(addr)
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        Ok(Self {
            stream: Arc::new(Mutex::new(Some(stream))),
        })
    }

    /// Wrap an established stream (e.g. after `TcpListener::accept`).
    #[must_use]
    pub fn from_stream(stream: TcpStream) -> Self {
        Self {
            stream: Arc::new(Mutex::new(Some(stream))),
        }
    }

    /// Send one length-prefixed payload (`u32` BE length + bytes), same wire as [`crate::write_length_prefixed`].
    pub async fn send_framed(&self, payload: &[u8]) -> crate::error::Result<()> {
        let mut guard = self.stream.lock().await;
        let stream = guard
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        write_length_prefixed(stream, payload).await
    }

    /// Receive one length-prefixed payload (full message; not a single kernel chunk).
    pub async fn recv_framed(&self) -> crate::error::Result<Bytes> {
        let mut guard = self.stream.lock().await;
        let stream = guard
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        let v = read_length_prefixed(stream).await?;
        Ok(Bytes::from(v))
    }
}

#[async_trait]
impl Transport for TcpTransport {
    fn kind(&self) -> TransportKind {
        TransportKind::Tcp
    }

    async fn send(&self, data: Bytes) -> crate::error::Result<()> {
        let mut guard = self.stream.lock().await;
        let stream = guard
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        stream
            .write_all(&data)
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        stream
            .flush()
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        Ok(())
    }

    async fn recv(&self) -> crate::error::Result<Bytes> {
        let mut guard = self.stream.lock().await;
        let stream = guard
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        let mut buf = vec![0u8; RECV_BUF];
        let n = stream
            .read(&mut buf)
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        if n == 0 {
            return Err(SrxError::Transport(TransportError::ChannelClosed));
        }
        buf.truncate(n);
        Ok(Bytes::from(buf))
    }

    async fn is_healthy(&self) -> bool {
        self.stream.lock().await.is_some()
    }

    async fn close(&self) -> crate::error::Result<()> {
        let mut guard = self.stream.lock().await;
        if let Some(mut s) = guard.take() {
            let _ = s.shutdown().await;
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Bind an ephemeral loopback listener and report the address it landed on.
    async fn loopback() -> (tokio::net::TcpListener, std::net::SocketAddr) {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        (listener, addr)
    }

    /// Raw (unframed) `send`/`recv` round-trip: the client pings, the server pongs.
    #[tokio::test]
    async fn send_recv_roundtrip() {
        let (listener, addr) = loopback().await;

        let server = tokio::spawn(async move {
            let (sock, _peer) = listener.accept().await.unwrap();
            let transport = TcpTransport::from_stream(sock);
            let request = transport.recv().await.unwrap();
            assert_eq!(request.as_ref(), b"ping");
            transport.send(Bytes::from_static(b"pong")).await.unwrap();
        });

        let client = TcpTransport::connect(addr).await.unwrap();
        client.send(Bytes::from_static(b"ping")).await.unwrap();
        let response = client.recv().await.unwrap();
        assert_eq!(response.as_ref(), b"pong");
        client.close().await.unwrap();

        server.await.unwrap();
    }

    /// Length-prefixed round-trip through `send_framed` / `recv_framed`.
    #[tokio::test]
    async fn framed_roundtrip_matches_length_prefix_wire() {
        let (listener, addr) = loopback().await;
        let message = b"framed-payload-srx";

        let server = tokio::spawn(async move {
            let (sock, _peer) = listener.accept().await.unwrap();
            let transport = TcpTransport::from_stream(sock);
            let received = transport.recv_framed().await.unwrap();
            assert_eq!(received.as_ref(), message);
            transport.send_framed(b"ack").await.unwrap();
        });

        let client = TcpTransport::connect(addr).await.unwrap();
        client.send_framed(message).await.unwrap();
        let ack = client.recv_framed().await.unwrap();
        assert_eq!(ack.as_ref(), b"ack");
        client.close().await.unwrap();

        server.await.unwrap();
    }
}