use crate::core::{ProtocolType, DetectionResult, ProtocolInfo};
use crate::core::detector::DetectionMethod;
use crate::error::{Result, DetectorError};
use super::{ProbeEngine, ProbeType};
use std::collections::HashMap;
use std::time::{Duration, Instant};
/// Simulated active protocol prober that tries a fixed list of protocols in turn and
/// reports the one with the highest confidence.
#[derive(Debug, Clone)]
pub struct ActiveProbe {
    /// Overall time budget for one probing pass; checked before each protocol is tried.
    timeout: Duration,
    /// Retry budget.
    /// NOTE(review): stored by the builder but not read anywhere in this file.
    max_retries: u32,
    /// When `true`, simulated confidences are scaled up (and capped at 1.0).
    aggressive_mode: bool,
}
impl ActiveProbe {
    /// Multiplier applied to the baseline confidence when aggressive mode is enabled.
    const AGGRESSIVE_CONFIDENCE_BOOST: f32 = 1.2;

    /// Fixed delay that stands in for a network round trip in the simulated probe.
    const SIMULATED_ROUND_TRIP: Duration = Duration::from_millis(10);

    /// Size of an HTTP/2 frame header: 24-bit length, 8-bit type, 8-bit flags and a
    /// 32-bit stream identifier.
    const HTTP2_FRAME_HEADER_LEN: usize = 9;

    /// Creates a probe with a 1000 ms timeout, 3 retries and aggressive mode disabled.
    pub fn new() -> Self {
        Self {
            timeout: Duration::from_millis(1000),
            max_retries: 3,
            aggressive_mode: false,
        }
    }

    /// Sets the overall time budget that `ProbeEngine::probe` checks before trying
    /// each protocol.
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Sets the retry budget.
    ///
    /// NOTE(review): `max_retries` is stored but nothing in this file reads it yet.
    pub fn with_max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }

    /// Enables or disables aggressive mode, which multiplies every simulated confidence
    /// by `AGGRESSIVE_CONFIDENCE_BOOST` before capping it at 1.0.
    pub fn with_aggressive_mode(mut self, enabled: bool) -> Self {
        self.aggressive_mode = enabled;
        self
    }

    /// Builds a minimal HTTP/1.1 `GET /` request that asks the peer to close afterwards.
    fn generate_http1_probe(&self) -> Vec<u8> {
        b"GET / HTTP/1.1\r\nHost: probe\r\nConnection: close\r\n\r\n".to_vec()
    }

    /// Builds the HTTP/2 client connection preface followed by an empty SETTINGS frame
    /// (length 0, type 0x04, no flags, stream 0).
    fn generate_http2_probe(&self) -> Vec<u8> {
        let mut probe = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".to_vec();
        let empty_settings_frame: [u8; 9] = [
            0x00, 0x00, 0x00, // payload length = 0
            0x04, // frame type = SETTINGS
            0x00, // flags
            0x00, 0x00, 0x00, 0x00, // stream identifier = 0
        ];
        probe.extend_from_slice(&empty_settings_frame);
        probe
    }

    /// Builds a QUIC v1 long-header packet shaped like an Initial packet.
    ///
    /// NOTE(review): the first byte's packet-number-length bits (00) declare a 1-byte
    /// packet number while 4 bytes are written, and the packet is neither protected nor
    /// padded to 1200 bytes, so a real QUIC server would most likely discard it.
    fn generate_quic_probe(&self) -> Vec<u8> {
        let mut probe = Vec::new();
        // Long-header form bit + fixed bit set; packet type bits 00 (Initial in v1).
        probe.push(0xc0);
        // Version 0x00000001 (QUIC v1).
        probe.extend_from_slice(&[0x00, 0x00, 0x00, 0x01]);
        // Destination connection ID: length 8, then eight arbitrary bytes.
        probe.push(0x08);
        probe.extend_from_slice(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]);
        // Empty source connection ID.
        probe.push(0x00);
        // Token length 0.
        probe.push(0x00);
        // Length as a 2-byte varint: 0x4010 -> 16, covering the 4 + 12 bytes below.
        probe.extend_from_slice(&[0x40, 0x10]);
        // Packet number, then a zero-filled payload.
        probe.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]);
        probe.extend_from_slice(&[0x00; 12]);
        probe
    }

    /// Builds an HTTP/1.1 WebSocket upgrade request for `/ws`.
    ///
    /// The `Sec-WebSocket-Key` is the fixed sample nonce from RFC 6455, not a fresh random
    /// value.
    fn generate_websocket_probe(&self) -> Vec<u8> {
        // A plain byte-string literal: there is nothing to interpolate, so `format!` was
        // an unnecessary allocation and formatting pass.
        b"GET /ws HTTP/1.1\r\n\
          Host: probe\r\n\
          Upgrade: websocket\r\n\
          Connection: Upgrade\r\n\
          Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
          Sec-WebSocket-Version: 13\r\n\r\n"
            .to_vec()
    }

    /// Scores a raw response with the analyzer for `probe_type`.
    ///
    /// Returns `None` when the response does not look like `probe_type`, or when no
    /// analyzer exists for that protocol.
    ///
    /// NOTE(review): nothing in this file calls this yet; `probe_protocol` uses fixed
    /// simulated confidences instead.
    fn analyze_response(&self, probe_type: ProtocolType, response: &[u8]) -> Option<f32> {
        match probe_type {
            ProtocolType::HTTP1_1 => self.analyze_http1_response(response),
            ProtocolType::HTTP2 => self.analyze_http2_response(response),
            ProtocolType::QUIC => self.analyze_quic_response(response),
            ProtocolType::WebSocket => self.analyze_websocket_response(response),
            _ => None,
        }
    }

    /// Returns 0.95 for an `HTTP/1.1` status line, 0.9 for `HTTP/1.0`, otherwise `None`
    /// (including for an empty response).
    fn analyze_http1_response(&self, response: &[u8]) -> Option<f32> {
        // The prefixes are pure ASCII, so a byte-level comparison is equivalent to the
        // lossy-UTF-8 comparison and avoids allocating a string.
        if response.starts_with(b"HTTP/1.1") {
            Some(0.95)
        } else if response.starts_with(b"HTTP/1.0") {
            Some(0.9)
        } else {
            None
        }
    }

    /// Inspects the type byte of the first HTTP/2 frame header in `response`.
    ///
    /// A SETTINGS frame (0x04) scores 0.9 and a GOAWAY frame (0x07) scores 0.7. Anything
    /// shorter than a full frame header, or any other frame type, yields `None`.
    fn analyze_http2_response(&self, response: &[u8]) -> Option<f32> {
        if response.len() < Self::HTTP2_FRAME_HEADER_LEN {
            return None;
        }
        // Byte 3 is the frame type; it follows the 24-bit payload length.
        match response[3] {
            0x04 => Some(0.9),
            0x07 => Some(0.7),
            _ => None,
        }
    }

    /// Scores a QUIC response by its first-byte header form.
    ///
    /// - A long-header packet (high bit set) whose 32-bit version field is zero is a
    ///   Version Negotiation packet and scores 0.9.
    /// - Any other long-header packet scores 0.8.
    /// - Short-header packets and empty input yield `None`. Their bytes 1..5 hold a
    ///   connection ID, not a version, so they are not interpreted.
    fn analyze_quic_response(&self, response: &[u8]) -> Option<f32> {
        let (&first_byte, rest) = response.split_first()?;
        if first_byte & 0x80 == 0 {
            return None;
        }
        // The version check must happen inside the long-header branch. Testing it after
        // returning for every long header made Version Negotiation undetectable.
        match rest {
            [0, 0, 0, 0, ..] => Some(0.9),
            _ => Some(0.8),
        }
    }

    /// Scores a response to the WebSocket upgrade request.
    ///
    /// A `101` status plus an `Upgrade: websocket` header scores 0.95. A `400` status
    /// that mentions `websocket` scores 0.7. Otherwise the result is `None`.
    fn analyze_websocket_response(&self, response: &[u8]) -> Option<f32> {
        // Header names and the `websocket` token are case-insensitive, and clients should
        // ignore the reason phrase, so match status codes on a lowercased copy.
        let text = String::from_utf8_lossy(response).to_ascii_lowercase();
        if text.contains("http/1.1 101 ") && text.contains("upgrade: websocket") {
            Some(0.95)
        } else if text.contains("http/1.1 400 ") && text.contains("websocket") {
            Some(0.7)
        } else {
            None
        }
    }

    /// Simulates probing for `protocol` and returns a confidence in `[0.0, 1.0]`.
    ///
    /// Steps:
    /// 1. Builds the protocol's probe payload (not transmitted yet).
    /// 2. Sleeps for `SIMULATED_ROUND_TRIP`.
    /// 3. Returns a fixed per-protocol baseline, multiplied by
    ///    `AGGRESSIVE_CONFIDENCE_BOOST` in aggressive mode and capped at 1.0.
    ///
    /// `_target_data` is currently ignored.
    ///
    /// # Errors
    /// Returns `DetectorError::unsupported_protocol` for protocols without a probe
    /// generator.
    fn probe_protocol(&self, protocol: ProtocolType, _target_data: &[u8]) -> Result<f32> {
        // One match selects both payload and baseline, so the two can never disagree.
        let (_probe_data, baseline): (Vec<u8>, f32) = match protocol {
            ProtocolType::HTTP1_1 => (self.generate_http1_probe(), 0.8),
            ProtocolType::HTTP2 => (self.generate_http2_probe(), 0.7),
            ProtocolType::QUIC => (self.generate_quic_probe(), 0.6),
            ProtocolType::WebSocket => (self.generate_websocket_probe(), 0.75),
            _ => return Err(DetectorError::unsupported_protocol(format!("{:?}", protocol))),
        };
        std::thread::sleep(Self::SIMULATED_ROUND_TRIP);
        // The explicit `f32` is required: calling `.min` on an unannotated float literal
        // binding is an "ambiguous numeric type" compile error.
        let confidence: f32 = if self.aggressive_mode {
            baseline * Self::AGGRESSIVE_CONFIDENCE_BOOST
        } else {
            baseline
        };
        Ok(confidence.min(1.0))
    }
}
impl ProbeEngine for ActiveProbe {
    /// Runs the simulated active probe for each supported protocol and reports the most
    /// confident match.
    ///
    /// Protocols are tried in a fixed order: HTTP/1.1, HTTP/2, QUIC, WebSocket. Before
    /// each attempt the elapsed time is compared with `self.timeout`, and the loop stops
    /// once it is exceeded. A later protocol replaces the current best only with strictly
    /// higher confidence, so ties keep the earlier protocol. Errors from individual
    /// protocol probes are recorded rather than propagated.
    ///
    /// # Errors
    /// Returns `DetectorError::detection_failed` when the best confidence is below 0.5.
    fn probe(&self, data: &[u8]) -> Result<DetectionResult> {
        // Minimum confidence required to report a detection instead of an error.
        const MIN_CONFIDENCE: f32 = 0.5;

        let start_time = Instant::now();
        let protocols = [
            ProtocolType::HTTP1_1,
            ProtocolType::HTTP2,
            ProtocolType::QUIC,
            ProtocolType::WebSocket,
        ];
        let mut best_protocol = ProtocolType::Unknown;
        let mut best_confidence: f32 = 0.0;
        // NOTE(review): these diagnostics are collected but never attached to the returned
        // result or error. Confirm whether DetectionResult can carry them.
        let mut metadata: HashMap<String, String> = HashMap::new();

        for &protocol in &protocols {
            if start_time.elapsed() > self.timeout {
                metadata.insert("timeout".to_string(), "true".to_string());
                break;
            }
            match self.probe_protocol(protocol, data) {
                Ok(confidence) => {
                    if confidence > best_confidence {
                        best_confidence = confidence;
                        best_protocol = protocol;
                    }
                    metadata.insert(
                        format!("{:?}_confidence", protocol),
                        confidence.to_string(),
                    );
                }
                Err(e) => {
                    metadata.insert(format!("{:?}_error", protocol), e.to_string());
                }
            }
        }

        // Sample the clock once so the recorded duration and the returned one agree.
        let elapsed = start_time.elapsed();
        metadata.insert("probe_duration_ms".to_string(), elapsed.as_millis().to_string());
        metadata.insert("aggressive_mode".to_string(), self.aggressive_mode.to_string());

        if best_confidence < MIN_CONFIDENCE {
            return Err(DetectorError::detection_failed(
                "Active probe confidence too low"
            ));
        }

        let protocol_info = ProtocolInfo::new(best_protocol, best_confidence);
        Ok(DetectionResult::new(
            protocol_info,
            elapsed,
            DetectionMethod::Active,
            "ActiveProbe".to_string(),
        ))
    }

    /// Identifies this engine as an active prober.
    fn probe_type(&self) -> ProbeType {
        ProbeType::Active
    }

    /// Always returns `false`: this engine never asks for more input before probing.
    fn needs_more_data(&self, _data: &[u8]) -> bool {
        false
    }
}
impl Default for ActiveProbe {
fn default() -> Self {
Self::new()
}
}