use crate::core::protocol::{ProtocolType, ProtocolInfo};
use crate::error::{DetectorError, Result};
/// TLS record-layer content types (the first byte of every TLS record).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TlsRecordType {
    ChangeCipherSpec = 0x14,
    Alert = 0x15,
    Handshake = 0x16,
    ApplicationData = 0x17,
}

impl TlsRecordType {
    /// Every variant. `from_u8` derives the byte-to-variant mapping from the discriminants,
    /// so each wire value is written exactly once.
    const ALL: [TlsRecordType; 4] = [
        TlsRecordType::ChangeCipherSpec,
        TlsRecordType::Alert,
        TlsRecordType::Handshake,
        TlsRecordType::ApplicationData,
    ];

    /// Converts a record content-type byte into a [`TlsRecordType`].
    ///
    /// Returns `None` for any byte that is not one of the four record types modelled here.
    pub fn from_u8(value: u8) -> Option<Self> {
        Self::ALL.iter().copied().find(|kind| *kind as u8 == value)
    }
}
/// TLS handshake message types (the first byte of a handshake message).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TlsHandshakeType {
    HelloRequest = 0x00,
    ClientHello = 0x01,
    ServerHello = 0x02,
    Certificate = 0x0b,
    ServerKeyExchange = 0x0c,
    CertificateRequest = 0x0d,
    ServerHelloDone = 0x0e,
    CertificateVerify = 0x0f,
    ClientKeyExchange = 0x10,
    Finished = 0x14,
}

impl TlsHandshakeType {
    /// Every variant. `from_u8` derives the byte-to-variant mapping from the discriminants,
    /// so each wire value is written exactly once.
    const ALL: [TlsHandshakeType; 10] = [
        TlsHandshakeType::HelloRequest,
        TlsHandshakeType::ClientHello,
        TlsHandshakeType::ServerHello,
        TlsHandshakeType::Certificate,
        TlsHandshakeType::ServerKeyExchange,
        TlsHandshakeType::CertificateRequest,
        TlsHandshakeType::ServerHelloDone,
        TlsHandshakeType::CertificateVerify,
        TlsHandshakeType::ClientKeyExchange,
        TlsHandshakeType::Finished,
    ];

    /// Converts a handshake message-type byte into a [`TlsHandshakeType`].
    ///
    /// Returns `None` for any byte not modelled by this enum.
    pub fn from_u8(value: u8) -> Option<Self> {
        Self::ALL.iter().copied().find(|kind| *kind as u8 == value)
    }
}
/// TLS extension types recognised by this module (the 16-bit `ExtensionType` field).
///
/// Derives the same comparison traits as the sibling record/handshake enums.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TlsExtensionType {
    ServerName = 0x0000,
    MaxFragmentLength = 0x0001,
    ClientCertificateUrl = 0x0002,
    TrustedCaKeys = 0x0003,
    TruncatedHmac = 0x0004,
    StatusRequest = 0x0005,
    SignatureAlgorithms = 0x000d,
    UseSrtp = 0x000e,
    Heartbeat = 0x000f,
    ApplicationLayerProtocolNegotiation = 0x0010,
    Padding = 0x0015,
}

impl TlsExtensionType {
    /// Every variant. `from_u16` derives the value-to-variant mapping from the discriminants,
    /// so each wire value is written exactly once.
    const ALL: [TlsExtensionType; 11] = [
        TlsExtensionType::ServerName,
        TlsExtensionType::MaxFragmentLength,
        TlsExtensionType::ClientCertificateUrl,
        TlsExtensionType::TrustedCaKeys,
        TlsExtensionType::TruncatedHmac,
        TlsExtensionType::StatusRequest,
        TlsExtensionType::SignatureAlgorithms,
        TlsExtensionType::UseSrtp,
        TlsExtensionType::Heartbeat,
        TlsExtensionType::ApplicationLayerProtocolNegotiation,
        TlsExtensionType::Padding,
    ];

    /// Converts an extension-type value into a [`TlsExtensionType`].
    ///
    /// Returns `None` for any value not modelled by this enum.
    pub fn from_u16(value: u16) -> Option<Self> {
        Self::ALL.iter().copied().find(|kind| *kind as u16 == value)
    }
}
/// Outcome of inspecting a TLS ClientHello for an ALPN extension.
#[derive(Debug, Clone)]
pub struct AlpnDetectionResult {
    /// ALPN protocol identifiers advertised by the client, in the order they were parsed.
    pub protocols: Vec<String>,
    /// Application protocol inferred from `protocols`; `None` if no identifier was recognised.
    pub primary_protocol: Option<ProtocolType>,
    /// Heuristic confidence score attached to this detection.
    pub confidence: f32,
}
/// Sniffs the application protocol of a TLS connection from the ALPN extension carried in
/// its ClientHello.
#[derive(Debug)]
pub struct TlsAlpnDetector {
    /// Inputs shorter than this many bytes are rejected before any parsing.
    min_data_size: usize,
    /// Protocol types this detector is configured for (see `with_enabled_protocols`).
    enabled_protocols: Vec<ProtocolType>,
}
impl TlsAlpnDetector {
pub fn new() -> Self {
Self {
min_data_size: 64, enabled_protocols: vec![
ProtocolType::HTTP2,
ProtocolType::HTTP1_1,
ProtocolType::HTTP3,
],
}
}
pub fn with_enabled_protocols(mut self, protocols: Vec<ProtocolType>) -> Self {
self.enabled_protocols = protocols;
self
}
pub fn detect_alpn(&self, data: &[u8]) -> Option<AlpnDetectionResult> {
if data.len() < self.min_data_size {
return None;
}
if data.len() < 5 {
return None;
}
let record_type = TlsRecordType::from_u8(data[0])?;
if record_type != TlsRecordType::Handshake {
return None;
}
let record_length = u16::from_be_bytes([data[3], data[4]]) as usize;
let available_data = if data.len() < 5 + record_length {
&data[5..]
} else {
&data[5..5 + record_length]
};
if available_data.is_empty() {
return None;
}
let handshake_type = TlsHandshakeType::from_u8(available_data[0])?;
if handshake_type != TlsHandshakeType::ClientHello {
return None;
}
self.parse_client_hello_alpn(available_data)
}
fn parse_client_hello_alpn(&self, handshake_data: &[u8]) -> Option<AlpnDetectionResult> {
if handshake_data.len() < 12 {
return None;
}
let mut pos = 1 + 3 + 2 + 32;
if handshake_data.len() < pos {
return None;
}
let session_id_len = if handshake_data.len() > pos {
handshake_data[pos] as usize
} else {
return None;
};
pos += 1 + session_id_len;
if handshake_data.len() < pos {
return None;
}
if handshake_data.len() < pos + 2 {
return None;
}
let cipher_suites_len = u16::from_be_bytes([handshake_data[pos], handshake_data[pos + 1]]) as usize;
pos += 2 + cipher_suites_len;
if handshake_data.len() < pos {
return None;
}
if handshake_data.len() < pos + 1 {
return None;
}
let compression_methods_len = handshake_data[pos] as usize;
pos += 1 + compression_methods_len;
if handshake_data.len() < pos {
return None;
}
if handshake_data.len() < pos + 2 {
return None;
}
let extensions_length = u16::from_be_bytes([handshake_data[pos], handshake_data[pos + 1]]) as usize;
pos += 2;
let available_extensions_length = if handshake_data.len() < pos + extensions_length {
handshake_data.len() - pos
} else {
extensions_length
};
if available_extensions_length == 0 {
return None;
}
let extensions_data = &handshake_data[pos..pos + available_extensions_length];
self.parse_alpn_extensions(extensions_data)
}
fn parse_alpn_extensions(&self, extensions_data: &[u8]) -> Option<AlpnDetectionResult> {
let mut pos = 0;
let mut alpn_protocols = Vec::new();
while pos + 4 <= extensions_data.len() {
let extension_type = u16::from_be_bytes([extensions_data[pos], extensions_data[pos + 1]]);
let extension_length = u16::from_be_bytes([extensions_data[pos + 2], extensions_data[pos + 3]]) as usize;
pos += 4;
if pos + extension_length > extensions_data.len() {
break;
}
if extension_type == 0x0010 {
let alpn_data = &extensions_data[pos..pos + extension_length];
if let Some(protocols) = self.parse_alpn_list(alpn_data) {
alpn_protocols = protocols;
break;
}
}
pos += extension_length;
}
if alpn_protocols.is_empty() {
return None;
}
let primary_protocol = self.determine_primary_protocol(&alpn_protocols);
let confidence = self.calculate_confidence(&alpn_protocols, primary_protocol);
Some(AlpnDetectionResult {
protocols: alpn_protocols,
primary_protocol,
confidence,
})
}
fn parse_alpn_list(&self, alpn_data: &[u8]) -> Option<Vec<String>> {
if alpn_data.is_empty() {
return None;
}
let mut pos = 0;
let mut protocols = Vec::new();
if pos + 2 > alpn_data.len() {
return None;
}
let alpn_list_length = u16::from_be_bytes([alpn_data[pos], alpn_data[pos + 1]]) as usize;
pos += 2;
if pos + alpn_list_length > alpn_data.len() {
return None;
}
while pos + 1 <= alpn_data.len() {
let protocol_name_len = alpn_data[pos] as usize;
pos += 1;
if pos + protocol_name_len > alpn_data.len() {
break;
}
let protocol_name = String::from_utf8_lossy(&alpn_data[pos..pos + protocol_name_len]).to_string();
protocols.push(protocol_name);
pos += protocol_name_len;
}
if protocols.is_empty() {
None
} else {
Some(protocols)
}
}
fn determine_primary_protocol(&self, protocols: &[String]) -> Option<ProtocolType> {
for protocol in protocols {
match protocol.as_str() {
"h2" => return Some(ProtocolType::HTTP2),
"h2-16" | "h2-14" => return Some(ProtocolType::HTTP2),
"http/1.1" | "http/1.0" => return Some(ProtocolType::HTTP1_1),
"h3" | "h3-29" | "h3-28" => return Some(ProtocolType::HTTP3),
_ => {}
}
}
None
}
fn calculate_confidence(&self, protocols: &[String], primary_protocol: Option<ProtocolType>) -> f32 {
if primary_protocol.is_none() {
return 0.5;
}
let mut confidence: f32 = 0.85;
if protocols.iter().any(|p| p.starts_with("h2")) {
confidence += 0.1;
}
if protocols.iter().any(|p| p.starts_with("h3")) {
confidence += 0.1;
}
confidence.min(0.95) }
pub fn create_protocol_info(&self, result: AlpnDetectionResult) -> Option<ProtocolInfo> {
if let Some(primary_protocol) = result.primary_protocol {
let mut info = ProtocolInfo::new(primary_protocol, result.confidence);
info.add_metadata("alpn_protocols", &result.protocols.join(","));
info.add_metadata("detection_method", "tls_alpn");
Some(info)
} else {
let mut info = ProtocolInfo::new(ProtocolType::TLS, 0.7);
info.add_metadata("alpn_protocols", &result.protocols.join(","));
info.add_metadata("detection_method", "tls_alpn");
Some(info)
}
}
}
impl Default for TlsAlpnDetector {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-ALPN extensions used as filler: signature_algorithms, supported_groups and
    /// ec_point_formats (8 + 8 + 6 bytes).
    const OTHER_EXTENSIONS: [u8; 22] = [
        0x00, 0x0d, 0x00, 0x04, 0x00, 0x02, 0x04, 0x03, // signature_algorithms
        0x00, 0x0a, 0x00, 0x04, 0x00, 0x02, 0x00, 0x17, // supported_groups
        0x00, 0x0b, 0x00, 0x02, 0x01, 0x00, // ec_point_formats
    ];

    /// Wraps `extensions` in a length-consistent ClientHello inside a handshake record.
    /// The hello has an empty session id, one cipher suite and null compression.
    fn client_hello_record(extensions: &[u8]) -> Vec<u8> {
        let mut body = vec![0x03, 0x03]; // client_version
        body.extend(1..=32u8); // random
        body.push(0x00); // session_id length
        body.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); // cipher_suites
        body.extend_from_slice(&[0x01, 0x00]); // compression_methods
        body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
        body.extend_from_slice(extensions);

        let mut handshake = vec![0x01]; // ClientHello
        handshake.extend_from_slice(&(body.len() as u32).to_be_bytes()[1..]);
        handshake.extend_from_slice(&body);

        let mut record = vec![0x16, 0x03, 0x01]; // handshake record, legacy version 3.1
        record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
        record.extend_from_slice(&handshake);
        record
    }

    /// Encodes an ALPN extension (type 0x0010) advertising `protocols` in order.
    fn alpn_extension(protocols: &[&str]) -> Vec<u8> {
        let mut list = Vec::new();
        for name in protocols {
            list.push(name.len() as u8);
            list.extend_from_slice(name.as_bytes());
        }
        let mut extension = vec![0x00, 0x10];
        extension.extend_from_slice(&((list.len() + 2) as u16).to_be_bytes()); // extension_data
        extension.extend_from_slice(&(list.len() as u16).to_be_bytes()); // protocol_name_list
        extension.extend_from_slice(&list);
        extension
    }

    /// A full ClientHello record advertising `protocols`, followed by the filler extensions.
    fn record_with_alpn(protocols: &[&str]) -> Vec<u8> {
        let mut extensions = alpn_extension(protocols);
        extensions.extend_from_slice(&OTHER_EXTENSIONS);
        client_hello_record(&extensions)
    }

    #[test]
    fn test_alpn_detection() {
        let detector = TlsAlpnDetector::new();
        let record = record_with_alpn(&["h2"]);
        assert!(record.len() >= 64);

        let detection = detector.detect_alpn(&record).expect("ALPN should be detected");
        assert_eq!(detection.protocols, vec!["h2".to_string()]);
        assert_eq!(detection.primary_protocol, Some(ProtocolType::HTTP2));
        assert!(detection.confidence > 0.9);
    }

    #[test]
    fn test_no_alpn_extension() {
        let detector = TlsAlpnDetector::new();
        let record = client_hello_record(&OTHER_EXTENSIONS);
        // Must clear the 64-byte guard, so that `None` comes from the parser, not the guard.
        assert!(record.len() >= 64);
        assert!(detector.detect_alpn(&record).is_none());
    }

    #[test]
    fn test_http11_only() {
        let detector = TlsAlpnDetector::new();
        let detection = detector
            .detect_alpn(&record_with_alpn(&["http/1.1"]))
            .expect("ALPN should be detected");
        assert_eq!(detection.protocols, vec!["http/1.1".to_string()]);
        assert_eq!(detection.primary_protocol, Some(ProtocolType::HTTP1_1));
        assert!((detection.confidence - 0.85).abs() < 1e-6);
    }

    #[test]
    fn test_first_recognised_protocol_wins() {
        let detector = TlsAlpnDetector::new();
        let detection = detector
            .detect_alpn(&record_with_alpn(&["spdy/3.1", "http/1.1", "h2"]))
            .expect("ALPN should be detected");
        assert_eq!(
            detection.protocols,
            vec!["spdy/3.1".to_string(), "http/1.1".to_string(), "h2".to_string()]
        );
        assert_eq!(detection.primary_protocol, Some(ProtocolType::HTTP1_1));
        assert!(detection.confidence > 0.9);
    }

    #[test]
    fn test_unrecognised_protocol() {
        let detector = TlsAlpnDetector::new();
        let detection = detector
            .detect_alpn(&record_with_alpn(&["spdy/3.1"]))
            .expect("ALPN should be detected");
        assert_eq!(detection.primary_protocol, None);
        assert!((detection.confidence - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_truncated_capture_still_detects_alpn() {
        let detector = TlsAlpnDetector::new();
        let record = record_with_alpn(&["h2"]);
        // Drop the tail of the trailing filler extensions; the ALPN extension comes first.
        let truncated = &record[..record.len() - 5];
        assert!(truncated.len() >= 64);
        let detection = detector.detect_alpn(truncated).expect("ALPN should be detected");
        assert_eq!(detection.primary_protocol, Some(ProtocolType::HTTP2));
    }

    #[test]
    fn test_rejects_short_or_non_handshake_input() {
        let detector = TlsAlpnDetector::new();
        let record = record_with_alpn(&["h2"]);
        assert!(detector.detect_alpn(&[]).is_none());
        assert!(detector.detect_alpn(&record[..63]).is_none());

        let mut application_data = record.clone();
        application_data[0] = 0x17;
        assert!(detector.detect_alpn(&application_data).is_none());
    }

    #[test]
    fn test_record_type_from_u8() {
        assert_eq!(TlsRecordType::from_u8(0x16), Some(TlsRecordType::Handshake));
        assert_eq!(TlsRecordType::from_u8(0x18), None);
        assert_eq!(TlsHandshakeType::from_u8(0x01), Some(TlsHandshakeType::ClientHello));
        assert_eq!(TlsHandshakeType::from_u8(0x03), None);
    }
}