stochastic-routing-extended 1.0.2

SRX (Stochastic Routing eXtended) — a next-generation VPN protocol with stochastic routing, DPI evasion, post-quantum cryptography, and multi-transport channel splitting
Documentation
//! HTTP `CONNECT` tunnel — after `HTTP/1.1 200`, the socket carries opaque SRX bytes (same semantics as raw TCP `send`/`recv`).

use std::net::SocketAddr;
use std::sync::Arc;

use async_trait::async_trait;
use bytes::Bytes;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::Mutex;

use super::{Transport, TransportKind};
use crate::error::{SrxError, TransportError};
use crate::frame::{read_length_prefixed, write_length_prefixed};

/// Scratch-buffer size for a single unframed `recv` (64 KiB); one call never yields
/// more than this many bytes.
const RECV_BUF: usize = 64 * 1024;

/// Byte tunnel over a TCP connection that has completed an HTTP proxy `CONNECT`.
///
/// The socket sits behind an async mutex shared by all operations; once the transport
/// has been closed the slot holds `None` and every I/O method reports
/// `TransportError::ChannelClosed`.
#[derive(Debug)]
pub struct HttpTunnelTransport {
    stream: Arc<Mutex<Option<TcpStream>>>,
}

impl HttpTunnelTransport {
    /// Wrap an established TCP connection (e.g. after you handled `CONNECT` yourself).
    #[must_use]
    pub fn from_tcp(stream: TcpStream) -> Self {
        Self {
            stream: Arc::new(Mutex::new(Some(stream))),
        }
    }

    /// Connect to `proxy`, issue `CONNECT target_authority HTTP/1.1`, expect `200`.
    ///
    /// `target_authority` is typically `host:port` as sent in the request line.
    pub async fn connect_via_proxy(
        proxy: SocketAddr,
        target_authority: &str,
    ) -> crate::error::Result<Self> {
        let mut stream = TcpStream::connect(proxy)
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        let req = format!(
            "CONNECT {target_authority} HTTP/1.1\r\nHost: {target_authority}\r\nProxy-Connection: keep-alive\r\n\r\n"
        );
        stream
            .write_all(req.as_bytes())
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;

        let mut buf = Vec::with_capacity(2048);
        let mut tmp = [0u8; 256];
        loop {
            let n = stream.read(&mut tmp).await.map_err(|e| {
                SrxError::Transport(TransportError::ConnectionFailed(e.to_string()))
            })?;
            if n == 0 {
                return Err(SrxError::Transport(TransportError::ChannelClosed));
            }
            buf.extend_from_slice(&tmp[..n]);
            if buf.windows(4).any(|w| w == b"\r\n\r\n") {
                break;
            }
            if buf.len() > 16 * 1024 {
                return Err(SrxError::Transport(TransportError::ConnectionFailed(
                    "CONNECT response headers too large".into(),
                )));
            }
        }
        let head_end = buf
            .windows(4)
            .position(|w| w == b"\r\n\r\n")
            .expect("checked")
            + 4;
        let head = std::str::from_utf8(&buf[..head_end])
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        let status = head.lines().next().unwrap_or("");
        if !status.starts_with("HTTP/1.") || !status.contains("200") {
            return Err(SrxError::Transport(TransportError::ConnectionFailed(
                format!("CONNECT failed: {status}"),
            )));
        }
        Ok(Self {
            stream: Arc::new(Mutex::new(Some(stream))),
        })
    }

    /// Send one length-prefixed payload (`u32` BE length + bytes) over the HTTP tunnel.
    pub async fn send_framed(&self, payload: &[u8]) -> crate::error::Result<()> {
        let mut guard = self.stream.lock().await;
        let stream = guard
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        write_length_prefixed(stream, payload).await
    }

    /// Receive one length-prefixed payload (full message) from the HTTP tunnel.
    pub async fn recv_framed(&self) -> crate::error::Result<Bytes> {
        let mut guard = self.stream.lock().await;
        let stream = guard
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        let v = read_length_prefixed(stream).await?;
        Ok(Bytes::from(v))
    }
}

/// Raw (unframed) byte transport over the tunnelled socket.
///
/// NOTE(review): every method locks the same `stream` mutex, and `recv` keeps it
/// locked while awaiting `read`. A pending `recv` therefore also blocks `send`,
/// `is_healthy` and `close` on this instance until data or EOF arrives. Confirm that
/// callers never drive send and recv concurrently, or split the socket into halves.
#[async_trait]
impl Transport for HttpTunnelTransport {
    /// Always reports `TransportKind::Http2`.
    ///
    /// NOTE(review): the tunnel itself is HTTP/1.1 `CONNECT`; presumably `Http2` is the
    /// closest existing kind. Confirm against `TransportKind`.
    fn kind(&self) -> TransportKind {
        TransportKind::Http2
    }

    /// Write all of `data` to the tunnel, then flush.
    async fn send(&self, data: Bytes) -> crate::error::Result<()> {
        let mut g = self.stream.lock().await;
        let s = g
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        s.write_all(&data)
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        s.flush()
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        Ok(())
    }

    /// Return whatever a single socket read yields, at most `RECV_BUF` bytes.
    ///
    /// Message boundaries are not preserved. EOF is reported as
    /// `TransportError::ChannelClosed`.
    async fn recv(&self) -> crate::error::Result<Bytes> {
        let mut g = self.stream.lock().await;
        let s = g
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        let mut buf = vec![0u8; RECV_BUF];
        let n = s
            .read(&mut buf)
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        if n == 0 {
            return Err(SrxError::Transport(TransportError::ChannelClosed));
        }
        // Copy out only the received bytes. `Bytes::from` on a truncated Vec can keep
        // the full RECV_BUF allocation alive for as long as the returned `Bytes` lives.
        Ok(Bytes::copy_from_slice(&buf[..n]))
    }

    /// `true` until `close` has been called.
    async fn is_healthy(&self) -> bool {
        self.stream.lock().await.is_some()
    }

    /// Shut down and drop the socket.
    ///
    /// Shutdown errors are ignored on purpose (best effort). Later I/O reports
    /// `ChannelClosed`, and repeated calls return `Ok(())`.
    async fn close(&self) -> crate::error::Result<()> {
        let mut g = self.stream.lock().await;
        if let Some(mut s) = g.take() {
            let _ = s.shutdown().await;
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Response a cooperative proxy sends once the tunnel is up.
    const ESTABLISHED: &[u8] = b"HTTP/1.1 200 Connection Established\r\n\r\n";
    /// Dummy authority; the fake proxy never dials it.
    const TARGET: &str = "127.0.0.1:9";

    /// Bind a fake proxy listener on an ephemeral loopback port.
    async fn bind_local() -> (tokio::net::TcpListener, SocketAddr) {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        (listener, addr)
    }

    /// Fake proxy: accept one client, read its request through the blank line, reply
    /// with `response`, and hand back the socket for further scripted I/O.
    async fn fake_proxy(
        listener: tokio::net::TcpListener,
        response: &'static [u8],
    ) -> TcpStream {
        let (mut stream, _) = listener.accept().await.unwrap();
        let mut buf = vec![0u8; 2048];
        let mut total = 0usize;
        loop {
            let n = stream.read(&mut buf[total..]).await.unwrap();
            if n == 0 {
                panic!("eof");
            }
            total += n;
            if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") {
                break;
            }
        }
        stream.write_all(response).await.unwrap();
        stream
    }

    #[tokio::test]
    async fn connect_proxy_roundtrip() {
        let (listener, proxy_addr) = bind_local().await;

        let serve = tokio::spawn(async move {
            let mut stream = fake_proxy(listener, ESTABLISHED).await;
            let mut out = [0u8; 64];
            let n = stream.read(&mut out).await.unwrap();
            assert_eq!(&out[..n], b"ping");
            stream.write_all(b"pong").await.unwrap();
        });

        let t = HttpTunnelTransport::connect_via_proxy(proxy_addr, TARGET)
            .await
            .unwrap();
        t.send(Bytes::from_static(b"ping")).await.unwrap();
        let r = t.recv().await.unwrap();
        assert_eq!(r.as_ref(), b"pong");
        t.close().await.unwrap();

        serve.await.unwrap();
    }

    #[tokio::test]
    async fn framed_roundtrip_over_http_tunnel() {
        let (listener, proxy_addr) = bind_local().await;
        let payload = b"framed-payload-srx";

        let serve = tokio::spawn(async move {
            // Run the framed transport on the proxy side of the tunnel too.
            let t = HttpTunnelTransport::from_tcp(fake_proxy(listener, ESTABLISHED).await);
            let got = t.recv_framed().await.unwrap();
            assert_eq!(got.as_ref(), payload);
            t.send_framed(b"ack").await.unwrap();
        });

        let t = HttpTunnelTransport::connect_via_proxy(proxy_addr, TARGET)
            .await
            .unwrap();
        t.send_framed(payload).await.unwrap();
        let reply = t.recv_framed().await.unwrap();
        assert_eq!(reply.as_ref(), b"ack");
        t.close().await.unwrap();

        serve.await.unwrap();
    }

    #[tokio::test]
    async fn connect_proxy_rejects_non_success_status() {
        let (listener, proxy_addr) = bind_local().await;

        let serve = tokio::spawn(async move {
            let _stream = fake_proxy(
                listener,
                b"HTTP/1.1 407 Proxy Authentication Required\r\nContent-Length: 0\r\n\r\n",
            )
            .await;
        });

        let res = HttpTunnelTransport::connect_via_proxy(proxy_addr, TARGET).await;
        assert!(res.is_err());

        serve.await.unwrap();
    }

    #[tokio::test]
    async fn connect_proxy_reports_eof_before_response() {
        let (listener, proxy_addr) = bind_local().await;

        // The proxy reads the request and hangs up without answering.
        let serve = tokio::spawn(async move {
            drop(fake_proxy(listener, b"").await);
        });

        let res = HttpTunnelTransport::connect_via_proxy(proxy_addr, TARGET).await;
        assert!(matches!(
            res,
            Err(SrxError::Transport(TransportError::ChannelClosed))
        ));

        serve.await.unwrap();
    }
}