1use tokio::net::TcpStream;
7
8pub async fn peek_sni(stream: &mut TcpStream) -> Option<String> {
14 let mut buf = [0u8; 2048];
15 let n = stream.peek(&mut buf).await.ok()?;
16 parse_sni(&buf[..n])
17}
18
/// Extracts the host name from the `server_name` (SNI) extension of a TLS
/// ClientHello.
///
/// `buf` must start at the first byte of the TLS record header. The buffer may
/// be truncated: the extension list is clamped to the bytes actually present,
/// and a host name that is not fully contained in `buf` yields `None`.
///
/// Returns `None` when:
/// - the record is not a handshake record or the message is not a ClientHello,
/// - any length-prefixed field runs past the end of `buf`,
/// - there is no `server_name` extension, or it holds no `host_name` entry,
/// - the host name is empty, is not valid UTF-8, or does not fit inside the
///   `server_name` extension's own declared length.
pub fn parse_sni(buf: &[u8]) -> Option<String> {
    const CONTENT_TYPE_HANDSHAKE: u8 = 0x16;
    const HANDSHAKE_CLIENT_HELLO: u8 = 0x01;
    const EXT_SERVER_NAME: u16 = 0x0000;
    const NAME_TYPE_HOST_NAME: u8 = 0x00;
    // record header (5) + handshake header (4) + client_version (2) + random (32)
    const FIXED_PREFIX: usize = 5 + 4 + 2 + 32;

    /// Reads a big-endian `u16` at offset `at`, or `None` if it runs past `bytes`.
    fn be_u16(bytes: &[u8], at: usize) -> Option<u16> {
        let end = at.checked_add(2)?;
        let pair = bytes.get(at..end)?;
        Some(u16::from_be_bytes([pair[0], pair[1]]))
    }

    /// Returns the first `host_name` entry of a `server_name` extension body.
    fn first_host_name(ext: &[u8]) -> Option<String> {
        let list_len = usize::from(be_u16(ext, 0)?);
        let list = ext.get(2..)?;
        // Clamp the declared list length so a lying length cannot reach past
        // the extension body.
        let mut entries = &list[..list_len.min(list.len())];

        // Each entry is: name_type (1) | length (2) | name bytes.
        while let [name_type, len_hi, len_lo, rest @ ..] = entries {
            let name_len = usize::from(u16::from_be_bytes([*len_hi, *len_lo]));
            let name = rest.get(..name_len)?;
            if *name_type == NAME_TYPE_HOST_NAME {
                // A host name must be at least one byte long.
                return std::str::from_utf8(name)
                    .ok()
                    .filter(|host| !host.is_empty())
                    .map(String::from);
            }
            entries = &rest[name_len..];
        }
        None
    }

    // The length test comes first so the two indexed reads are in bounds.
    if buf.len() < 9 || buf[0] != CONTENT_TYPE_HANDSHAKE || buf[5] != HANDSHAKE_CLIENT_HELLO {
        return None;
    }

    let mut pos = FIXED_PREFIX;

    // session_id: 1-byte length prefix.
    let session_id_len = usize::from(*buf.get(pos)?);
    pos += 1 + session_id_len;

    // cipher_suites: 2-byte length prefix.
    let cipher_suites_len = usize::from(be_u16(buf, pos)?);
    pos += 2 + cipher_suites_len;

    // compression_methods: 1-byte length prefix.
    let compression_len = usize::from(*buf.get(pos)?);
    pos += 1 + compression_len;

    // extensions: 2-byte total length, clamped to what was actually received.
    let extensions_len = usize::from(be_u16(buf, pos)?);
    pos += 2;
    let ext_end = (pos + extensions_len).min(buf.len());

    // Each extension is: type (2) | length (2) | body.
    while pos + 4 <= ext_end {
        let ext_type = be_u16(buf, pos)?;
        let ext_len = usize::from(be_u16(buf, pos + 2)?);
        pos += 4;

        if ext_type == EXT_SERVER_NAME {
            // Parse only this extension's own body so that a short or empty
            // server_name extension cannot make us read its neighbours.
            let body_end = (pos + ext_len).min(ext_end);
            return first_host_name(&buf[pos..body_end]);
        }
        pos += ext_len;
    }
    None
}
85
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal TLS ClientHello record whose only extension is a
    /// `server_name` extension carrying `sni`.
    ///
    /// The message is assembled inside-out so each length prefix is taken
    /// from the buffer it describes.
    fn build_clienthello(sni: &str) -> Vec<u8> {
        let host = sni.as_bytes();

        // Server name entry: name type byte (0), 2-byte length, name bytes.
        let mut entry = vec![0x00];
        entry.extend_from_slice(&(host.len() as u16).to_be_bytes());
        entry.extend_from_slice(host);

        // Extension body: 2-byte list length followed by the single entry.
        let mut ext_body = (entry.len() as u16).to_be_bytes().to_vec();
        ext_body.extend_from_slice(&entry);

        // Extension: type 0x0000, 2-byte body length, body.
        let mut extensions = vec![0x00, 0x00];
        extensions.extend_from_slice(&(ext_body.len() as u16).to_be_bytes());
        extensions.extend_from_slice(&ext_body);

        // ClientHello body.
        let mut hello = vec![0x03, 0x03]; // client_version
        hello.extend_from_slice(&[0u8; 32]); // random
        hello.push(0x00); // session_id length (empty)
        hello.extend_from_slice(&[0x00, 0x02, 0x00, 0x35]); // one cipher suite
        hello.extend_from_slice(&[0x01, 0x00]); // one compression method
        hello.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
        hello.extend_from_slice(&extensions);

        // Record header, then handshake header (type + 24-bit length), then body.
        let hello_len = hello.len();
        let mut record = vec![0x16, 0x03, 0x01];
        record.extend_from_slice(&((4 + hello_len) as u16).to_be_bytes());
        record.push(0x01);
        record.extend_from_slice(&[
            (hello_len >> 16) as u8,
            (hello_len >> 8) as u8,
            hello_len as u8,
        ]);
        record.extend_from_slice(&hello);
        record
    }

    #[test]
    fn parse_minimal_clienthello() {
        let hello = build_clienthello("example.com");
        assert_eq!(parse_sni(&hello).as_deref(), Some("example.com"));
    }

    #[test]
    fn parse_clienthello_with_subdomain() {
        let hello = build_clienthello("api.test.example.com");
        assert_eq!(parse_sni(&hello).as_deref(), Some("api.test.example.com"));
    }

    #[test]
    fn returns_none_for_non_handshake() {
        // Content type 0x17 is not a handshake record.
        let record = [0x17, 0x03, 0x03, 0x00, 0x05];
        assert!(parse_sni(&record).is_none());
    }

    #[test]
    fn returns_none_for_truncated() {
        assert!(parse_sni(&[0x16, 0x03]).is_none());
        assert!(parse_sni(&[]).is_none());
    }

    #[test]
    fn returns_none_for_not_clienthello() {
        // Handshake record whose message type byte is 0x02 rather than 0x01.
        let record = [0x16, 0x03, 0x01, 0x00, 0x10, 0x02];
        assert!(parse_sni(&record).is_none());
    }
}