use alloc::string::String;
use alloc::vec::Vec;
use crate::tls::codec::{ClientHello, ExtensionType, extension as ext, hs_type, read_record};
use crate::tls::{ContentType, Error};
/// The subset of a ClientHello surfaced by [`peek_client_hello`]: the SNI host
/// name and the offered ALPN protocol identifiers.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ClientHelloInfo {
    /// Host name from the `server_name` extension, as returned by
    /// `ext::parse_server_name`. Stays `None` when the ClientHello carries no
    /// such extension.
    pub server_name: Option<String>,
    /// Protocol identifiers from the ALPN extension, as returned by
    /// `ext::parse_alpn`. Stays empty when the ClientHello carries no such
    /// extension.
    pub alpn_protocols: Vec<Vec<u8>>,
}
/// Extracts the SNI host name and ALPN protocol list from the ClientHello at
/// the start of `buf` without modifying `buf`.
///
/// Records are read from the beginning of `buf`, and the fragments of
/// handshake records are concatenated in order. This reassembles a
/// ClientHello that is split across several records. Once the ClientHello is
/// complete, any records that follow it are not inspected.
///
/// # Returns
///
/// * `Ok(Some(info))` once a complete ClientHello has been decoded.
/// * `Ok(None)` when `read_record` yields no further record before the
///   ClientHello is complete (for example, any strict prefix of a valid
///   flight). Call again after more bytes have been buffered.
///
/// # Errors
///
/// * `Error::UnexpectedMessage` if a record read before the ClientHello is
///   complete is not a handshake record, or if the first handshake message is
///   not a ClientHello.
/// * Errors from `read_record`, `ClientHello::decode`, and the SNI/ALPN
///   extension parsers are propagated unchanged.
///
/// NOTE(review): the handshake length field permits bodies up to 2^24 - 1
/// bytes, and no smaller cap is applied here. A peer can therefore keep this
/// returning `Ok(None)` while claiming a very large ClientHello, so callers
/// should bound how much they buffer.
pub fn peek_client_hello(buf: &[u8]) -> Result<Option<ClientHelloInfo>, Error> {
    let mut reassembled: Vec<u8> = Vec::new();
    let mut consumed = 0usize;
    loop {
        // Look at what has been gathered so far before touching the next
        // record: once the ClientHello is complete, later records are ignored.
        let parsed = parse_client_hello(&reassembled)?;
        if parsed.is_some() {
            return Ok(parsed);
        }

        let record = match read_record(&buf[consumed..])? {
            Some(record) => record,
            None => return Ok(None),
        };
        if record.content_type != ContentType::Handshake {
            return Err(Error::UnexpectedMessage);
        }
        reassembled.extend_from_slice(record.fragment);
        consumed += record.len;
    }
}
/// Tries to decode one ClientHello from reassembled handshake bytes and
/// collects its SNI and ALPN values into a [`ClientHelloInfo`].
///
/// `handshake` must start with a handshake header: a 1-byte message type
/// followed by a 24-bit big-endian body length. Bytes beyond the first
/// message's body are ignored.
///
/// Returns `Ok(None)` while either the 4-byte header or the full body is
/// still missing.
///
/// # Errors
///
/// * `Error::UnexpectedMessage` if the message type is not ClientHello.
/// * Errors from `ClientHello::decode`, `ext::parse_server_name`, and
///   `ext::parse_alpn` are propagated unchanged.
fn parse_client_hello(handshake: &[u8]) -> Result<Option<ClientHelloInfo>, Error> {
    // msg_type (1 byte) + length (3 bytes).
    const HEADER_LEN: usize = 4;

    let header = match handshake.get(..HEADER_LEN) {
        Some(header) => header,
        None => return Ok(None),
    };
    if header[0] != hs_type::CLIENT_HELLO {
        return Err(Error::UnexpectedMessage);
    }
    let body_len = header[1..]
        .iter()
        .fold(0usize, |len, &byte| (len << 8) | usize::from(byte));

    let body = match handshake.get(HEADER_LEN..HEADER_LEN + body_len) {
        Some(body) => body,
        // The body has not been fully reassembled yet.
        None => return Ok(None),
    };

    let hello = ClientHello::decode(body)?;
    let mut info = ClientHelloInfo::default();
    for (kind, data) in &hello.extensions {
        // Only SNI and ALPN are surfaced. If `decode` lets a duplicate
        // extension through, the last occurrence overwrites earlier ones.
        if *kind == ExtensionType::SERVER_NAME {
            info.server_name = ext::parse_server_name(data)?;
        } else if *kind == ExtensionType::ALPN {
            info.alpn_protocols = ext::parse_alpn(data)?;
        }
    }
    Ok(Some(info))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tls::codec::CipherSuite;

    /// Frames `msg` as consecutive handshake records (record version 0x0301),
    /// each carrying at most `chunk` bytes. A `chunk` of 0 is treated as 1.
    fn records(msg: &[u8], chunk: usize) -> Vec<u8> {
        let mut out = Vec::new();
        for frag in msg.chunks(chunk.max(1)) {
            // Checked rather than `as u16` so an oversized chunk fails loudly
            // instead of silently producing a corrupt length field.
            let frag_len =
                u16::try_from(frag.len()).expect("test record fragment exceeds u16::MAX");
            out.push(ContentType::Handshake.as_u8());
            out.extend_from_slice(&0x0301u16.to_be_bytes());
            out.extend_from_slice(&frag_len.to_be_bytes());
            out.extend_from_slice(frag);
        }
        out
    }

    /// An encoded ClientHello handshake message offering SNI "example.com"
    /// and ALPN ["h2", "acme-tls/1"].
    fn sample_client_hello() -> Vec<u8> {
        ClientHello {
            legacy_version: 0x0303,
            random: [0x42u8; 32],
            session_id: Vec::new(),
            cipher_suites: alloc::vec![CipherSuite(0x1301)],
            extensions: alloc::vec![
                ext::server_name("example.com"),
                ext::alpn_protocols(&[b"h2".as_slice(), b"acme-tls/1".as_slice()]),
            ],
        }
        .encode()
    }

    #[test]
    fn peeks_sni_and_alpn_in_one_record() {
        let buf = records(&sample_client_hello(), 4096);
        let before = buf.clone();
        let info = peek_client_hello(&buf).unwrap().unwrap();
        assert_eq!(info.server_name.as_deref(), Some("example.com"));
        assert_eq!(
            info.alpn_protocols,
            alloc::vec![b"h2".to_vec(), b"acme-tls/1".to_vec()]
        );
        assert_eq!(buf, before, "peek must not consume the buffer");
    }

    #[test]
    fn reassembles_client_hello_split_across_records() {
        let buf = records(&sample_client_hello(), 7);
        let info = peek_client_hello(&buf).unwrap().unwrap();
        assert_eq!(info.server_name.as_deref(), Some("example.com"));
        assert_eq!(info.alpn_protocols.len(), 2);
    }

    #[test]
    fn incomplete_buffer_needs_more_bytes() {
        let full = records(&sample_client_hello(), 4096);
        for n in 0..full.len() {
            assert_eq!(
                peek_client_hello(&full[..n]).unwrap(),
                None,
                "prefix of length {n} should ask for more bytes"
            );
        }
        assert!(peek_client_hello(&full).unwrap().is_some());
    }

    #[test]
    fn client_hello_without_sni_or_alpn() {
        let msg = ClientHello {
            legacy_version: 0x0303,
            random: [0u8; 32],
            session_id: Vec::new(),
            cipher_suites: alloc::vec![CipherSuite(0x1301)],
            extensions: Vec::new(),
        }
        .encode();
        let info = peek_client_hello(&records(&msg, 4096)).unwrap().unwrap();
        assert_eq!(info.server_name, None);
        assert!(info.alpn_protocols.is_empty());
    }

    #[test]
    fn non_handshake_first_record_is_rejected() {
        let buf = alloc::vec![23u8, 0x03, 0x03, 0x00, 0x02, 0xAB, 0xCD];
        assert!(peek_client_hello(&buf).is_err());
    }

    #[test]
    fn non_client_hello_handshake_message_is_rejected() {
        // A complete handshake header with an empty body whose message type is
        // anything other than ClientHello.
        let msg = [hs_type::CLIENT_HELLO.wrapping_add(1), 0, 0, 0];
        assert!(peek_client_hello(&records(&msg, 4096)).is_err());
    }

    #[test]
    fn records_after_complete_client_hello_are_not_read() {
        let mut buf = records(&sample_client_hello(), 4096);
        // An application-data record that would be rejected if it were read.
        buf.extend_from_slice(&[23u8, 0x03, 0x03, 0x00, 0x01, 0x00]);
        let info = peek_client_hello(&buf).unwrap().unwrap();
        assert_eq!(info.server_name.as_deref(), Some("example.com"));
        assert_eq!(info.alpn_protocols.len(), 2);
    }
}