stochastic-routing-extended 1.0.2

SRX (Stochastic Routing eXtended) — a next-generation VPN protocol with stochastic routing, DPI evasion, post-quantum cryptography, and multi-transport channel splitting
Documentation
//! WebSocket transport — binary messages carry SRX frame bytes (`Message::Binary`).

use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};

use async_trait::async_trait;
use bytes::Bytes;
use futures_util::{SinkExt, StreamExt};
use rustls::pki_types::ServerName;
use tokio::{
    io::{AsyncRead, AsyncWrite},
    net::TcpStream,
    sync::Mutex,
};
use tokio_rustls::{TlsAcceptor, TlsConnector};
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{WebSocketStream, accept_async, client_async};

use super::pinning::{CertPin, verify_peer_cert_pins};
use super::{Transport, TransportKind};
use crate::error::{SrxError, TransportError};

/// Byte stream underneath a [`WebSocketStream`]: plain TCP, or TCP wrapped in TLS on
/// either the client or the server side of the handshake.
enum WsStream {
    /// Unencrypted TCP (`ws://`), used by both the plain client and plain server paths.
    Plain(TcpStream),
    /// Outbound TLS session produced by [`TlsConnector`] (`wss://` client).
    ClientTls(tokio_rustls::client::TlsStream<TcpStream>),
    /// Inbound TLS session produced by [`TlsAcceptor`] (`wss://` server).
    ServerTls(tokio_rustls::server::TlsStream<TcpStream>),
}

impl AsyncRead for WsStream {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        match self.get_mut() {
            WsStream::Plain(s) => Pin::new(s).poll_read(cx, buf),
            WsStream::ClientTls(s) => Pin::new(s).poll_read(cx, buf),
            WsStream::ServerTls(s) => Pin::new(s).poll_read(cx, buf),
        }
    }
}

impl AsyncWrite for WsStream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        match self.get_mut() {
            WsStream::Plain(s) => Pin::new(s).poll_write(cx, buf),
            WsStream::ClientTls(s) => Pin::new(s).poll_write(cx, buf),
            WsStream::ServerTls(s) => Pin::new(s).poll_write(cx, buf),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match self.get_mut() {
            WsStream::Plain(s) => Pin::new(s).poll_flush(cx),
            WsStream::ClientTls(s) => Pin::new(s).poll_flush(cx),
            WsStream::ServerTls(s) => Pin::new(s).poll_flush(cx),
        }
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match self.get_mut() {
            WsStream::Plain(s) => Pin::new(s).poll_shutdown(cx),
            WsStream::ClientTls(s) => Pin::new(s).poll_shutdown(cx),
            WsStream::ServerTls(s) => Pin::new(s).poll_shutdown(cx),
        }
    }
}

impl Unpin for WsStream {}

/// Tokio WebSocket over TCP or TLS (`ws://` / `wss://`).
pub struct WebSocketTransport {
    ws: Arc<Mutex<Option<WebSocketStream<WsStream>>>>,
}

impl WebSocketTransport {
    /// Connect to `ws://` URL (plain TCP).
    ///
    /// For TLS (`wss://`), use [`Self::connect_tls`] instead.
    pub async fn connect(uri: impl AsRef<str>) -> crate::error::Result<Self> {
        // Parse the URI to extract host:port for TCP connection.
        let uri_str = uri.as_ref();
        let addr = uri_str
            .strip_prefix("ws://")
            .unwrap_or(uri_str)
            .trim_end_matches('/');
        let tcp = TcpStream::connect(addr)
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        let request = http::Request::builder()
            .method("GET")
            .header("Host", addr)
            .header("Connection", "Upgrade")
            .header("Upgrade", "websocket")
            .header("Sec-WebSocket-Version", "13")
            .header(
                "Sec-WebSocket-Key",
                tokio_tungstenite::tungstenite::handshake::client::generate_key(),
            )
            .uri(uri_str)
            .body(())
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        let (ws, _) = client_async(request, WsStream::Plain(tcp))
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        Ok(Self {
            ws: Arc::new(Mutex::new(Some(ws))),
        })
    }

    /// Connect via TLS (`wss://`) using an explicit rustls client config.
    ///
    /// This performs a manual TLS handshake over TCP before the WebSocket upgrade,
    /// giving full control over certificate verification (e.g. trusting a self-signed cert).
    pub async fn connect_tls(
        addr: impl tokio::net::ToSocketAddrs,
        server_name: ServerName<'static>,
        client_config: Arc<rustls::ClientConfig>,
    ) -> crate::error::Result<Self> {
        Self::connect_tls_pinned(addr, server_name, client_config, &[]).await
    }

    /// Same as [`Self::connect_tls`], but additionally enforces certificate pinning.
    pub async fn connect_tls_pinned(
        addr: impl tokio::net::ToSocketAddrs,
        server_name: ServerName<'static>,
        client_config: Arc<rustls::ClientConfig>,
        pins: &[CertPin],
    ) -> crate::error::Result<Self> {
        let tcp = TcpStream::connect(addr)
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        let host = server_name.to_str().to_string();
        let connector = TlsConnector::from(client_config);
        let tls_stream = connector
            .connect(server_name, tcp)
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        verify_peer_cert_pins(tls_stream.get_ref().1.peer_certificates(), pins)?;
        let uri = format!("wss://{host}/");
        let request = http::Request::builder()
            .method("GET")
            .header("Host", host.as_str())
            .header("Connection", "Upgrade")
            .header("Upgrade", "websocket")
            .header("Sec-WebSocket-Version", "13")
            .header(
                "Sec-WebSocket-Key",
                tokio_tungstenite::tungstenite::handshake::client::generate_key(),
            )
            .uri(&uri)
            .body(())
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        let (ws, _) = client_async(request, WsStream::ClientTls(tls_stream))
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        Ok(Self {
            ws: Arc::new(Mutex::new(Some(ws))),
        })
    }

    /// Accept an inbound WebSocket on an already-accepted TCP connection (plain `ws:` only).
    pub async fn accept(stream: TcpStream) -> crate::error::Result<Self> {
        let ws = accept_async(WsStream::Plain(stream))
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        Ok(Self {
            ws: Arc::new(Mutex::new(Some(ws))),
        })
    }

    /// Accept an inbound **TLS** WebSocket (`wss:`) on an already-accepted TCP connection.
    ///
    /// Performs the TLS handshake via [`TlsAcceptor`] then the WebSocket upgrade.
    pub async fn accept_tls(
        stream: TcpStream,
        server_config: Arc<rustls::ServerConfig>,
    ) -> crate::error::Result<Self> {
        let acceptor = TlsAcceptor::from(server_config);
        let tls_stream = acceptor
            .accept(stream)
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        let ws = accept_async(WsStream::ServerTls(tls_stream))
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        Ok(Self {
            ws: Arc::new(Mutex::new(Some(ws))),
        })
    }
}

#[async_trait]
impl Transport for WebSocketTransport {
    fn kind(&self) -> TransportKind {
        TransportKind::WebSocket
    }

    async fn send(&self, data: Bytes) -> crate::error::Result<()> {
        let mut g = self.ws.lock().await;
        let ws = g
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        ws.send(Message::Binary(data))
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        Ok(())
    }

    async fn recv(&self) -> crate::error::Result<Bytes> {
        let mut g = self.ws.lock().await;
        let ws = g
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        loop {
            let msg = ws.next().await.transpose().map_err(|e| {
                SrxError::Transport(TransportError::ConnectionFailed(e.to_string()))
            })?;
            let Some(msg) = msg else {
                return Err(SrxError::Transport(TransportError::ChannelClosed));
            };
            match msg {
                Message::Binary(b) => return Ok(b),
                Message::Ping(p) => {
                    ws.send(Message::Pong(p)).await.map_err(|e| {
                        SrxError::Transport(TransportError::ConnectionFailed(e.to_string()))
                    })?;
                }
                Message::Close(_) => {
                    return Err(SrxError::Transport(TransportError::ChannelClosed));
                }
                Message::Pong(_) | Message::Frame(_) | Message::Text(_) => {}
            }
        }
    }

    async fn is_healthy(&self) -> bool {
        self.ws.lock().await.is_some()
    }

    async fn close(&self) -> crate::error::Result<()> {
        let mut g = self.ws.lock().await;
        if let Some(mut ws) = g.take() {
            let _ = SinkExt::close(&mut ws).await;
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};

    /// Fixture for the TLS tests: a freshly generated self-signed `localhost`
    /// certificate and its server config. Also returns a client config whose only
    /// trust root is that certificate, plus the matching server name.
    fn localhost_tls() -> (
        Arc<rustls::ServerConfig>,
        Arc<rustls::ClientConfig>,
        ServerName<'static>,
    ) {
        let generated = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
        let cert = CertificateDer::from(generated.cert);
        let key = PrivatePkcs8KeyDer::from(generated.signing_key.serialize_der());

        let server = rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(vec![cert.clone()], key.into())
            .unwrap();

        let mut trusted = rustls::RootCertStore::empty();
        trusted.add(cert).unwrap();
        let client = rustls::ClientConfig::builder()
            .with_root_certificates(Arc::new(trusted))
            .with_no_client_auth();

        let name = ServerName::try_from("localhost".to_string()).unwrap();
        (Arc::new(server), Arc::new(client), name)
    }

    #[tokio::test]
    async fn ws_binary_roundtrip() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let uri = format!("ws://{}/", listener.local_addr().unwrap());

        let server = tokio::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let transport = WebSocketTransport::accept(socket).await.unwrap();
            let inbound = transport.recv().await.unwrap();
            assert_eq!(inbound.as_ref(), b"ws-ping");
            transport
                .send(Bytes::from_static(b"ws-pong"))
                .await
                .unwrap();
        });

        let client = WebSocketTransport::connect(&uri).await.unwrap();
        client.send(Bytes::from_static(b"ws-ping")).await.unwrap();
        let answer = client.recv().await.unwrap();
        assert_eq!(answer.as_ref(), b"ws-pong");
        client.close().await.unwrap();

        server.await.unwrap();
    }

    #[tokio::test]
    async fn ws_tls_binary_roundtrip() {
        let (server_cfg, client_cfg, server_name) = localhost_tls();

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            let transport = WebSocketTransport::accept_tls(socket, server_cfg)
                .await
                .unwrap();
            let inbound = transport.recv().await.unwrap();
            assert_eq!(inbound.as_ref(), b"wss-ping");
            transport
                .send(Bytes::from_static(b"wss-pong"))
                .await
                .unwrap();
        });

        let client = WebSocketTransport::connect_tls(addr, server_name, client_cfg)
            .await
            .unwrap();
        client.send(Bytes::from_static(b"wss-ping")).await.unwrap();
        let answer = client.recv().await.unwrap();
        assert_eq!(answer.as_ref(), b"wss-pong");
        client.close().await.unwrap();

        server.await.unwrap();
    }

    #[tokio::test]
    async fn ws_tls_pinned_rejects_wrong_pin() {
        let (server_cfg, client_cfg, server_name) = localhost_tls();
        // Deliberately cannot match the generated certificate.
        let bogus_pin = [0x42u8; 32];

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (socket, _) = listener.accept().await.unwrap();
            // Outcome ignored: the client is expected to abandon the handshake.
            let _ = WebSocketTransport::accept_tls(socket, server_cfg).await;
        });

        let outcome =
            WebSocketTransport::connect_tls_pinned(addr, server_name, client_cfg, &[bogus_pin])
                .await;
        assert!(outcome.is_err(), "expected pin mismatch");

        server.abort();
    }
}