stochastic-routing-extended 1.0.2

SRX (Stochastic Routing eXtended) — a next-generation VPN protocol with stochastic routing, DPI evasion, post-quantum cryptography, and multi-transport channel splitting
Documentation
//! UDP transport implementation (connected datagram socket).

use std::net::SocketAddr;
use std::sync::Arc;

use async_trait::async_trait;
use bytes::Bytes;
use tokio::net::UdpSocket;
use tokio::sync::Mutex;

use super::{Transport, TransportKind};
use crate::error::{FrameError, SrxError, TransportError};

/// Largest UDP payload that fits in one IPv4 datagram:
/// 65 535 total bytes − 20-byte IPv4 header − 8-byte UDP header.
const MAX_DATAGRAM: usize = 65_535 - 20 - 8;

/// Connected UDP socket: `send` / `recv` map to one peer.
pub struct UdpTransport {
    socket: Arc<Mutex<Option<UdpSocket>>>,
}

impl UdpTransport {
    /// Bind ephemeral port and `connect` to remote `addr`.
    pub async fn connect(addr: impl tokio::net::ToSocketAddrs) -> crate::error::Result<Self> {
        let socket = UdpSocket::bind("0.0.0.0:0")
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        socket
            .connect(addr)
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        Ok(Self {
            socket: Arc::new(Mutex::new(Some(socket))),
        })
    }

    /// Bind local address (server); use [`UdpTransport::connect_peer`] before send/recv in connected mode.
    pub async fn bind(addr: impl tokio::net::ToSocketAddrs) -> crate::error::Result<Self> {
        let socket = UdpSocket::bind(addr)
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        Ok(Self {
            socket: Arc::new(Mutex::new(Some(socket))),
        })
    }

    /// After [`UdpTransport::bind`], connect to a single peer for [`Transport`] semantics.
    pub async fn connect_peer(&self, addr: SocketAddr) -> crate::error::Result<()> {
        let mut g = self.socket.lock().await;
        let s = g
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        s.connect(addr)
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        Ok(())
    }

    #[must_use]
    pub fn from_socket(socket: UdpSocket) -> Self {
        Self {
            socket: Arc::new(Mutex::new(Some(socket))),
        }
    }

    /// Send one length-prefixed payload (`u32` BE length + bytes) in a single UDP datagram.
    ///
    /// The wire format is: 4-byte big-endian payload length followed by the payload bytes.
    /// This allows the receiver to know the exact payload size before reading the data.
    pub async fn send_framed(&self, payload: &[u8]) -> crate::error::Result<()> {
        let len = u32::try_from(payload.len()).map_err(|_| {
            SrxError::Frame(FrameError::FrameTooLarge {
                size: payload.len(),
                max: MAX_DATAGRAM,
            })
        })?;
        // 4 bytes for length + payload must fit in a single UDP datagram
        let total = 4 + payload.len();
        if total > MAX_DATAGRAM {
            return Err(SrxError::Frame(FrameError::FrameTooLarge {
                size: total,
                max: MAX_DATAGRAM,
            }));
        }
        let mut framed = Vec::with_capacity(total);
        framed.extend_from_slice(&len.to_be_bytes());
        framed.extend_from_slice(payload);
        self.send(Bytes::from(framed)).await
    }

    /// Receive one length-prefixed payload from a single UDP datagram.
    ///
    /// Reads the entire datagram, extracts the `u32` BE length prefix,
    /// validates it, and returns the payload bytes.
    pub async fn recv_framed(&self) -> crate::error::Result<Bytes> {
        let data = self.recv().await?;
        if data.len() < 4 {
            return Err(SrxError::Frame(FrameError::Corrupted(
                "datagram too short for length prefix".into(),
            )));
        }
        let len = u32::from_be_bytes([data[0], data[1], data[2], data[3]]);
        let payload_len = len as usize;
        if data.len() != 4 + payload_len {
            return Err(SrxError::Frame(FrameError::Corrupted(
                "datagram size does not match length prefix".into(),
            )));
        }
        Ok(data.slice(4..))
    }
}

#[async_trait]
impl Transport for UdpTransport {
    fn kind(&self) -> TransportKind {
        TransportKind::Udp
    }

    async fn send(&self, data: Bytes) -> crate::error::Result<()> {
        let mut g = self.socket.lock().await;
        let s = g
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        s.send(&data)
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        Ok(())
    }

    async fn recv(&self) -> crate::error::Result<Bytes> {
        let mut g = self.socket.lock().await;
        let s = g
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        let mut buf = vec![0u8; MAX_DATAGRAM];
        let n = s
            .recv(&mut buf)
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        buf.truncate(n);
        Ok(Bytes::from(buf))
    }

    async fn is_healthy(&self) -> bool {
        self.socket.lock().await.is_some()
    }

    async fn close(&self) -> crate::error::Result<()> {
        *self.socket.lock().await = None;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Two loopback sockets connected to each other, wrapped as transports.
    async fn connected_pair() -> (UdpTransport, UdpTransport) {
        let left = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let right = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let (left_addr, right_addr) = (left.local_addr().unwrap(), right.local_addr().unwrap());
        left.connect(right_addr).await.unwrap();
        right.connect(left_addr).await.unwrap();
        (UdpTransport::from_socket(left), UdpTransport::from_socket(right))
    }

    #[tokio::test]
    async fn connected_udp_roundtrip() {
        let (a, b) = connected_pair().await;

        let responder = tokio::spawn(async move {
            b.send(Bytes::from_static(b"ping")).await.unwrap();
            let reply = b.recv().await.unwrap();
            assert_eq!(reply.as_ref(), b"pong");
        });

        let greeting = a.recv().await.unwrap();
        assert_eq!(greeting.as_ref(), b"ping");
        a.send(Bytes::from_static(b"pong")).await.unwrap();
        responder.await.unwrap();
    }

    #[tokio::test]
    async fn framed_roundtrip_udp() {
        let (a, b) = connected_pair().await;
        let payload = b"framed-payload-srx";

        let responder = tokio::spawn(async move {
            let received = b.recv_framed().await.unwrap();
            assert_eq!(received.as_ref(), payload);
            b.send_framed(b"ack").await.unwrap();
        });

        a.send_framed(payload).await.unwrap();
        let ack = a.recv_framed().await.unwrap();
        assert_eq!(ack.as_ref(), b"ack");
        responder.await.unwrap();
    }
}