use std::collections::VecDeque;
use std::time::Duration;
use serde::{Deserialize, Serialize};
/// Output of [`QualityPredictor::predict`]: a forecast RTT plus signals derived from it.
// Field order is part of the serde layout; keep it stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Prediction {
    /// Forecast round-trip time at the requested horizon; `Duration::ZERO` when the
    /// predictor has fewer than three samples.
    pub rtt: Duration,
    /// Confidence in the forecast, in `[0, 1]`: 0.0 with too little history, otherwise
    /// the product of a history-fill factor and a variance-based factor.
    pub confidence: f64,
    /// Set when the smoothed trend exceeds +5 ms per sample and the forecast is more
    /// than 20% above the current level.
    pub degradation_likely: bool,
    /// Set when the smoothed trend is below -5 ms per sample and the forecast is less
    /// than 80% of the current level.
    pub improvement_likely: bool,
    /// Suggested action derived from the forecast RTT and recent loss values.
    pub recommendation: PredictionRecommendation,
}
/// Action suggested by [`QualityPredictor::predict`].
///
/// Rules are evaluated in priority order: `Failover`, `ConsiderSwitch`,
/// `ReduceLoad`, `IncreaseLoad`, then `Continue` as the fallback.
// Variant order is part of the serde encoding (variant index); do not reorder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PredictionRecommendation {
    /// No other rule matched; also returned when there is too little history
    /// (fewer than three samples) to forecast.
    Continue,
    /// The RTT is judged to be degrading and the forecast exceeds 200 ms.
    ConsiderSwitch,
    /// Average recent loss is above 0.05, or the forecast RTT is above 100 ms.
    ReduceLoad,
    /// Average recent loss is below 0.01 and the forecast RTT is below 50 ms.
    IncreaseLoad,
    /// Average recent loss is above 0.2, or the forecast RTT is above 1000 ms.
    Failover,
}
/// A single retained observation: the RTT (converted to milliseconds) and the
/// loss value that was reported with it.
#[derive(Clone, Debug)]
struct Sample {
    /// Round-trip time in milliseconds.
    rtt_ms: f64,
    /// Loss value supplied by the caller alongside the RTT.
    loss: f64,
}
/// Online RTT forecaster built on double exponential (level + trend) smoothing
/// over a bounded window of recent samples.
#[derive(Debug, Clone)]
pub struct QualityPredictor {
    /// Most recent observations, oldest first; capped at `max_samples`.
    samples: VecDeque<Sample>,
    /// Maximum number of observations retained in `samples`.
    max_samples: usize,
    /// Smoothing factor for the level (and for the variance estimate).
    alpha: f64,
    /// Smoothing factor for the trend.
    beta: f64,
    /// Smoothed RTT estimate, in milliseconds.
    level: f64,
    /// Smoothed per-sample change of the RTT, in milliseconds.
    trend: f64,
    /// Exponentially weighted mean of squared deviations, in ms².
    variance: f64,
}
impl QualityPredictor {
    /// Number of most recent samples whose loss values are averaged by `recommend`.
    const RECENT_LOSS_WINDOW: usize = 5;

    /// Creates a predictor that retains at most `max_samples` observations.
    ///
    /// The level and trend smoothing factors are fixed at 0.3 and 0.1. `predict`
    /// needs at least three retained samples, so a `max_samples` below 3 always
    /// yields the neutral prediction.
    pub fn new(max_samples: usize) -> Self {
        Self {
            samples: VecDeque::with_capacity(max_samples),
            max_samples,
            alpha: 0.3,
            beta: 0.1,
            level: 0.0,
            trend: 0.0,
            variance: 0.0,
        }
    }

    /// Records one RTT / loss observation and updates the smoothed estimates.
    ///
    /// The first observation seeds the level directly. Later observations update
    /// level and trend with Holt's linear-trend smoothing. `variance` tracks an
    /// exponentially weighted mean of the squared deviation between the new RTT
    /// and the level *before* this update. When the window is full, the oldest
    /// observation is evicted.
    pub fn add_sample(&mut self, rtt: Duration, loss: f64) {
        let rtt_ms = rtt.as_secs_f64() * 1000.0;
        if self.samples.is_empty() {
            self.level = rtt_ms;
            self.trend = 0.0;
        } else {
            let prev_level = self.level;
            self.level = self.alpha * rtt_ms + (1.0 - self.alpha) * (prev_level + self.trend);
            self.trend = self.beta * (self.level - prev_level) + (1.0 - self.beta) * self.trend;
            let error = rtt_ms - prev_level;
            self.variance = self.alpha * error * error + (1.0 - self.alpha) * self.variance;
        }
        if self.samples.len() >= self.max_samples {
            self.samples.pop_front();
        }
        self.samples.push_back(Sample { rtt_ms, loss });
    }

    /// Forecasts the RTT `horizon` ahead and derives trend flags and a recommendation.
    ///
    /// `horizon` is converted to seconds and used directly as the number of trend
    /// steps. NOTE(review): this implicitly assumes about one sample per second;
    /// confirm against the caller's sampling rate.
    ///
    /// With fewer than three retained samples, a neutral prediction is returned:
    /// zero RTT, zero confidence, and `Continue`. A forecast too large to represent
    /// as a `Duration` saturates to `Duration::MAX` instead of panicking.
    pub fn predict(&self, horizon: Duration) -> Prediction {
        if self.samples.len() < 3 {
            return Prediction {
                rtt: Duration::ZERO,
                confidence: 0.0,
                degradation_likely: false,
                improvement_likely: false,
                recommendation: PredictionRecommendation::Continue,
            };
        }
        let steps = horizon.as_secs_f64();
        let predicted_rtt_ms = self.level + self.trend * steps;
        let std_dev = self.variance.sqrt();
        // Both factors lie in (0, 1]: fuller history and lower spread raise confidence.
        let sample_confidence = (self.samples.len() as f64 / self.max_samples as f64).min(1.0);
        let variance_confidence = 1.0 / (1.0 + std_dev / 100.0);
        let confidence = sample_confidence * variance_confidence;
        let degradation_likely = self.trend > 5.0 && predicted_rtt_ms > self.level * 1.2;
        let improvement_likely = self.trend < -5.0 && predicted_rtt_ms < self.level * 0.8;
        let recommendation = self.recommend(predicted_rtt_ms, degradation_likely);
        Prediction {
            rtt: Self::ms_to_duration(predicted_rtt_ms),
            confidence,
            degradation_likely,
            improvement_likely,
            recommendation,
        }
    }

    /// Chooses an action from the forecast RTT, the degradation flag and recent loss.
    ///
    /// Rules are checked in priority order; the first match wins.
    fn recommend(&self, predicted_rtt_ms: f64, degrading: bool) -> PredictionRecommendation {
        let recent_loss = self.recent_loss();
        if recent_loss > 0.2 || predicted_rtt_ms > 1000.0 {
            PredictionRecommendation::Failover
        } else if degrading && predicted_rtt_ms > 200.0 {
            PredictionRecommendation::ConsiderSwitch
        } else if recent_loss > 0.05 || predicted_rtt_ms > 100.0 {
            PredictionRecommendation::ReduceLoad
        } else if recent_loss < 0.01 && predicted_rtt_ms < 50.0 {
            PredictionRecommendation::IncreaseLoad
        } else {
            PredictionRecommendation::Continue
        }
    }

    /// Mean loss over the last `RECENT_LOSS_WINDOW` samples, or over all samples
    /// when fewer are retained. Returns 0.0 when there are no samples.
    fn recent_loss(&self) -> f64 {
        // Divide by the number of samples actually averaged, not the window size;
        // otherwise short histories understate loss.
        let window = self.samples.len().min(Self::RECENT_LOSS_WINDOW);
        if window == 0 {
            return 0.0;
        }
        let total: f64 = self.samples.iter().rev().take(window).map(|s| s.loss).sum();
        total / window as f64
    }

    /// Returns `true` when the latest RTT lies more than `threshold_stddev`
    /// standard deviations from the current level.
    ///
    /// Requires at least 10 retained samples and a non-zero variance; otherwise it
    /// returns `false`. The latest sample has already been folded into `level` and
    /// `variance` by `add_sample`, so the score is measured against the
    /// post-update estimate.
    pub fn detect_anomaly(&self, threshold_stddev: f64) -> bool {
        if self.samples.len() < 10 {
            return false;
        }
        let std_dev = self.variance.sqrt();
        if std_dev == 0.0 {
            return false;
        }
        if let Some(latest) = self.samples.back() {
            let z_score = (latest.rtt_ms - self.level).abs() / std_dev;
            z_score > threshold_stddev
        } else {
            false
        }
    }

    /// Classifies the smoothed trend: more than +1 ms per sample is `Increasing`,
    /// less than -1 ms per sample is `Decreasing`, anything else is `Stable`.
    pub fn trend_direction(&self) -> TrendDirection {
        if self.trend > 1.0 {
            TrendDirection::Increasing
        } else if self.trend < -1.0 {
            TrendDirection::Decreasing
        } else {
            TrendDirection::Stable
        }
    }

    /// Current smoothed RTT estimate (the level), clamped at zero.
    pub fn current_estimate(&self) -> Duration {
        Self::ms_to_duration(self.level)
    }

    /// Converts milliseconds to a `Duration`.
    ///
    /// Negative inputs and NaN become zero, because `f64::max` returns the
    /// non-NaN operand. Values too large for `Duration` saturate to
    /// `Duration::MAX` rather than panicking as `Duration::from_secs_f64` would.
    fn ms_to_duration(ms: f64) -> Duration {
        Duration::try_from_secs_f64(ms.max(0.0) / 1000.0).unwrap_or(Duration::MAX)
    }
}
/// Coarse direction of the smoothed RTT trend, as reported by
/// [`QualityPredictor::trend_direction`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TrendDirection {
    /// The trend is rising by more than 1 ms per sample.
    Increasing,
    /// The trend is falling by more than 1 ms per sample.
    Decreasing,
    /// The trend magnitude is at most 1 ms per sample.
    Stable,
}
impl Default for QualityPredictor {
fn default() -> Self {
Self::new(100)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A gently rising RTT series (50..69 ms) yields a positive forecast with
    /// non-zero confidence.
    #[test]
    fn test_predictor_basic() {
        let mut predictor = QualityPredictor::new(100);
        (50u64..70).for_each(|ms| predictor.add_sample(Duration::from_millis(ms), 0.0));
        let prediction = predictor.predict(Duration::from_secs(1));
        assert!(prediction.confidence > 0.0);
        assert!(prediction.rtt > Duration::ZERO);
    }

    /// An RTT climbing by 10 ms per sample is classified as an increasing trend.
    #[test]
    fn test_trend_detection() {
        let mut predictor = QualityPredictor::new(100);
        let ramp = (0u64..20).map(|step| Duration::from_millis(50 + step * 10));
        for rtt in ramp {
            predictor.add_sample(rtt, 0.0);
        }
        assert_eq!(predictor.trend_direction(), TrendDirection::Increasing);
    }
}