// zlayer-proxy 0.13.0
//
// High-performance reverse proxy with TLS termination and L4/L7 routing.
// Documentation
//! Minimal, defensive TLS `ClientHello` SNI parser.
//!
//! This is used by the HTTPS ingress to peek at the SNI host name in a
//! `ClientHello` *before* terminating TLS, so an unmanaged SNI can be
//! TCP-spliced straight to its real upstream instead of hanging the client
//! when no matching certificate exists.
//!
//! The parser is deliberately tiny and never panics: every multi-byte read is
//! bounds-checked and any malformed / truncated input yields `None`.

/// Decode the big-endian `u16` stored at `buf[off..off + 2]`.
///
/// Yields `None` when either of the two bytes lies outside `buf`, so callers
/// can chain the read with `?` instead of indexing.
#[inline]
fn read_u16(buf: &[u8], off: usize) -> Option<u16> {
    let end = off.checked_add(2)?;
    let pair = buf.get(off..end)?;
    // Most-significant byte first: each step shifts the accumulator up one
    // byte and ORs in the next one.
    Some(
        pair.iter()
            .fold(0u16, |acc, &byte| (acc << 8) | u16::from(byte)),
    )
}

/// Pull the first `server_name` (SNI) host out of a raw TLS `ClientHello`.
///
/// `buf` is expected to start at the TLS record header. The bytes are walked
/// in this order:
///
/// 1. TLS record header: `content_type` must be `22` (handshake), followed by
///    2 version bytes and 2 length bytes.
/// 2. Handshake header: `type` must be `1` (`ClientHello`), followed by 3
///    length bytes.
/// 3. Client version (2 bytes), random (32 bytes), `session_id` (1-byte length
///    and body).
/// 4. `cipher_suites` (2-byte length and body), `compression_methods` (1-byte
///    length and body).
/// 5. Extensions: a 2-byte total length, then `type`/`len`/`body` records,
///    scanned for extension type `0` (`server_name`).
/// 6. The `server_name` body is handed to `parse_server_name_list`.
///
/// The record and extension lengths are clamped to the bytes actually present
/// in `buf`, so a partially received `ClientHello` still yields its SNI when
/// the whole `server_name` extension has arrived.
///
/// Returns `None` on any malformed, truncated, or SNI-less input. Never panics.
#[must_use]
pub fn parse_sni(buf: &[u8]) -> Option<String> {
    const CONTENT_TYPE_HANDSHAKE: u8 = 22;
    const HANDSHAKE_CLIENT_HELLO: u8 = 1;
    const EXT_SERVER_NAME: u16 = 0;
    // content_type (1) + legacy record version (2) + record length (2)
    const RECORD_HEADER_LEN: usize = 5;
    // handshake type (1) + body length (3) + client_version (2) + random (32)
    const FIXED_PREFIX_LEN: usize = 1 + 3 + 2 + 32;

    // --- TLS record layer ---
    if buf.first() != Some(&CONTENT_TYPE_HANDSHAKE) {
        return None;
    }
    let declared_len = usize::from(read_u16(buf, 3)?);
    // Only part of the record may have been peeked so far; look at whatever
    // portion is available rather than insisting on the full record.
    let visible_end = RECORD_HEADER_LEN
        .checked_add(declared_len)?
        .min(buf.len());
    let hs = buf.get(RECORD_HEADER_LEN..visible_end)?;

    // --- Handshake layer ---
    if hs.first() != Some(&HANDSHAKE_CLIENT_HELLO) {
        return None;
    }

    // Step over the fixed-size prefix, then the three variable-length vectors
    // that precede the extensions block.
    let mut cursor = FIXED_PREFIX_LEN;

    let session_id_len = usize::from(*hs.get(cursor)?);
    cursor = cursor.checked_add(1)?.checked_add(session_id_len)?;

    let cipher_suites_len = usize::from(read_u16(hs, cursor)?);
    cursor = cursor.checked_add(2)?.checked_add(cipher_suites_len)?;

    let compression_len = usize::from(*hs.get(cursor)?);
    cursor = cursor.checked_add(1)?.checked_add(compression_len)?;

    // --- Extensions block ---
    let extensions_len = usize::from(read_u16(hs, cursor)?);
    cursor = cursor.checked_add(2)?;
    let extensions_end = cursor.checked_add(extensions_len)?.min(hs.len());
    let mut remaining = hs.get(cursor..extensions_end)?;

    // Each record is type (2) + length (2) + body; stop once fewer than four
    // header bytes are left.
    while let [type_hi, type_lo, len_hi, len_lo, rest @ ..] = remaining {
        let ext_type = u16::from_be_bytes([*type_hi, *type_lo]);
        let ext_len = usize::from(u16::from_be_bytes([*len_hi, *len_lo]));
        if ext_len > rest.len() {
            // The body runs past the (possibly clamped) extensions block.
            return None;
        }
        let (body, tail) = rest.split_at(ext_len);

        if ext_type == EXT_SERVER_NAME {
            return parse_server_name_list(body);
        }
        remaining = tail;
    }

    None
}

/// Parse the body of a `server_name` extension and return the first `host_name`.
///
/// `snl` is the extension body: a 2-byte big-endian `server_name_list` length
/// followed by entries of `name_type` (1 byte) + `length` (2 bytes) + `name`.
/// The declared list length is clamped to the bytes actually present in `snl`.
///
/// Entries whose `name_type` is not `0` (`host_name`) are skipped. The first
/// `host_name` entry decides the result.
///
/// Returns `None` when:
/// - `snl` is shorter than the 2-byte list length,
/// - an entry's declared length runs past the end of the list,
/// - the first `host_name` is empty (RFC 6066 requires at least one byte),
/// - the first `host_name` is not valid UTF-8,
/// - no `host_name` entry is present.
fn parse_server_name_list(snl: &[u8]) -> Option<String> {
    const NAME_TYPE_HOST_NAME: u8 = 0;

    let (list_len, after_len) = match snl {
        [hi, lo, rest @ ..] => (usize::from(u16::from_be_bytes([*hi, *lo])), rest),
        _ => return None,
    };
    // Never trust the declared length beyond what was actually received.
    let mut entries = after_len.get(..list_len.min(after_len.len()))?;

    // Each entry is name_type (1) + length (2) + name; stop once fewer than
    // three header bytes remain.
    while let [name_type, len_hi, len_lo, rest @ ..] = entries {
        let name_len = usize::from(u16::from_be_bytes([*len_hi, *len_lo]));
        if name_len > rest.len() {
            return None;
        }
        let (name, tail) = rest.split_at(name_len);

        if *name_type == NAME_TYPE_HOST_NAME {
            // `HostName` is `opaque<1..2^16-1>`: a zero-length name is
            // malformed, so report "no SNI" rather than an empty host that a
            // caller might try to route on.
            if name.is_empty() {
                return None;
            }
            // host_name must be valid UTF-8 (ASCII in practice).
            return std::str::from_utf8(name).ok().map(str::to_string);
        }

        entries = tail;
    }

    None
}

#[cfg(test)]
#[allow(clippy::cast_possible_truncation)] // test fixtures build fixed-size TLS records
mod tests {
    use super::*;

    /// Append `payload` to `out`, preceded by its length as a big-endian `u16`.
    fn put_u16_prefixed(out: &mut Vec<u8>, payload: &[u8]) {
        out.extend_from_slice(&(payload.len() as u16).to_be_bytes());
        out.extend_from_slice(payload);
    }

    /// Append one extension record (`type` + 2-byte length + body) to `out`.
    fn put_extension(out: &mut Vec<u8>, ext_type: u16, body: &[u8]) {
        out.extend_from_slice(&ext_type.to_be_bytes());
        put_u16_prefixed(out, body);
    }

    /// Assemble a small, well-formed TLS `ClientHello` record. When `sni` is
    /// `Some`, a `server_name` extension carrying that host comes first; an
    /// unrelated extension is always appended after it.
    fn build_client_hello(sni: Option<&str>) -> Vec<u8> {
        // --- extensions ---
        let mut extensions = Vec::new();
        if let Some(host) = sni {
            // Entry: name_type 0 (host_name) + 2-byte length + name.
            let mut entry = vec![0u8];
            put_u16_prefixed(&mut entry, host.as_bytes());
            // server_name_list: 2-byte length + the single entry.
            let mut server_name_list = Vec::new();
            put_u16_prefixed(&mut server_name_list, &entry);
            // Extension type 0 (server_name).
            put_extension(&mut extensions, 0, &server_name_list);
        }
        // supported_versions (type 43) with a trivial body, so the parser has
        // an unrelated extension to step over.
        put_extension(&mut extensions, 43, &[0x02, 0x03, 0x04]);

        // --- handshake body ---
        let mut body = vec![0x03u8, 0x03]; // client_version TLS 1.2
        body.extend_from_slice(&[0u8; 32]); // random
        body.push(0); // empty session_id
        put_u16_prefixed(&mut body, &[0x13, 0x01]); // cipher_suites: one suite
        body.extend_from_slice(&[1, 0]); // compression_methods: just "null"
        put_u16_prefixed(&mut body, &extensions);

        // --- handshake header: type 1 (ClientHello) + 24-bit body length ---
        let mut hs = vec![1u8];
        // The low three bytes of the big-endian u32 are the 24-bit length.
        hs.extend_from_slice(&(body.len() as u32).to_be_bytes()[1..]);
        hs.extend_from_slice(&body);

        // --- record header: handshake (22) + legacy version + length ---
        let mut rec = vec![22u8, 0x03, 0x01];
        put_u16_prefixed(&mut rec, &hs);
        rec
    }

    #[test]
    fn parses_sni_example_com() {
        let hello = build_client_hello(Some("example.com"));
        let sni = parse_sni(&hello);
        assert_eq!(sni.as_deref(), Some("example.com"));
    }

    #[test]
    fn parses_sni_subdomain() {
        let hello = build_client_hello(Some("api.service.internal"));
        let sni = parse_sni(&hello);
        assert_eq!(sni.as_deref(), Some("api.service.internal"));
    }

    #[test]
    fn no_sni_extension_returns_none() {
        let hello = build_client_hello(None);
        assert_eq!(parse_sni(&hello), None);
    }

    #[test]
    fn truncated_returns_none() {
        let hello = build_client_hello(Some("example.com"));
        // Every prefix below ends before the extensions are reached; each must
        // yield None without panicking.
        let cuts = [0usize, 1, 5, 6, 10, 20, hello.len() / 2];
        for &cut in &cuts {
            let cut = cut.min(hello.len());
            assert_eq!(parse_sni(&hello[..cut]), None, "cut={cut}");
        }
    }

    #[test]
    fn non_handshake_record_returns_none() {
        let mut hello = build_client_hello(Some("example.com"));
        hello[0] = 23; // application_data instead of handshake
        assert_eq!(parse_sni(&hello), None);
    }

    #[test]
    fn not_a_client_hello_returns_none() {
        let mut hello = build_client_hello(Some("example.com"));
        hello[5] = 2; // ServerHello handshake type
        assert_eq!(parse_sni(&hello), None);
    }

    #[test]
    fn empty_input_returns_none() {
        let inputs: [&[u8]; 3] = [&[], &[22], &[22, 3, 1]];
        for input in inputs.iter() {
            assert_eq!(parse_sni(input), None);
        }
    }

    #[test]
    fn garbage_does_not_panic() {
        // Deterministic pseudo-random buffers of length 0..64; the only
        // requirement is that parsing never panics.
        for seed in 0u32..2000 {
            let len = (seed % 64) as usize;
            let mut bytes = Vec::with_capacity(len);
            for i in 0..len {
                bytes.push((seed.wrapping_mul(31) ^ i as u32) as u8);
            }
            let _ = parse_sni(&bytes);
        }
        // A record header that claims far more bytes than are present.
        let oversized = [22u8, 3, 1, 0xff, 0xff, 1, 0, 0, 0];
        let _ = parse_sni(&oversized);
    }
}