use std::fmt;
/// Upper bound accepted for the TLS record length field by `extract_sni`.
///
/// Equal to 2^14 (16384); presumably chosen to match the record-size limit in
/// RFC 8446 §5.1 — confirm if this is ever changed.
pub const MAX_RECORD_LEN: usize = 1 << 14;

/// Reasons `extract_sni` can reject its input.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SniParseError {
    /// The buffer is shorter than the 5-byte TLS record header.
    TooShort,
    /// The record content type (first byte) is not 22 (handshake).
    NotHandshake,
    /// The record length field exceeds [`MAX_RECORD_LEN`].
    RecordTooLarge,
    /// The handshake message type is not 1 (ClientHello).
    NotClientHello,
    /// A field (or the record itself) runs past the end of the available bytes.
    Truncated,
    /// A length field is inconsistent with the structure that contains it.
    InvalidLength,
}

impl fmt::Display for SniParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match self {
            Self::TooShort => "buffer too short for TLS record header",
            Self::NotHandshake => "first byte is not handshake (type=22)",
            Self::RecordTooLarge => {
                // The only message that interpolates a value.
                return write!(f, "TLS record length exceeds {MAX_RECORD_LEN} bytes");
            }
            Self::NotClientHello => "handshake type is not ClientHello (type=1)",
            Self::Truncated => "ClientHello truncated mid-field",
            Self::InvalidLength => "ClientHello length field overflows container",
        };
        f.write_str(text)
    }
}

impl std::error::Error for SniParseError {}
/// Consumes `n` bytes of `buf` starting at `*idx` and advances the cursor past them.
///
/// # Errors
/// - [`SniParseError::InvalidLength`] if `*idx + n` overflows `usize`.
/// - [`SniParseError::Truncated`] if fewer than `n` bytes remain in `buf`.
///
/// On error the cursor is left unchanged.
fn take<'a>(buf: &'a [u8], idx: &mut usize, n: usize) -> Result<&'a [u8], SniParseError> {
    let start = *idx;
    let Some(stop) = start.checked_add(n) else {
        return Err(SniParseError::InvalidLength);
    };
    // `start <= stop` always holds here, so `get` only fails when `stop` is past the end.
    let chunk = buf.get(start..stop).ok_or(SniParseError::Truncated)?;
    *idx = stop;
    Ok(chunk)
}
/// Reads one byte at `*idx` and advances the cursor by one.
///
/// # Errors
/// Propagates the error from `take` when no byte remains.
fn read_u8(buf: &[u8], idx: &mut usize) -> Result<u8, SniParseError> {
    take(buf, idx, 1).map(|byte| byte[0])
}
/// Reads a big-endian `u16` at `*idx` and advances the cursor by two.
///
/// # Errors
/// Propagates the error from `take` when fewer than two bytes remain.
fn read_u16(buf: &[u8], idx: &mut usize) -> Result<u16, SniParseError> {
    take(buf, idx, 2).map(|pair| (u16::from(pair[0]) << 8) | u16::from(pair[1]))
}
/// Reads a big-endian 24-bit unsigned integer at `*idx`, widened to `u32`, and
/// advances the cursor by three.
///
/// # Errors
/// Propagates the error from `take` when fewer than three bytes remain.
fn read_u24(buf: &[u8], idx: &mut usize) -> Result<u32, SniParseError> {
    take(buf, idx, 3).map(|triple| {
        triple
            .iter()
            .fold(0u32, |acc, &byte| (acc << 8) | u32::from(byte))
    })
}
/// Extracts the SNI host name from a buffer that starts with a TLS ClientHello record.
///
/// `client_hello` must begin with a TLS record header (content type, legacy record
/// version, 16-bit length). That record must be complete within the buffer, and its
/// payload must start with a ClientHello handshake message (type 1) that fits entirely
/// inside the record. Bytes after the record are never inspected. Every ClientHello
/// field is read strictly within the bounds of the handshake message, so a malformed
/// length can never pull in bytes that follow the message.
///
/// Returns `Ok(Some(name))` for the first `host_name` entry of the first `server_name`
/// extension (extension type 0). Returns `Ok(None)` in three cases:
/// - the ClientHello ends right after its compression methods (no extensions block);
/// - no `server_name` extension is present;
/// - the `server_name` extension yields no host name.
///
/// NOTE(review): a ClientHello fragmented across several TLS records is rejected with
/// `InvalidLength` (its handshake length exceeds the first record) rather than being
/// reassembled.
///
/// # Errors
/// - [`SniParseError::TooShort`]: the buffer is shorter than the 5-byte record header.
/// - [`SniParseError::NotHandshake`]: the record content type is not 22.
/// - [`SniParseError::RecordTooLarge`]: the record length exceeds [`MAX_RECORD_LEN`].
/// - [`SniParseError::NotClientHello`]: the handshake type is not 1.
/// - [`SniParseError::Truncated`]: the record extends past the buffer, or a field runs
///   past the end of the record or of the ClientHello message.
/// - [`SniParseError::InvalidLength`]: a length field is inconsistent with its
///   container. This covers a handshake longer than its record, a session id longer
///   than 32 bytes, an odd cipher-suite length, or a vector overrunning the message.
pub fn extract_sni(client_hello: &[u8]) -> Result<Option<String>, SniParseError> {
    if client_hello.len() < 5 {
        return Err(SniParseError::TooShort);
    }
    // `idx` is an absolute offset into `client_hello`. The `record` and `hello` views
    // below are prefixes of it, so the same offset stays valid in each.
    let mut idx = 0usize;

    // Record header: content type (1) | legacy record version (2) | length (2).
    if read_u8(client_hello, &mut idx)? != 22 {
        return Err(SniParseError::NotHandshake);
    }
    let _legacy_record_version = read_u16(client_hello, &mut idx)?;
    let record_len = usize::from(read_u16(client_hello, &mut idx)?);
    if record_len > MAX_RECORD_LEN {
        return Err(SniParseError::RecordTooLarge);
    }
    let record_end = idx
        .checked_add(record_len)
        .ok_or(SniParseError::InvalidLength)?;
    if record_end > client_hello.len() {
        return Err(SniParseError::Truncated);
    }
    let record = &client_hello[..record_end];

    // Handshake header: message type (1) | length (3), read only from this record.
    if read_u8(record, &mut idx)? != 1 {
        return Err(SniParseError::NotClientHello);
    }
    let hs_len = read_u24(record, &mut idx)? as usize;
    let hs_end = idx
        .checked_add(hs_len)
        .ok_or(SniParseError::InvalidLength)?;
    if hs_end > record.len() {
        return Err(SniParseError::InvalidLength);
    }
    // From here on, every read is confined to the ClientHello message itself. This
    // also guarantees `idx <= hello.len()`, which makes the subtractions below safe.
    let hello = &record[..hs_end];
    if hello.len() - idx < 2 + 32 {
        return Err(SniParseError::Truncated);
    }
    let _legacy_version = read_u16(hello, &mut idx)?;
    let _random = take(hello, &mut idx, 32)?;

    let sid_len = usize::from(read_u8(hello, &mut idx)?);
    if sid_len > 32 {
        return Err(SniParseError::InvalidLength);
    }
    let _session_id = take(hello, &mut idx, sid_len)?;

    let cs_len = usize::from(read_u16(hello, &mut idx)?);
    // Each cipher suite is two bytes.
    if !cs_len.is_multiple_of(2) {
        return Err(SniParseError::InvalidLength);
    }
    if cs_len > hello.len() - idx {
        return Err(SniParseError::InvalidLength);
    }
    let _cipher_suites = take(hello, &mut idx, cs_len)?;

    let comp_len = usize::from(read_u8(hello, &mut idx)?);
    if comp_len > hello.len() - idx {
        return Err(SniParseError::InvalidLength);
    }
    let _compression_methods = take(hello, &mut idx, comp_len)?;

    // A ClientHello that ends here carries no extensions block at all.
    if idx == hello.len() {
        return Ok(None);
    }
    let ext_total = usize::from(read_u16(hello, &mut idx)?);
    if ext_total > hello.len() - idx {
        return Err(SniParseError::InvalidLength);
    }
    let extensions = &hello[idx..idx + ext_total];

    // Walk extensions: type (2) | length (2) | body. Leftover bytes too short to hold
    // an extension header are ignored.
    let mut pos = 0usize;
    while extensions.len() - pos >= 4 {
        let ext_type = read_u16(extensions, &mut pos)?;
        let ext_len = usize::from(read_u16(extensions, &mut pos)?);
        if ext_len > extensions.len() - pos {
            return Err(SniParseError::InvalidLength);
        }
        if ext_type == 0 {
            return parse_server_name_extension(&extensions[pos..pos + ext_len]);
        }
        pos += ext_len;
    }
    Ok(None)
}
/// Parses a `server_name` extension body and returns its first `host_name` entry.
///
/// `body` is the extension payload. It holds a 2-byte list length, followed by entries
/// of the form `name_type (u8) | name_len (u16) | name bytes`. Entries whose
/// `name_type` is not 0 (`host_name`) are skipped. Leftover list bytes too short to
/// hold an entry header are ignored.
///
/// The returned name is ASCII-lowercased, and a single trailing `.` is removed. A name
/// that is empty is reported as `Ok(None)`, whether it was empty on the wire or became
/// empty after normalization (e.g. a bare `"."`).
///
/// # Errors
/// - [`SniParseError::Truncated`]: the list length cannot be read.
/// - [`SniParseError::InvalidLength`]: the list length or an entry length overruns its
///   container.
fn parse_server_name_extension(body: &[u8]) -> Result<Option<String>, SniParseError> {
    let mut idx = 0usize;
    let list_len = usize::from(read_u16(body, &mut idx)?);
    if idx + list_len > body.len() {
        return Err(SniParseError::InvalidLength);
    }
    let list_end = idx + list_len;
    while idx + 3 <= list_end {
        let name_type = read_u8(body, &mut idx)?;
        let name_len = usize::from(read_u16(body, &mut idx)?);
        if idx + name_len > list_end {
            return Err(SniParseError::InvalidLength);
        }
        let raw = &body[idx..idx + name_len];
        idx += name_len;
        if name_type != 0 {
            continue;
        }
        // NOTE(review): RFC 6066 defines host names as ASCII. This decode is lossy, so
        // invalid UTF-8 becomes U+FFFD instead of being rejected. Confirm that callers
        // never need the exact wire bytes.
        let mut host = String::from_utf8_lossy(raw).into_owned();
        host.make_ascii_lowercase();
        if host.ends_with('.') {
            host.pop();
        }
        // Check emptiness after normalization so "." is treated like "".
        return Ok(if host.is_empty() { None } else { Some(host) });
    }
    Ok(None)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-record TLS ClientHello for tests.
    ///
    /// Absolute byte offsets of the returned buffer:
    /// - `0`: record content type (22); `1..3`: record version `03 01`; `3..5`: record length
    /// - `5`: handshake type (1); `6..9`: handshake length (24-bit)
    /// - `9..11`: ClientHello legacy_version `03 03`; `11..43`: all-zero random
    /// - `43`: session id length (0); `44..46`: cipher-suites length (2); `46..48`: suite
    /// - `48`: compression-methods length (1); `49`: null compression
    /// - `50..52`: extensions length. If `snis` is non-empty, one `server_name` extension
    ///   follows, listing every entry of `snis` as a `host_name`.
    fn build_client_hello(snis: &[&str]) -> Vec<u8> {
        let mut body = Vec::new();
        body.extend_from_slice(&[0x03, 0x03]); // legacy_version
        body.extend_from_slice(&[0u8; 32]); // random
        body.push(0); // empty legacy_session_id
        body.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); // one cipher suite (0x1301)
        body.extend_from_slice(&[0x01, 0x00]); // one compression method (null)

        let mut ext_section = Vec::new();
        if !snis.is_empty() {
            let mut name_list = Vec::new();
            for name in snis {
                name_list.push(0u8); // name_type = host_name
                name_list.extend_from_slice(&(name.len() as u16).to_be_bytes());
                name_list.extend_from_slice(name.as_bytes());
            }
            let mut sn_body = Vec::new();
            sn_body.extend_from_slice(&(name_list.len() as u16).to_be_bytes());
            sn_body.extend_from_slice(&name_list);

            ext_section.extend_from_slice(&[0x00, 0x00]); // extension type = server_name
            ext_section.extend_from_slice(&(sn_body.len() as u16).to_be_bytes());
            ext_section.extend_from_slice(&sn_body);
        }
        body.extend_from_slice(&(ext_section.len() as u16).to_be_bytes());
        body.extend_from_slice(&ext_section);

        let mut hs = vec![1]; // handshake type = ClientHello
        hs.extend_from_slice(&(body.len() as u32).to_be_bytes()[1..]); // 24-bit length
        hs.extend_from_slice(&body);

        let mut rec = vec![22, 0x03, 0x01]; // handshake record, version 03 01
        rec.extend_from_slice(&(hs.len() as u16).to_be_bytes());
        rec.extend_from_slice(&hs);
        rec
    }

    #[test]
    fn extracts_well_formed_sni() {
        let bytes = build_client_hello(&["api.example.com"]);
        let sni = extract_sni(&bytes).unwrap();
        assert_eq!(sni.as_deref(), Some("api.example.com"));
    }

    #[test]
    fn no_sni_returns_ok_none() {
        let bytes = build_client_hello(&[]);
        assert_eq!(extract_sni(&bytes), Ok(None));
    }

    #[test]
    fn malformed_too_short_record_header() {
        let bytes = [22, 0x03];
        assert_eq!(extract_sni(&bytes), Err(SniParseError::TooShort));
    }

    #[test]
    fn non_handshake_record_type() {
        let bytes = [23, 0x03, 0x03, 0x00, 0x10, 0xff, 0xff];
        assert_eq!(extract_sni(&bytes), Err(SniParseError::NotHandshake));
    }

    #[test]
    fn oversized_record_length_rejected() {
        let bytes = [22, 0x03, 0x03, 0xff, 0xff, 0x01];
        assert_eq!(extract_sni(&bytes), Err(SniParseError::RecordTooLarge));
    }

    #[test]
    fn truncated_random_is_truncation_error() {
        // Cutting the buffer inside the random (offsets 11..43) leaves the record
        // header claiming more bytes than are present, so the record check reports it.
        let mut bytes = build_client_hello(&["api.example.com"]);
        bytes.truncate(21);
        assert_eq!(extract_sni(&bytes), Err(SniParseError::Truncated));
    }

    #[test]
    fn multiple_sni_returns_first() {
        let bytes = build_client_hello(&["first.example.com", "second.example.com"]);
        let sni = extract_sni(&bytes).unwrap();
        assert_eq!(sni.as_deref(), Some("first.example.com"));
    }

    #[test]
    fn ipv4_literal_sni_parses_but_is_caller_concern() {
        let bytes = build_client_hello(&["192.0.2.1"]);
        let sni = extract_sni(&bytes).unwrap();
        assert_eq!(sni.as_deref(), Some("192.0.2.1"));
    }

    #[test]
    fn empty_sni_string_returns_ok_none() {
        let bytes = build_client_hello(&[""]);
        assert_eq!(extract_sni(&bytes), Ok(None));
    }

    #[test]
    fn trailing_dot_is_stripped() {
        let bytes = build_client_hello(&["api.example.com."]);
        let sni = extract_sni(&bytes).unwrap();
        assert_eq!(sni.as_deref(), Some("api.example.com"));
    }

    #[test]
    fn uppercase_sni_is_lowercased() {
        let bytes = build_client_hello(&["API.Example.COM"]);
        let sni = extract_sni(&bytes).unwrap();
        assert_eq!(sni.as_deref(), Some("api.example.com"));
    }

    #[test]
    fn tls13_client_version_accepted() {
        // Offsets 9..11 hold the ClientHello legacy_version, not the record version.
        // The parser must not care about its value.
        let mut bytes = build_client_hello(&["modern.example.com"]);
        bytes[9] = 0x03;
        bytes[10] = 0x04;
        let sni = extract_sni(&bytes).unwrap();
        assert_eq!(sni.as_deref(), Some("modern.example.com"));
    }

    #[test]
    fn handshake_type_must_be_one() {
        let mut bytes = build_client_hello(&["api.example.com"]);
        bytes[5] = 2; // handshake type byte
        assert_eq!(extract_sni(&bytes), Err(SniParseError::NotClientHello));
    }

    #[test]
    fn odd_cipher_suites_length_rejected() {
        // Offsets 44..46 hold the cipher-suites length; 3 is not a multiple of 2.
        let mut bytes = build_client_hello(&["api.example.com"]);
        bytes[44] = 0x00;
        bytes[45] = 0x03;
        assert_eq!(extract_sni(&bytes), Err(SniParseError::InvalidLength));
    }
}