use std::convert::TryFrom;
use std::fmt;
/// Errors produced while parsing a TLS record header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TlsError {
    /// Fewer than the 5 header bytes were available.
    TooShort,
    /// The first header byte is not a recognised content type; carries that byte.
    InvalidContentType(u8),
    /// The version bytes are not one of the accepted record versions.
    InvalidVersion { major: u8, minor: u8 },
    /// The header declares more payload bytes than follow the header in the buffer.
    InconsistentLength { declared: u16, available: usize },
}

impl fmt::Display for TlsError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TlsError::TooShort => f.write_str("buffer too short for a 5-byte TLS record header"),
            TlsError::InvalidContentType(value) => {
                write!(f, "invalid TLS content type: {value}")
            }
            TlsError::InvalidVersion { major, minor } => {
                write!(f, "invalid TLS record version: {major}.{minor}")
            }
            TlsError::InconsistentLength {
                declared,
                available,
            } => write!(
                f,
                "TLS record declares {declared} payload bytes but only {available} are available"
            ),
        }
    }
}

// Lets callers propagate `TlsError` through `Box<dyn Error>` / `?` chains.
impl std::error::Error for TlsError {}
/// A single TLS record borrowed from an input buffer.
///
/// When produced by `TlsPacket::try_from(&[u8])`, `payload` is exactly `length`
/// bytes long and points into the buffer that was parsed. The fields are public,
/// so a literally-constructed value is not checked for that consistency.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TlsPacket<'a> {
    /// Record content type (header byte 0).
    pub content_type: TlsContentType,
    /// Record-layer version (header bytes 1 and 2).
    pub version: TlsVersion,
    /// Declared payload length (header bytes 3 and 4, big-endian).
    pub length: u16,
    /// The record body, borrowed from the parsed buffer.
    pub payload: &'a [u8],
}
impl fmt::Display for TlsPacket<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"TLS Packet: content_type={}, version={}, length={}, payload={:02X?}",
self.content_type, self.version, self.length, self.payload
)
}
}
/// The TLS record content type (the first byte of every record header).
///
/// Each discriminant is the on-the-wire byte value, so `TlsContentType::Alert as u8 == 21`.
/// `#[repr(u8)]` pins the representation to that single byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum TlsContentType {
    /// `change_cipher_spec` (20).
    ChangeCipherSpec = 20,
    /// `alert` (21).
    Alert = 21,
    /// `handshake` (22).
    Handshake = 22,
    /// `application_data` (23).
    ApplicationData = 23,
    /// `heartbeat` (24).
    Heartbeat = 24,
}

impl From<TlsContentType> for u8 {
    /// Returns the wire byte for `content_type`; the inverse of `TryFrom<u8>`.
    fn from(content_type: TlsContentType) -> Self {
        content_type as u8
    }
}
impl fmt::Display for TlsContentType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match self {
TlsContentType::ChangeCipherSpec => "ChangeCipherSpec",
TlsContentType::Alert => "Alert",
TlsContentType::Handshake => "Handshake",
TlsContentType::ApplicationData => "ApplicationData",
TlsContentType::Heartbeat => "Heartbeat",
};
write!(f, "{s}")
}
}
impl TryFrom<u8> for TlsContentType {
    type Error = TlsError;

    /// Maps a wire byte to its content type.
    ///
    /// Any byte outside 20..=24 is rejected with `TlsError::InvalidContentType`,
    /// which carries the offending byte.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        let content_type = match value {
            20 => Self::ChangeCipherSpec,
            21 => Self::Alert,
            22 => Self::Handshake,
            23 => Self::ApplicationData,
            24 => Self::Heartbeat,
            unknown => return Err(TlsError::InvalidContentType(unknown)),
        };
        Ok(content_type)
    }
}
/// A TLS record-layer version, kept as the raw `(major, minor)` header bytes.
///
/// `TlsVersion::new` only accepts 3.1 through 3.4; building the struct literally
/// bypasses that check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TlsVersion {
    /// First version byte (3 for every version `new` accepts).
    pub major: u8,
    /// Second version byte (1 to 4 for the versions `new` accepts).
    pub minor: u8,
}
impl TlsVersion {
    /// Validates a `(major, minor)` pair taken from a record header.
    ///
    /// Accepts 3.1, 3.2, 3.3 and 3.4 (displayed as TLS 1.0 through TLS 1.3).
    ///
    /// # Errors
    ///
    /// Returns `TlsError::InvalidVersion`, echoing both bytes, for any other pair
    /// (including 3.0).
    pub fn new(major: u8, minor: u8) -> Result<Self, TlsError> {
        let supported = major == 3 && (1..=4).contains(&minor);
        if !supported {
            return Err(TlsError::InvalidVersion { major, minor });
        }
        Ok(Self { major, minor })
    }
}
impl fmt::Display for TlsVersion {
    /// Writes `TLS 1.0` through `TLS 1.3` for the recognised pairs, and
    /// `major.minor` for anything else.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let known_name = match (self.major, self.minor) {
            (3, 1) => Some("TLS 1.0"),
            (3, 2) => Some("TLS 1.1"),
            (3, 3) => Some("TLS 1.2"),
            (3, 4) => Some("TLS 1.3"),
            _ => None,
        };
        match known_name {
            Some(name) => f.write_str(name),
            None => write!(f, "{}.{}", self.major, self.minor),
        }
    }
}
impl<'a> TryFrom<&'a [u8]> for TlsPacket<'a> {
    type Error = TlsError;

    /// Parses the TLS record at the start of `buf`.
    ///
    /// The 5-byte header is: content type, version major, version minor, then a
    /// big-endian `u16` payload length. Bytes past the declared payload are
    /// ignored, and the returned `payload` borrows from `buf`.
    ///
    /// # Errors
    ///
    /// These are checked in order:
    /// - `TooShort` if `buf` holds fewer than 5 bytes.
    /// - `InvalidContentType`, then `InvalidVersion`, from header validation.
    /// - `InconsistentLength` if fewer than `length` bytes follow the header.
    fn try_from(buf: &'a [u8]) -> Result<Self, Self::Error> {
        let (raw_type, major, minor, length, body) = match *buf {
            [raw_type, major, minor, len_hi, len_lo, ref body @ ..] => {
                (raw_type, major, minor, u16::from_be_bytes([len_hi, len_lo]), body)
            }
            _ => return Err(TlsError::TooShort),
        };

        let content_type = TlsContentType::try_from(raw_type)?;
        let version = TlsVersion::new(major, minor)?;

        // `get` yields `None` exactly when the declared length overruns the body.
        let payload = match body.get(..usize::from(length)) {
            Some(payload) => payload,
            None => {
                return Err(TlsError::InconsistentLength {
                    declared: length,
                    available: body.len(),
                })
            }
        };

        Ok(Self {
            content_type,
            version,
            length,
            payload,
        })
    }
}
/// Splits `buf` into consecutive TLS records, starting at offset 0.
///
/// Parsing stops silently at the first position where a complete, valid record
/// cannot be read. That happens when fewer than 5 header bytes remain, the content
/// type is unknown, the version is unsupported, or the payload is shorter than its
/// declared length. Everything from that point on is ignored, so the result is
/// always a prefix of the record stream. Every returned payload borrows from `buf`.
pub fn parse_tls_records(buf: &[u8]) -> Vec<TlsPacket<'_>> {
    // Size of the fixed record header: type (1) + version (2) + length (2).
    const HEADER_LEN: usize = 5;

    let mut records = Vec::new();
    let mut rest = buf;
    // `TlsPacket::try_from` already rejects short headers and truncated payloads,
    // so any `Err` simply means "no further complete record here".
    while let Ok(packet) = TlsPacket::try_from(rest) {
        // A successful parse guarantees HEADER_LEN + payload.len() <= rest.len().
        rest = &rest[HEADER_LEN + packet.payload.len()..];
        records.push(packet);
    }
    records
}
/// Heuristically checks whether `buf` starts with a TLS record.
///
/// Returns `true` only when a well-formed header and its full declared payload
/// are present at the start of `buf`; trailing bytes are allowed. A record whose
/// payload is cut off yields `false`.
pub fn looks_like_tls(buf: &[u8]) -> bool {
    let parsed = TlsPacket::try_from(buf);
    parsed.is_ok()
}
#[cfg(test)]
mod tests {
    // Unit tests for record parsing, multi-record splitting and formatting.
    use super::*;
    use std::convert::TryFrom;

    #[test]
    fn test_parse_valid_tls_packet() {
        let tls_payload: &[u8] = &[22, 3, 3, 0, 5, 1, 2, 3, 4, 5];
        let packet = TlsPacket::try_from(tls_payload).expect("Expected TLS packet");
        assert_eq!(packet.content_type, TlsContentType::Handshake);
        assert_eq!(packet.version, TlsVersion { major: 3, minor: 3 });
        assert_eq!(packet.length, 5);
        assert_eq!(packet.payload, &[1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_parse_tls_packet_ignores_trailing_bytes() {
        // Only `length` bytes after the header belong to the record.
        let buf: &[u8] = &[23, 3, 3, 0, 1, 0xAA, 0xBB, 0xCC];
        let packet = TlsPacket::try_from(buf).expect("Expected TLS packet");
        assert_eq!(packet.length, 1);
        assert_eq!(packet.payload, &[0xAA]);
    }

    #[test]
    fn test_invalid_content_type() {
        let invalid: &[u8] = &[99, 3, 3, 0, 5, 1, 2, 3, 4, 5];
        let err = TlsPacket::try_from(invalid).unwrap_err();
        assert_eq!(err, TlsError::InvalidContentType(99));
    }

    #[test]
    fn test_invalid_tls_version() {
        let invalid: &[u8] = &[22, 3, 9, 0, 5, 1, 2, 3, 4, 5];
        let err = TlsPacket::try_from(invalid).unwrap_err();
        assert_eq!(err, TlsError::InvalidVersion { major: 3, minor: 9 });
    }

    #[test]
    fn test_inconsistent_length() {
        let invalid: &[u8] = &[22, 3, 3, 0, 6, 1, 2, 3, 4, 5];
        let err = TlsPacket::try_from(invalid).unwrap_err();
        assert_eq!(
            err,
            TlsError::InconsistentLength {
                declared: 6,
                available: 5
            }
        );
    }

    #[test]
    fn test_too_short() {
        let short: &[u8] = &[22, 3, 3, 0];
        let err = TlsPacket::try_from(short).unwrap_err();
        assert_eq!(err, TlsError::TooShort);
    }

    #[test]
    fn test_header_errors_are_checked_before_length() {
        // Bad content type and a truncated payload: the content type is reported.
        let invalid: &[u8] = &[99, 3, 3, 0, 9];
        assert_eq!(
            TlsPacket::try_from(invalid).unwrap_err(),
            TlsError::InvalidContentType(99)
        );
    }

    #[test]
    fn test_parse_multiple_tls_records_in_one_buffer() {
        let buf: &[u8] = &[
            20, 3, 3, 0, 1, 0x00, // ChangeCipherSpec, 1-byte payload
            23, 3, 3, 0, 3, 0x01, 0x02, 0x03, // ApplicationData, 3-byte payload
        ];
        let records = parse_tls_records(buf);
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].content_type, TlsContentType::ChangeCipherSpec);
        assert_eq!(records[0].version, TlsVersion { major: 3, minor: 3 });
        assert_eq!(records[0].length, 1);
        assert_eq!(records[0].payload, &[0x00]);
        assert_eq!(records[1].content_type, TlsContentType::ApplicationData);
        assert_eq!(records[1].version, TlsVersion { major: 3, minor: 3 });
        assert_eq!(records[1].length, 3);
        assert_eq!(records[1].payload, &[0x01, 0x02, 0x03]);
    }

    #[test]
    fn test_parse_tls_records_truncated_last_record() {
        let buf: &[u8] = &[
            23, 3, 3, 0, 2, 0xAA, 0xBB, // complete record
            23, 3, 3, 0, 4, 0xCC, // declares 4 payload bytes, only 1 present
        ];
        let records = parse_tls_records(buf);
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].content_type, TlsContentType::ApplicationData);
        assert_eq!(records[0].length, 2);
        assert_eq!(records[0].payload, &[0xAA, 0xBB]);
    }

    #[test]
    fn test_parse_tls_records_non_tls_content() {
        let buf: &[u8] = &[1, 3, 3, 0, 5, 0, 0, 0, 0, 0];
        let records = parse_tls_records(buf);
        assert!(records.is_empty());
    }

    #[test]
    fn test_parse_tls_records_header_too_short_at_end() {
        let buf: &[u8] = &[
            22, 3, 3, 0, 1, 0x01, // complete record
            0x23, 0x00, 0x00, 0x00, // 4 trailing bytes: not enough for a header
        ];
        let records = parse_tls_records(buf);
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].content_type, TlsContentType::Handshake);
        assert_eq!(records[0].payload, &[0x01]);
    }

    #[test]
    fn test_parse_tls_records_stops_at_invalid_record() {
        let buf: &[u8] = &[
            22, 3, 3, 0, 1, 0x01, // valid record
            99, 3, 3, 0, 1, 0x02, // invalid content type
            22, 3, 3, 0, 1, 0x03, // valid, but never reached
        ];
        let records = parse_tls_records(buf);
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].payload, &[0x01]);
    }

    #[test]
    fn test_parse_tls_records_zero_length_records() {
        let buf: &[u8] = &[21, 3, 3, 0, 0, 21, 3, 1, 0, 0];
        let records = parse_tls_records(buf);
        assert_eq!(records.len(), 2);
        assert!(records.iter().all(|r| r.length == 0 && r.payload.is_empty()));
        assert_eq!(records[1].version, TlsVersion { major: 3, minor: 1 });
    }

    #[test]
    fn test_parse_tls_records_empty_input() {
        let empty: &[u8] = &[];
        assert!(parse_tls_records(empty).is_empty());
    }

    #[test]
    fn test_looks_like_tls_when_true() {
        let tls_buf: &[u8] = &[22, 3, 3, 0, 2, 0xAA, 0xBB];
        assert!(looks_like_tls(tls_buf));
    }

    #[test]
    fn test_looks_like_tls_when_false_invalid_content_type() {
        let non_tls: &[u8] = &[0, 3, 3, 0, 2, 0xAA, 0xBB];
        assert!(!looks_like_tls(non_tls));
    }

    #[test]
    fn test_looks_like_tls_when_false_too_short() {
        let too_short: &[u8] = &[22, 3, 3, 0];
        assert!(!looks_like_tls(too_short));
    }

    #[test]
    fn test_looks_like_tls_when_false_truncated_payload() {
        let truncated: &[u8] = &[22, 3, 3, 0, 4, 0xAA];
        assert!(!looks_like_tls(truncated));
    }

    #[test]
    fn test_tls_content_type_from_u8_all_valid_values() {
        for (value, expected) in [
            (20u8, TlsContentType::ChangeCipherSpec),
            (21, TlsContentType::Alert),
            (22, TlsContentType::Handshake),
            (23, TlsContentType::ApplicationData),
            (24, TlsContentType::Heartbeat),
        ] {
            let ct = TlsContentType::try_from(value).unwrap();
            assert_eq!(ct, expected);
        }
    }

    #[test]
    fn test_tls_content_type_from_u8_invalid_value() {
        let err = TlsContentType::try_from(0xFF_u8).unwrap_err();
        assert_eq!(err, TlsError::InvalidContentType(0xFF));
    }

    #[test]
    fn test_tls_version_new_valid_versions() {
        for (maj, min) in [(3, 1), (3, 2), (3, 3), (3, 4)] {
            let v = TlsVersion::new(maj, min).expect("valid version");
            assert_eq!(v.major, maj);
            assert_eq!(v.minor, min);
        }
    }

    #[test]
    fn test_tls_version_new_invalid_version() {
        let err = TlsVersion::new(3, 0).unwrap_err();
        assert_eq!(err, TlsError::InvalidVersion { major: 3, minor: 0 });
    }

    #[test]
    fn test_display_formats() {
        assert_eq!(TlsContentType::Handshake.to_string(), "Handshake");
        assert_eq!(TlsVersion { major: 3, minor: 1 }.to_string(), "TLS 1.0");
        assert_eq!(TlsVersion { major: 3, minor: 4 }.to_string(), "TLS 1.3");
        // Unknown versions fall back to "major.minor".
        assert_eq!(TlsVersion { major: 2, minor: 0 }.to_string(), "2.0");

        let buf: &[u8] = &[23, 3, 3, 0, 2, 0x0A, 0xFF];
        let packet = TlsPacket::try_from(buf).expect("Expected TLS packet");
        assert_eq!(
            packet.to_string(),
            "TLS Packet: content_type=ApplicationData, version=TLS 1.2, length=2, payload=[0A, FF]"
        );
    }
}