use crate::core::protocol::{ProtocolType, ProtocolInfo};
use crate::core::tls_alpn::TlsAlpnDetector;
use crate::error::{DetectorError, Result};
use std::collections::HashMap;
/// A byte pattern that identifies a protocol when it is found at a fixed
/// offset inside the inspected data.
#[derive(Debug, Clone)]
pub struct MagicSignature {
    /// Protocol reported when the signature matches.
    pub protocol: ProtocolType,
    /// Pattern bytes; only the first `match_length` of them are compared.
    pub magic_bytes: Vec<u8>,
    /// Index in the inspected data at which the pattern must start.
    pub offset: usize,
    /// Number of leading pattern bytes that take part in the comparison.
    pub match_length: usize,
    /// Confidence value attached to a match.
    pub confidence: f32,
    /// Human-readable label for the signature.
    pub description: String,
    /// When `false`, bytes are compared ignoring ASCII case.
    pub case_sensitive: bool,
}

impl MagicSignature {
    /// Creates a case-sensitive signature that compares every byte of
    /// `magic_bytes` against the data starting at `offset`.
    pub fn new(
        protocol: ProtocolType,
        magic_bytes: Vec<u8>,
        offset: usize,
        confidence: f32,
        description: String,
    ) -> Self {
        let match_length = magic_bytes.len();
        Self {
            protocol,
            magic_bytes,
            offset,
            match_length,
            confidence,
            description,
            case_sensitive: true,
        }
    }

    /// Switches the signature to ASCII case-insensitive comparison.
    pub fn case_insensitive(mut self) -> Self {
        self.case_sensitive = false;
        self
    }

    /// Limits the comparison to the first `length` pattern bytes.
    ///
    /// Values larger than the pattern are clamped to the pattern length.
    pub fn with_match_length(mut self, length: usize) -> Self {
        self.match_length = length.min(self.magic_bytes.len());
        self
    }

    /// Returns `true` when `data` contains the pattern at `offset`.
    ///
    /// Never panics: `offset` and `match_length` are public fields, so they may
    /// hold values that are inconsistent with the pattern or that overflow when
    /// added together. An oversized `match_length` is clamped to the pattern
    /// length (the same rule `with_match_length` applies), and a window that
    /// does not fit inside `data` is simply a non-match.
    pub fn matches(&self, data: &[u8]) -> bool {
        let length = self.match_length.min(self.magic_bytes.len());
        let pattern = &self.magic_bytes[..length];

        // `checked_add` guards against `offset + length` wrapping around, and
        // `get` rejects windows that extend past the end of `data`.
        let window = match self
            .offset
            .checked_add(length)
            .and_then(|end| data.get(self.offset..end))
        {
            Some(window) => window,
            None => return false,
        };

        if self.case_sensitive {
            window == pattern
        } else {
            window.eq_ignore_ascii_case(pattern)
        }
    }
}
/// Protocol detector driven by magic-byte signatures.
///
/// Signatures anchored at offset 0 are additionally bucketed by the first byte
/// of their pattern so that `quick_detect` only has to test candidates that
/// can start with the payload's first byte.
#[derive(Debug)]
pub struct MagicDetector {
    /// Offset-0 signatures keyed by the first byte of their pattern.
    byte_indexed_signatures: HashMap<u8, Vec<MagicSignature>>,
    /// Every registered signature, in registration order.
    all_signatures: Vec<MagicSignature>,
    /// When `Some`, only these protocols may be reported.
    enabled_protocols: Option<Vec<ProtocolType>>,
    /// Consulted to refine a TLS signature match.
    tls_alpn_detector: TlsAlpnDetector,
}

impl MagicDetector {
    /// Creates a detector loaded with the built-in signatures and no protocol
    /// filter.
    pub fn new() -> Self {
        let mut detector = Self {
            byte_indexed_signatures: HashMap::new(),
            all_signatures: Vec::new(),
            enabled_protocols: None,
            tls_alpn_detector: TlsAlpnDetector::new(),
        };
        detector.load_common_signatures();
        detector
    }

    /// Registers the built-in signature table.
    ///
    /// Registration order matters: `quick_detect` reports the first matching
    /// signature, so for patterns shared by several protocols (`"GET "`,
    /// `"220 "`) the entry listed first wins.
    fn load_common_signatures(&mut self) {
        /// Builds a case-sensitive signature that compares the whole pattern.
        fn entry(
            protocol: ProtocolType,
            magic: &[u8],
            offset: usize,
            confidence: f32,
            description: &str,
        ) -> MagicSignature {
            MagicSignature::new(
                protocol,
                magic.to_vec(),
                offset,
                confidence,
                description.to_string(),
            )
        }

        let signatures = vec![
            entry(ProtocolType::HTTP1_1, b"GET ", 0, 0.95, "HTTP GET request"),
            entry(ProtocolType::HTTP1_1, b"POST ", 0, 0.95, "HTTP POST request"),
            entry(ProtocolType::HTTP1_1, b"PUT ", 0, 0.95, "HTTP PUT request"),
            entry(ProtocolType::HTTP1_1, b"HEAD ", 0, 0.95, "HTTP HEAD request"),
            entry(ProtocolType::HTTP1_1, b"OPTIONS ", 0, 0.95, "HTTP OPTIONS request"),
            entry(ProtocolType::HTTP1_1, b"DELETE ", 0, 0.95, "HTTP DELETE request"),
            entry(ProtocolType::HTTP1_1, b"HTTP/1.", 0, 0.98, "HTTP response"),
            entry(
                ProtocolType::HTTP2,
                b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n",
                0,
                1.0,
                "HTTP/2 connection preface",
            ),
            entry(ProtocolType::TLS, &[0x16, 0x03], 0, 0.9, "TLS handshake"),
            entry(ProtocolType::QUIC, &[0x80], 0, 0.7, "QUIC long header"),
            entry(ProtocolType::SSH, b"SSH-", 0, 0.99, "SSH protocol"),
            entry(ProtocolType::FTP, b"220 ", 0, 0.85, "FTP welcome message"),
            entry(ProtocolType::SMTP, b"220 ", 0, 0.8, "SMTP welcome"),
            entry(ProtocolType::WebSocket, b"GET ", 0, 0.3, "Potential WebSocket upgrade"),
            entry(ProtocolType::GRPC, b"application/grpc", 0, 0.95, "gRPC content type"),
            entry(ProtocolType::DNS, &[0x00, 0x00, 0x01, 0x00], 2, 0.8, "DNS query"),
            entry(ProtocolType::MQTT, &[0x10], 0, 0.7, "MQTT CONNECT"),
            entry(ProtocolType::Redis, b"*", 0, 0.6, "Redis bulk array"),
            entry(ProtocolType::Redis, b"+OK\r\n", 0, 0.9, "Redis OK response"),
            entry(ProtocolType::MySQL, &[0x0a], 4, 0.8, "MySQL handshake"),
        ];

        for signature in signatures {
            self.add_signature(signature);
        }
    }

    /// Registers `signature` with the detector.
    ///
    /// Every signature is appended to the full list. Only non-empty patterns
    /// anchored at offset 0 are also placed in the first-byte index: the index
    /// is looked up with the payload's first byte, which says nothing about a
    /// pattern that starts further into the payload. Case-insensitive patterns
    /// are indexed under both ASCII cases of their first byte.
    pub fn add_signature(&mut self, signature: MagicSignature) {
        if signature.offset == 0 {
            if let Some(&lead) = signature.magic_bytes.first() {
                let (primary, alternate) = if signature.case_sensitive {
                    (lead, None)
                } else {
                    let lower = lead.to_ascii_lowercase();
                    let upper = lead.to_ascii_uppercase();
                    (lower, if upper != lower { Some(upper) } else { None })
                };
                for key in std::iter::once(primary).chain(alternate) {
                    self.byte_indexed_signatures
                        .entry(key)
                        .or_default()
                        .push(signature.clone());
                }
            }
        }
        self.all_signatures.push(signature);
    }

    /// Restricts detection results to the given protocols.
    pub fn with_enabled_protocols(mut self, protocols: Vec<ProtocolType>) -> Self {
        self.enabled_protocols = Some(protocols);
        self
    }

    /// Returns whether `protocol` passes the enabled-protocol filter (always
    /// `true` when no filter is configured).
    fn protocol_enabled(&self, protocol: ProtocolType) -> bool {
        self.enabled_protocols
            .as_ref()
            .map_or(true, |enabled| enabled.contains(&protocol))
    }

    /// Builds the result reported for a matching signature.
    fn info_from_signature(signature: &MagicSignature) -> ProtocolInfo {
        let mut info = ProtocolInfo::new(signature.protocol, signature.confidence);
        info.add_metadata("detection_method", "magic_bytes");
        info.add_metadata("signature_desc", &signature.description);
        info
    }

    /// Returns the first enabled signature that matches `data`, or a
    /// first-byte heuristic guess when no signature matches.
    ///
    /// Candidates are tried in this order:
    /// 1. offset-0 signatures indexed under `data[0]`, in registration order;
    /// 2. signatures anchored at a non-zero offset, in registration order
    ///    (these cannot be found through the first-byte index);
    /// 3. the first-byte heuristics.
    ///
    /// When the winning signature is TLS, the ALPN detector is consulted and
    /// its result is returned instead if it produces an enabled protocol.
    /// Returns `None` for empty input or when nothing applies.
    pub fn quick_detect(&self, data: &[u8]) -> Option<ProtocolInfo> {
        let first_byte = *data.first()?;

        let indexed = self
            .byte_indexed_signatures
            .get(&first_byte)
            .into_iter()
            .flatten();
        // Empty patterns are skipped here, as they are in the index.
        let anchored_later = self
            .all_signatures
            .iter()
            .filter(|signature| signature.offset != 0 && !signature.magic_bytes.is_empty());

        let hit = indexed
            .chain(anchored_later)
            .find(|signature| self.protocol_enabled(signature.protocol) && signature.matches(data));

        let signature = match hit {
            Some(signature) => signature,
            None => return self.heuristic_by_first_byte(data, first_byte),
        };

        if signature.protocol == ProtocolType::TLS {
            let refined = self
                .tls_alpn_detector
                .detect_alpn(data)
                .and_then(|alpn| self.tls_alpn_detector.create_protocol_info(alpn))
                .filter(|info| self.protocol_enabled(info.protocol_type));
            if refined.is_some() {
                return refined;
            }
        }

        Some(Self::info_from_signature(signature))
    }

    /// Guesses a protocol from the first byte and the payload length alone.
    ///
    /// The guess is dropped when its protocol is filtered out by the
    /// enabled-protocol list.
    fn heuristic_by_first_byte(&self, data: &[u8], first_byte: u8) -> Option<ProtocolInfo> {
        let guess = match first_byte {
            0x14..=0x17 if data.len() >= 3 && data[1] == 0x03 => {
                Some((ProtocolType::TLS, 0.85, "TLS record"))
            }
            0x80..=0xFF if data.len() >= 5 => Some((ProtocolType::QUIC, 0.6, "QUIC long header")),
            b'G' | b'P' | b'H' | b'O' | b'D' if data.len() >= 4 => {
                Some((ProtocolType::HTTP1_1, 0.4, "HTTP method"))
            }
            0x00 if data.len() >= 12 => Some((ProtocolType::DNS, 0.3, "DNS length prefix")),
            _ => None,
        };

        let (protocol, conf, desc) = guess?;
        if !self.protocol_enabled(protocol) {
            return None;
        }
        let mut info = ProtocolInfo::new(protocol, conf);
        info.add_metadata("detection_method", "heuristic");
        info.add_metadata("heuristic_desc", desc);
        Some(info)
    }

    /// Tests every enabled signature against `data` and returns all matches,
    /// highest confidence first (ties keep registration order).
    pub fn deep_detect(&self, data: &[u8]) -> Vec<ProtocolInfo> {
        let mut results: Vec<ProtocolInfo> = self
            .all_signatures
            .iter()
            .filter(|signature| self.protocol_enabled(signature.protocol) && signature.matches(data))
            .map(|signature| {
                let mut info = Self::info_from_signature(signature);
                info.add_metadata("match_offset", &signature.offset.to_string());
                info
            })
            .collect();
        results.sort_by(|a, b| {
            b.confidence
                .partial_cmp(&a.confidence)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        results
    }

    /// Returns each protocol that has at least one registered signature,
    /// without duplicates, in order of first registration.
    pub fn supported_protocols(&self) -> Vec<ProtocolType> {
        let mut seen = std::collections::HashSet::new();
        self.all_signatures
            .iter()
            .map(|sig| sig.protocol)
            .filter(|protocol| seen.insert(*protocol))
            .collect()
    }

    /// Returns the registered signatures that report `protocol`.
    pub fn get_signatures_for_protocol(&self, protocol: ProtocolType) -> Vec<&MagicSignature> {
        self.all_signatures
            .iter()
            .filter(|sig| sig.protocol == protocol)
            .collect()
    }
}

impl Default for MagicDetector {
    /// Equivalent to [`MagicDetector::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Fluent builder for a user-defined [`MagicSignature`].
pub struct CustomSignatureBuilder {
    /// Protocol the finished signature reports.
    protocol: ProtocolType,
    /// Pattern bytes; empty until one of the `with_magic_*` methods is called.
    magic_bytes: Vec<u8>,
    /// Index in the inspected data at which the pattern must start.
    offset: usize,
    /// Confidence attached to a match.
    confidence: f32,
    /// Human-readable label.
    description: String,
    /// Whether bytes are compared exactly or ignoring ASCII case.
    case_sensitive: bool,
    /// Optional cap on how many pattern bytes are compared.
    match_length: Option<usize>,
}

impl CustomSignatureBuilder {
    /// Starts a builder with an empty pattern, offset 0, confidence 0.8 and
    /// case-sensitive matching.
    pub fn new(protocol: ProtocolType, description: &str) -> Self {
        Self {
            description: description.to_string(),
            protocol,
            magic_bytes: Vec::new(),
            match_length: None,
            case_sensitive: true,
            confidence: 0.8,
            offset: 0,
        }
    }

    /// Uses the UTF-8 bytes of `magic` as the pattern.
    pub fn with_magic_string(self, magic: &str) -> Self {
        Self {
            magic_bytes: magic.as_bytes().to_vec(),
            ..self
        }
    }

    /// Uses `magic` as the pattern.
    pub fn with_magic_bytes(self, magic: Vec<u8>) -> Self {
        Self {
            magic_bytes: magic,
            ..self
        }
    }

    /// Sets the index in the inspected data at which the pattern must start.
    pub fn with_offset(self, offset: usize) -> Self {
        Self { offset, ..self }
    }

    /// Sets the confidence, clamped into `0.0..=1.0`.
    pub fn with_confidence(self, confidence: f32) -> Self {
        Self {
            confidence: confidence.clamp(0.0, 1.0),
            ..self
        }
    }

    /// Makes the signature compare bytes ignoring ASCII case.
    pub fn case_insensitive(self) -> Self {
        Self {
            case_sensitive: false,
            ..self
        }
    }

    /// Requests that only the first `length` pattern bytes be compared; the
    /// value is clamped to the pattern length when the signature is built.
    pub fn with_match_length(self, length: usize) -> Self {
        Self {
            match_length: Some(length),
            ..self
        }
    }

    /// Consumes the builder and produces the signature.
    pub fn build(self) -> MagicSignature {
        let Self {
            protocol,
            magic_bytes,
            offset,
            confidence,
            description,
            case_sensitive,
            match_length,
        } = self;

        let mut signature =
            MagicSignature::new(protocol, magic_bytes, offset, confidence, description);
        signature.case_sensitive = case_sensitive;
        match match_length {
            // Clamps to the pattern length, exactly like the requested cap.
            Some(requested) => signature.with_match_length(requested),
            None => signature,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An HTTP/1.1 request line and the HTTP/2 connection preface are both
    /// recognised by their leading bytes.
    #[test]
    fn test_http_magic_detection() {
        let detector = MagicDetector::new();

        // Plain HTTP/1.1 GET request.
        let request = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n";
        let detected = detector.quick_detect(request).unwrap();
        assert_eq!(detected.protocol_type, ProtocolType::HTTP1_1);
        assert!(detected.confidence > 0.9);

        // Full HTTP/2 client connection preface.
        let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
        let detected = detector.quick_detect(preface).unwrap();
        assert_eq!(detected.protocol_type, ProtocolType::HTTP2);
        assert_eq!(detected.confidence, 1.0);
    }

    /// A record starting with `0x16 0x03` is reported as TLS.
    #[test]
    fn test_tls_magic_detection() {
        let detector = MagicDetector::new();

        let record_header = &[0x16, 0x03, 0x01, 0x00, 0x2f];
        let detected = detector.quick_detect(record_header).unwrap();

        assert_eq!(detected.protocol_type, ProtocolType::TLS);
        assert!(detected.confidence > 0.8);
    }

    /// A signature produced by the builder is picked up after registration.
    #[test]
    fn test_custom_signature() {
        let signature = CustomSignatureBuilder::new(ProtocolType::Custom, "My Protocol")
            .with_magic_string("MYPROT")
            .with_confidence(0.95)
            .build();

        let mut detector = MagicDetector::new();
        detector.add_signature(signature);

        let payload = b"MYPROT version 1.0";
        let detected = detector.quick_detect(payload).unwrap();
        assert_eq!(detected.protocol_type, ProtocolType::Custom);
        assert_eq!(detected.confidence, 0.95);
    }
}