stochastic-routing-extended 1.0.2

SRX (Stochastic Routing eXtended) — a next-generation VPN protocol with stochastic routing, DPI evasion, post-quantum cryptography, and multi-transport channel splitting
Documentation
//! High-level SRX API: secure framed TCP session with automatic handshake.
//!
//! This API hides handshake/session bootstrapping and provides simple
//! send/receive primitives for application payloads and in-band signals.

use std::net::SocketAddr;
use std::sync::Arc;

use tokio::sync::Mutex;
use tracing::{debug, info};

use crate::config::SrxConfig;
use crate::error::{Result, SessionError, SrxError};
use crate::pipeline::Payload;
use crate::session::{Handshake, Session};
use crate::signaling::inband::Signal;
use crate::transport::TcpTransport;
use crate::{SrxNode, TransportManager};

/// Secure framed TCP session built on SRX.
///
/// Instances are produced by [`SecureTcpSession::connect`] (initiator side) or
/// [`SecureTcpSession::accept`] (responder side); both run the full SRX
/// handshake over the TCP stream before returning a session.
pub struct SecureTcpSession {
    /// Framed TCP transport that carries the handshake messages and, afterwards,
    /// every outgoing/incoming envelope.
    transport: Arc<TcpTransport>,
    /// SRX node holding the established session state. It is used to build
    /// outgoing envelopes and decode incoming ones. It sits behind an async
    /// mutex because the send paths need mutable access from `&self`.
    node: Mutex<SrxNode>,
}

impl SecureTcpSession {
    /// Connect as initiator and establish a secure SRX session over TCP.
    ///
    /// Opens a TCP connection to `addr`, then runs the initiator side of the SRX
    /// handshake (ClientHello → ServerHello → ClientFinished) before returning.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the TCP connection cannot be opened;
    /// - a handshake frame fails to send, receive or validate;
    /// - no master secret is established;
    /// - session, AEAD-pipeline or node setup from `config` fails.
    pub async fn connect(addr: SocketAddr, config: SrxConfig) -> Result<Self> {
        info!("SRX: connecting to {addr}");
        let stream = tokio::net::TcpStream::connect(addr).await?;
        debug!("SRX: TCP connection established to {addr}");
        let transport = Arc::new(TcpTransport::from_stream(stream));
        Self::connect_over_transport(transport, config).await
    }

    /// Accept as responder and establish a secure SRX session over an accepted stream.
    ///
    /// Runs the responder side of the SRX handshake on `stream` before returning.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - a handshake frame fails to send, receive or validate;
    /// - no master secret is established;
    /// - session, AEAD-pipeline or node setup from `config` fails.
    pub async fn accept(stream: tokio::net::TcpStream, config: SrxConfig) -> Result<Self> {
        debug!("SRX: accepting incoming session");
        let transport = Arc::new(TcpTransport::from_stream(stream));
        Self::accept_over_transport(transport, config).await
    }

    /// Run the initiator side of the handshake over an already-connected transport.
    // NOTE(review): no timeout is applied at this layer. If the peer stalls and
    // `recv_framed` does not time out internally, this await blocks indefinitely.
    async fn connect_over_transport(
        transport: Arc<TcpTransport>,
        config: SrxConfig,
    ) -> Result<Self> {
        debug!("SRX handshake: initiating as client");
        let mut hs = Handshake::new_initiator();
        let client_hello = hs.client_hello()?;
        debug!("SRX handshake: sending ClientHello ({} bytes)", client_hello.len());
        transport.send_framed(&client_hello).await?;
        debug!("SRX handshake: waiting for ServerHello");
        let server_hello = transport.recv_framed().await?;
        debug!("SRX handshake: received ServerHello ({} bytes)", server_hello.len());
        let client_finished = hs.finalize(&server_hello)?;
        debug!("SRX handshake: sending ClientFinished ({} bytes)", client_finished.len());
        transport.send_framed(&client_finished).await?;

        Self::from_completed_handshake(
            transport,
            config,
            false,
            hs.master_secret(),
            &client_hello,
        )
    }

    /// Run the responder side of the handshake over an accepted transport.
    // NOTE(review): as on the initiator side, no timeout is applied at this layer.
    async fn accept_over_transport(
        transport: Arc<TcpTransport>,
        config: SrxConfig,
    ) -> Result<Self> {
        debug!("SRX handshake: initiating as server (responder)");
        let mut hs = Handshake::new_responder();
        debug!("SRX handshake: waiting for ClientHello");
        let client_hello = transport.recv_framed().await?;
        debug!("SRX handshake: received ClientHello ({} bytes)", client_hello.len());
        let server_hello = hs.server_hello(&client_hello)?;
        debug!("SRX handshake: sending ServerHello ({} bytes)", server_hello.len());
        transport.send_framed(&server_hello).await?;
        debug!("SRX handshake: waiting for ClientFinished");
        let client_finished = transport.recv_framed().await?;
        debug!("SRX handshake: received ClientFinished ({} bytes)", client_finished.len());
        hs.server_finish(&client_finished)?;

        Self::from_completed_handshake(
            transport,
            config,
            true,
            hs.master_secret(),
            &client_hello,
        )
    }

    /// Turn a completed handshake into a ready-to-use session.
    ///
    /// This is the shared tail of both handshake roles. It:
    /// 1. requires that the handshake produced a master secret;
    /// 2. derives the [`Session`] from that secret and the ClientHello;
    /// 3. builds the AEAD pipeline from the session's data key;
    /// 4. wraps everything in an [`SrxNode`].
    ///
    /// `is_server` is `true` for the responder and `false` for the initiator.
    fn from_completed_handshake(
        transport: Arc<TcpTransport>,
        config: SrxConfig,
        is_server: bool,
        master_secret: Option<[u8; 32]>,
        client_hello: &[u8],
    ) -> Result<Self> {
        let master = master_secret.ok_or_else(|| {
            SrxError::Session(SessionError::HandshakeFailed(
                "master secret not established".to_string(),
            ))
        })?;

        let session = session_from_client_hello(is_server, master, client_hello)?;
        let aead = Arc::new(config.build_aead_pipeline(&session.data_key)?);
        let node = SrxNode::from_session(config, session, aead, TransportManager::new())?;
        let role = if is_server { "responder" } else { "initiator" };
        info!("SRX handshake: session established ({role})");
        Ok(Self {
            transport,
            node: Mutex::new(node),
        })
    }

    /// Send application data over the secure channel.
    ///
    /// The node turns `payload` into an envelope (`prepare_outgoing`), and the
    /// envelope is written as one frame. The node lock covers only the
    /// preparation step and is released before the network write. As a result,
    /// concurrent callers may write envelopes in a different order than they
    /// were prepared.
    ///
    /// # Errors
    ///
    /// Returns an error if envelope preparation or the framed send fails.
    pub async fn send_data(&self, payload: &[u8]) -> Result<()> {
        let envelope = {
            let mut node = self.node.lock().await;
            node.prepare_outgoing(payload)?
        };
        self.transport.send_framed(&envelope).await?;
        Ok(())
    }

    /// Send an in-band control signal over the secure channel.
    ///
    /// Locking and ordering behave as in [`Self::send_data`].
    ///
    /// # Errors
    ///
    /// Returns an error if the node fails to encode `signal` or the framed send fails.
    pub async fn send_signal(&self, signal: &Signal) -> Result<()> {
        let envelope = {
            let mut node = self.node.lock().await;
            node.send_signal(signal)?
        };
        self.transport.send_framed(&envelope).await?;
        Ok(())
    }

    /// Receive the next framed envelope and decode it into data or a signal.
    ///
    /// The node lock is taken only after a full frame has arrived, so waiting on
    /// the network does not hold it.
    ///
    /// # Errors
    ///
    /// Returns an error if receiving the frame fails or the node rejects/cannot
    /// decode the envelope.
    pub async fn recv(&self) -> Result<Payload> {
        let envelope = self.transport.recv_framed().await?;
        let node = self.node.lock().await;
        node.process_incoming_dispatched(&envelope)
    }
}

/// Derive the SRX [`Session`] for one side of a completed handshake.
///
/// Both peers run this over the same ClientHello bytes. The initiator uses the
/// hello it sent and the responder the one it received, so the `timestamp` and
/// `nonce` extracted here match on both sides. `is_server` only selects the
/// numeric session id handed to [`Session::from_master_secret`]: `1` for the
/// responder and `2` for the initiator.
///
/// # Errors
///
/// Propagates failures from [`parse_client_hello`] and from
/// [`Session::from_master_secret`].
fn session_from_client_hello(
    is_server: bool,
    master_secret: [u8; 32],
    client_hello: &[u8],
) -> Result<Session> {
    parse_client_hello(client_hello).and_then(|(timestamp, nonce)| {
        let local_session_id = if is_server { 1 } else { 2 };
        Session::from_master_secret(local_session_id, &master_secret, timestamp, &nonce)
    })
}

fn parse_client_hello(client_hello: &[u8]) -> Result<(u64, [u8; 16])> {
    if client_hello.len() < 10 {
        return Err(SrxError::Session(SessionError::HandshakeFailed(
            "ClientHello too short".to_string(),
        )));
    }
    if &client_hello[0..4] != b"SRXH" {
        return Err(SrxError::Session(SessionError::HandshakeFailed(
            "Invalid magic in ClientHello".to_string(),
        )));
    }
    if client_hello[5] != 1 {
        return Err(SrxError::Session(SessionError::HandshakeFailed(
            "Invalid message type in ClientHello".to_string(),
        )));
    }
    let payload_len = u32::from_be_bytes(client_hello[6..10].try_into().expect("len")) as usize;
    if client_hello.len() != 10 + payload_len || payload_len != 24 {
        return Err(SrxError::Session(SessionError::HandshakeFailed(
            "Invalid ClientHello payload length".to_string(),
        )));
    }

    let payload = &client_hello[10..];
    let timestamp = u64::from_be_bytes(payload[0..8].try_into().expect("timestamp"));
    let mut nonce = [0u8; 16];
    nonce.copy_from_slice(&payload[8..24]);
    Ok((timestamp, nonce))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration with replay-state persistence switched off.
    fn ephemeral_config() -> SrxConfig {
        let mut config = SrxConfig::default();
        config.replay.persist_enabled = false;
        config
    }

    /// End-to-end check: the initiator sends "ping" and the responder answers "pong"
    /// over a real loopback TCP connection.
    #[tokio::test]
    async fn secure_tcp_session_roundtrip() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server_task = tokio::spawn(async move {
            let (socket, _peer) = listener.accept().await.unwrap();
            let responder = SecureTcpSession::accept(socket, ephemeral_config())
                .await
                .unwrap();
            let received = responder.recv().await.unwrap();
            if let Payload::Data(bytes) = received {
                assert_eq!(bytes, b"ping");
            } else {
                panic!("expected data");
            }
            responder.send_data(b"pong").await.unwrap();
        });

        let initiator = SecureTcpSession::connect(addr, ephemeral_config())
            .await
            .unwrap();
        initiator.send_data(b"ping").await.unwrap();
        let reply = initiator.recv().await.unwrap();
        match reply {
            Payload::Signal(_) => panic!("expected data"),
            Payload::Data(bytes) => assert_eq!(bytes, b"pong"),
        }

        server_task.await.unwrap();
    }
}