// stochastic-routing-extended 1.0.2
//
// SRX (Stochastic Routing eXtended) — a next-generation VPN protocol with stochastic routing,
// DPI evasion, post-quantum cryptography, and multi-transport channel splitting.
//! TLS-wrapped TCP transport — same [`Transport`] and framed API as [`super::TcpTransport`], over `tokio-rustls`.

use std::sync::Arc;

use async_trait::async_trait;
use bytes::Bytes;
use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer, ServerName};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Mutex;
use tokio_rustls::TlsAcceptor;
use tokio_rustls::TlsConnector;
use tokio_rustls::TlsStream;

use super::pinning::{CertPin, verify_peer_cert_pins};
use super::{Transport, TransportKind};
use crate::error::{SrxError, TransportError};
use crate::frame::{read_length_prefixed, write_length_prefixed};

/// Size in bytes of the scratch buffer for one raw [`Transport::recv`] read (64 KiB).
const RECV_BUF: usize = 64 * 1024;

/// Async TLS-over-TCP [`Transport`] using Tokio + rustls.
///
/// Holds one TLS session, produced either by a client handshake (`connect` / `connect_pinned`)
/// or a server handshake (`accept`). Every operation locks the same [`tokio::sync::Mutex`], so
/// reads and writes on one instance are serialized.
pub struct TlsTcpTransport {
    /// The live TLS session. It is `None` once `close` has taken it; later operations then
    /// report `ChannelClosed`. The `Arc<Mutex<…>>` gives `&self` methods mutable access.
    stream: Arc<Mutex<Option<TlsStream<TcpStream>>>>,
}

impl TlsTcpTransport {
    /// Build a rustls server config from a self-signed rcgen-style cert + PKCS#8 key (DER).
    pub fn server_config_from_der(
        cert_der: CertificateDer<'static>,
        key_der: PrivatePkcs8KeyDer<'static>,
    ) -> crate::error::Result<Arc<rustls::ServerConfig>> {
        let cfg = rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(vec![cert_der], key_der.into())
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        Ok(Arc::new(cfg))
    }

    /// Build a rustls client config trusting a single server certificate (e.g. from the same helper as the server).
    pub fn client_config_trust_server_cert(
        cert_der: &CertificateDer<'_>,
    ) -> crate::error::Result<Arc<rustls::ClientConfig>> {
        let mut roots = rustls::RootCertStore::empty();
        roots
            .add(cert_der.clone())
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        let cfg = rustls::ClientConfig::builder()
            .with_root_certificates(Arc::new(roots))
            .with_no_client_auth();
        Ok(Arc::new(cfg))
    }

    /// Map a hostname to rustls [`ServerName`] (for SNI / certificate verification).
    pub fn server_name_dns(host: &str) -> crate::error::Result<ServerName<'static>> {
        ServerName::try_from(host.to_string()).map_err(|_| {
            SrxError::Transport(TransportError::ConnectionFailed(format!(
                "invalid TLS server name: {host}"
            )))
        })
    }

    /// TLS handshake over TCP to `addr`; `server_name` must match the certificate (e.g. `localhost`).
    pub async fn connect(
        addr: impl tokio::net::ToSocketAddrs,
        server_name: ServerName<'static>,
        client_config: Arc<rustls::ClientConfig>,
    ) -> crate::error::Result<Self> {
        Self::connect_pinned(addr, server_name, client_config, &[]).await
    }

    /// Same as [`Self::connect`], but additionally enforces certificate pinning.
    ///
    /// Pin verification matches SHA-256 over raw certificate DER bytes.
    /// At least one certificate in the peer chain must match a provided pin.
    pub async fn connect_pinned(
        addr: impl tokio::net::ToSocketAddrs,
        server_name: ServerName<'static>,
        client_config: Arc<rustls::ClientConfig>,
        pins: &[CertPin],
    ) -> crate::error::Result<Self> {
        let tcp = TcpStream::connect(addr)
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        let connector = TlsConnector::from(client_config);
        let tls_client = connector
            .connect(server_name, tcp)
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        verify_peer_cert_pins(tls_client.get_ref().1.peer_certificates(), pins)?;
        let tls: TlsStream<TcpStream> = tls_client.into();
        Ok(Self {
            stream: Arc::new(Mutex::new(Some(tls))),
        })
    }

    /// Server: `accept` one TCP connection from `listener` and perform TLS (`TlsAcceptor`).
    pub async fn accept(
        listener: &TcpListener,
        server_config: Arc<rustls::ServerConfig>,
    ) -> crate::error::Result<Self> {
        let (tcp, _) = listener
            .accept()
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        let acceptor = TlsAcceptor::from(server_config);
        let tls: TlsStream<TcpStream> = acceptor
            .accept(tcp)
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?
            .into();
        Ok(Self {
            stream: Arc::new(Mutex::new(Some(tls))),
        })
    }

    /// Wrap an established TLS stream.
    #[must_use]
    pub fn from_tls_stream(stream: TlsStream<TcpStream>) -> Self {
        Self {
            stream: Arc::new(Mutex::new(Some(stream))),
        }
    }

    /// Same wire as [`super::TcpTransport::send_framed`].
    pub async fn send_framed(&self, payload: &[u8]) -> crate::error::Result<()> {
        let mut guard = self.stream.lock().await;
        let stream = guard
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        write_length_prefixed(stream, payload).await
    }

    /// Same wire as [`super::TcpTransport::recv_framed`].
    pub async fn recv_framed(&self) -> crate::error::Result<Bytes> {
        let mut guard = self.stream.lock().await;
        let stream = guard
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        let v = read_length_prefixed(stream).await?;
        Ok(Bytes::from(v))
    }
}

#[async_trait]
impl Transport for TlsTcpTransport {
    /// Reports [`TransportKind::Tcp`], because the TLS session runs over a TCP socket.
    ///
    /// NOTE(review): `kind()` does not distinguish this from plain TCP. Confirm that is
    /// intended; no TLS-specific `TransportKind` variant is visible from this file.
    fn kind(&self) -> TransportKind {
        TransportKind::Tcp
    }

    /// Write `data` to the TLS stream unframed (no length prefix), then flush.
    ///
    /// # Errors
    /// [`TransportError::ChannelClosed`] after [`Transport::close`]. Otherwise
    /// [`TransportError::ConnectionFailed`] with the I/O error text if the write or the
    /// flush fails.
    async fn send(&self, data: Bytes) -> crate::error::Result<()> {
        let mut guard = self.stream.lock().await;
        let stream = guard
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        stream
            .write_all(&data)
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        stream
            .flush()
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        Ok(())
    }

    /// Read the next chunk of decrypted bytes, at most [`RECV_BUF`] bytes long.
    ///
    /// This is a raw stream read, so chunk boundaries need not match the peer's `send` calls.
    /// Use `recv_framed` when message boundaries matter.
    ///
    /// NOTE(review): the mutex guard is held across the read. A concurrent `send` on the same
    /// transport therefore waits until this read returns.
    ///
    /// # Errors
    /// [`TransportError::ChannelClosed`] after close or on EOF (a 0-byte read).
    /// [`TransportError::ConnectionFailed`] on read errors.
    async fn recv(&self) -> crate::error::Result<Bytes> {
        let mut guard = self.stream.lock().await;
        let stream = guard
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        let mut buf = vec![0u8; RECV_BUF];
        let n = stream
            .read(&mut buf)
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        if n == 0 {
            return Err(SrxError::Transport(TransportError::ChannelClosed));
        }
        // Copy out exactly `n` bytes. With newer `bytes` versions, `Bytes::from(vec)` keeps the
        // Vec's spare capacity, so each small read would otherwise pin a full RECV_BUF
        // allocation for as long as the returned `Bytes` lives.
        Ok(Bytes::copy_from_slice(&buf[..n]))
    }

    /// Returns `true` until [`Transport::close`] has taken the stream.
    ///
    /// This does not probe the socket. A peer that has disconnected (for example one that
    /// made `recv` report EOF) still reads as healthy here.
    async fn is_healthy(&self) -> bool {
        self.stream.lock().await.is_some()
    }

    /// Take the stream out of the transport and shut it down (TLS/TCP) on a best-effort basis.
    ///
    /// Shutdown errors are deliberately ignored, and closing twice is a no-op. Later
    /// operations report `ChannelClosed`.
    async fn close(&self) -> crate::error::Result<()> {
        let mut guard = self.stream.lock().await;
        if let Some(mut s) = guard.take() {
            let _ = s.shutdown().await;
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer, ServerName};

    /// Build fresh self-signed `localhost` material for both tests.
    ///
    /// Returns the server config, a client config that trusts exactly that certificate, and
    /// the SNI name to dial.
    fn localhost_tls() -> (
        Arc<rustls::ServerConfig>,
        Arc<rustls::ClientConfig>,
        ServerName<'static>,
    ) {
        let generated = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
        let cert = CertificateDer::from(generated.cert);
        let key = PrivatePkcs8KeyDer::from(generated.signing_key.serialize_der());
        let server = TlsTcpTransport::server_config_from_der(cert.clone(), key).unwrap();
        let client = TlsTcpTransport::client_config_trust_server_cert(&cert).unwrap();
        let sni = TlsTcpTransport::server_name_dns("localhost").unwrap();
        (server, client, sni)
    }

    #[tokio::test]
    async fn tls_framed_roundtrip() {
        let (server_cfg, client_cfg, sni) = localhost_tls();

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // Echo-style peer: expect one request frame, answer with one ack frame.
        let server = tokio::spawn(async move {
            let peer = TlsTcpTransport::accept(&listener, server_cfg).await.unwrap();
            let request = peer.recv_framed().await.unwrap();
            assert_eq!(request.as_ref(), b"tls-payload");
            peer.send_framed(b"tls-ack").await.unwrap();
        });

        let client = TlsTcpTransport::connect(addr, sni, client_cfg).await.unwrap();
        client.send_framed(b"tls-payload").await.unwrap();
        let reply = client.recv_framed().await.unwrap();
        assert_eq!(reply.as_ref(), b"tls-ack");
        client.close().await.unwrap();

        // Awaiting the handle also surfaces any assertion panic from the server task.
        server.await.unwrap();
    }

    #[tokio::test]
    async fn tls_pinned_rejects_wrong_pin() {
        let (server_cfg, client_cfg, sni) = localhost_tls();
        let bogus_pin = [0xA5u8; 32];

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // The server's result does not matter: the client drops the session after the
        // pin check fails.
        let server = tokio::spawn(async move {
            let _ = TlsTcpTransport::accept(&listener, server_cfg).await;
        });

        let outcome = TlsTcpTransport::connect_pinned(addr, sni, client_cfg, &[bogus_pin]).await;
        assert!(outcome.is_err(), "expected pin mismatch");

        server.abort();
    }
}