use crate::core::protocol::{ProtocolType, ProtocolInfo};
use crate::core::detector::{DetectionResult, DetectionMethod};
use crate::error::{DetectorError, Result};
use std::time::{Duration, Instant};
use std::collections::HashMap;
/// How aggressively protocol probing is performed.
///
/// `ProbeAggregator::create_result` maps `Passive` to `DetectionMethod::Passive`, `Active` to
/// `DetectionMethod::Active`, and both `Hybrid` and `Adaptive` to `DetectionMethod::Hybrid`.
///
/// The enum is fieldless, so it is `Copy` and can be used as a hash key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProbeStrategy {
    Passive,
    Active,
    Hybrid,
    Adaptive,
}
/// Tunable settings for protocol probing.
#[derive(Debug, Clone)]
pub struct ProbeConfig {
    /// Probing strategy; `ProbeAggregator::create_result` derives the detection method from it.
    pub strategy: ProbeStrategy,
    /// Time budget for probing. NOTE(review): not consumed in this file — presumably compared
    /// against `ProbeContext::is_timeout` by callers; confirm.
    pub max_probe_time: Duration,
    /// Minimum confidence `ProbeAggregator::aggregate` requires before reporting a result.
    pub min_confidence: f32,
    /// Feature switch; not read anywhere in this file.
    pub enable_simd: bool,
    /// Feature switch; not read anywhere in this file.
    pub enable_heuristic: bool,
    /// Buffer size in bytes (per the default of 4096); not read anywhere in this file.
    pub buffer_size: usize,
}
impl Default for ProbeConfig {
    /// Passive strategy, a 100 ms probe budget, a 0.8 confidence threshold, `enable_simd` and
    /// `enable_heuristic` both on, and a 4096-byte buffer.
    fn default() -> Self {
        let max_probe_time = Duration::from_millis(100);
        Self {
            strategy: ProbeStrategy::Passive,
            max_probe_time,
            min_confidence: 0.8,
            enable_simd: true,
            enable_heuristic: true,
            buffer_size: 4096,
        }
    }
}
/// Mutable state carried through one probing session.
#[derive(Debug)]
pub struct ProbeContext {
    /// When the session began; `ProbeContext::new` sets it to `Instant::now()`.
    pub start_time: Instant,
    /// Byte counter, initialised to zero. NOTE(review): presumably advanced by probes; confirm.
    pub bytes_read: usize,
    /// Attempt counter, initialised to zero. NOTE(review): presumably advanced by probes; confirm.
    pub attempt_count: u32,
    /// Highest confidence among `candidates`, refreshed by `add_candidate`.
    pub current_confidence: f32,
    /// Every protocol guess recorded so far, in insertion order.
    pub candidates: Vec<ProtocolInfo>,
}
impl ProbeContext {
    /// Creates a fresh context.
    ///
    /// The clock starts now, both counters are zero, `current_confidence` is `0.0`, and there
    /// are no candidates.
    pub fn new() -> Self {
        Self {
            start_time: Instant::now(),
            bytes_read: 0,
            attempt_count: 0,
            current_confidence: 0.0,
            candidates: Vec::new(),
        }
    }

    /// Records `protocol` as a candidate and sets `current_confidence` to the highest confidence
    /// among all recorded candidates.
    ///
    /// `f32::max` ignores a NaN operand, so a NaN confidence only becomes `current_confidence`
    /// when every candidate's confidence is NaN.
    pub fn add_candidate(&mut self, protocol: ProtocolInfo) {
        self.candidates.push(protocol);
        // `candidates` is non-empty after the push, so `reduce` always yields a value.
        if let Some(max_confidence) = self
            .candidates
            .iter()
            .map(|p| p.confidence)
            .reduce(f32::max)
        {
            self.current_confidence = max_confidence;
        }
    }

    /// Returns the candidate with the highest confidence, or `None` if there are no candidates.
    ///
    /// NaN confidences rank below every other value instead of causing a panic. When several
    /// candidates tie for the maximum, the one added last is returned.
    pub fn best_candidate(&self) -> Option<&ProtocolInfo> {
        // Map NaN to -inf so the comparison is a total order; `partial_cmp(..).unwrap()`
        // would panic on NaN.
        let rank = |confidence: f32| {
            if confidence.is_nan() {
                f32::NEG_INFINITY
            } else {
                confidence
            }
        };
        self.candidates
            .iter()
            .max_by(|a, b| rank(a.confidence).total_cmp(&rank(b.confidence)))
    }

    /// Returns `true` once strictly more than `max_time` has elapsed since `start_time`.
    pub fn is_timeout(&self, max_time: Duration) -> bool {
        self.start_time.elapsed() > max_time
    }
}
/// A pluggable detector that inspects raw bytes to recognise one or more protocols.
///
/// Implementors must be `Send + Sync` so they can be shared across threads.
pub trait ProtocolProbe: Send + Sync {
    /// Static identifier for this probe.
    fn name(&self) -> &'static str;

    /// Protocols this probe can recognise.
    ///
    /// `ProbeRegistry` uses this list to decide which protocols a global probe applies to.
    fn supported_protocols(&self) -> Vec<ProtocolType>;

    /// Inspects `data`, possibly updating `context`.
    ///
    /// NOTE(review): presumably `Ok(None)` means "no match" and `Ok(Some(_))` a detection;
    /// confirm against implementors.
    fn probe(&self, data: &[u8], context: &mut ProbeContext) -> Result<Option<ProtocolInfo>>;

    /// Ordering key: `ProbeRegistry` returns higher-priority probes first. Defaults to 50.
    fn priority(&self) -> u8 {
        50
    }

    /// Whether `data` is too short for this probe; the default wants at least 64 bytes.
    fn needs_more_data(&self, data: &[u8]) -> bool {
        data.len() < 64
    }
}
/// Collection of registered probes.
///
/// Each probe is either dedicated to a single protocol or global. A global probe applies to
/// every protocol it lists in `ProtocolProbe::supported_protocols`.
#[derive(Default)]
pub struct ProbeRegistry {
    /// Dedicated probes, keyed by the protocol they were registered for.
    probes: HashMap<ProtocolType, Vec<Box<dyn ProtocolProbe>>>,
    /// Global probes, in registration order.
    global_probes: Vec<Box<dyn ProtocolProbe>>,
}
impl std::fmt::Debug for ProbeRegistry {
    /// Prints a size summary instead of the probes themselves, since `ProtocolProbe` has no
    /// `Debug` bound.
    ///
    /// `probes_count` is the number of protocols that have dedicated probes (map entries), not
    /// the number of dedicated probes. `global_probes_count` is the number of global probes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let protocol_entries = self.probes.len();
        let global_entries = self.global_probes.len();
        let mut summary = f.debug_struct("ProbeRegistry");
        summary.field("probes_count", &protocol_entries);
        summary.field("global_probes_count", &global_entries);
        summary.finish()
    }
}
impl ProbeRegistry {
    /// Creates an empty registry with no dedicated or global probes.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `probe` as a dedicated probe for `protocol`.
    ///
    /// Several probes may share a protocol. They are stored in registration order and sorted by
    /// priority only at lookup time.
    pub fn register_probe(&mut self, protocol: ProtocolType, probe: Box<dyn ProtocolProbe>) {
        self.probes.entry(protocol).or_default().push(probe);
    }

    /// Registers a global probe.
    ///
    /// The probe is considered for every protocol in its `supported_protocols()`.
    pub fn register_global_probe(&mut self, probe: Box<dyn ProtocolProbe>) {
        self.global_probes.push(probe);
    }

    /// Returns every probe applicable to `protocol`, highest priority first.
    ///
    /// The result contains the dedicated probes for `protocol`, followed by every global probe
    /// whose `supported_protocols()` contains `protocol`. The sort is stable, so probes with
    /// equal priority keep that relative order.
    pub fn get_probes(&self, protocol: ProtocolType) -> Vec<&dyn ProtocolProbe> {
        self.collect_probes(&protocol, |probe| {
            probe.supported_protocols().contains(&protocol)
        })
    }

    /// Like [`get_probes`](Self::get_probes), but global probes are also filtered by
    /// `enabled_protocols`.
    ///
    /// Dedicated probes for `protocol` are always included. A global probe is included when it
    /// supports `protocol` and at least one of its supported protocols appears in
    /// `enabled_protocols`. NOTE(review): the enabled check covers any supported protocol, not
    /// `protocol` itself; confirm that this is intended.
    pub fn get_probes_for_enabled_protocol(
        &self,
        protocol: ProtocolType,
        enabled_protocols: &[ProtocolType],
    ) -> Vec<&dyn ProtocolProbe> {
        self.collect_probes(&protocol, |probe| {
            let supported = probe.supported_protocols();
            supported.contains(&protocol)
                && supported.iter().any(|p| enabled_protocols.contains(p))
        })
    }

    /// Returns every registered probe, dedicated and global, highest priority first.
    ///
    /// Deprecated: prefer `get_probes_for_enabled_protocol` for better performance. Dedicated
    /// probes of equal priority appear in `HashMap` iteration order, which is unspecified.
    #[deprecated(note = "使用 get_probes_for_enabled_protocol 以获得更好的性能")]
    pub fn get_all_probes(&self) -> Vec<&dyn ProtocolProbe> {
        let mut probes = Vec::new();
        probes.extend(self.probes.values().flatten().map(|p| p.as_ref()));
        probes.extend(self.global_probes.iter().map(|p| p.as_ref()));
        probes.sort_by_key(|probe| std::cmp::Reverse(probe.priority()));
        probes
    }

    /// Collects the probes for `protocol`, ordered by descending priority.
    ///
    /// The result holds the dedicated probes for `protocol` plus every global probe accepted by
    /// `include_global`. The sort is stable, so ties keep dedicated-then-global registration
    /// order.
    fn collect_probes<F>(&self, protocol: &ProtocolType, mut include_global: F) -> Vec<&dyn ProtocolProbe>
    where
        F: FnMut(&dyn ProtocolProbe) -> bool,
    {
        let mut probes = Vec::new();
        if let Some(protocol_probes) = self.probes.get(protocol) {
            probes.extend(protocol_probes.iter().map(|p| p.as_ref()));
        }
        probes.extend(
            self.global_probes
                .iter()
                .map(|p| p.as_ref())
                .filter(|probe| include_global(*probe)),
        );
        probes.sort_by_key(|probe| std::cmp::Reverse(probe.priority()));
        probes
    }
}
/// Merges the results of several probes into a single detection.
#[derive(Debug)]
pub struct ProbeAggregator {
    /// Settings for the aggregator.
    ///
    /// `aggregate` reads `min_confidence`; `create_result` reads `strategy`.
    config: ProbeConfig,
}
impl ProbeAggregator {
    /// Creates an aggregator driven by `config`.
    pub fn new(config: ProbeConfig) -> Self {
        Self { config }
    }

    /// Picks the most confident known-protocol result, if it meets `config.min_confidence`.
    ///
    /// The following results are discarded:
    /// - results whose `protocol_type` is `ProtocolType::Unknown`;
    /// - results whose confidence is NaN. These previously caused a panic in the comparator.
    ///
    /// Among results tied for the highest confidence, the earliest one in `results` wins.
    /// Returns `None` when `results` is empty, when nothing survives filtering, or when the best
    /// confidence is below the threshold.
    pub fn aggregate(&self, results: Vec<ProtocolInfo>) -> Option<ProtocolInfo> {
        let best = results
            .into_iter()
            .filter(|info| {
                info.protocol_type != ProtocolType::Unknown && !info.confidence.is_nan()
            })
            // A single O(n) pass instead of a full sort. Strict `>` keeps the earliest of
            // equally confident results, and the winner is moved out rather than cloned.
            .reduce(|best, candidate| {
                if candidate.confidence > best.confidence {
                    candidate
                } else {
                    best
                }
            })?;
        if best.confidence >= self.config.min_confidence {
            Some(best)
        } else {
            None
        }
    }

    /// Wraps `protocol_info` in a `DetectionResult`, taking the detection method from
    /// `config.strategy`.
    ///
    /// The mapping is: `Passive` → `Passive`, `Active` → `Active`, and both `Hybrid` and
    /// `Adaptive` → `Hybrid`.
    pub fn create_result(
        &self,
        protocol_info: ProtocolInfo,
        duration: Duration,
        detector_name: String,
    ) -> DetectionResult {
        let method = match self.config.strategy {
            ProbeStrategy::Passive => DetectionMethod::Passive,
            ProbeStrategy::Active => DetectionMethod::Active,
            ProbeStrategy::Hybrid | ProbeStrategy::Adaptive => DetectionMethod::Hybrid,
        };
        DetectionResult::new(protocol_info, duration, method, detector_name)
    }
}