use super::extension::TlsExtension;
use super::extension::SniParserError;
use super::extension::TlsExtensionType;
/// Reasons a byte buffer could not be parsed into a [`TlsClientHello`].
#[derive(Debug)]
pub enum TlsClientHelloError {
/// The buffer is shorter than 6 bytes or its first byte is not `0x16`
/// (the TLS handshake record content type).
NotTLSHandshake,
/// The handshake message type byte (offset 5) is not `0x01` (ClientHello).
NotClientHello,
/// The buffer ended before a field could be read in full. The payload is the
/// parser's byte offset into the buffer at the point the shortfall was detected.
MessageIncomplete(#[allow(dead_code)]usize),
}
/// Fields extracted from a TLS ClientHello handshake message.
#[derive(Debug)]
pub struct TlsClientHello {
/// Version advertised in the ClientHello body, as `(major, minor)` bytes.
protocol_version: (u8, u8),
/// The 32 bytes of client random copied from the message.
random: Vec<u8>,
/// Session id bytes; its length comes from a one-byte prefix, so it may be empty.
session_id: Vec<u8>,
/// Offered cipher suites, each decoded as a big-endian `u16`.
cipher_suites: Vec<u16>,
/// Offered compression method identifiers, one byte each.
compression_methods: Vec<u8>,
/// Extensions in the order they appear in the message.
extensions: Vec<TlsExtension>,
}
#[allow(dead_code)]
impl TlsClientHello {
/// Returns a human-readable name for the protocol version in this ClientHello.
///
/// `(3, 3)` is reported as `"TLS 1.2"` and `(3, 4)` as `"TLS 1.3"`; every other
/// pair yields `"Unknown TLS Version <major>.<minor>"`.
pub fn get_protocol_version_str(&self) -> String {
    let (major, minor) = self.protocol_version;
    let known_name = match (major, minor) {
        (3, 3) => Some("TLS 1.2"),
        (3, 4) => Some("TLS 1.3"),
        _ => None,
    };
    match known_name {
        Some(name) => String::from(name),
        None => format!("Unknown TLS Version {:?}.{:?}", major, minor),
    }
}
/// Returns the session id as an uppercase hexadecimal string, two digits per
/// byte with no separators. An empty session id gives an empty string.
pub fn get_session_id_hex(&self) -> String {
    let mut hex = String::with_capacity(self.session_id.len() * 2);
    for octet in &self.session_id {
        hex.push_str(&format!("{:02X}", octet));
    }
    hex
}
/// Returns the client random as an uppercase hexadecimal string, two digits
/// per byte with no separators.
pub fn get_random_hex(&self) -> String {
    let mut hex = String::with_capacity(self.random.len() * 2);
    for octet in &self.random {
        hex.push_str(&format!("{:02X}", octet));
    }
    hex
}
/// Returns the name of every offered cipher suite, in the order offered.
///
/// Each identifier is translated with [`Self::cipher_suite_to_str`].
pub fn get_cipher_suites_str(&self) -> Vec<String> {
    let mut names = Vec::with_capacity(self.cipher_suites.len());
    for suite in self.cipher_suites.iter().copied() {
        names.push(Self::cipher_suite_to_str(suite));
    }
    names
}
/// Returns a description of every offered compression method, in the order offered.
///
/// Each identifier is translated with [`Self::compression_method_to_str`].
pub fn get_compression_methods_str(&self) -> Vec<String> {
    let mut descriptions = Vec::with_capacity(self.compression_methods.len());
    for method in self.compression_methods.iter().copied() {
        descriptions.push(Self::compression_method_to_str(method));
    }
    descriptions
}
/// Maps a TLS cipher suite code point to its IANA registry name.
///
/// Only the suites listed below are recognised; any other value yields
/// `"UNKNOWN"`.
pub fn cipher_suite_to_str(suite: u16) -> String {
    let name = match suite {
        // TLS 1.3 suites (RFC 8446).
        0x1301 => "TLS_AES_128_GCM_SHA256",
        0x1302 => "TLS_AES_256_GCM_SHA384",
        0x1303 => "TLS_CHACHA20_POLY1305_SHA256",
        0x1304 => "TLS_AES_128_CCM_SHA256",
        0x1305 => "TLS_AES_128_CCM_8_SHA256",
        // ECDHE-ECDSA.
        0xC02B => "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
        0xC02C => "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
        0xC023 => "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256",
        0xC024 => "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384",
        // ECDHE-RSA.
        0xC02F => "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
        0xC030 => "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
        0xC027 => "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256",
        0xC028 => "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384",
        // DHE-RSA.
        0x009E => "TLS_DHE_RSA_WITH_AES_128_GCM_SHA256",
        0x009F => "TLS_DHE_RSA_WITH_AES_256_GCM_SHA384",
        0x0067 => "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256",
        0x006B => "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256",
        // Static RSA.
        0x009C => "TLS_RSA_WITH_AES_128_GCM_SHA256",
        0x009D => "TLS_RSA_WITH_AES_256_GCM_SHA384",
        0x003C => "TLS_RSA_WITH_AES_128_CBC_SHA256",
        0x003D => "TLS_RSA_WITH_AES_256_CBC_SHA256",
        // PSK. The GCM suites live at 0x00A8/0x00A9 and the SHA-2 CBC suites at
        // 0x00AE/0x00AF (RFC 5487); 0x008C/0x008D are the SHA-1 CBC suites
        // (RFC 4279).
        0x00A8 => "TLS_PSK_WITH_AES_128_GCM_SHA256",
        0x00A9 => "TLS_PSK_WITH_AES_256_GCM_SHA384",
        0x00AE => "TLS_PSK_WITH_AES_128_CBC_SHA256",
        0x00AF => "TLS_PSK_WITH_AES_256_CBC_SHA384",
        0x008C => "TLS_PSK_WITH_AES_128_CBC_SHA",
        0x008D => "TLS_PSK_WITH_AES_256_CBC_SHA",
        // Legacy suites.
        0x0005 => "TLS_RSA_WITH_RC4_128_SHA",
        0x000A => "TLS_RSA_WITH_3DES_EDE_CBC_SHA",
        0x002F => "TLS_RSA_WITH_AES_128_CBC_SHA",
        0x0035 => "TLS_RSA_WITH_AES_256_CBC_SHA",
        _ => "UNKNOWN",
    };
    name.to_string()
}
/// Describes a TLS compression method identifier.
///
/// `0` is the null method; every other value is reported as unknown together
/// with its two-digit uppercase hex code.
pub fn compression_method_to_str(method: u8) -> String {
    if method == 0 {
        return String::from("null (no compression)");
    }
    format!("Unknown Compression Method 0x{:02X}", method)
}
/// Extracts the host name carried in the ServerName extension.
///
/// Extensions are scanned in order. For each ServerName extension the name
/// type is read from payload offset 2 and a big-endian 16-bit name length from
/// offsets 3..5; the name itself starts at offset 5. An extension whose name
/// type is not `0x00` is skipped.
///
/// # Errors
///
/// * `InvalidExtensionFormat` if a ServerName payload is shorter than 5 bytes
///   or the declared name length runs past the end of the payload.
/// * `Utf8Error` if the name bytes are not valid UTF-8.
/// * `NoSniFound` if no ServerName extension produced a host name.
pub fn read_sni_hostname(&self) -> Result<String, SniParserError> {
    let server_name_extensions = self
        .extensions
        .iter()
        .filter(|ext| matches!(ext.typ, TlsExtensionType::ServerName));

    for ext in server_name_extensions {
        let payload_len = ext.data.len();
        if payload_len < 5 {
            return Err(SniParserError::InvalidExtensionFormat);
        }

        // Only name type 0x00 is handled; anything else skips this extension.
        if ext.data[2] != 0x00 {
            continue;
        }

        let name_len = usize::from(u16::from_be_bytes([ext.data[3], ext.data[4]]));
        if name_len > payload_len - 5 {
            return Err(SniParserError::InvalidExtensionFormat);
        }

        let name_end = 5 + name_len;
        return match std::str::from_utf8(&ext.data[5..name_end]) {
            Ok(hostname) => Ok(hostname.to_string()),
            Err(err) => Err(SniParserError::Utf8Error(err)),
        };
    }

    Err(SniParserError::NoSniFound)
}
}
/// Parses a ClientHello from a buffer that starts at a TLS record header.
///
/// Bytes are consumed in this order:
///
/// | bytes | field |
/// |-------|-------|
/// | 1     | record content type, must be `0x16` |
/// | 4     | record version and length (skipped) |
/// | 1     | handshake type, must be `0x01` |
/// | 3     | handshake body length |
/// | 2     | protocol version |
/// | 32    | random |
/// | 1 + n | session id |
/// | 2 + n | cipher suites |
/// | 1 + n | compression methods |
/// | 2 + n | extensions (optional) |
///
/// # Errors
///
/// * `NotTLSHandshake` if the buffer is shorter than 6 bytes or does not start
///   with `0x16`.
/// * `NotClientHello` if the handshake type is not `0x01`.
/// * `MessageIncomplete(offset)` if the buffer ends before a field can be read,
///   or an extension claims more bytes than the extensions block holds.
impl TryFrom<&[u8]> for TlsClientHello {
    type Error = TlsClientHelloError;

    fn try_from(data: &[u8]) -> Result<Self, Self::Error> {
        // Fails with `MessageIncomplete(at)` unless `buf` holds `needed` bytes
        // starting at offset `at`.
        fn require(buf: &[u8], at: usize, needed: usize) -> Result<(), TlsClientHelloError> {
            if buf.len() < at + needed {
                Err(TlsClientHelloError::MessageIncomplete(at))
            } else {
                Ok(())
            }
        }

        // Record header (5 bytes) plus the handshake type byte.
        if data.len() < 6 || data[0] != 0x16 {
            return Err(TlsClientHelloError::NotTLSHandshake);
        }
        // Offset 5 is in range thanks to the length check above.
        if data[5] != 0x01 {
            return Err(TlsClientHelloError::NotClientHello);
        }
        let mut current = 6;

        // 24-bit handshake body length; it tells us where the ClientHello ends.
        require(data, current, 3)?;
        let handshake_length =
            u32::from_be_bytes([0, data[current], data[current + 1], data[current + 2]]) as usize;
        current += 3;
        let handshake_end = current + handshake_length;

        require(data, current, 2)?;
        let protocol_version = (data[current], data[current + 1]);
        current += 2;

        require(data, current, 32)?;
        let random = data[current..current + 32].to_vec();
        current += 32;

        // Session id, prefixed by a one-byte length.
        require(data, current, 1)?;
        let session_id_length = usize::from(data[current]);
        require(data, current, 1 + session_id_length)?;
        let session_id = data[current + 1..current + 1 + session_id_length].to_vec();
        current += 1 + session_id_length;

        // Cipher suites, prefixed by a two-byte length. A trailing odd byte is ignored.
        require(data, current, 2)?;
        let cipher_suites_length =
            usize::from(u16::from_be_bytes([data[current], data[current + 1]]));
        require(data, current, 2 + cipher_suites_length)?;
        let cipher_suites = data[current + 2..current + 2 + cipher_suites_length]
            .chunks_exact(2)
            .map(|pair| u16::from_be_bytes([pair[0], pair[1]]))
            .collect();
        current += 2 + cipher_suites_length;

        // Compression methods, prefixed by a one-byte length.
        require(data, current, 1)?;
        let compression_methods_length = usize::from(data[current]);
        require(data, current, 1 + compression_methods_length)?;
        let compression_methods =
            data[current + 1..current + 1 + compression_methods_length].to_vec();
        current += 1 + compression_methods_length;

        let mut extensions = Vec::new();
        // The extensions block is optional: when the declared handshake body
        // ends right after the compression methods, the client sent none, and
        // that is a complete message rather than a truncated one.
        if current != handshake_end {
            require(data, current, 2)?;
            let extensions_length =
                usize::from(u16::from_be_bytes([data[current], data[current + 1]]));
            current += 2;
            require(data, current, extensions_length)?;
            let extensions_end = current + extensions_length;

            // Each extension is: type (2) + length (2) + payload.
            while current + 4 <= extensions_end {
                let typ = u16::from_be_bytes([data[current], data[current + 1]]);
                let extension_length =
                    usize::from(u16::from_be_bytes([data[current + 2], data[current + 3]]));
                current += 4;
                // `extensions_end` is already known to lie within `data`, so
                // this single check also keeps the slice below in bounds.
                if current + extension_length > extensions_end {
                    return Err(TlsClientHelloError::MessageIncomplete(current));
                }
                let payload = data[current..current + extension_length].to_vec();
                extensions.push(TlsExtension::new(typ.into(), payload));
                current += extension_length;
            }
        }

        Ok(TlsClientHello {
            protocol_version,
            random,
            session_id,
            cipher_suites,
            compression_methods,
            extensions,
        })
    }
}