use crate::core::{ProtocolType, DetectionResult, ProtocolInfo};
use crate::core::detector::DetectionMethod;
use crate::core::probe::{ProtocolProbe, ProbeContext};
use crate::error::{Result, DetectorError};
use super::{ProbeEngine, ProbeType};
/// Passive protocol detector.
///
/// Scores an already-captured byte buffer against signature heuristics for
/// HTTP/1.x, HTTP/2, HTTP/3, QUIC, gRPC, WebSocket, TLS and SSH. It only reads
/// the buffer it is given.
pub struct PassiveProbe {
    /// Minimum number of buffered bytes required before detection is attempted.
    min_data_size: usize,
    /// Minimum score a detection must reach to be reported.
    confidence_threshold: f32,
}

impl PassiveProbe {
    /// HTTP/2 client connection preface (RFC 9113 §3.4).
    const HTTP2_PREFACE: &'static [u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";

    /// Creates a probe with the default settings: at least 16 bytes of input
    /// and a reporting threshold of 0.7.
    pub fn new() -> Self {
        Self {
            min_data_size: 16,
            confidence_threshold: 0.7,
        }
    }

    /// Sets the minimum number of buffered bytes required before probing.
    pub fn with_min_data_size(mut self, size: usize) -> Self {
        self.min_data_size = size;
        self
    }

    /// Sets the minimum score a detection must reach to be reported.
    pub fn with_confidence_threshold(mut self, threshold: f32) -> Self {
        self.confidence_threshold = threshold;
        self
    }

    /// Scores an HTTP/1.x request line or status line at the start of `data`.
    ///
    /// Returns `Some(0.9)` for a request that starts with a known method,
    /// `Some(0.95)` for a response (`HTTP/1.`), and `None` otherwise or when
    /// fewer than 8 bytes are available.
    fn detect_http1(&self, data: &[u8]) -> Option<f32> {
        // The 4-byte prefixes are the original set. "POST", "HEAD" and "DELE"
        // match without a trailing space. The full tokens add the remaining
        // RFC 9110 methods plus PATCH (RFC 5789).
        const REQUEST_PREFIXES: &[&[u8]] = &[
            b"GET ", b"POST", b"PUT ", b"HEAD", b"DELE",
            b"OPTIONS ", b"PATCH ", b"CONNECT ", b"TRACE ",
        ];

        if data.len() < 8 {
            return None;
        }
        if REQUEST_PREFIXES.iter().any(|prefix| data.starts_with(prefix)) {
            return Some(0.9);
        }
        if data.starts_with(b"HTTP/1.") {
            return Some(0.95);
        }
        None
    }

    /// Scores HTTP/2 traffic.
    ///
    /// The client connection preface yields `Some(1.0)`. Otherwise the first
    /// 9 bytes are parsed as an HTTP/2 frame header (RFC 9113 §4.1). A
    /// well-formed SETTINGS or HEADERS frame header yields `Some(0.8)`.
    /// Returns `None` for anything else or when fewer than 24 bytes are available.
    fn detect_http2(&self, data: &[u8]) -> Option<f32> {
        // Initial SETTINGS_MAX_FRAME_SIZE (RFC 9113 §6.5.2). A peer may not
        // send larger frames before the first SETTINGS exchange completes.
        const DEFAULT_MAX_FRAME_SIZE: usize = 1 << 14;
        const FRAME_HEADERS: u8 = 0x1;
        const FRAME_SETTINGS: u8 = 0x4;
        // Flags defined for HEADERS: END_STREAM | END_HEADERS | PADDED | PRIORITY.
        const HEADERS_FLAGS: u8 = 0x01 | 0x04 | 0x08 | 0x20;
        // The only flag defined for SETTINGS is ACK.
        const SETTINGS_FLAGS: u8 = 0x01;

        if data.len() < 24 {
            return None;
        }
        if data.starts_with(Self::HTTP2_PREFACE) {
            return Some(1.0);
        }

        let length =
            (usize::from(data[0]) << 16) | (usize::from(data[1]) << 8) | usize::from(data[2]);
        let frame_type = data[3];
        let flags = data[4];
        // The reserved bit must be unset by senders (RFC 9113 §4.1).
        let reserved_bit_set = (data[5] & 0x80) != 0;
        let stream_id = u32::from_be_bytes([data[5] & 0x7F, data[6], data[7], data[8]]);

        if reserved_bit_set || length > DEFAULT_MAX_FRAME_SIZE {
            return None;
        }

        // Checking only the type byte matched any buffer whose 4th byte was
        // 0x01 or 0x04, for example SSH packet lengths. The stream and length
        // rules below come from RFC 9113 §6.2 and §6.5.
        let plausible = match frame_type {
            FRAME_SETTINGS => {
                stream_id == 0 && length % 6 == 0 && (flags & !SETTINGS_FLAGS) == 0
            }
            FRAME_HEADERS => stream_id != 0 && (flags & !HEADERS_FLAGS) == 0,
            _ => false,
        };
        if plausible {
            Some(0.8)
        } else {
            None
        }
    }

    /// Scores a QUIC long-header packet by its version field.
    ///
    /// Returns `Some(0.95)` for QUIC v1 or v2, `Some(0.9)` for draft-29 and
    /// `Some(0.7)` for a Version Negotiation packet (version 0). Returns `None`
    /// for short headers, unknown versions, or fewer than 16 bytes.
    fn detect_quic(&self, data: &[u8]) -> Option<f32> {
        const VERSION_NEGOTIATION: u32 = 0x0000_0000;
        const QUIC_V1: u32 = 0x0000_0001; // RFC 9000
        const QUIC_V2: u32 = 0x6b33_43cf; // RFC 9369
        const DRAFT_29: u32 = 0xff00_001d;

        // The high bit of the first byte marks a long header.
        if data.len() < 16 || (data[0] & 0x80) == 0 {
            return None;
        }
        let version = u32::from_be_bytes([data[1], data[2], data[3], data[4]]);
        match version {
            QUIC_V1 | QUIC_V2 => Some(0.95),
            DRAFT_29 => Some(0.9),
            VERSION_NEGOTIATION => Some(0.7),
            _ => None,
        }
    }

    /// Scores HTTP/3 on top of a QUIC packet.
    ///
    /// Only packets whose QUIC score is above 0.7 are considered. Extra
    /// HTTP/3 evidence adds to the QUIC score, capped at 0.95, once it reaches
    /// 0.4. Without that evidence the QUIC score is scaled down by 0.6.
    fn detect_http3(&self, data: &[u8]) -> Option<f32> {
        // Byte offsets probed for an HTTP/3 frame-type byte.
        // NOTE(review): these offsets likely fall inside protected or encrypted
        // packet bytes, and 0x00 is a very common byte value, so this signal is weak.
        const FRAME_TYPE_OFFSETS: [usize; 10] = [16, 20, 24, 28, 32, 36, 40, 44, 48, 52];

        let quic_confidence = self.detect_quic(data)?;
        if quic_confidence <= 0.7 {
            return None;
        }

        let mut http3_confidence = 0.0;
        // "h3" also matches every "h3-<draft>" ALPN token, so one search suffices.
        if self.fast_search(data, b"h3") {
            http3_confidence += 0.5;
        }
        if data.len() >= 20 {
            let has_frame_type = FRAME_TYPE_OFFSETS
                .iter()
                .filter_map(|&pos| data.get(pos))
                .any(|&byte| matches!(byte, 0x0 | 0x1 | 0x4 | 0x5 | 0x7 | 0xd | 0xe));
            if has_frame_type {
                http3_confidence += 0.4;
            }
        }
        if self.fast_search(data, &[0x01, 0x40]) || self.fast_search(data, &[0x06, 0x40]) {
            http3_confidence += 0.3;
        }

        if http3_confidence >= 0.4 {
            Some((quic_confidence + http3_confidence).min(0.95))
        } else {
            Some(quic_confidence * 0.6)
        }
    }

    /// Scores gRPC over HTTP/2.
    ///
    /// Evidence is additive: the HTTP/2 preface adds 0.4, an
    /// `application/grpc` substring adds 0.5, and a frame-type byte of 0x00
    /// to 0x08 at one of the probed frame-header offsets adds 0.3. A total of
    /// 0.8 or more is raised to at least 0.9. Returns `Some` only above 0.5.
    ///
    /// NOTE(review): the total can reach 1.2 and is not clamped. The probe
    /// loops keep the first candidate on ties, so clamping to 1.0 would let
    /// HTTP/2 (1.0, listed earlier) win over gRPC.
    fn detect_grpc(&self, data: &[u8]) -> Option<f32> {
        // Candidate frame-header offsets: the buffer start, right after the
        // 24-byte preface, and 9 bytes (one frame header) after that.
        const FRAME_HEADER_OFFSETS: [usize; 3] = [0, 24, 33];

        if data.len() < 16 {
            return None;
        }

        let mut confidence: f32 = 0.0;
        if data.starts_with(Self::HTTP2_PREFACE) {
            confidence += 0.4;
        }
        if self.fast_search(data, b"application/grpc") {
            confidence += 0.5;
        }
        let has_frame_header = FRAME_HEADER_OFFSETS
            .iter()
            .any(|&pos| pos + 9 <= data.len() && data[pos + 3] <= 0x08);
        if has_frame_header {
            confidence += 0.3;
        }
        if confidence >= 0.8 {
            confidence = confidence.max(0.9);
        }

        if confidence > 0.5 {
            Some(confidence)
        } else {
            None
        }
    }

    /// Returns `true` if `needle` occurs anywhere in `haystack`.
    ///
    /// An empty needle never matches.
    fn fast_search(&self, haystack: &[u8], needle: &[u8]) -> bool {
        !needle.is_empty() && haystack.windows(needle.len()).any(|window| window == needle)
    }

    /// Same as `fast_search`, but compares ASCII letters case-insensitively.
    fn contains_ignore_ascii_case(&self, haystack: &[u8], needle: &[u8]) -> bool {
        !needle.is_empty()
            && haystack
                .windows(needle.len())
                .any(|window| window.eq_ignore_ascii_case(needle))
    }

    /// Scores a WebSocket handshake or a bare WebSocket frame.
    ///
    /// HTTP-looking input scores 0.98 for a `101` upgrade response and 0.75
    /// for another message carrying `Upgrade: websocket`; otherwise it scores
    /// `None`. Non-HTTP input is checked as an RFC 6455 §5.2 frame header,
    /// which scores 0.6 for a control or data opcode with a short payload length.
    fn detect_websocket(&self, data: &[u8]) -> Option<f32> {
        if data.len() < 20 {
            return None;
        }

        let is_http_like = self.fast_search(data, b"HTTP/")
            || self.fast_search(data, b"GET ")
            || self.fast_search(data, b"POST ");
        if is_http_like {
            // Header field names are case-insensitive (RFC 9110 §5.1), and RFC 6455
            // §4.1 treats the "websocket" value case-insensitively too, so
            // "Upgrade: WebSocket" must also match.
            if !self.contains_ignore_ascii_case(data, b"upgrade: websocket") {
                return None;
            }
            return if self.fast_search(data, b"HTTP/1.1 101") {
                Some(0.98)
            } else {
                Some(0.75)
            };
        }

        // data.len() >= 20 here, so both header bytes exist.
        let opcode = data[0] & 0x0F;
        let masked = (data[1] & 0x80) != 0;
        let payload_len = data[1] & 0x7F;
        let header_len = if masked { 6 } else { 2 };
        if matches!(opcode, 0x0..=0x2 | 0x8..=0xA) && payload_len <= 125 && data.len() >= header_len
        {
            return Some(0.6);
        }
        None
    }

    /// Scores a TLS record header: content type 0x14 to 0x17 and major version 3.
    ///
    /// Handshake records (0x16) score 0.95 and other record types score 0.8.
    /// Returns `None` when fewer than 5 bytes are available.
    fn detect_tls(&self, data: &[u8]) -> Option<f32> {
        const HANDSHAKE: u8 = 0x16;

        if data.len() < 5 {
            return None;
        }
        let content_type = data[0];
        let known_content_type = (0x14..=0x17).contains(&content_type);
        // Accepts SSL 3.0 (3,0) through (3,4).
        let known_version = data[1] == 0x03 && data[2] <= 0x04;
        if !(known_content_type && known_version) {
            return None;
        }
        Some(if content_type == HANDSHAKE { 0.95 } else { 0.8 })
    }

    /// Scores an SSH identification string or an SSH binary packet header.
    ///
    /// `SSH-2.0` scores 0.98, `SSH-1.` scores 0.95 and any other `SSH-` prefix
    /// scores 0.9. Otherwise the data is checked against the RFC 4253 §6
    /// binary packet layout, which scores 0.6.
    fn detect_ssh(&self, data: &[u8]) -> Option<f32> {
        if data.len() < 4 {
            return None;
        }
        if data.starts_with(b"SSH-2.0") {
            return Some(0.98);
        }
        if data.starts_with(b"SSH-1.") {
            return Some(0.95);
        }
        if data.starts_with(b"SSH-") {
            return Some(0.9);
        }
        if data.len() < 6 {
            return None;
        }

        // Layout: uint32 packet_length, byte padding_length, payload, padding.
        let packet_length = u32::from_be_bytes([data[0], data[1], data[2], data[3]]);
        let padding_length = u32::from(data[4]);
        let fits_in_buffer = packet_length as usize <= data.len() - 4;
        // RFC 4253 §6: padding is at least 4 bytes (max 255). packet_length
        // counts the padding_length byte itself, so padding < packet_length.
        if packet_length < 65536
            && fits_in_buffer
            && padding_length >= 4
            && padding_length < packet_length
        {
            return Some(0.6);
        }
        None
    }
}
impl ProbeEngine for PassiveProbe {
fn probe(&self, data: &[u8]) -> Result<DetectionResult> {
if data.len() < self.min_data_size {
return Err(DetectorError::NeedMoreData(self.min_data_size));
}
let mut best_protocol = ProtocolType::Unknown;
let mut best_confidence = 0.0;
let mut detections = [(ProtocolType::Unknown, 0.0); 8];
let mut detection_count = 0;
if let Some(confidence) = self.detect_http3(data) {
detections[detection_count] = (ProtocolType::HTTP3, confidence);
detection_count += 1;
}
if let Some(confidence) = self.detect_quic(data) {
detections[detection_count] = (ProtocolType::QUIC, confidence);
detection_count += 1;
}
if let Some(confidence) = self.detect_http2(data) {
detections[detection_count] = (ProtocolType::HTTP2, confidence);
detection_count += 1;
}
if let Some(confidence) = self.detect_grpc(data) {
detections[detection_count] = (ProtocolType::GRPC, confidence);
detection_count += 1;
}
if let Some(confidence) = self.detect_http1(data) {
detections[detection_count] = (ProtocolType::HTTP1_1, confidence);
detection_count += 1;
}
if let Some(confidence) = self.detect_tls(data) {
detections[detection_count] = (ProtocolType::TLS, confidence);
detection_count += 1;
}
if let Some(confidence) = self.detect_ssh(data) {
detections[detection_count] = (ProtocolType::SSH, confidence);
detection_count += 1;
}
if let Some(confidence) = self.detect_websocket(data) {
detections[detection_count] = (ProtocolType::WebSocket, confidence);
detection_count += 1;
}
for i in 0..detection_count {
let (protocol, confidence) = detections[i];
if confidence > best_confidence {
best_confidence = confidence;
best_protocol = protocol;
}
}
if best_confidence < self.confidence_threshold {
return Err(DetectorError::detection_failed(
format!("Confidence {} below threshold {}",
best_confidence, self.confidence_threshold)
));
}
let protocol_info = ProtocolInfo::new(best_protocol, best_confidence);
Ok(DetectionResult::new(
protocol_info,
std::time::Duration::from_millis(0), DetectionMethod::Passive,
"PassiveProbe".to_string(),
))
}
fn probe_type(&self) -> ProbeType {
ProbeType::Passive
}
fn needs_more_data(&self, data: &[u8]) -> bool {
data.len() < self.min_data_size
}
}
impl Default for PassiveProbe {
fn default() -> Self {
Self::new()
}
}
impl ProtocolProbe for PassiveProbe {
    /// Identifier reported for this probe.
    fn name(&self) -> &'static str {
        "PassiveProbe"
    }

    /// Protocols this probe claims to recognise.
    fn supported_protocols(&self) -> Vec<ProtocolType> {
        vec![
            ProtocolType::HTTP1_1,
            ProtocolType::HTTP2,
            ProtocolType::HTTP3,
            ProtocolType::QUIC,
            ProtocolType::GRPC,
            ProtocolType::WebSocket,
            ProtocolType::TLS,
            ProtocolType::SSH,
            // NOTE(review): no detector in this probe ever reports UDP. Kept so
            // that callers routing on this list see no change — confirm intent.
            ProtocolType::UDP,
        ]
    }

    /// Runs every passive detector over `data` and records the best match.
    ///
    /// Returns `Ok(None)` when fewer than the configured minimum bytes are
    /// buffered, when no detector matched, or when the best score is below the
    /// threshold. Otherwise the winning `ProtocolInfo` is added to `context`
    /// as a candidate and returned. On equal scores the earlier detector wins
    /// (HTTP/3, QUIC, HTTP/2, gRPC, HTTP/1.1, TLS, SSH, WebSocket).
    fn probe(&self, data: &[u8], context: &mut ProbeContext) -> Result<Option<ProtocolInfo>> {
        if data.len() < self.min_data_size {
            return Ok(None);
        }

        let candidates = [
            (ProtocolType::HTTP3, self.detect_http3(data)),
            (ProtocolType::QUIC, self.detect_quic(data)),
            (ProtocolType::HTTP2, self.detect_http2(data)),
            (ProtocolType::GRPC, self.detect_grpc(data)),
            (ProtocolType::HTTP1_1, self.detect_http1(data)),
            (ProtocolType::TLS, self.detect_tls(data)),
            (ProtocolType::SSH, self.detect_ssh(data)),
            (ProtocolType::WebSocket, self.detect_websocket(data)),
        ];

        let mut best: Option<(ProtocolType, f32)> = None;
        for &(protocol, confidence) in &candidates {
            let confidence = match confidence {
                Some(confidence) => confidence,
                None => continue,
            };
            // A strict `>` keeps the earliest candidate on ties.
            if best.map_or(true, |(_, best_confidence)| confidence > best_confidence) {
                best = Some((protocol, confidence));
            }
        }

        // When nothing matched, report no candidate. Previously a threshold
        // <= 0.0 registered and returned `Unknown` here.
        match best {
            Some((protocol, confidence)) if confidence >= self.confidence_threshold => {
                let protocol_info = ProtocolInfo::new(protocol, confidence);
                context.add_candidate(protocol_info.clone());
                Ok(Some(protocol_info))
            }
            _ => Ok(None),
        }
    }

    /// Scheduling priority of this probe.
    fn priority(&self) -> u8 {
        80
    }

    /// Returns `true` while fewer than the configured minimum bytes are buffered.
    fn needs_more_data(&self, data: &[u8]) -> bool {
        data.len() < self.min_data_size
    }
}