use std::fmt;
/// A single parsed TLS record: the 5-byte header fields plus the record body.
///
/// `PartialEq` is derived so whole packets can be compared; both field types
/// (`TlsContentType`, `TlsVersion`) already implement it.
#[derive(Debug, PartialEq)]
pub struct TlsPacket {
    /// Record content type, decoded from header byte 0.
    pub content_type: TlsContentType,
    /// Record-layer protocol version, decoded from header bytes 1 and 2.
    pub version: TlsVersion,
    /// Body length as declared in header bytes 3 and 4 (big-endian).
    pub length: u16,
    /// Record body; `parse_tls_packet` fills it with exactly `length` bytes.
    pub payload: Vec<u8>,
}

impl fmt::Display for TlsPacket {
    /// Writes the header fields, then the body as upper-case two-digit hex
    /// bytes in list form (e.g. `payload=[01, 0A]`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let TlsPacket {
            content_type,
            version,
            length,
            payload,
        } = self;
        write!(
            f,
            "TLS Packet: content_type={}, version={}, length={}, payload={:02X?}",
            content_type, version, length, payload
        )
    }
}
/// Record-layer content type carried in the first byte of a TLS record header.
///
/// The explicit discriminants equal the header byte values that
/// `check_content_type` maps to each variant, so e.g.
/// `TlsContentType::Handshake as u8 == 22`.
///
/// It is a fieldless enum, so it is cheap to copy and can be used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TlsContentType {
    ChangeCipherSpec = 20,
    Alert = 21,
    Handshake = 22,
    ApplicationData = 23,
    Heartbeat = 24,
}

impl fmt::Display for TlsContentType {
    /// Writes the variant name (e.g. `Handshake`).
    ///
    /// Uses `Formatter::pad` so width/alignment flags such as `{:>12}` are
    /// honoured; a nested `write!(f, "{}", ..)` would discard them.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            TlsContentType::ChangeCipherSpec => "ChangeCipherSpec",
            TlsContentType::Alert => "Alert",
            TlsContentType::Handshake => "Handshake",
            TlsContentType::ApplicationData => "ApplicationData",
            TlsContentType::Heartbeat => "Heartbeat",
        };
        f.pad(name)
    }
}
/// Record-layer protocol version taken from bytes 1 and 2 of a TLS record header.
///
/// This is a plain two-byte value, so it is `Copy` and hashable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TlsVersion {
    pub major: u8,
    pub minor: u8,
}

/// Record-layer versions accepted by `check_tls_version`: `3.1` through `3.4`
/// (TLS 1.0 to TLS 1.3). `3.0` (SSL 3.0) is not included.
///
/// NOTE: RFC 8446 §5.1 sets the TLS 1.3 record-layer version to `3.3` (or `3.1`
/// on an initial ClientHello), so `3.4` should not appear in a record header;
/// accepting it is merely permissive.
const VALID_TLS_VERSIONS: [TlsVersion; 4] = [
    TlsVersion { major: 3, minor: 1 },
    TlsVersion { major: 3, minor: 2 },
    TlsVersion { major: 3, minor: 3 },
    TlsVersion { major: 3, minor: 4 },
];

impl fmt::Display for TlsVersion {
    /// Writes `TLS 1.0`..`TLS 1.3` for the known versions, otherwise the raw
    /// `major.minor` pair (e.g. `3.9`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match (self.major, self.minor) {
            (3, 1) => "TLS 1.0",
            (3, 2) => "TLS 1.1",
            (3, 3) => "TLS 1.2",
            (3, 4) => "TLS 1.3",
            (major, minor) => return write!(f, "{}.{}", major, minor),
        };
        f.write_str(name)
    }
}
/// Ensures `payload` holds at least a full TLS record header.
///
/// The header read by this module is 5 bytes: content type (byte 0),
/// version (bytes 1-2) and big-endian body length (bytes 3-4).
///
/// # Errors
///
/// Returns `Err(false)` when fewer than 5 bytes are available.
fn check_minimum_length(payload: &[u8]) -> Result<(), bool> {
    const HEADER_LEN: usize = 5;
    match payload.len() {
        available if available >= HEADER_LEN => Ok(()),
        _ => Err(false),
    }
}
/// Decodes the record content type from the first header byte.
///
/// # Errors
///
/// Returns `Err(false)` if `payload` is empty or its first byte is not one of
/// the values 20..=24 modelled by [`TlsContentType`].
fn check_content_type(payload: &[u8]) -> Result<TlsContentType, bool> {
    // `first()` rather than `payload[0]`: an empty slice is rejected instead of
    // panicking when this helper is used without `check_minimum_length`.
    match payload.first().copied() {
        Some(20) => Ok(TlsContentType::ChangeCipherSpec),
        Some(21) => Ok(TlsContentType::Alert),
        Some(22) => Ok(TlsContentType::Handshake),
        Some(23) => Ok(TlsContentType::ApplicationData),
        Some(24) => Ok(TlsContentType::Heartbeat),
        _ => Err(false),
    }
}
/// Decodes the record-layer version from header bytes 1 and 2 and checks it
/// against `VALID_TLS_VERSIONS`.
///
/// # Errors
///
/// Returns `Err(false)` if `payload` is shorter than 3 bytes or the version is
/// not listed in `VALID_TLS_VERSIONS`.
fn check_tls_version(payload: &[u8]) -> Result<TlsVersion, bool> {
    // `get(1..3)` instead of direct indexing, so a too-short slice is rejected
    // rather than panicking when this helper is called on its own.
    let version = match payload.get(1..3) {
        Some(&[major, minor]) => TlsVersion { major, minor },
        _ => return Err(false),
    };
    if VALID_TLS_VERSIONS.contains(&version) {
        Ok(version)
    } else {
        Err(false)
    }
}
/// Reads the declared record-body length from header bytes 3 and 4 (big-endian).
///
/// # Panics
///
/// Panics if `payload` is shorter than 5 bytes. `parse_tls_packet` calls
/// `check_minimum_length` before this.
fn extract_length(payload: &[u8]) -> u16 {
    let high = u16::from(payload[3]);
    let low = u16::from(payload[4]);
    (high << 8) | low
}
/// Checks that `payload` contains the 5-byte header plus `length` body bytes.
///
/// Extra bytes beyond the record are allowed.
///
/// # Errors
///
/// Returns `Err(false)` when `payload` is shorter than `5 + length` bytes.
fn validate_payload_length(payload: &[u8], length: u16) -> Result<(), bool> {
    const HEADER_LEN: usize = 5;
    // usize::from(u16) + 5 cannot overflow.
    let required = HEADER_LEN + usize::from(length);
    if payload.len() >= required {
        Ok(())
    } else {
        Err(false)
    }
}
/// Copies the `length` body bytes that follow the 5-byte record header.
///
/// # Panics
///
/// Panics if `payload` is shorter than `5 + length` bytes. `parse_tls_packet`
/// calls `validate_payload_length` before this.
fn extract_payload(payload: &[u8], length: u16) -> Vec<u8> {
    let after_header = &payload[5..];
    after_header[..usize::from(length)].to_owned()
}
/// Parses the first TLS record at the start of `payload`.
///
/// It validates the 5-byte header (size, content type, version), then copies
/// exactly `length` body bytes. Any bytes after that record are ignored.
///
/// # Errors
///
/// Returns `Err(false)` (treated as "not a TLS packet") when any of these hold:
/// - the header is shorter than 5 bytes;
/// - the content type or version is not recognised;
/// - fewer body bytes are present than the header declares.
pub fn parse_tls_packet(payload: &[u8]) -> Result<TlsPacket, bool> {
    // Must run first: the helpers below read fixed header offsets.
    check_minimum_length(payload)?;
    // Tuple elements evaluate left to right, so the content type is still
    // checked before the version, as in a sequential pair of `let`s.
    let (content_type, version) = (check_content_type(payload)?, check_tls_version(payload)?);
    let length = extract_length(payload);
    validate_payload_length(payload, length).map(|()| TlsPacket {
        content_type,
        version,
        length,
        payload: extract_payload(payload, length),
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `parse_tls_packet` rejects `bytes` with `Err(false)`.
    ///
    /// `reason` is used only in the failure message.
    fn assert_rejected(bytes: &[u8], reason: &str) {
        match parse_tls_packet(bytes) {
            Ok(packet) => panic!("Expected non-TLS packet due to {}, got {}", reason, packet),
            Err(is_tls) => assert!(!is_tls),
        }
    }

    #[test]
    fn test_parse_tls_packet() {
        let tls_payload: &[u8] = &[22, 3, 3, 0, 5, 1, 2, 3, 4, 5];
        let packet = parse_tls_packet(tls_payload).expect("Expected TLS packet");
        assert_eq!(packet.content_type, TlsContentType::Handshake);
        assert_eq!(packet.version, TlsVersion { major: 3, minor: 3 });
        assert_eq!(packet.length, 5);
        assert_eq!(packet.payload, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_parse_ignores_trailing_bytes() {
        let packet = parse_tls_packet(&[23, 3, 3, 0, 2, 0xAA, 0xBB, 0xCC])
            .expect("Expected ApplicationData record");
        assert_eq!(packet.content_type, TlsContentType::ApplicationData);
        assert_eq!(packet.length, 2);
        assert_eq!(packet.payload, vec![0xAA, 0xBB]);
    }

    #[test]
    fn test_parse_zero_length_record() {
        let packet = parse_tls_packet(&[21, 3, 1, 0, 0]).expect("Expected empty Alert record");
        assert_eq!(packet.content_type, TlsContentType::Alert);
        assert_eq!(packet.version, TlsVersion { major: 3, minor: 1 });
        assert_eq!(packet.length, 0);
        assert!(packet.payload.is_empty());
    }

    #[test]
    fn test_parse_rejects_non_tls() {
        assert_rejected(&[99, 3, 3, 0, 5, 1, 2, 3, 4, 5], "invalid content type");
        assert_rejected(&[22, 3, 9, 0, 5, 1, 2, 3, 4, 5], "invalid TLS version");
        assert_rejected(&[22, 3, 3, 0, 6, 1, 2, 3, 4, 5], "inconsistent length");
        assert_rejected(&[22, 3, 3, 0], "short payload");
        assert_rejected(&[], "empty payload");
    }

    #[test]
    fn test_display() {
        assert_eq!(TlsContentType::Handshake.to_string(), "Handshake");
        assert_eq!(TlsVersion { major: 3, minor: 3 }.to_string(), "TLS 1.2");
        assert_eq!(TlsVersion { major: 3, minor: 9 }.to_string(), "3.9");
        let packet = parse_tls_packet(&[22, 3, 3, 0, 2, 1, 0x0A]).expect("Expected TLS packet");
        assert_eq!(
            packet.to_string(),
            "TLS Packet: content_type=Handshake, version=TLS 1.2, length=2, payload=[01, 0A]"
        );
    }

    #[test]
    fn test_check_minimum_length() {
        assert!(check_minimum_length(&[1, 2, 3, 4, 5]).is_ok());
        assert!(check_minimum_length(&[1, 2, 3, 4]).is_err());
    }

    #[test]
    fn test_check_content_type() {
        assert_eq!(
            check_content_type(&[22, 3, 3, 0, 5]).unwrap(),
            TlsContentType::Handshake
        );
        assert!(check_content_type(&[99, 3, 3, 0, 5]).is_err());
    }

    #[test]
    fn test_check_tls_version() {
        assert_eq!(
            check_tls_version(&[22, 3, 3, 0, 5]).unwrap(),
            TlsVersion { major: 3, minor: 3 }
        );
        assert!(check_tls_version(&[22, 3, 9, 0, 5]).is_err());
    }

    #[test]
    fn test_extract_length() {
        assert_eq!(extract_length(&[22, 3, 3, 0, 5]), 5);
        assert_eq!(extract_length(&[22, 3, 3, 0x01, 0x00]), 256);
    }

    #[test]
    fn test_validate_payload_length() {
        assert!(validate_payload_length(&[22, 3, 3, 0, 5, 1, 2, 3, 4, 5], 5).is_ok());
        assert!(validate_payload_length(&[22, 3, 3, 0, 6, 1, 2, 3, 4, 5], 6).is_err());
    }

    #[test]
    fn test_extract_payload() {
        assert_eq!(
            extract_payload(&[22, 3, 3, 0, 5, 1, 2, 3, 4, 5], 5),
            vec![1, 2, 3, 4, 5]
        );
    }
}