stochastic-routing-extended 1.0.2

SRX (Stochastic Routing eXtended) — a next-generation VPN protocol with stochastic routing, DPI evasion, post-quantum cryptography, and multi-transport channel splitting
Documentation
//! QUIC transport with a default bidirectional stream plus optional additional
//! logical bidirectional streams on the same QUIC connection.

use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};

use async_trait::async_trait;
use bytes::Bytes;
use quinn::{ClientConfig, Connection, Endpoint, RecvStream, SendStream, ServerConfig};
use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
use tokio::sync::Mutex;

use super::pinning::{CertPin, verify_peer_cert_pins};
use super::{Transport, TransportKind};
use crate::error::{SrxError, TransportError};
use crate::frame::{read_length_prefixed, write_length_prefixed};

/// One logical QUIC bidirectional stream carrying length-prefixed SRX payloads.
pub struct QuicStreamChannel {
    send: Arc<Mutex<Option<SendStream>>>,
    recv: Arc<Mutex<Option<RecvStream>>>,
}

impl QuicStreamChannel {
    fn from_bi(send: SendStream, recv: RecvStream) -> Self {
        Self {
            send: Arc::new(Mutex::new(Some(send))),
            recv: Arc::new(Mutex::new(Some(recv))),
        }
    }

    /// Send one length-prefixed payload on this stream.
    pub async fn send(&self, data: Bytes) -> crate::error::Result<()> {
        let mut g = self.send.lock().await;
        let s = g
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        write_length_prefixed(s, &data).await
    }

    /// Receive one length-prefixed payload on this stream.
    pub async fn recv(&self) -> crate::error::Result<Bytes> {
        let mut g = self.recv.lock().await;
        let r = g
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        let v = read_length_prefixed(r).await?;
        Ok(Bytes::from(v))
    }

    /// Check if both directions are still present.
    pub async fn is_healthy(&self) -> bool {
        self.send.lock().await.is_some() && self.recv.lock().await.is_some()
    }

    /// Close stream directions held by this channel.
    pub async fn close(&self) -> crate::error::Result<()> {
        let mut s = self.send.lock().await;
        if let Some(mut stream) = s.take() {
            let _ = stream.finish();
        }
        self.recv.lock().await.take();
        Ok(())
    }
}

/// QUIC transport using one default stream for [`Transport`] trait I/O,
/// while allowing additional logical streams on demand.
pub struct QuicTransport {
    conn: Arc<Connection>,
    default_stream: QuicStreamChannel,
}

impl QuicTransport {
    fn from_bi(conn: Connection, send: SendStream, recv: RecvStream) -> Self {
        Self {
            conn: Arc::new(conn),
            default_stream: QuicStreamChannel::from_bi(send, recv),
        }
    }

    /// Build a server [`ServerConfig`] and trust anchor (DER) for clients.
    pub fn server_config() -> crate::error::Result<(ServerConfig, CertificateDer<'static>)> {
        let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()])
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        let cert_der = CertificateDer::from(cert.cert);
        let key = PrivatePkcs8KeyDer::from(cert.signing_key.serialize_der());
        let server_config = ServerConfig::with_single_cert(vec![cert_der.clone()], key.into())
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        Ok((server_config, cert_der))
    }

    /// Client [`ClientConfig`] trusting a single server certificate (e.g. from [`Self::server_config`]).
    pub fn client_config_trust_server(
        cert_der: &CertificateDer<'_>,
    ) -> crate::error::Result<ClientConfig> {
        let mut roots = rustls::RootCertStore::empty();
        roots
            .add(cert_der.clone())
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        ClientConfig::with_root_certificates(Arc::new(roots))
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))
    }

    /// Client [`ClientConfig`] with Mozilla webpki roots (public CAs).
    pub fn client_config_webpki() -> crate::error::Result<ClientConfig> {
        let mut roots = rustls::RootCertStore::empty();
        roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
        ClientConfig::with_root_certificates(Arc::new(roots))
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))
    }

    /// Connect from `bind` to `server`, open one default bidirectional stream.
    pub async fn connect(
        bind: SocketAddr,
        server: SocketAddr,
        server_name: &str,
        server_cert_der: &CertificateDer<'_>,
    ) -> crate::error::Result<Self> {
        Self::connect_pinned(bind, server, server_name, server_cert_der, &[]).await
    }

    /// Same as [`Self::connect`], but additionally enforces certificate pinning.
    pub async fn connect_pinned(
        bind: SocketAddr,
        server: SocketAddr,
        server_name: &str,
        server_cert_der: &CertificateDer<'_>,
        pins: &[CertPin],
    ) -> crate::error::Result<Self> {
        let mut endpoint = Endpoint::client(bind)
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        let cfg = Self::client_config_trust_server(server_cert_der)?;
        endpoint.set_default_client_config(cfg);
        let conn = endpoint
            .connect(server, server_name)
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        Self::verify_connection_pins(&conn, pins)?;
        let (send, recv) = conn
            .open_bi()
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        Ok(Self::from_bi(conn, send, recv))
    }

    /// Server side: wait for the peer to open the default bidirectional stream.
    pub async fn accept_bi(conn: Connection) -> crate::error::Result<Self> {
        let (send, recv) = conn
            .accept_bi()
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        Ok(Self::from_bi(conn, send, recv))
    }

    /// Open an additional logical bidirectional stream on this QUIC connection.
    pub async fn open_stream(&self) -> crate::error::Result<QuicStreamChannel> {
        let (send, recv) =
            self.conn.open_bi().await.map_err(|e| {
                SrxError::Transport(TransportError::ConnectionFailed(e.to_string()))
            })?;
        Ok(QuicStreamChannel::from_bi(send, recv))
    }

    /// Accept an inbound logical bidirectional stream opened by the peer.
    pub async fn accept_stream(&self) -> crate::error::Result<QuicStreamChannel> {
        let (send, recv) =
            self.conn.accept_bi().await.map_err(|e| {
                SrxError::Transport(TransportError::ConnectionFailed(e.to_string()))
            })?;
        Ok(QuicStreamChannel::from_bi(send, recv))
    }

    fn verify_connection_pins(conn: &Connection, pins: &[CertPin]) -> crate::error::Result<()> {
        if pins.is_empty() {
            return Ok(());
        }

        let identity = conn.peer_identity().ok_or_else(|| {
            SrxError::Transport(TransportError::ConnectionFailed(
                "missing QUIC peer identity for pin verification".into(),
            ))
        })?;
        let certs = identity
            .downcast::<Vec<CertificateDer<'static>>>()
            .map_err(|_| {
                SrxError::Transport(TransportError::ConnectionFailed(
                    "unexpected QUIC peer identity type for pin verification".into(),
                ))
            })?;
        verify_peer_cert_pins(Some(certs.as_slice()), pins)
    }
}

#[async_trait]
impl Transport for QuicTransport {
    fn kind(&self) -> TransportKind {
        TransportKind::Quic
    }

    async fn send(&self, data: Bytes) -> crate::error::Result<()> {
        self.default_stream.send(data).await
    }

    async fn recv(&self) -> crate::error::Result<Bytes> {
        self.default_stream.recv().await
    }

    async fn is_healthy(&self) -> bool {
        self.default_stream.is_healthy().await
    }

    async fn close(&self) -> crate::error::Result<()> {
        self.default_stream.close().await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Loopback address with an OS-assigned port, used for every endpoint here.
    fn loopback_any_port() -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], 0))
    }

    /// Start a QUIC server endpoint with a fresh self-signed certificate.
    ///
    /// Returns the endpoint, the address it actually bound, and the certificate
    /// clients must trust. Must be called from inside a Tokio runtime.
    fn start_server() -> (Endpoint, SocketAddr, CertificateDer<'static>) {
        let (server_config, cert_der) = QuicTransport::server_config().unwrap();
        let endpoint = Endpoint::server(server_config, loopback_any_port()).unwrap();
        let bound = endpoint.local_addr().unwrap();
        (endpoint, bound, cert_der)
    }

    /// Connect an unpinned client transport to `server_addr`, trusting `cert_der`.
    async fn dial(server_addr: SocketAddr, cert_der: &CertificateDer<'_>) -> QuicTransport {
        QuicTransport::connect(loopback_any_port(), server_addr, "localhost", cert_der)
            .await
            .expect("client connect")
    }

    #[test]
    fn client_config_webpki_builds() {
        let built = QuicTransport::client_config_webpki();
        assert!(built.is_ok());
    }

    #[tokio::test]
    async fn quic_length_prefixed_roundtrip() {
        let (endpoint, server_addr, cert_der) = start_server();

        let server = tokio::spawn(async move {
            let conn = endpoint.accept().await.expect("accept").await.expect("conn");
            let transport = QuicTransport::accept_bi(conn).await.expect("bi");

            let ping = transport.recv().await.expect("recv");
            assert_eq!(ping.as_ref(), b"quic-ping");
            transport
                .send(Bytes::from_static(b"quic-pong"))
                .await
                .expect("send");

            // Park until the client closes its side; the outcome is irrelevant.
            let _ = transport.recv().await;
        });

        let client = dial(server_addr, &cert_der).await;
        client
            .send(Bytes::from_static(b"quic-ping"))
            .await
            .expect("send");
        let pong = client.recv().await.expect("recv");
        assert_eq!(pong.as_ref(), b"quic-pong");
        client.close().await.ok();

        server.await.unwrap();
    }

    #[tokio::test]
    async fn quic_pinned_rejects_wrong_pin() {
        let (endpoint, server_addr, cert_der) = start_server();
        let wrong_pin = [0x33u8; 32];

        let server = tokio::spawn(async move {
            // Let the handshake run; whatever happens afterwards is ignored.
            let _ = endpoint.accept().await.expect("accept").await;
        });

        let outcome = QuicTransport::connect_pinned(
            loopback_any_port(),
            server_addr,
            "localhost",
            &cert_der,
            &[wrong_pin],
        )
        .await;
        assert!(outcome.is_err(), "expected pin mismatch");

        server.abort();
    }

    #[tokio::test]
    async fn quic_multi_stream_roundtrip() {
        let (endpoint, server_addr, cert_der) = start_server();

        let server = tokio::spawn(async move {
            let conn = endpoint.accept().await.expect("accept").await.expect("conn");
            let transport = QuicTransport::accept_bi(conn)
                .await
                .expect("default stream");

            let ping = transport.recv().await.expect("default recv");
            assert_eq!(ping.as_ref(), b"default-ping");
            transport
                .send(Bytes::from_static(b"default-pong"))
                .await
                .expect("default send");

            let first = transport.accept_stream().await.expect("accept stream1");
            let first_msg = first.recv().await.expect("recv stream1");
            assert_eq!(first_msg.as_ref(), b"stream-1");
            first
                .send(Bytes::from_static(b"ack-1"))
                .await
                .expect("send stream1");

            let second = transport.accept_stream().await.expect("accept stream2");
            let second_msg = second.recv().await.expect("recv stream2");
            assert_eq!(second_msg.as_ref(), b"stream-2");
            second
                .send(Bytes::from_static(b"ack-2"))
                .await
                .expect("send stream2");

            // Hold the connection open until the client confirms it saw both acks.
            let finale = transport.recv().await.expect("done recv");
            assert_eq!(finale.as_ref(), b"done");
        });

        let client = dial(server_addr, &cert_der).await;

        client
            .send(Bytes::from_static(b"default-ping"))
            .await
            .expect("default send");
        let pong = client.recv().await.expect("default recv");
        assert_eq!(pong.as_ref(), b"default-pong");

        let lane_one = client.open_stream().await.expect("open stream1");
        lane_one
            .send(Bytes::from_static(b"stream-1"))
            .await
            .expect("send stream1");
        let ack_one = lane_one.recv().await.expect("recv stream1");
        assert_eq!(ack_one.as_ref(), b"ack-1");

        let lane_two = client.open_stream().await.expect("open stream2");
        lane_two
            .send(Bytes::from_static(b"stream-2"))
            .await
            .expect("send stream2");
        let ack_two = lane_two.recv().await.expect("recv stream2");
        assert_eq!(ack_two.as_ref(), b"ack-2");

        client
            .send(Bytes::from_static(b"done"))
            .await
            .expect("done send");

        server.await.unwrap();
    }
}