use crate::core::protocol::{ProtocolType, UpgradePath, UpgradeMethod};
use crate::error::{DetectorError, Result};
use std::collections::HashMap;
use std::time::{Duration, Instant};
pub mod http;
pub mod websocket;
pub use http::HttpUpgrader;
pub use websocket::WebSocketUpgrader;
/// Synchronous strategy for upgrading a connection from one protocol to another.
///
/// Implementors must be `Send + Sync + Debug`.
pub trait ProtocolUpgrader: Send + Sync + std::fmt::Debug {
    /// Reports whether this upgrader handles the `from` -> `to` transition.
    fn can_upgrade(&self, from: ProtocolType, to: ProtocolType) -> bool;

    /// Performs the `from` -> `to` upgrade using `data` as input.
    fn upgrade(&self, from: ProtocolType, to: ProtocolType, data: &[u8]) -> Result<UpgradeResult>;

    /// Lists every upgrade path this upgrader offers.
    fn supported_upgrades(&self) -> Vec<UpgradePath>;

    /// Static name identifying this upgrader.
    fn name(&self) -> &'static str;

    /// Estimated duration of a `from` -> `to` upgrade.
    ///
    /// The default implementation ignores both protocols and returns a flat
    /// 100 ms; implementors may override it.
    fn estimate_upgrade_time(&self, from: ProtocolType, to: ProtocolType) -> Duration {
        const DEFAULT_ESTIMATE_MS: u64 = 100;
        Duration::from_millis(DEFAULT_ESTIMATE_MS)
    }

    /// Validates that an upgrade may be attempted.
    ///
    /// # Errors
    ///
    /// Returns an upgrade-failed error when [`can_upgrade`](Self::can_upgrade)
    /// rejects the pair, or — checked second — when `data` is empty.
    fn check_prerequisites(&self, from: ProtocolType, to: ProtocolType, data: &[u8]) -> Result<()> {
        // Support is checked before the payload, so an unsupported pair is
        // reported even when `data` is also empty.
        let rejection = if !self.can_upgrade(from, to) {
            Some("Upgrade not supported")
        } else if data.is_empty() {
            Some("Cannot upgrade with empty data")
        } else {
            None
        };

        match rejection {
            Some(reason) => Err(DetectorError::upgrade_failed(
                format!("{:?}", from),
                format!("{:?}", to),
                reason
            )),
            None => Ok(()),
        }
    }
}
/// Asynchronous upgrade interface.
///
/// Only compiled when the `runtime-tokio` or `runtime-async-std` feature is
/// enabled. Unlike `ProtocolUpgrader`, this trait declares no supertraits and
/// provides no default method bodies.
#[cfg(any(feature = "runtime-tokio", feature = "runtime-async-std"))]
pub trait AsyncProtocolUpgrader {
/// Asynchronously upgrades `data` from `from` to `to`.
/// NOTE(review): presumably the async counterpart of `ProtocolUpgrader::upgrade` — confirm with implementors.
async fn upgrade_async(&self, from: ProtocolType, to: ProtocolType, data: &[u8]) -> Result<UpgradeResult>;
/// Asynchronously validates that the upgrade may be attempted.
/// NOTE(review): presumably mirrors `ProtocolUpgrader::check_prerequisites`; no default is provided here.
async fn check_prerequisites_async(&self, from: ProtocolType, to: ProtocolType, data: &[u8]) -> Result<()>;
}
/// Outcome of a single protocol upgrade attempt.
#[derive(Debug, Clone, PartialEq)]
pub struct UpgradeResult {
/// Protocol the upgrade was targeting.
pub target_protocol: ProtocolType,
/// `true` when built via `UpgradeResult::success`, `false` when built via `UpgradeResult::failure`.
pub success: bool,
/// Data produced by the upgrade; empty for results built via `UpgradeResult::failure`.
pub upgraded_data: Vec<u8>,
/// Upgrade method associated with this result.
pub method: UpgradeMethod,
/// Duration attributed to the upgrade.
pub duration: Duration,
/// Free-form key/value annotations; starts empty and is filled through `with_metadata`.
pub metadata: HashMap<String, String>,
/// Failure description; `None` for results built via `UpgradeResult::success`.
pub error_message: Option<String>,
}
impl UpgradeResult {
    /// Builds a successful result carrying `upgraded_data`.
    ///
    /// The metadata map starts empty and no error message is set.
    pub fn success(
        target_protocol: ProtocolType,
        upgraded_data: Vec<u8>,
        method: UpgradeMethod,
        duration: Duration,
    ) -> Self {
        Self {
            success: true,
            error_message: None,
            metadata: HashMap::new(),
            target_protocol,
            upgraded_data,
            method,
            duration,
        }
    }

    /// Builds a failed result described by `error`.
    ///
    /// The upgraded data is empty and the metadata map starts empty.
    pub fn failure(
        target_protocol: ProtocolType,
        method: UpgradeMethod,
        duration: Duration,
        error: String,
    ) -> Self {
        let error_message = Some(error);
        Self {
            success: false,
            upgraded_data: Vec::new(),
            metadata: HashMap::new(),
            error_message,
            target_protocol,
            method,
            duration,
        }
    }

    /// Attaches a metadata entry and returns the result for chaining.
    ///
    /// An existing entry under the same `key` is replaced.
    pub fn with_metadata(mut self, key: String, value: String) -> Self {
        // The previous value (if any) is intentionally discarded.
        let _replaced = self.metadata.insert(key, value);
        self
    }

    /// Returns the `success` flag.
    pub fn is_success(&self) -> bool {
        self.success
    }

    /// Returns the error message, if one was recorded.
    pub fn error(&self) -> Option<&str> {
        self.error_message.as_deref()
    }
}
/// Tunable settings for the upgrade machinery.
#[derive(Debug, Clone)]
pub struct UpgradeConfig {
    /// Time limit setting (default 30 s).
    /// NOTE(review): `UpgradeManager::upgrade` in this file does not read it — confirm where it is enforced.
    pub timeout: Duration,
    /// Auto-upgrade switch (default `true`).
    /// NOTE(review): not read anywhere in this file — confirm its consumer.
    pub auto_upgrade: bool,
    /// Number of additional attempts after the first one fails (default 3).
    pub max_retries: u32,
    /// Pause between two consecutive attempts (default 500 ms).
    pub retry_interval: Duration,
    /// Whether successful results are cached and reused (default `true`).
    pub enable_cache: bool,
    /// Maximum age of a cached result before it is discarded (default 300 s).
    pub cache_ttl: Duration,
    /// Free-form extra options; empty by default.
    pub custom_options: HashMap<String, String>,
}

impl Default for UpgradeConfig {
    /// Returns the stock configuration: 30 s timeout, auto-upgrade on,
    /// 3 retries spaced 500 ms apart, caching on with a 300 s TTL, and no
    /// custom options.
    fn default() -> Self {
        let timeout = Duration::from_secs(30);
        let retry_interval = Duration::from_millis(500);
        let cache_ttl = Duration::from_secs(300);

        Self {
            timeout,
            retry_interval,
            cache_ttl,
            max_retries: 3,
            auto_upgrade: true,
            enable_cache: true,
            custom_options: HashMap::new(),
        }
    }
}
/// Running counters describing the upgrade results recorded so far.
#[derive(Debug, Clone, Default)]
pub struct UpgradeStats {
/// Number of results recorded, successful or not.
pub total_upgrades: u64,
/// Number of recorded results whose `success` flag was set.
pub successful_upgrades: u64,
/// Number of recorded results whose `success` flag was not set.
pub failed_upgrades: u64,
/// Running mean of the recorded durations, updated incrementally by `record_upgrade`.
pub average_upgrade_time: Duration,
/// Number of recorded results per `(from, target)` protocol pair.
pub protocol_upgrades: HashMap<(ProtocolType, ProtocolType), u64>,
/// Number of recorded results per upgrade method.
pub method_usage: HashMap<UpgradeMethod, u64>,
}
impl UpgradeStats {
    /// Creates an empty statistics record (all counters zero, maps empty).
    pub fn new() -> Self {
        Self::default()
    }

    /// Folds one upgrade `result` into the counters.
    ///
    /// `from` is the source protocol of the attempt; together with
    /// `result.target_protocol` it forms the key in `protocol_upgrades`.
    pub fn record_upgrade(&mut self, result: &UpgradeResult, from: ProtocolType) {
        self.total_upgrades += 1;
        if result.success {
            self.successful_upgrades += 1;
        } else {
            self.failed_upgrades += 1;
        }

        // Update the running mean in u128: `Duration::as_nanos` already
        // returns u128, and multiplying the previous mean by the sample count
        // in u64 could overflow (panic in debug builds, wrap in release).
        let samples = u128::from(self.total_upgrades);
        let previous_total = self
            .average_upgrade_time
            .as_nanos()
            .saturating_mul(samples - 1);
        let mean_nanos = previous_total.saturating_add(result.duration.as_nanos()) / samples;
        // Clamp before narrowing so the cast below can never truncate.
        let mean_nanos = mean_nanos.min(u128::from(u64::MAX)) as u64;
        self.average_upgrade_time = Duration::from_nanos(mean_nanos);

        let upgrade_pair = (from, result.target_protocol);
        *self.protocol_upgrades.entry(upgrade_pair).or_insert(0) += 1;
        *self.method_usage.entry(result.method.clone()).or_insert(0) += 1;
    }

    /// Fraction of recorded upgrades that succeeded, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when nothing has been recorded yet.
    pub fn success_rate(&self) -> f64 {
        if self.total_upgrades == 0 {
            0.0
        } else {
            self.successful_upgrades as f64 / self.total_upgrades as f64
        }
    }

    /// Returns the `(from, target)` pair with the highest recorded count, or
    /// `None` when nothing has been recorded.
    ///
    /// When several pairs share the highest count, which one is returned
    /// depends on the map's iteration order and is therefore unspecified.
    pub fn most_common_upgrade(&self) -> Option<(ProtocolType, ProtocolType)> {
        self.protocol_upgrades
            .iter()
            .max_by_key(|(_, &count)| count)
            .map(|(&upgrade_pair, _)| upgrade_pair)
    }

    /// Clears every counter and map back to the freshly-created state.
    pub fn reset(&mut self) {
        *self = Self::new();
    }
}
/// Dispatches upgrade requests to registered upgraders, adding retries,
/// result caching and statistics.
pub struct UpgradeManager {
/// Registered upgraders, consulted in insertion order; the first whose `can_upgrade` accepts a request handles it.
upgraders: Vec<Box<dyn ProtocolUpgrader + Send + Sync>>,
/// Settings; `upgrade` reads `enable_cache`, `cache_ttl`, `max_retries` and `retry_interval` from it.
config: UpgradeConfig,
/// Counters updated for every upgrade that is actually attempted (cache hits are not recorded).
stats: UpgradeStats,
/// Cached successful results, keyed by `(from, to)` protocol pair.
cache: HashMap<(ProtocolType, ProtocolType), UpgradeResult>,
/// Insertion time of each cache entry, compared against `config.cache_ttl`.
cache_timestamps: HashMap<(ProtocolType, ProtocolType), Instant>,
}
impl UpgradeManager {
    /// Creates a manager with the default configuration and no upgraders.
    pub fn new() -> Self {
        Self::with_config(UpgradeConfig::default())
    }

    /// Creates a manager with the given `config` and no upgraders.
    pub fn with_config(config: UpgradeConfig) -> Self {
        Self {
            upgraders: Vec::new(),
            config,
            stats: UpgradeStats::new(),
            cache: HashMap::new(),
            cache_timestamps: HashMap::new(),
        }
    }

    /// Registers an upgrader. Upgraders are consulted in registration order.
    pub fn add_upgrader(&mut self, upgrader: Box<dyn ProtocolUpgrader + Send + Sync>) {
        self.upgraders.push(upgrader);
    }

    /// Upgrades `data` from `from` to `to` using the first registered
    /// upgrader that accepts the pair.
    ///
    /// A failing attempt is retried up to `config.max_retries` more times,
    /// sleeping `config.retry_interval` on the current thread between
    /// attempts. The returned result's `duration` is the wall-clock time
    /// across all attempts, including the sleeps.
    ///
    /// With caching enabled, a still-fresh cached success for the same pair
    /// is returned immediately; such cache hits are not added to the
    /// statistics.
    ///
    /// # Errors
    ///
    /// Returns an error only when no registered upgrader supports the pair.
    /// When every attempt fails, the call still returns `Ok` with a result
    /// whose `success` flag is `false` and whose message is the last error.
    pub fn upgrade(&mut self, from: ProtocolType, to: ProtocolType, data: &[u8]) -> Result<UpgradeResult> {
        // NOTE(review): the cache is keyed by (from, to) only, so a cached
        // result — including its `upgraded_data` — is reused even when `data`
        // differs from the call that populated it. Confirm this is intended.
        if self.config.enable_cache {
            if let Some(cached_result) = self.get_cached_result(from, to) {
                return Ok(cached_result);
            }
        }

        let upgrader = self.upgraders
            .iter()
            .find(|u| u.can_upgrade(from, to))
            .ok_or_else(|| DetectorError::upgrade_failed(
                format!("{:?}", from),
                format!("{:?}", to),
                "No upgrader found"
            ))?;

        let start = Instant::now();
        let mut last_error = None;

        // `0..=max_retries` yields one initial attempt plus `max_retries` retries.
        for attempt in 0..=self.config.max_retries {
            match upgrader.upgrade(from, to, data) {
                Ok(result) => {
                    // Replace the upgrader-reported duration with the time
                    // measured here, which also covers earlier failed attempts.
                    let final_result = UpgradeResult {
                        duration: start.elapsed(),
                        ..result
                    };
                    self.stats.record_upgrade(&final_result, from);
                    // Only successes are cached; an `Ok` carrying
                    // `success == false` is returned but never stored.
                    if self.config.enable_cache && final_result.success {
                        self.cache_result(from, to, final_result.clone());
                    }
                    return Ok(final_result);
                }
                Err(e) => {
                    last_error = Some(e);
                    // No sleep after the final attempt.
                    if attempt < self.config.max_retries {
                        std::thread::sleep(self.config.retry_interval);
                    }
                }
            }
        }

        let error_result = UpgradeResult::failure(
            to,
            UpgradeMethod::Direct,
            start.elapsed(),
            last_error.map(|e| e.to_string()).unwrap_or_else(|| "Unknown error".to_string()),
        );
        self.stats.record_upgrade(&error_result, from);
        Ok(error_result)
    }

    /// Collects the upgrade paths of every registered upgrader, in
    /// registration order. Duplicates are not removed.
    pub fn supported_upgrades(&self) -> Vec<UpgradePath> {
        self.upgraders
            .iter()
            .flat_map(|u| u.supported_upgrades())
            .collect()
    }

    /// Reports whether any registered upgrader accepts the `from` -> `to` pair.
    pub fn can_upgrade(&self, from: ProtocolType, to: ProtocolType) -> bool {
        self.upgraders.iter().any(|u| u.can_upgrade(from, to))
    }

    /// Read-only access to the accumulated statistics.
    pub fn stats(&self) -> &UpgradeStats {
        &self.stats
    }

    /// Drops every cache entry older than `config.cache_ttl`.
    pub fn cleanup_cache(&mut self) {
        let now = Instant::now();
        let ttl = self.config.cache_ttl;
        // Borrow the result map separately so both maps can be pruned in a
        // single pass over the timestamps and stay in sync.
        let cache = &mut self.cache;
        self.cache_timestamps.retain(|key, stored_at| {
            let fresh = now.duration_since(*stored_at) <= ttl;
            if !fresh {
                cache.remove(key);
            }
            fresh
        });
    }

    /// Returns a clone of the cached result for `(from, to)` if it is still
    /// within the TTL; an expired entry is evicted and `None` is returned.
    fn get_cached_result(&mut self, from: ProtocolType, to: ProtocolType) -> Option<UpgradeResult> {
        let key = (from, to);
        let stored_at = *self.cache_timestamps.get(&key)?;
        if stored_at.elapsed() <= self.config.cache_ttl {
            return self.cache.get(&key).cloned();
        }
        self.cache.remove(&key);
        self.cache_timestamps.remove(&key);
        None
    }

    /// Stores `result` for `(from, to)` and stamps it with the current time,
    /// replacing any previous entry for the pair.
    fn cache_result(&mut self, from: ProtocolType, to: ProtocolType, result: UpgradeResult) {
        let key = (from, to);
        self.cache.insert(key, result);
        self.cache_timestamps.insert(key, Instant::now());
    }
}
impl Default for UpgradeManager {
fn default() -> Self {
let mut manager = Self::new();
manager.add_upgrader(Box::new(HttpUpgrader::new()));
manager.add_upgrader(Box::new(WebSocketUpgrader::new()));
manager
}
}