use crate::core::protocol::{ProtocolType, ProtocolInfo};
use crate::core::probe::{ProbeRegistry, ProbeConfig, ProbeContext, ProbeAggregator};
use crate::core::magic::MagicDetector;
use crate::error::{DetectorError, Result};
use std::time::{Duration, Instant};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::collections::HashMap;
/// The side of a connection an agent acts as.
///
/// The default `ProtocolAgent` methods in this file branch on this value:
/// active probing and automatic fallback run only for `Client` and return an
/// "unsupported" error for `Server`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Role {
/// Server side; active probing and auto fallback are rejected for this role.
Server,
/// Client side; active probing and auto fallback are allowed for this role.
Client,
}
impl Role {
    /// Returns `true` when this role is [`Role::Server`].
    pub fn is_server(&self) -> bool {
        match self {
            Role::Server => true,
            Role::Client => false,
        }
    }

    /// Returns `true` when this role is [`Role::Client`].
    pub fn is_client(&self) -> bool {
        match self {
            Role::Client => true,
            Role::Server => false,
        }
    }
}
/// Configuration used to construct an `Agent`.
#[derive(Debug, Clone)]
pub struct AgentConfig {
/// Role the agent acts as; reported by the agent's `role()` accessor.
pub role: Role,
/// Identifier of this agent instance; the `Default` impl fills it with a
/// freshly generated UUID v4 string.
pub instance_id: String,
/// Detection settings carried with the agent configuration.
pub detection_config: DetectionConfig,
/// Probe settings carried with the agent configuration.
pub probe_config: ProbeConfig,
/// Protocols the agent reports as supported and is willing to probe.
pub enabled_protocols: Vec<ProtocolType>,
/// Upgrade toggle (defaults to `true`).
/// NOTE(review): not consulted anywhere in the visible part of this file —
/// confirm where it is enforced.
pub enable_upgrade: bool,
/// When `Some`, `Agent::new` builds a `LoadBalancer` from it.
pub load_balancer_config: Option<LoadBalancerConfig>,
}
impl Default for AgentConfig {
fn default() -> Self {
Self {
role: Role::Server,
instance_id: uuid::Uuid::new_v4().to_string(),
detection_config: DetectionConfig::default(),
probe_config: ProbeConfig::default(),
enabled_protocols: vec![
ProtocolType::HTTP1_1,
ProtocolType::HTTP2,
ProtocolType::TLS,
],
enable_upgrade: true,
load_balancer_config: None,
}
}
}
/// Settings from which a `LoadBalancer` is built.
#[derive(Debug, Clone)]
pub struct LoadBalancerConfig {
/// Load-balancer flag.
/// NOTE(review): not read anywhere in the visible part of this file.
pub is_load_balancer: bool,
/// Backend instance ids; `LoadBalancer::new` creates one `BackendState`
/// entry per id.
pub backend_instances: Vec<String>,
/// Selection strategy applied by `LoadBalancer::select_backend`.
pub strategy: LoadBalanceStrategy,
}
/// Backend selection strategy used by `LoadBalancer::select_backend`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalanceStrategy {
/// Cycle through the healthy backends using a shared counter.
RoundRobin,
/// Pick the healthy backend with the fewest `active_connections`.
LeastConnections,
/// Map a shared counter onto the cumulative weights of the healthy backends.
WeightedRoundRobin,
/// NOTE(review): `select_backend` currently just returns the first healthy
/// backend for this variant; no hashing key is involved.
ConsistentHash,
}
/// Outcome of one protocol detection run.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DetectionResult {
/// Detected protocol together with its confidence score.
pub protocol_info: ProtocolInfo,
/// Time the detection took.
pub detection_time: Duration,
/// Technique that produced this result.
pub detection_method: DetectionMethod,
/// Name of the detector that produced this result.
pub detector_name: String,
}
impl DetectionResult {
pub fn new(
protocol_info: ProtocolInfo,
detection_time: Duration,
detection_method: DetectionMethod,
detector_name: String,
) -> Self {
Self {
protocol_info,
detection_time,
detection_method,
detector_name,
}
}
pub fn protocol_type(&self) -> ProtocolType {
self.protocol_info.protocol_type
}
pub fn confidence(&self) -> f32 {
self.protocol_info.confidence
}
pub fn is_high_confidence(&self) -> bool {
self.confidence() >= 0.8
}
pub fn is_acceptable(&self, min_confidence: f32) -> bool {
self.confidence() >= min_confidence
}
}
/// Detector that combines a magic-byte detector, registered probes and a
/// result aggregator.
#[derive(Debug)]
pub struct DefaultProtocolDetector {
// Source of the probes that `detect` runs for each enabled protocol.
registry: ProbeRegistry,
// Probe settings; a clone of this value is handed to the aggregator in `new`.
probe_config: ProbeConfig,
// Supplies the timeout and the min/max probe sizes used by `detect`.
detection_config: DetectionConfig,
// Protocols this detector may report; guaranteed non-empty by `new`.
enabled_protocols: Vec<ProtocolType>,
// Picks the best candidate and builds the final `DetectionResult`.
aggregator: ProbeAggregator,
// Magic-byte detector restricted to `enabled_protocols`.
magic_detector: MagicDetector,
}
impl DefaultProtocolDetector {
pub fn new(
registry: ProbeRegistry,
probe_config: ProbeConfig,
detection_config: DetectionConfig,
enabled_protocols: Vec<ProtocolType>,
) -> Result<Self> {
let aggregator = ProbeAggregator::new(probe_config.clone());
if enabled_protocols.is_empty() {
return Err(DetectorError::config_error(
"至少需要启用一个协议,否则无法进行协议检测"
));
}
let magic_detector = MagicDetector::new()
.with_enabled_protocols(enabled_protocols.clone());
Ok(Self {
registry,
probe_config,
detection_config,
enabled_protocols,
aggregator,
magic_detector,
})
}
pub fn probe_config(&self) -> &ProbeConfig {
&self.probe_config
}
pub fn detection_config(&self) -> &DetectionConfig {
&self.detection_config
}
pub fn enabled_protocols(&self) -> &[ProtocolType] {
&self.enabled_protocols
}
}
// `ProtocolDetector` implementation: size validation, a magic-byte fast path,
// then registered probes per enabled protocol, then aggregation.
impl ProtocolDetector for DefaultProtocolDetector {
// Detects the protocol carried by `data`.
//
// Errors: `InsufficientData` when `data` is shorter than `min_probe_size()`,
// `DataTooLarge` when it is longer than `max_probe_size()`, and
// `NoProtocolDetected` when the aggregator yields no result.
fn detect(&self, data: &[u8]) -> Result<DetectionResult> {
let start_time = Instant::now();
let mut context = ProbeContext::new();
context.bytes_read = data.len();
// Reject inputs outside the configured size window before doing any work.
if data.len() < self.min_probe_size() {
return Err(DetectorError::InsufficientData(
format!("需要至少 {} 字节,但只有 {} 字节", self.min_probe_size(), data.len())
));
}
if data.len() > self.max_probe_size() {
return Err(DetectorError::DataTooLarge(
format!("数据大小 {} 字节超过最大限制 {} 字节", data.len(), self.max_probe_size())
));
}
// Fast path: a magic-byte hit with confidence >= 0.95 is returned immediately;
// a weaker hit is kept as a candidate for aggregation.
if let Some(magic_result) = self.magic_detector.quick_detect(data) {
if magic_result.confidence >= 0.95 {
let detection_time = start_time.elapsed();
return Ok(DetectionResult::new(
magic_result,
detection_time,
DetectionMethod::SimdAccelerated, "MagicBytesDetector".to_string(),
));
}
context.add_candidate(magic_result);
}
let mut all_results = Vec::with_capacity(self.enabled_protocols.len());
let max_detection_time = self.detection_config.timeout;
// Probe names already run, so a probe shared by several protocols runs once.
let mut processed_probes = std::collections::HashSet::new();
for &protocol in &self.enabled_protocols {
// Stop starting new protocols once the time budget is used up.
if start_time.elapsed() > max_detection_time {
break;
}
let probes = self.registry.get_probes_for_enabled_protocol(protocol, &self.enabled_protocols);
for probe in probes {
let probe_name = probe.name();
if processed_probes.contains(probe_name) {
continue;
}
processed_probes.insert(probe_name);
// Skipped probes stay marked as processed and are not retried.
if probe.needs_more_data(data) {
continue;
}
match probe.probe(data, &mut context) {
Ok(Some(protocol_info)) => {
// Only results for enabled protocols are kept.
if self.enabled_protocols.contains(&protocol_info.protocol_type) {
let high_confidence = protocol_info.confidence >= 0.9;
all_results.push(protocol_info);
// NOTE(review): this `break` leaves only the inner probe loop; the outer
// loop continues with the next enabled protocol. Confirm whether a full
// early exit was intended.
if high_confidence {
break;
}
}
}
Ok(None) => {
}
// Probe errors are ignored; the remaining probes still run.
Err(_) => {
}
}
// The clock is re-checked only when the processed-probe count is a multiple of 5.
if processed_probes.len() % 5 == 0 && start_time.elapsed() > max_detection_time {
break;
}
}
}
// Fall back to the deep magic scan only when no probe produced a result.
if all_results.is_empty() {
let deep_magic_results = self.magic_detector.deep_detect(data);
all_results.extend(deep_magic_results);
}
// Candidates collected in the context (e.g. a weak magic hit) also compete.
all_results.extend(context.candidates.clone());
let best_result = self.aggregator.aggregate(all_results)
.ok_or_else(|| DetectorError::NoProtocolDetected("未检测到任何协议".to_string()))?;
let detection_time = start_time.elapsed();
Ok(self.aggregator.create_result(
best_result,
detection_time,
"DefaultProtocolDetector".to_string(),
))
}
// Minimum accepted input length, taken from the detection config.
fn min_probe_size(&self) -> usize {
self.detection_config.min_probe_size
}
// Maximum accepted input length, taken from the detection config.
fn max_probe_size(&self) -> usize {
self.detection_config.max_probe_size
}
// Returns a copy of the enabled protocol list.
fn supported_protocols(&self) -> Vec<ProtocolType> {
self.enabled_protocols.clone()
}
// Fixed detector name.
fn name(&self) -> &str {
"DefaultProtocolDetector"
}
}
/// Technique by which a `DetectionResult` was obtained.
///
/// Only `SimdAccelerated` is constructed in the visible part of this file
/// (the magic-byte fast path of `DefaultProtocolDetector::detect`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DetectionMethod {
/// Passive detection.
Passive,
/// Active detection.
Active,
/// Heuristic detection.
Heuristic,
/// Reported for results coming from the magic-byte fast path.
SimdAccelerated,
/// Hybrid detection.
Hybrid,
}
/// Role-aware protocol agent: passive detection, active capability probing,
/// protocol negotiation / fallback and transport upgrade.
///
/// Implementors supply `detect`, `upgrade`, `supports_protocol`, `role`,
/// `instance_id` and `name`; every other method has a default built on those.
pub trait ProtocolAgent: Send + Sync + std::fmt::Debug {
    /// Passively detects the protocol carried by `data`.
    fn detect(&self, data: &[u8]) -> Result<DetectionResult>;

    /// Actively probes the peer behind `transport` for supported protocols.
    ///
    /// # Errors
    ///
    /// Returns an "unsupported protocol" error when the agent is a server.
    fn probe_capabilities(&self, transport: &mut dyn Transport) -> Result<Vec<ProtocolType>> {
        match self.role() {
            Role::Client => self.active_probe(transport),
            Role::Server => Err(DetectorError::unsupported_protocol(
                "Server role does not support active probing",
            )),
        }
    }

    /// Sends one probe per locally supported protocol, in the order HTTP/3,
    /// HTTP/2, HTTP/1.1, TLS, QUIC, and returns those the peer answered.
    ///
    /// Probing stops after a positive HTTP/3 or QUIC answer. When nothing
    /// answered, the result contains only HTTP/1.1.
    fn active_probe(&self, transport: &mut dyn Transport) -> Result<Vec<ProtocolType>> {
        let probe_order = [
            ProtocolType::HTTP3,
            ProtocolType::HTTP2,
            ProtocolType::HTTP1_1,
            ProtocolType::TLS,
            ProtocolType::QUIC,
        ];
        let mut supported_protocols = Vec::new();

        for protocol in probe_order {
            if !self.supports_protocol(protocol) {
                continue;
            }
            // Best effort: a failed or negative probe just moves on to the next one.
            if let Ok(true) = self.send_protocol_probe(transport, protocol) {
                supported_protocols.push(protocol);
                if matches!(protocol, ProtocolType::HTTP3 | ProtocolType::QUIC) {
                    break;
                }
            }
        }

        if supported_protocols.is_empty() {
            supported_protocols.push(ProtocolType::HTTP1_1);
        }
        Ok(supported_protocols)
    }

    /// Writes a protocol-specific probe to `transport` and inspects one read.
    ///
    /// Returns `Ok(true)` when the answer looks like the probed protocol.
    /// Protocols without a probe (HTTP/3 and everything not listed) yield
    /// `Ok(false)`. A failed read also yields `Ok(false)`.
    ///
    /// # Errors
    ///
    /// Propagates errors from writing the probe.
    fn send_protocol_probe(&self, transport: &mut dyn Transport, protocol: ProtocolType) -> Result<bool> {
        match protocol {
            ProtocolType::HTTP2 => {
                let h2_preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
                transport.write(h2_preface)?;
                let mut response = vec![0u8; 24];
                match transport.read(&mut response) {
                    // A frame header is 9 bytes with the type at offset 3. Only
                    // trust it when that many bytes were actually received
                    // (`n`), not merely because the buffer is large enough.
                    // 0x04 is the SETTINGS frame type.
                    Ok(n) if n >= 9 => Ok(response[3] == 0x04),
                    _ => Ok(false),
                }
            }
            // No HTTP/3 probe is implemented.
            ProtocolType::HTTP3 => Ok(false),
            ProtocolType::HTTP1_1 => {
                let options_request = b"OPTIONS * HTTP/1.1\r\nHost: probe\r\n\r\n";
                transport.write(options_request)?;
                let mut response = vec![0u8; 256];
                match transport.read(&mut response) {
                    Ok(n) if n > 0 => {
                        let head = &response[..n];
                        Ok(head.starts_with(b"HTTP/1.1") || head.starts_with(b"HTTP/1.0"))
                    }
                    _ => Ok(false),
                }
            }
            ProtocolType::TLS => {
                let client_hello = self.create_tls_client_hello();
                transport.write(&client_hello)?;
                let mut response = vec![0u8; 1024];
                match transport.read(&mut response) {
                    // Handshake record (0x16) with a 3.x version byte.
                    Ok(n) if n >= 5 => Ok(response[0] == 0x16 && response[1] == 0x03),
                    _ => Ok(false),
                }
            }
            _ => Ok(false),
        }
    }

    /// Builds a minimal TLS 1.2 ClientHello record used as the TLS probe.
    ///
    /// The declared record length (47) and handshake length (43) now match the
    /// bytes that follow them: the previously missing two-byte extensions
    /// length is included, so the peer is not left waiting for more data.
    fn create_tls_client_hello(&self) -> Vec<u8> {
        let mut hello = Vec::with_capacity(52);
        // Record header: handshake, version 3.1, length 47.
        hello.extend_from_slice(&[0x16, 0x03, 0x01, 0x00, 0x2f]);
        // Handshake header: client_hello, length 43.
        hello.extend_from_slice(&[0x01, 0x00, 0x00, 0x2b]);
        // Client version: TLS 1.2.
        hello.extend_from_slice(&[0x03, 0x03]);
        // 32-byte client random (all zero, as before).
        hello.extend_from_slice(&[0u8; 32]);
        // Empty session id.
        hello.extend_from_slice(&[0x00]);
        // One cipher suite: 0x0035.
        hello.extend_from_slice(&[0x00, 0x02, 0x00, 0x35]);
        // One compression method: null.
        hello.extend_from_slice(&[0x01, 0x00]);
        // Empty extensions block.
        hello.extend_from_slice(&[0x00, 0x00]);
        hello
    }

    /// Upgrades `transport` on behalf of `role`.
    fn upgrade(
        &self,
        transport: Box<dyn Transport>,
        role: Role,
    ) -> Result<Box<dyn Transport>>;

    /// Picks the first of HTTP/3, HTTP/2, HTTP/1.1 that is locally supported
    /// and not greater than `max_supported`; falls back to HTTP/1.1.
    fn negotiate_protocol(&self, max_supported: ProtocolType) -> ProtocolType {
        let candidates = [
            ProtocolType::HTTP3,
            ProtocolType::HTTP2,
            ProtocolType::HTTP1_1,
        ];
        candidates
            .iter()
            .copied()
            .find(|&proto| self.supports_protocol(proto) && proto <= max_supported)
            .unwrap_or(ProtocolType::HTTP1_1)
    }

    /// Probes `preferred` and, if the peer does not answer it, walks a
    /// fallback chain; returns the first protocol that answered, or HTTP/1.1.
    ///
    /// # Errors
    ///
    /// Returns an "unsupported protocol" error for the server role and
    /// propagates probe write errors.
    fn auto_fallback(&self, transport: &mut dyn Transport, preferred: ProtocolType) -> Result<ProtocolType> {
        if let Role::Server = self.role() {
            return Err(DetectorError::unsupported_protocol(
                "Server role does not support auto fallback",
            ));
        }

        if self.send_protocol_probe(transport, preferred)? {
            return Ok(preferred);
        }

        let fallback_chain = match preferred {
            ProtocolType::HTTP3 => vec![ProtocolType::HTTP2, ProtocolType::HTTP1_1],
            ProtocolType::HTTP2 => vec![ProtocolType::HTTP1_1],
            ProtocolType::QUIC => vec![ProtocolType::TLS, ProtocolType::TCP],
            _ => vec![ProtocolType::HTTP1_1],
        };
        for fallback in fallback_chain {
            // Locally unsupported fallbacks are skipped without touching the transport.
            if self.supports_protocol(fallback) && self.send_protocol_probe(transport, fallback)? {
                return Ok(fallback);
            }
        }
        Ok(ProtocolType::HTTP1_1)
    }

    /// Whether this agent is willing to speak `protocol`.
    fn supports_protocol(&self, protocol: ProtocolType) -> bool;

    /// Role this agent acts as.
    fn role(&self) -> Role;

    /// Identifier of this agent instance.
    fn instance_id(&self) -> &str;

    /// Human-readable agent name.
    fn name(&self) -> &str;
}
/// Byte-oriented connection abstraction used for probing and upgrades.
pub trait Transport: Send + Sync {
/// Reads into `buf` and returns the number of bytes read.
fn read(&mut self, buf: &mut [u8]) -> Result<usize>;
/// Writes `data`; presumably returns the number of bytes written — confirm
/// with the implementations (callers in this file ignore the value).
fn write(&mut self, data: &[u8]) -> Result<usize>;
/// Returns up to `size` bytes without taking `&mut self`.
/// NOTE(review): whether the bytes stay available to a later `read` is not
/// visible here — confirm with the implementations.
fn peek(&self, size: usize) -> Result<Vec<u8>>;
/// Closes the transport.
fn close(&mut self) -> Result<()>;
/// Short textual name of the transport kind.
fn transport_type(&self) -> &str;
}
/// Synchronous protocol detector working on an in-memory byte slice.
pub trait ProtocolDetector: Send + Sync + std::fmt::Debug {
    /// Detects the protocol carried by `data`.
    fn detect(&self, data: &[u8]) -> Result<DetectionResult>;

    /// Runs `detect` and returns only the confidence of its result.
    fn confidence(&self, data: &[u8]) -> Result<f32> {
        let outcome = self.detect(data)?;
        Ok(outcome.confidence())
    }

    /// Smallest input length worth probing; 64 bytes unless overridden.
    fn min_probe_size(&self) -> usize {
        64
    }

    /// Largest input length accepted; 4096 bytes unless overridden.
    fn max_probe_size(&self) -> usize {
        4096
    }

    /// Protocols this detector can report.
    fn supported_protocols(&self) -> Vec<ProtocolType>;

    /// Human-readable detector name.
    fn name(&self) -> &str;

    /// `true` when `protocol` is among `supported_protocols()`.
    fn can_detect(&self, protocol: ProtocolType) -> bool {
        let supported = self.supported_protocols();
        supported.iter().any(|candidate| *candidate == protocol)
    }

    /// Runs `detect` on every chunk in order, stopping at the first error.
    fn detect_batch(&self, data_chunks: &[&[u8]]) -> Result<Vec<DetectionResult>> {
        let mut detections = Vec::new();
        for chunk in data_chunks {
            detections.push(self.detect(chunk)?);
        }
        Ok(detections)
    }
}
/// Asynchronous counterpart of `ProtocolDetector`; compiled only when one of
/// the `runtime-tokio` / `runtime-async-std` features is enabled.
#[cfg(any(feature = "runtime-tokio", feature = "runtime-async-std"))]
#[async_trait::async_trait]
pub trait AsyncProtocolDetector: Send + Sync {
    /// Detects the protocol carried by `data`.
    async fn detect_async(&self, data: &[u8]) -> Result<DetectionResult>;

    /// Runs `detect_async` and returns only the confidence of its result.
    async fn confidence_async(&self, data: &[u8]) -> Result<f32> {
        let outcome = self.detect_async(data).await?;
        Ok(outcome.confidence())
    }

    /// Smallest input length worth probing; 64 bytes unless overridden.
    fn min_probe_size(&self) -> usize {
        64
    }

    /// Largest input length accepted; 4096 bytes unless overridden.
    fn max_probe_size(&self) -> usize {
        4096
    }

    /// Protocols this detector can report.
    fn supported_protocols(&self) -> Vec<ProtocolType>;

    /// Human-readable detector name.
    fn name(&self) -> &str;

    /// Awaits `detect_async` for each chunk one after another, stopping at the
    /// first error.
    async fn detect_batch_async(&self, data_chunks: &[&[u8]]) -> Result<Vec<DetectionResult>> {
        let mut detections = Vec::new();
        for chunk in data_chunks {
            let detection = self.detect_async(chunk).await?;
            detections.push(detection);
        }
        Ok(detections)
    }
}
/// Tunables for protocol detection.
///
/// In the visible part of this file `DefaultProtocolDetector` reads `timeout`,
/// `min_probe_size` and `max_probe_size`; the other fields are only set by the
/// builder methods and the `Default` impl here.
#[derive(Debug, Clone)]
pub struct DetectionConfig {
/// Minimum confidence; the builder clamps it to `0.0..=1.0`.
pub min_confidence: f32,
/// Time budget for the probe loop of a single detection.
pub timeout: Duration,
/// Heuristic detection toggle.
pub enable_heuristic: bool,
/// Active probing toggle.
pub enable_active_probing: bool,
/// Inputs longer than this many bytes are rejected by the detector.
pub max_probe_size: usize,
/// Inputs shorter than this many bytes are rejected by the detector.
pub min_probe_size: usize,
/// SIMD acceleration toggle.
pub enable_simd: bool,
}
impl Default for DetectionConfig {
fn default() -> Self {
Self {
min_confidence: 0.7,
timeout: Duration::from_millis(1000),
enable_heuristic: true,
enable_active_probing: false,
max_probe_size: 1024 * 1024, min_probe_size: 16, enable_simd: true,
}
}
}
/// Concrete `ProtocolAgent` backed by a detector, an optional upgrader and an
/// optional load balancer.
#[derive(Debug)]
pub struct Agent {
// Static configuration (role, instance id, enabled protocols, ...).
config: AgentConfig,
// Detector that `detect` delegates to.
detector: Arc<dyn ProtocolDetector>,
// Upgrader used by `upgrade`; `None` makes `upgrade` return an error.
upgrader: Option<Arc<dyn crate::upgrade::ProtocolUpgrader>>,
// Mutable runtime counters, shared behind a read/write lock.
state: Arc<std::sync::RwLock<AgentState>>,
// Present only when the config carries a load balancer section.
load_balancer: Option<Arc<LoadBalancer>>,
}
/// Runtime counters and health information of an `Agent`.
#[derive(Debug, Clone)]
pub struct AgentState {
/// Current connection count, adjusted by `Agent::update_connection_count`.
pub active_connections: usize,
/// Incremented on every `detect` call and every client-side capability probe.
pub total_requests: u64,
/// Number of upgrades for which the upgrader returned `Ok`.
pub successful_upgrades: u64,
/// Number of upgrades for which the upgrader returned `Err`.
pub failed_upgrades: u64,
/// Time of the most recent recorded activity; `Agent::health_check`
/// requires it to be less than 300 seconds old.
pub last_activity: Instant,
/// Health flag read by `Agent::health_check`; starts as `true` and is not
/// changed anywhere in the visible part of this file.
pub is_healthy: bool,
}
impl Default for AgentState {
fn default() -> Self {
Self {
active_connections: 0,
total_requests: 0,
successful_upgrades: 0,
failed_upgrades: 0,
last_activity: Instant::now(),
is_healthy: true,
}
}
}
/// Selects a backend instance according to the configured strategy.
#[derive(Debug)]
pub struct LoadBalancer {
// Strategy and the initial backend list.
config: LoadBalancerConfig,
// Per-backend state keyed by instance id.
backends: Arc<std::sync::RwLock<HashMap<String, BackendState>>>,
// Counter advanced by the round-robin and weighted round-robin strategies.
round_robin_index: Arc<std::sync::atomic::AtomicUsize>,
}
/// State tracked for one backend instance of a `LoadBalancer`.
#[derive(Debug, Clone)]
pub struct BackendState {
/// Backend identifier; also the key under which this entry is stored.
pub instance_id: String,
/// Connection count compared by the least-connections strategy; starts at 0.
pub active_connections: usize,
/// Weight used by weighted round robin; `LoadBalancer::new` sets it to 1.
pub weight: u32,
/// Unhealthy backends are never selected; starts as `true`.
pub is_healthy: bool,
/// Initialised to the creation time by `LoadBalancer::new`.
pub last_health_check: Instant,
}
impl Agent {
    /// Creates an agent from its configuration, detector and optional upgrader.
    ///
    /// A `LoadBalancer` is built when `config.load_balancer_config` is `Some`.
    pub fn new(
        config: AgentConfig,
        detector: Arc<dyn ProtocolDetector>,
        upgrader: Option<Arc<dyn crate::upgrade::ProtocolUpgrader>>,
    ) -> Self {
        let load_balancer = config
            .load_balancer_config
            .as_ref()
            .map(|lb_config| Arc::new(LoadBalancer::new(lb_config.clone())));
        Self {
            config,
            detector,
            upgrader,
            state: Arc::new(std::sync::RwLock::new(AgentState::default())),
            load_balancer,
        }
    }

    /// The configuration this agent was built with.
    pub fn config(&self) -> &AgentConfig {
        &self.config
    }

    /// Returns a snapshot of the current runtime state.
    ///
    /// # Errors
    ///
    /// Returns an internal error when the state lock is poisoned.
    pub fn state(&self) -> Result<AgentState> {
        self.state
            .read()
            .map_err(|_| DetectorError::internal_error("Failed to read agent state"))
            .map(|state| state.clone())
    }

    /// Adjusts the active connection count by `delta` and refreshes
    /// `last_activity`.
    ///
    /// The count saturates at `0` and at `usize::MAX` instead of overflowing,
    /// and `i32::MIN` is handled without negating it (the previous `(-delta)`
    /// overflowed for that value).
    ///
    /// NOTE(review): the update is skipped, with only a warning logged, when
    /// the lock cannot be taken immediately, and `Ok(())` is still returned, so
    /// the count can drift under contention. Confirm this best-effort
    /// behaviour is intended.
    pub fn update_connection_count(&self, delta: i32) -> Result<()> {
        match self.state.try_write() {
            Ok(mut state) => {
                // `unsigned_abs` yields a `u32`; widening it to `usize` is
                // lossless on 32- and 64-bit targets.
                let magnitude = delta.unsigned_abs() as usize;
                state.active_connections = if delta > 0 {
                    state.active_connections.saturating_add(magnitude)
                } else {
                    state.active_connections.saturating_sub(magnitude)
                };
                state.last_activity = Instant::now();
            }
            Err(_) => {
                rat_logger::warn!("Failed to acquire lock for connection count update");
            }
        }
        Ok(())
    }

    /// Asks the load balancer for a backend; `None` when no load balancer is
    /// configured or no backend can be selected.
    pub fn select_backend(&self) -> Option<String> {
        self.load_balancer.as_ref()?.select_backend()
    }

    /// `true` when the agent is flagged healthy and was active within the
    /// last 300 seconds; `false` when the state lock is poisoned.
    pub fn health_check(&self) -> bool {
        match self.state.read() {
            Ok(state) => {
                state.is_healthy && state.last_activity.elapsed() < Duration::from_secs(300)
            }
            Err(_) => false,
        }
    }
}
impl LoadBalancer {
pub fn new(config: LoadBalancerConfig) -> Self {
let backends = config.backend_instances.iter()
.map(|instance_id| {
let state = BackendState {
instance_id: instance_id.clone(),
active_connections: 0,
weight: 1,
is_healthy: true,
last_health_check: Instant::now(),
};
(instance_id.clone(), state)
})
.collect();
Self {
config,
backends: Arc::new(std::sync::RwLock::new(backends)),
round_robin_index: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
}
}
pub fn select_backend(&self) -> Option<String> {
let backends = self.backends.read().ok()?;
let healthy_backends: Vec<_> = backends.values()
.filter(|backend| backend.is_healthy)
.collect();
if healthy_backends.is_empty() {
return None;
}
match self.config.strategy {
LoadBalanceStrategy::RoundRobin => {
let index = self.round_robin_index.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
Some(healthy_backends[index % healthy_backends.len()].instance_id.clone())
},
LoadBalanceStrategy::LeastConnections => {
healthy_backends.iter()
.min_by_key(|backend| backend.active_connections)
.map(|backend| backend.instance_id.clone())
},
LoadBalanceStrategy::WeightedRoundRobin => {
let total_weight: u32 = healthy_backends.iter().map(|b| b.weight).sum();
if total_weight == 0 {
return None;
}
let mut target = (self.round_robin_index.load(std::sync::atomic::Ordering::Relaxed) as u32) % total_weight;
self.round_robin_index.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
for backend in healthy_backends {
if target < backend.weight {
return Some(backend.instance_id.clone());
}
target -= backend.weight;
}
None
},
LoadBalanceStrategy::ConsistentHash => {
healthy_backends.first().map(|backend| backend.instance_id.clone())
},
}
}
}
impl ProtocolAgent for Agent {
fn detect(&self, data: &[u8]) -> Result<DetectionResult> {
if let Ok(mut state) = self.state.write() {
state.total_requests += 1;
state.last_activity = Instant::now();
}
match self.config.role {
Role::Server => {
self.detector.detect(data)
},
Role::Client => {
self.detector.detect(data)
},
}
}
fn probe_capabilities(&self, transport: &mut dyn Transport) -> Result<Vec<ProtocolType>> {
match self.config.role {
Role::Client => {
if let Ok(mut state) = self.state.write() {
state.total_requests += 1;
state.last_activity = Instant::now();
}
self.active_probe(transport)
},
Role::Server => {
Err(DetectorError::unsupported_protocol(
"Server role does not support active probing"
))
},
}
}
fn active_probe(&self, transport: &mut dyn Transport) -> Result<Vec<ProtocolType>> {
let mut supported_protocols = Vec::new();
let probe_order = [
ProtocolType::HTTP3,
ProtocolType::HTTP2,
ProtocolType::HTTP1_1,
ProtocolType::TLS,
ProtocolType::QUIC,
];
for protocol in probe_order {
if self.config.enabled_protocols.contains(&protocol) {
match self.send_protocol_probe(transport, protocol) {
Ok(true) => {
supported_protocols.push(protocol);
if matches!(protocol, ProtocolType::HTTP3 | ProtocolType::QUIC) {
break;
}
},
Ok(false) => continue,
Err(_) => continue, }
}
}
if supported_protocols.is_empty() {
supported_protocols.push(ProtocolType::HTTP1_1);
}
Ok(supported_protocols)
}
fn send_protocol_probe(&self, transport: &mut dyn Transport, protocol: ProtocolType) -> Result<bool> {
match protocol {
ProtocolType::HTTP2 => {
let h2_preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
transport.write(h2_preface)?;
let mut response = vec![0u8; 24];
match transport.read(&mut response) {
Ok(n) if n > 0 => {
Ok(response.len() >= 9 && response[3] == 0x04) },
_ => Ok(false),
}
},
ProtocolType::HTTP3 => {
Ok(false) },
ProtocolType::HTTP1_1 => {
let options_request = b"OPTIONS * HTTP/1.1\r\nHost: probe\r\n\r\n";
transport.write(options_request)?;
let mut response = vec![0u8; 256];
match transport.read(&mut response) {
Ok(n) if n > 0 => {
let response_str = String::from_utf8_lossy(&response[..n]);
Ok(response_str.starts_with("HTTP/1.1") || response_str.starts_with("HTTP/1.0"))
},
_ => Ok(false),
}
},
ProtocolType::TLS => {
let client_hello = self.create_tls_client_hello();
transport.write(&client_hello)?;
let mut response = vec![0u8; 1024];
match transport.read(&mut response) {
Ok(n) if n >= 5 => {
Ok(response[0] == 0x16 && response[1] == 0x03) },
_ => Ok(false),
}
},
_ => Ok(false), }
}
fn create_tls_client_hello(&self) -> Vec<u8> {
vec![
0x16, 0x03, 0x01, 0x00, 0x2f, 0x01, 0x00, 0x00, 0x2b, 0x03, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x02, 0x00, 0x35, 0x01, 0x00, ]
}
fn auto_fallback(&self, transport: &mut dyn Transport, preferred: ProtocolType) -> Result<ProtocolType> {
match self.config.role {
Role::Client => {
if self.send_protocol_probe(transport, preferred)? {
return Ok(preferred);
}
let fallback_chain = match preferred {
ProtocolType::HTTP3 => vec![ProtocolType::HTTP2, ProtocolType::HTTP1_1],
ProtocolType::HTTP2 => vec![ProtocolType::HTTP1_1],
ProtocolType::QUIC => vec![ProtocolType::TLS, ProtocolType::TCP],
_ => vec![ProtocolType::HTTP1_1],
};
for fallback in fallback_chain {
if self.config.enabled_protocols.contains(&fallback) &&
self.send_protocol_probe(transport, fallback)? {
return Ok(fallback);
}
}
Ok(ProtocolType::HTTP1_1)
},
Role::Server => {
Err(DetectorError::unsupported_protocol(
"Server role does not support auto fallback"
))
},
}
}
fn upgrade(
&self,
transport: Box<dyn Transport>,
role: Role,
) -> Result<Box<dyn Transport>> {
match &self.upgrader {
Some(upgrader) => {
let result = match role {
Role::Server => {
let current_protocol = ProtocolType::HTTP1_1; let target_protocol = ProtocolType::HTTP2; let data = b""; upgrader.upgrade(current_protocol, target_protocol, data)
},
Role::Client => {
let current_protocol = ProtocolType::HTTP1_1; let target_protocol = ProtocolType::HTTP2; let data = b""; upgrader.upgrade(current_protocol, target_protocol, data)
},
};
if let Ok(mut state) = self.state.write() {
match result {
Ok(_) => state.successful_upgrades += 1,
Err(_) => state.failed_upgrades += 1,
}
state.last_activity = Instant::now();
}
result.map(|_| transport) },
None => Err(DetectorError::unsupported_protocol("Protocol upgrade not supported")),
}
}
fn supports_protocol(&self, protocol: ProtocolType) -> bool {
self.config.enabled_protocols.contains(&protocol)
}
fn role(&self) -> Role {
self.config.role
}
fn instance_id(&self) -> &str {
&self.config.instance_id
}
fn name(&self) -> &str {
match self.config.role {
Role::Server => "PSI Server Agent",
Role::Client => "PSI Client Agent",
}
}
}
impl DetectionConfig {
    /// Same as `DetectionConfig::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the minimum confidence, clamped to `0.0..=1.0`.
    pub fn with_min_confidence(self, confidence: f32) -> Self {
        Self { min_confidence: confidence.clamp(0.0, 1.0), ..self }
    }

    /// Sets the detection timeout.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }

    /// Turns heuristic detection on.
    pub fn enable_heuristic(self) -> Self {
        Self { enable_heuristic: true, ..self }
    }

    /// Turns heuristic detection off.
    pub fn disable_heuristic(self) -> Self {
        Self { enable_heuristic: false, ..self }
    }

    /// Turns active probing on.
    pub fn enable_active_probing(self) -> Self {
        Self { enable_active_probing: true, ..self }
    }

    /// Turns active probing off.
    pub fn disable_active_probing(self) -> Self {
        Self { enable_active_probing: false, ..self }
    }

    /// Sets the maximum probe size in bytes (not validated against the minimum).
    pub fn with_max_probe_size(self, size: usize) -> Self {
        Self { max_probe_size: size, ..self }
    }

    /// Sets the minimum probe size in bytes (not validated against the maximum).
    pub fn with_min_probe_size(self, size: usize) -> Self {
        Self { min_probe_size: size, ..self }
    }

    /// Turns SIMD acceleration on.
    pub fn enable_simd(self) -> Self {
        Self { enable_simd: true, ..self }
    }

    /// Turns SIMD acceleration off.
    pub fn disable_simd(self) -> Self {
        Self { enable_simd: false, ..self }
    }
}
/// Running statistics over detection attempts.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DetectionStats {
/// Number of recorded attempts (successes plus failures).
pub total_detections: u64,
/// Number of attempts recorded through `record_success`.
pub successful_detections: u64,
/// Number of attempts recorded through `record_failure`.
pub failed_detections: u64,
/// Running mean duration over all recorded attempts.
pub avg_detection_time: Duration,
/// Number of successful detections per protocol.
pub protocol_counts: std::collections::HashMap<ProtocolType, u64>,
}
impl DetectionStats {
    /// Creates empty statistics (all counters zero, no protocols seen).
    pub fn new() -> Self {
        Self::default()
    }

    /// Records a successful detection of `protocol` that took `duration`.
    pub fn record_success(&mut self, protocol: ProtocolType, duration: Duration) {
        self.total_detections += 1;
        self.successful_detections += 1;
        self.update_avg_time(duration);
        let seen = self.protocol_counts.entry(protocol).or_insert(0);
        *seen += 1;
    }

    /// Records a failed detection that took `duration`.
    pub fn record_failure(&mut self, duration: Duration) {
        self.total_detections += 1;
        self.failed_detections += 1;
        self.update_avg_time(duration);
    }

    /// Fraction of recorded attempts that succeeded; `0.0` when nothing was
    /// recorded yet.
    pub fn success_rate(&self) -> f64 {
        match self.total_detections {
            0 => 0.0,
            total => self.successful_detections as f64 / total as f64,
        }
    }

    /// Protocol with the highest success count, or `None` when no success was
    /// recorded. Ties are resolved by map iteration order.
    pub fn most_common_protocol(&self) -> Option<ProtocolType> {
        let busiest = self
            .protocol_counts
            .iter()
            .max_by_key(|&(_, &count)| count);
        match busiest {
            Some((protocol, _)) => Some(*protocol),
            None => None,
        }
    }

    /// Folds `new_duration` into the running mean. Must be called after
    /// `total_detections` was incremented for the new sample.
    fn update_avg_time(&mut self, new_duration: Duration) {
        let samples = self.total_detections;
        if samples == 1 {
            self.avg_detection_time = new_duration;
            return;
        }
        // Rebuild the previous sum from the mean, add the new sample, then
        // divide again; integer division truncates to whole nanoseconds.
        let previous_sum = self.avg_detection_time.as_nanos() * (samples - 1) as u128;
        let combined = previous_sum + new_duration.as_nanos();
        let mean_nanos = (combined / samples as u128) as u64;
        self.avg_detection_time = Duration::from_nanos(mean_nanos);
    }
}