#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
use crate::core::protocol::ProtocolType;
use crate::error::{DetectorError, Result};
use crate::simd::{SimdDetectionResult, SimdDetector, SimdInstructionSet};
use std::time::Instant;
/// Protocol detector that uses AArch64 NEON instructions for pattern
/// matching when they are available and a scalar search otherwise.
///
/// Construct it with [`AArch64SimdDetector::new`], which probes for NEON once.
pub struct AArch64SimdDetector {
    /// Whether the NEON code paths may be taken.
    has_neon: bool,
    /// Instruction set copied into every detection result (`NEON` or `None`).
    instruction_set: SimdInstructionSet,
}
impl AArch64SimdDetector {
    /// Creates a detector, probing once for NEON support.
    ///
    /// The detector reports [`SimdInstructionSet::NEON`] when NEON is
    /// available and [`SimdInstructionSet::None`] otherwise. In the latter
    /// case pattern matching uses the scalar fallback.
    pub fn new() -> Self {
        let has_neon = detect_neon_support();
        let instruction_set = if has_neon {
            SimdInstructionSet::NEON
        } else {
            SimdInstructionSet::None
        };
        Self {
            instruction_set,
            has_neon,
        }
    }

    /// Returns the offset of the first occurrence of `needle` in `haystack`.
    ///
    /// NEON is used to skip 16-byte blocks that contain no copy of the
    /// needle's first byte.
    ///
    /// Returns `None` when `needle` is empty, longer than `haystack`, or
    /// absent. It also returns `None` when the detector was created without
    /// NEON support.
    ///
    /// # Safety
    /// NEON intrinsics run only when `self.has_neon` is set, so that flag must
    /// accurately reflect the CPU's capabilities.
    #[cfg(target_arch = "aarch64")]
    unsafe fn neon_pattern_match(&self, haystack: &[u8], needle: &[u8]) -> Option<usize> {
        if !self.has_neon || needle.is_empty() || haystack.len() < needle.len() {
            return None;
        }
        let first_byte = needle[0];
        if needle.len() == 1 {
            return self.neon_find_byte(haystack, first_byte);
        }
        // Highest offset at which a full copy of `needle` still fits.
        let last_start = haystack.len() - needle.len();
        let is_match_at = |start: usize| {
            haystack[start] == first_byte && haystack[start..].starts_with(needle)
        };
        let splat = vdupq_n_u8(first_byte);
        let mut pos = 0;
        // Stop as soon as no remaining offset can start a match.
        while pos + 16 <= haystack.len() && pos <= last_start {
            // SAFETY: `pos + 16 <= haystack.len()`, so all 16 loaded bytes are in bounds.
            let block = vld1q_u8(haystack.as_ptr().add(pos));
            // Matching lanes are 0xFF, so a non-zero horizontal maximum means
            // at least one candidate start lies in this block.
            if vmaxvq_u8(vceqq_u8(block, splat)) != 0 {
                let block_end = (pos + 16).min(last_start + 1);
                if let Some(start) = (pos..block_end).find(|&start| is_match_at(start)) {
                    return Some(start);
                }
            }
            pos += 16;
        }
        // Scalar scan of the offsets not covered by whole 16-byte blocks.
        (pos..=last_start).find(|&start| is_match_at(start))
    }

    /// Returns the offset of the first occurrence of `byte` in `data`.
    ///
    /// Returns `None` when `data` is empty, when the byte is absent, or when
    /// the detector was created without NEON support.
    ///
    /// # Safety
    /// NEON intrinsics run only when `self.has_neon` is set, so that flag must
    /// accurately reflect the CPU's capabilities.
    #[cfg(target_arch = "aarch64")]
    unsafe fn neon_find_byte(&self, data: &[u8], byte: u8) -> Option<usize> {
        if !self.has_neon || data.is_empty() {
            return None;
        }
        let splat = vdupq_n_u8(byte);
        let mut pos = 0;
        while pos + 16 <= data.len() {
            // SAFETY: `pos + 16 <= data.len()`, so all 16 loaded bytes are in bounds.
            let block = vld1q_u8(data.as_ptr().add(pos));
            if vmaxvq_u8(vceqq_u8(block, splat)) != 0 {
                // The block is known to contain `byte`, so this scan succeeds.
                if let Some(offset) = data[pos..pos + 16].iter().position(|&b| b == byte) {
                    return Some(pos + offset);
                }
            }
            pos += 16;
        }
        data[pos..]
            .iter()
            .position(|&b| b == byte)
            .map(|offset| pos + offset)
    }

    /// Finds `needle` in `haystack`, using NEON when the detector supports it
    /// and the scalar fallback otherwise.
    fn fast_pattern_match(&self, haystack: &[u8], needle: &[u8]) -> Option<usize> {
        #[cfg(target_arch = "aarch64")]
        {
            if self.has_neon {
                // SAFETY: `has_neon` is only set when NEON support was detected.
                return unsafe { self.neon_pattern_match(haystack, needle) };
            }
        }
        self.fallback_pattern_match(haystack, needle)
    }

    /// Scalar search for the first occurrence of `needle` in `haystack`.
    ///
    /// Returns `None` for an empty needle, or when `needle` does not occur in
    /// `haystack` (including when it is longer than `haystack`).
    fn fallback_pattern_match(&self, haystack: &[u8], needle: &[u8]) -> Option<usize> {
        // `windows(0)` panics, and an empty needle is treated as "no match".
        if needle.is_empty() {
            return None;
        }
        haystack
            .windows(needle.len())
            .position(|window| window == needle)
    }

    /// Counts how many bytes of `data` equal `byte`.
    ///
    /// Falls back to a scalar count when the detector lacks NEON support.
    ///
    /// # Safety
    /// NEON intrinsics run only when `self.has_neon` is set, so that flag must
    /// accurately reflect the CPU's capabilities.
    #[cfg(target_arch = "aarch64")]
    // NOTE(review): no caller in this file; kept for future heuristics. Remove if it stays unused.
    #[allow(dead_code)]
    unsafe fn neon_count_bytes(&self, data: &[u8], byte: u8) -> usize {
        if !self.has_neon {
            return data.iter().filter(|&&b| b == byte).count();
        }
        let splat = vdupq_n_u8(byte);
        let mut count = 0;
        let mut pos = 0;
        while pos + 16 <= data.len() {
            // SAFETY: `pos + 16 <= data.len()`, so all 16 loaded bytes are in bounds.
            let block = vld1q_u8(data.as_ptr().add(pos));
            // Matching lanes are 0xFF. Shifting right by 7 turns each into 1,
            // so the horizontal add counts matches (at most 16, fits in a u8).
            count += vaddvq_u8(vshrq_n_u8(vceqq_u8(block, splat), 7)) as usize;
            pos += 16;
        }
        count + data[pos..].iter().filter(|&&b| b == byte).count()
    }
}
impl SimdDetector for AArch64SimdDetector {
    /// Detects HTTP/2 from the client connection preface or a frame header.
    ///
    /// The preface is searched for anywhere in `data`. When found, it is
    /// reported at the offset where it starts, with confidence 1.0.
    ///
    /// Otherwise the first nine bytes are read as an HTTP/2 frame header
    /// (RFC 9113 §4.1). DATA, HEADERS and SETTINGS frames are accepted only
    /// when their stream identifier (and, for SETTINGS, the payload length)
    /// satisfies the rules RFC 9113 places on that frame type.
    ///
    /// # Errors
    /// Returns a detection-failed error when neither check matches.
    fn detect_http2(&self, data: &[u8]) -> Result<SimdDetectionResult> {
        let http2_preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
        // Report where the preface actually starts instead of assuming offset 0.
        if let Some(pos) = self.fast_pattern_match(data, http2_preface) {
            return Ok(SimdDetectionResult {
                protocol: ProtocolType::HTTP2,
                confidence: 1.0,
                match_positions: vec![pos],
                instruction_set: self.instruction_set,
            });
        }
        if data.len() >= 9 {
            // Frame header: length (24 bits) | type | flags | R + stream id (31 bits).
            let payload_len = u32::from_be_bytes([0, data[0], data[1], data[2]]);
            let frame_type = data[3];
            let stream_id = u32::from_be_bytes([data[5], data[6], data[7], data[8]]) & 0x7FFF_FFFF;
            let confidence = match frame_type {
                // DATA (§6.1) and HEADERS (§6.2) MUST be associated with a stream.
                0x0 if stream_id != 0 => Some(0.8),
                0x1 if stream_id != 0 => Some(0.9),
                // SETTINGS (§6.5) MUST use stream 0 and carry 6-octet entries.
                0x4 if stream_id == 0 && payload_len % 6 == 0 => Some(0.95),
                _ => None,
            };
            if let Some(confidence) = confidence {
                return Ok(SimdDetectionResult {
                    protocol: ProtocolType::HTTP2,
                    confidence,
                    match_positions: vec![3],
                    instruction_set: self.instruction_set,
                });
            }
        }
        Err(DetectorError::detection_failed("No HTTP/2 patterns found"))
    }

    /// Detects a QUIC packet from its first-byte header form.
    ///
    /// - Long headers (0x80 set) are accepted for QUIC v1 and v2 (0.95) and
    ///   for Version Negotiation packets (0.9).
    /// - Short headers (0x80 clear) are accepted only with the fixed bit 0x40
    ///   set (0.7).
    ///
    /// # Errors
    /// Returns a detection-failed error for empty input or unrecognised headers.
    fn detect_quic(&self, data: &[u8]) -> Result<SimdDetectionResult> {
        if data.is_empty() {
            return Err(DetectorError::detection_failed("Empty data"));
        }
        let first_byte = data[0];
        let confidence = if (first_byte & 0x80) != 0 {
            // Long header: a 32-bit version follows the first byte.
            if data.len() >= 5 {
                match u32::from_be_bytes([data[1], data[2], data[3], data[4]]) {
                    // QUIC v1 (RFC 9000) and QUIC v2 (RFC 9369).
                    0x0000_0001 | 0x6B33_43CF => Some(0.95),
                    // Version Negotiation packet.
                    0x0000_0000 => Some(0.9),
                    _ => None,
                }
            } else {
                None
            }
        } else if (first_byte & 0x40) != 0 {
            // Short header. The fixed bit must be 1 (RFC 9000 §17.3.1).
            // Packets that grease this bit (RFC 9287) are not recognised.
            Some(0.7)
        } else {
            None
        };
        match confidence {
            Some(confidence) => Ok(SimdDetectionResult {
                protocol: ProtocolType::QUIC,
                confidence,
                match_positions: vec![0],
                instruction_set: self.instruction_set,
            }),
            None => Err(DetectorError::detection_failed("No QUIC patterns found")),
        }
    }

    /// Detects gRPC via its content type or a length-prefixed message frame.
    ///
    /// # Errors
    /// Returns a detection-failed error when no pattern matches.
    fn detect_grpc(&self, data: &[u8]) -> Result<SimdDetectionResult> {
        // "application/grpc" is a prefix of "application/grpc-web". The more
        // specific gRPC-Web type must therefore be tested first; otherwise its
        // branch can never be reached.
        let grpc_web = b"application/grpc-web";
        if let Some(pos) = self.fast_pattern_match(data, grpc_web) {
            return Ok(SimdDetectionResult {
                protocol: ProtocolType::GRPC,
                confidence: 0.85,
                match_positions: vec![pos],
                instruction_set: self.instruction_set,
            });
        }
        let grpc_content_type = b"application/grpc";
        if let Some(pos) = self.fast_pattern_match(data, grpc_content_type) {
            return Ok(SimdDetectionResult {
                protocol: ProtocolType::GRPC,
                confidence: 0.9,
                match_positions: vec![pos],
                instruction_set: self.instruction_set,
            });
        }
        // Length-prefixed message: 1-byte compressed flag (0 or 1), then a
        // 4-byte big-endian message length.
        if data.len() >= 5 && data[0] <= 1 {
            let message_length = u32::from_be_bytes([data[1], data[2], data[3], data[4]]) as usize;
            // Compare against the bytes after the prefix so that no
            // `5 + message_length` sum can overflow on 32-bit targets.
            if message_length > 0 && data.len() - 5 >= message_length {
                return Ok(SimdDetectionResult {
                    protocol: ProtocolType::GRPC,
                    confidence: 0.7,
                    match_positions: vec![0],
                    instruction_set: self.instruction_set,
                });
            }
        }
        Err(DetectorError::detection_failed("No gRPC patterns found"))
    }

    /// Detects WebSocket by combining upgrade-handshake headers with a frame
    /// opcode check.
    ///
    /// Each signal adds weight. Detection succeeds when the total exceeds 0.5,
    /// and the reported confidence is capped at 1.0.
    ///
    /// # Errors
    /// Returns a detection-failed error when the combined weight is too low.
    fn detect_websocket(&self, data: &[u8]) -> Result<SimdDetectionResult> {
        let mut positions = Vec::new();
        let mut confidence = 0.0;
        let upgrade_header = b"Upgrade: websocket";
        if let Some(pos) = self.fast_pattern_match(data, upgrade_header) {
            positions.push(pos);
            confidence += 0.4;
        }
        let connection_header = b"Connection: Upgrade";
        if let Some(pos) = self.fast_pattern_match(data, connection_header) {
            positions.push(pos);
            confidence += 0.3;
        }
        let websocket_key = b"Sec-WebSocket-Key:";
        if let Some(pos) = self.fast_pattern_match(data, websocket_key) {
            positions.push(pos);
            confidence += 0.3;
        }
        if data.len() >= 2 {
            // Low nibble of the first byte is the frame opcode (RFC 6455 §5.2).
            let opcode = data[0] & 0x0F;
            if matches!(opcode, 0x0 | 0x1 | 0x2 | 0x8 | 0x9 | 0xA) {
                confidence += 0.2;
                if positions.is_empty() {
                    positions.push(0);
                }
            }
        }
        if confidence > 0.5 {
            // All signals together weigh 1.2, so cap at 1.0 to keep a valid confidence.
            let confidence = if confidence > 1.0 { 1.0 } else { confidence };
            Ok(SimdDetectionResult {
                protocol: ProtocolType::WebSocket,
                confidence,
                match_positions: positions,
                instruction_set: self.instruction_set,
            })
        } else {
            Err(DetectorError::detection_failed("No WebSocket patterns found"))
        }
    }

    /// Detects a TLS record from its 5-byte record-layer header.
    ///
    /// Requires all of the following:
    /// - a known content type;
    /// - a record version from 0x0300 to 0x0304;
    /// - a non-zero length no larger than 2^14 + 2048;
    /// - the whole record being present in `data`.
    ///
    /// # Errors
    /// Returns a detection-failed error when the header does not validate.
    fn detect_tls(&self, data: &[u8]) -> Result<SimdDetectionResult> {
        // Largest TLSCiphertext length allowed by RFC 5246 §6.2.3. Encrypted
        // records may legitimately exceed the 2^14 plaintext limit.
        const MAX_RECORD_LEN: usize = (1 << 14) + 2048;
        if data.len() < 5 {
            return Err(DetectorError::detection_failed("Data too short for TLS"));
        }
        let confidence = match data[0] {
            0x16 => 0.95, // handshake
            0x17 => 0.9,  // application_data
            0x15 => 0.85, // alert
            0x14 => 0.8,  // change_cipher_spec
            _ => return Err(DetectorError::detection_failed("No TLS patterns found")),
        };
        let valid_version = data[1] == 0x03 && data[2] <= 0x04;
        let length = usize::from(u16::from_be_bytes([data[3], data[4]]));
        if valid_version && length > 0 && length <= MAX_RECORD_LEN && data.len() >= 5 + length {
            return Ok(SimdDetectionResult {
                protocol: ProtocolType::TLS,
                confidence,
                match_positions: vec![0],
                instruction_set: self.instruction_set,
            });
        }
        Err(DetectorError::detection_failed("No TLS patterns found"))
    }

    /// Runs the detector for each requested protocol and returns the
    /// successful results, ordered by descending confidence.
    ///
    /// Protocols with no detector in this implementation are skipped. This
    /// method itself always returns `Ok`.
    fn detect_multiple(&self, data: &[u8], protocols: &[ProtocolType]) -> Result<Vec<SimdDetectionResult>> {
        let mut results = Vec::new();
        for &protocol in protocols {
            let result = match protocol {
                ProtocolType::HTTP2 => self.detect_http2(data),
                ProtocolType::QUIC => self.detect_quic(data),
                ProtocolType::GRPC => self.detect_grpc(data),
                ProtocolType::WebSocket => self.detect_websocket(data),
                ProtocolType::TLS => self.detect_tls(data),
                _ => continue,
            };
            if let Ok(detection) = result {
                results.push(detection);
            }
        }
        // Stable sort: results with equal confidence keep the order of `protocols`.
        results.sort_by(|a, b| {
            b.confidence
                .partial_cmp(&a.confidence)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        Ok(results)
    }

    /// Returns the instruction set chosen when the detector was constructed.
    fn instruction_set(&self) -> SimdInstructionSet {
        self.instruction_set
    }

    /// Reports whether this detector claims to handle `protocol`.
    // NOTE(review): UDP is reported as supported, but `detect_multiple` has no
    // UDP branch and silently skips it. Confirm whether this is intended.
    fn supports_protocol(&self, protocol: ProtocolType) -> bool {
        matches!(
            protocol,
            ProtocolType::HTTP2
                | ProtocolType::QUIC
                | ProtocolType::GRPC
                | ProtocolType::WebSocket
                | ProtocolType::TLS
                | ProtocolType::UDP
        )
    }
}
/// Reports whether NEON (Advanced SIMD) may be used on the running CPU.
///
/// On AArch64 this asks the standard library's feature detection instead of
/// assuming NEON is present. The intrinsic code paths gated on this flag
/// therefore only run when the CPU reports NEON support. On every other
/// architecture it returns `false`.
fn detect_neon_support() -> bool {
    #[cfg(target_arch = "aarch64")]
    {
        std::arch::is_aarch64_feature_detected!("neon")
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        false
    }
}
/// Alias for [`AArch64SimdDetector`], named after the NEON instruction set it targets.
pub type NeonDetector = AArch64SimdDetector;