runewarp 0.1.0

Runewarp is an ingress tunneling tool for exposing local services without moving TLS termination to the edge. Clients connect out over QUIC, so you can publish services without putting your backend directly on the Internet or leaking your public IP.
Documentation
use std::fmt;
use std::io::Cursor;

use rustls::server::Acceptor;
use tokio::io::{AsyncRead, AsyncReadExt};

use crate::hostname::validate_public_hostname;

/// Maximum number of bytes (16 KiB) that [`read_client_hello`] will buffer while
/// waiting for a complete ClientHello; buffering more than this fails with
/// [`ClientHelloError::TooLong`].
pub const CLIENT_HELLO_BUFFER_LIMIT: usize = 16 * 1024;

/// Result of successfully sniffing a TLS ClientHello with [`read_client_hello`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedClientHello {
    /// Every byte read from the client while parsing, in arrival order. This
    /// includes any bytes that arrived after the ClientHello in the same read.
    buffered_bytes: Vec<u8>,
    /// SNI host name as returned by `validate_public_hostname` (the tests in
    /// this file expect a lower-cased name).
    server_name: String,
    /// ALPN protocol identifiers offered by the client, as raw bytes; empty
    /// when the ClientHello carried no ALPN list.
    alpn_protocols: Vec<Vec<u8>>,
}

impl ParsedClientHello {
    /// Borrows every byte that was read from the client while parsing.
    pub fn buffered_bytes(&self) -> &[u8] {
        self.buffered_bytes.as_slice()
    }

    /// Consumes the value and returns only the buffered bytes.
    pub fn into_buffered_bytes(self) -> Vec<u8> {
        let Self { buffered_bytes, .. } = self;
        buffered_bytes
    }

    /// Borrows the validated SNI host name.
    pub fn server_name(&self) -> &str {
        self.server_name.as_str()
    }

    /// Returns `true` when `protocol` is byte-for-byte equal to one of the
    /// ALPN identifiers the client offered.
    pub fn offers_alpn_protocol(&self, protocol: &[u8]) -> bool {
        self.alpn_protocols
            .iter()
            .map(Vec::as_slice)
            .any(|offered| offered == protocol)
    }

    /// Consumes the value and returns `(server_name, buffered_bytes)`; the
    /// ALPN list is dropped.
    pub fn into_parts(self) -> (String, Vec<u8>) {
        let Self {
            server_name,
            buffered_bytes,
            ..
        } = self;
        (server_name, buffered_bytes)
    }
}

/// Reasons [`read_client_hello`] can fail.
#[derive(Debug)]
pub enum ClientHelloError {
    /// Reading from the client, or handing the bytes to the rustls acceptor,
    /// returned an I/O error.
    Io(std::io::Error),
    /// The rustls acceptor rejected the buffered bytes.
    InvalidTls,
    /// The ClientHello carried an SNI name that `validate_public_hostname` rejected.
    InvalidSni,
    /// The ClientHello carried no SNI server name.
    MissingSni,
    /// More than `limit` bytes were buffered without completing a ClientHello.
    TooLong { limit: usize },
    /// The reader reported end-of-stream (a zero-length read) before a
    /// ClientHello was complete.
    UnexpectedEof,
}

/// One-line, human-readable description of each failure.
///
/// The wrapped I/O error is not printed here; the `std::error::Error`
/// implementation exposes it through `source()`.
impl fmt::Display for ClientHelloError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Only `TooLong` needs interpolation; every other variant maps to a
        // fixed message.
        let fixed_message = match self {
            Self::TooLong { limit } => {
                return write!(formatter, "client hello exceeded the {limit}-byte limit");
            }
            Self::Io(_) => "client hello IO error",
            Self::InvalidTls => "invalid TLS client hello",
            Self::InvalidSni => "invalid SNI in client hello",
            Self::MissingSni => "missing SNI in client hello",
            Self::UnexpectedEof => "client connection closed before the client hello completed",
        };
        formatter.write_str(fixed_message)
    }
}

/// Only [`ClientHelloError::Io`] wraps an underlying error; every other
/// variant reports no source.
impl std::error::Error for ClientHelloError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // Kept exhaustive (no `_` arm) so adding a variant forces a decision here.
        match self {
            Self::InvalidTls | Self::InvalidSni | Self::MissingSni | Self::UnexpectedEof => None,
            Self::TooLong { limit: _ } => None,
            Self::Io(cause) => Some(cause),
        }
    }
}

pub async fn read_client_hello<R>(reader: &mut R) -> Result<ParsedClientHello, ClientHelloError>
where
    R: AsyncRead + Unpin,
{
    let mut acceptor = Acceptor::default();
    let mut buffered_bytes = Vec::new();
    let mut read_buffer = [0_u8; 4096];

    loop {
        let read = reader
            .read(&mut read_buffer)
            .await
            .map_err(ClientHelloError::Io)?;
        if read == 0 {
            return Err(ClientHelloError::UnexpectedEof);
        }

        buffered_bytes.extend_from_slice(&read_buffer[..read]);
        if buffered_bytes.len() > CLIENT_HELLO_BUFFER_LIMIT {
            return Err(ClientHelloError::TooLong {
                limit: CLIENT_HELLO_BUFFER_LIMIT,
            });
        }

        let mut cursor = Cursor::new(&read_buffer[..read]);
        acceptor
            .read_tls(&mut cursor)
            .map_err(ClientHelloError::Io)?;

        match acceptor.accept() {
            Ok(Some(accepted)) => {
                let server_name = accepted
                    .client_hello()
                    .server_name()
                    .ok_or(ClientHelloError::MissingSni)?
                    .to_owned();
                let server_name = validate_public_hostname(&server_name)
                    .map_err(|_| ClientHelloError::InvalidSni)?;
                let alpn_protocols = accepted
                    .client_hello()
                    .alpn()
                    .into_iter()
                    .flatten()
                    .map(|protocol| protocol.to_vec())
                    .collect();

                return Ok(ParsedClientHello {
                    buffered_bytes,
                    server_name,
                    alpn_protocols,
                });
            }
            Ok(None) => {}
            Err((_error, _alert)) => return Err(ClientHelloError::InvalidTls),
        }
    }
}

#[cfg(test)]
mod tests {
    use std::collections::VecDeque;
    use std::io;
    use std::net::{IpAddr, Ipv4Addr};
    use std::pin::Pin;
    use std::sync::Arc;
    use std::task::{Context, Poll};

    use proptest::collection::vec as prop_vec;
    use proptest::prelude::*;
    use rcgen::generate_simple_self_signed;
    use rustls::ClientConnection;
    use rustls::RootCertStore;
    use rustls::pki_types::{CertificateDer, ServerName};
    use tokio::io::{AsyncRead, ReadBuf};

    use crate::hostname::validate_public_hostname;

    use super::{
        CLIENT_HELLO_BUFFER_LIMIT, ClientHelloError, ParsedClientHello, read_client_hello,
    };

    /// A complete ClientHello delivered as one chunk yields its SNI name and is
    /// buffered byte-for-byte.
    #[tokio::test]
    async fn parses_sni_from_a_valid_client_hello() {
        const HOST: &str = "app.example.test";
        let hello_bytes = build_client_hello(ServerName::try_from(HOST).unwrap());

        let outcome = parse_from_chunks(vec![hello_bytes.clone()]).await;
        let parsed = outcome.unwrap();

        assert_eq!(parsed.server_name(), HOST);
        assert_eq!(parsed.buffered_bytes(), hello_bytes.as_slice());
    }

    /// Bytes that arrive in the same read after the ClientHello must still be
    /// present in the buffered output.
    #[tokio::test]
    async fn preserves_extra_bytes_read_past_the_client_hello() {
        let hello_bytes = build_client_hello(ServerName::try_from("app.example.test").unwrap());
        // Six extra bytes shaped like a further TLS record (type 0x14, one-byte body).
        let trailing_record = [0x14_u8, 0x03, 0x03, 0x00, 0x01, 0x01];
        let wire = [hello_bytes.as_slice(), trailing_record.as_slice()].concat();

        let parsed = parse_from_chunks(vec![wire.clone()]).await.unwrap();

        assert_eq!(parsed.server_name(), "app.example.test");
        assert_eq!(parsed.buffered_bytes(), wire.as_slice());
    }

    /// A mixed-case SNI name is reported in lower case.
    #[tokio::test]
    async fn normalizes_the_parsed_server_name_for_routing() {
        let mixed_case = ServerName::try_from("App.Example.Test").unwrap();
        let hello_bytes = build_client_hello(mixed_case);

        let parsed = parse_from_chunks(vec![hello_bytes]).await.unwrap();

        assert_eq!(parsed.server_name(), "app.example.test");
    }

    /// A ClientHello re-framed into two TLS records and delivered one byte per
    /// read is still parsed, and every byte is retained.
    #[tokio::test]
    async fn parses_a_client_hello_split_across_multiple_tls_records() {
        let hello_bytes = build_client_hello(ServerName::try_from("app.example.test").unwrap());
        // The first record carries 19 body bytes; the rest goes into a second record.
        let two_records = split_tls_record(&hello_bytes, 19);
        let single_byte_reads: Vec<Vec<u8>> = two_records
            .chunks(1)
            .map(|single| single.to_vec())
            .collect();

        let parsed = parse_from_chunks(single_byte_reads).await.unwrap();

        assert_eq!(parsed.server_name(), "app.example.test");
        assert_eq!(parsed.buffered_bytes(), two_records.as_slice());
    }

    /// A ClientHello built for an IP-address server name is expected to carry
    /// no SNI and must be rejected with `MissingSni`.
    #[tokio::test]
    async fn rejects_a_client_hello_without_sni() {
        let ip_target = ServerName::IpAddress(IpAddr::V4(Ipv4Addr::LOCALHOST).into());
        let hello_bytes = build_client_hello(ip_target);

        let outcome = parse_from_chunks(vec![hello_bytes]).await;
        let error = outcome.unwrap_err();

        assert!(matches!(error, ClientHelloError::MissingSni));
    }

    /// Input that never completes a ClientHello is cut off once more than
    /// `CLIENT_HELLO_BUFFER_LIMIT` bytes have been buffered.
    #[tokio::test]
    async fn enforces_the_client_hello_size_limit() {
        // Five header bytes (type 0x16, version 0x0303, length 0x4001) followed by
        // `CLIENT_HELLO_BUFFER_LIMIT` zero bytes: five bytes over the limit in total.
        let mut oversized = vec![0x16_u8, 0x03, 0x03, 0x40, 0x01];
        oversized.resize(oversized.len() + CLIENT_HELLO_BUFFER_LIMIT, 0);

        let outcome = parse_from_chunks(vec![oversized]).await;
        let error = outcome.unwrap_err();

        assert!(matches!(
            error,
            ClientHelloError::TooLong {
                limit: CLIENT_HELLO_BUFFER_LIMIT
            }
        ));
    }

    /// Runs `read_client_hello` against a `ChunkedReader` that replays `chunks`
    /// in order.
    async fn parse_from_chunks(
        chunks: Vec<Vec<u8>>,
    ) -> Result<ParsedClientHello, ClientHelloError> {
        read_client_hello(&mut ChunkedReader::new(chunks)).await
    }

    /// Synchronous wrapper around `parse_from_chunks` for the property tests,
    /// which are plain `#[test]` functions: builds a current-thread Tokio
    /// runtime and blocks on the parse.
    fn parse_from_chunks_blocking(
        chunks: Vec<Vec<u8>>,
    ) -> Result<ParsedClientHello, ClientHelloError> {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("property tests should be able to build a runtime");

        runtime.block_on(parse_from_chunks(chunks))
    }

    /// Produces the bytes a rustls client writes first when connecting to
    /// `server_name`, i.e. its ClientHello.
    ///
    /// A throwaway self-signed certificate for `localhost` is used as the only
    /// trust root so that a client configuration can be built.
    fn build_client_hello(server_name: ServerName<'static>) -> Vec<u8> {
        let self_signed = generate_simple_self_signed(vec!["localhost".to_owned()]).unwrap();
        let mut trust_anchors = RootCertStore::empty();
        trust_anchors
            .add(CertificateDer::from(self_signed.cert))
            .unwrap();

        let client_config = rustls::ClientConfig::builder()
            .with_root_certificates(trust_anchors)
            .with_no_client_auth();
        let mut client = ClientConnection::new(Arc::new(client_config), server_name).unwrap();

        // Nothing has been read from a peer yet, so this writes only the
        // client's opening flight.
        let mut wire = Vec::new();
        client.write_tls(&mut wire).unwrap();
        wire
    }

    /// Re-frames a single TLS record as two consecutive records.
    ///
    /// `bytes` is treated as a 5-byte record header followed by the body. The
    /// first output record carries `first_record_body_len` body bytes and the
    /// second carries the remainder; both reuse the first three header bytes
    /// and get their own big-endian 16-bit length.
    ///
    /// # Panics
    ///
    /// Panics unless `bytes` is longer than five bytes and
    /// `0 < first_record_body_len < body length`.
    fn split_tls_record(bytes: &[u8], first_record_body_len: usize) -> Vec<u8> {
        assert!(bytes.len() > 5);

        let (header, body) = bytes.split_at(5);
        assert!(first_record_body_len > 0);
        assert!(first_record_body_len < body.len());

        let (leading, remainder) = body.split_at(first_record_body_len);
        let mut framed = Vec::with_capacity(bytes.len() + 5);
        for fragment in [leading, remainder] {
            framed.extend_from_slice(&header[..3]);
            // Truncating to `u16` is deliberate: the length field is the low
            // 16 bits of the fragment length, big-endian.
            framed.extend_from_slice(&(fragment.len() as u16).to_be_bytes());
            framed.extend_from_slice(fragment);
        }

        framed
    }

    /// Splits `bytes` into consecutive chunks whose lengths follow `chunk_sizes`.
    ///
    /// Sizes are consumed in order until `bytes` is exhausted; a size that
    /// reaches past the end is clamped to the remaining bytes. Whatever is left
    /// once the sizes run out becomes one final chunk. The result is never
    /// empty: empty input yields a single empty chunk. A zero size yields an
    /// empty chunk at that position.
    fn chunk_bytes(bytes: Vec<u8>, chunk_sizes: Vec<usize>) -> Vec<Vec<u8>> {
        let mut chunks = Vec::new();
        let mut offset = 0_usize;
        for chunk_size in chunk_sizes {
            if offset >= bytes.len() {
                break;
            }
            // Saturate rather than overflow: a huge requested size after a
            // non-zero offset must clamp to the end of `bytes`, not panic.
            let end = offset.saturating_add(chunk_size).min(bytes.len());
            chunks.push(bytes[offset..end].to_vec());
            offset = end;
        }
        // Emit the unchunked tail, or a lone empty chunk for empty input.
        if offset < bytes.len() || chunks.is_empty() {
            chunks.push(bytes[offset..].to_vec());
        }
        chunks
    }

    struct ChunkedReader {
        chunks: VecDeque<Vec<u8>>,
    }

    impl ChunkedReader {
        fn new(chunks: Vec<Vec<u8>>) -> Self {
            Self {
                chunks: chunks.into(),
            }
        }
    }

    impl AsyncRead for ChunkedReader {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buffer: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            let Some(front) = self.chunks.front_mut() else {
                return Poll::Ready(Ok(()));
            };

            let written = front.len().min(buffer.remaining());
            buffer.put_slice(&front[..written]);

            if written == front.len() {
                self.chunks.pop_front();
            } else {
                front.drain(..written);
            }

            Poll::Ready(Ok(()))
        }
    }

    // Property tests. Each case builds its own Tokio runtime through
    // `parse_from_chunks_blocking`, so the case count is kept at 16.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(16))]

        // A valid ClientHello followed by up to 31 arbitrary trailing bytes and
        // delivered in chunks of 1..=31 bytes must always parse. The buffered
        // output must be a prefix of the input that covers at least the whole
        // ClientHello.
        #[test]
        fn preserves_valid_client_hellos_across_chunk_boundaries(
            chunk_sizes in prop_vec(1usize..32, 0..32),
            trailing_bytes in prop_vec(any::<u8>(), 0..32),
        ) {
            let client_hello = build_client_hello(ServerName::try_from("App.Example.Test").unwrap());
            let mut buffered = client_hello.clone();
            buffered.extend_from_slice(&trailing_bytes);

            let parsed = parse_from_chunks_blocking(chunk_bytes(buffered.clone(), chunk_sizes)).unwrap();

            prop_assert_eq!(parsed.server_name(), "app.example.test");
            prop_assert!(buffered.starts_with(parsed.buffered_bytes()));
            prop_assert!(parsed.buffered_bytes().len() >= client_hello.len());
        }

        // Arbitrary input, up to just past the buffer limit and delivered as a
        // single chunk, may fail in any way. If it succeeds, the reported name
        // must pass `validate_public_hostname` and the buffered bytes must equal
        // the input.
        #[test]
        fn arbitrary_client_hello_bytes_only_succeed_with_validated_server_names(
            bytes in prop_vec(any::<u8>(), 0..(CLIENT_HELLO_BUFFER_LIMIT + 8)),
        ) {
            let result = parse_from_chunks_blocking(vec![bytes.clone()]);

            if let Ok(parsed) = result {
                prop_assert!(validate_public_hostname(parsed.server_name()).is_ok());
                prop_assert_eq!(parsed.buffered_bytes(), bytes.as_slice());
            }
        }
    }
}