Skip to main content

orca_proxy/
sni.rs

1//! Peek at TLS ClientHello to extract SNI hostname without consuming bytes.
2//!
3//! Used by the HTTPS listener to decide between local TLS termination
4//! and TCP pass-through to a fallback backend.
5
6use tokio::net::TcpStream;
7
8/// Peek at the TLS ClientHello buffered on `stream` and return the SNI hostname.
9///
10/// Returns `None` if the data isn't a valid ClientHello, no SNI extension is
11/// present, or the buffer is too short. Uses `TcpStream::peek` so the bytes
12/// remain available for the subsequent TLS handshake.
13pub async fn peek_sni(stream: &mut TcpStream) -> Option<String> {
14    let mut buf = [0u8; 2048];
15    let n = stream.peek(&mut buf).await.ok()?;
16    parse_sni(&buf[..n])
17}
18
/// Parse a TLS ClientHello byte slice and extract the SNI hostname.
///
/// `buf` must start at a TLS record header. Parsing is confined to that first
/// record: bytes beyond its declared length are ignored. A ClientHello split
/// across several records, or cut short by a partial read, therefore yields
/// `None` when the `server_name` extension is not wholly within the available
/// bytes.
///
/// Returns `None` in any of these cases:
/// - the record is not a handshake record carrying a ClientHello;
/// - a length field points past the available data;
/// - no `server_name` extension with a `host_name` entry is present
///   (RFC 6066 §3);
/// - the host name is empty or not valid UTF-8.
pub fn parse_sni(buf: &[u8]) -> Option<String> {
    const RECORD_HEADER_LEN: usize = 5;
    const CONTENT_TYPE_HANDSHAKE: u8 = 0x16;
    const HANDSHAKE_CLIENT_HELLO: u8 = 0x01;
    const EXT_SERVER_NAME: u16 = 0x0000;

    // TLS record header: content_type(1) + legacy_version(2) + length(2).
    let header = buf.get(..RECORD_HEADER_LEN)?;
    if header[0] != CONTENT_TYPE_HANDSHAKE {
        return None;
    }
    let record_len = usize::from(u16::from_be_bytes([header[3], header[4]]));
    // Stop at the record boundary so bytes of a following record are never
    // misread as ClientHello fields; stop at `buf.len()` to tolerate truncation.
    let record_end = (RECORD_HEADER_LEN + record_len).min(buf.len());
    let mut cur = &buf[RECORD_HEADER_LEN..record_end];

    // Handshake header: msg_type(1) + length(3).
    if read_u8(&mut cur)? != HANDSHAKE_CLIENT_HELLO {
        return None;
    }
    split_front(&mut cur, 3)?;
    // legacy_version(2) + random(32).
    split_front(&mut cur, 2 + 32)?;
    // legacy_session_id, cipher_suites and legacy_compression_methods are
    // length-prefixed vectors that are skipped wholesale.
    let session_id_len = usize::from(read_u8(&mut cur)?);
    split_front(&mut cur, session_id_len)?;
    let cipher_suites_len = usize::from(read_u16(&mut cur)?);
    split_front(&mut cur, cipher_suites_len)?;
    let compression_len = usize::from(read_u8(&mut cur)?);
    split_front(&mut cur, compression_len)?;

    // Extensions block. If the read was short, scan whatever prefix arrived;
    // the loop bails out on the first extension that is incomplete.
    let extensions_len = usize::from(read_u16(&mut cur)?);
    let mut extensions = &cur[..extensions_len.min(cur.len())];

    while extensions.len() >= 4 {
        let ext_type = read_u16(&mut extensions)?;
        let ext_len = usize::from(read_u16(&mut extensions)?);
        // Confine each extension's parsing to its own declared length.
        let ext_data = split_front(&mut extensions, ext_len)?;
        if ext_type == EXT_SERVER_NAME {
            return parse_server_name(ext_data);
        }
    }
    None
}

/// Return the `host_name` entry of a `server_name` extension body (RFC 6066 §3).
///
/// The body is a `u16`-length-prefixed list of entries, each a `name_type(1)`
/// followed by a `u16`-length-prefixed name.
fn parse_server_name(mut data: &[u8]) -> Option<String> {
    const NAME_TYPE_HOST_NAME: u8 = 0x00;

    let list_len = usize::from(read_u16(&mut data)?);
    let mut list = &data[..list_len.min(data.len())];
    while !list.is_empty() {
        let name_type = read_u8(&mut list)?;
        let name_len = usize::from(read_u16(&mut list)?);
        let name = split_front(&mut list, name_len)?;
        if name_type != NAME_TYPE_HOST_NAME {
            // Skip unknown name types rather than misreport them as a host name.
            continue;
        }
        // HostName is `opaque HostName<1..2^16-1>`; an empty one is malformed.
        if name.is_empty() {
            return None;
        }
        return std::str::from_utf8(name).ok().map(String::from);
    }
    None
}

/// Split the first `n` bytes off `cur` and advance it; `None` if fewer remain.
fn split_front<'a>(cur: &mut &'a [u8], n: usize) -> Option<&'a [u8]> {
    let bytes: &'a [u8] = *cur;
    if bytes.len() < n {
        return None;
    }
    let (head, tail) = bytes.split_at(n);
    *cur = tail;
    Some(head)
}

/// Read one byte from the front of `cur`, advancing it.
fn read_u8(cur: &mut &[u8]) -> Option<u8> {
    split_front(cur, 1).map(|b| b[0])
}

/// Read a big-endian `u16` from the front of `cur`, advancing it.
fn read_u16(cur: &mut &[u8]) -> Option<u16> {
    split_front(cur, 2).map(|b| u16::from_be_bytes([b[0], b[1]]))
}
85
#[cfg(test)]
mod tests {
    use super::*;

    /// Prefix `payload` with its length encoded as a big-endian `u16`.
    fn with_u16_len(payload: &[u8]) -> Vec<u8> {
        let mut out = (payload.len() as u16).to_be_bytes().to_vec();
        out.extend_from_slice(payload);
        out
    }

    /// Build a minimal valid ClientHello whose only extension is
    /// `server_name` carrying `sni`, assembled from the innermost field out.
    fn build_clienthello(sni: &str) -> Vec<u8> {
        // ServerNameList holding a single host_name entry.
        let mut entry = vec![0x00]; // name_type = host_name
        entry.extend(with_u16_len(sni.as_bytes()));
        let server_name = with_u16_len(&entry);

        // Extension framing: ext_type = server_name, then length-prefixed data.
        let mut extension = vec![0x00, 0x00];
        extension.extend(with_u16_len(&server_name));

        let mut body = vec![0x03, 0x03]; // version
        body.extend_from_slice(&[0u8; 32]); // random
        body.push(0x00); // session_id_len
        body.extend(with_u16_len(&[0x00, 0x35])); // 1 cipher suite
        body.extend_from_slice(&[0x01, 0x00]); // 1 compression method: null
        body.extend(with_u16_len(&extension));

        let body_len = body.len();
        let mut handshake = vec![
            0x01, // ClientHello
            (body_len >> 16) as u8,
            (body_len >> 8) as u8,
            body_len as u8,
        ];
        handshake.extend(body);

        let mut record = vec![0x16, 0x03, 0x01];
        record.extend(with_u16_len(&handshake));
        record
    }

    #[test]
    fn parse_minimal_clienthello() {
        let hello = build_clienthello("example.com");
        assert_eq!(parse_sni(&hello).as_deref(), Some("example.com"));
    }

    #[test]
    fn parse_clienthello_with_subdomain() {
        let hello = build_clienthello("api.test.example.com");
        assert_eq!(parse_sni(&hello).as_deref(), Some("api.test.example.com"));
    }

    #[test]
    fn returns_none_for_non_handshake() {
        // Content type 0x17 is application_data, not handshake.
        let record = [0x17, 0x03, 0x03, 0x00, 0x05];
        assert!(parse_sni(&record).is_none());
    }

    #[test]
    fn returns_none_for_truncated() {
        for short in [&[0x16u8, 0x03][..], &[][..]] {
            assert!(parse_sni(short).is_none());
        }
    }

    #[test]
    fn returns_none_for_not_clienthello() {
        // Handshake record whose message type is 0x02 (ServerHello).
        let record = [0x16, 0x03, 0x01, 0x00, 0x10, 0x02];
        assert!(parse_sni(&record).is_none());
    }
}