#![allow(dead_code)]
use std::collections::VecDeque;
use std::fmt;
use std::time::Duration;
use crate::error::{NetError, NetResult};
/// Tuning parameters for [`EwmaBandwidthEstimator`].
///
/// [`EwmaBandwidthEstimator::new`] runs [`EwmaConfig::validate`] before accepting a config,
/// and the estimator relies on the invariants it checks when smoothing and clamping.
#[derive(Debug, Clone)]
pub struct EwmaConfig {
    /// Weight of the newest sample in the moving average; must be in `(0, 1]`.
    /// Larger values track changes faster, smaller values smooth more.
    pub alpha: f64,
    /// How strongly loss discounts a throughput sample, in `[0, 1]`: each sample is
    /// scaled by `1 - loss_penalty_factor * loss_fraction`.
    pub loss_penalty_factor: f64,
    /// Lower clamp for the smoothed estimate, in kbps; must be strictly below `max_kbps`.
    pub min_kbps: f64,
    /// Upper clamp for the smoothed estimate, in kbps.
    pub max_kbps: f64,
}

impl Default for EwmaConfig {
    /// `alpha = 0.2`, `loss_penalty_factor = 0.5`, estimate clamped to `[64, 100_000]` kbps.
    fn default() -> Self {
        Self {
            alpha: 0.2,
            loss_penalty_factor: 0.5,
            min_kbps: 64.0,
            max_kbps: 100_000.0,
        }
    }
}

impl EwmaConfig {
    /// Checks that every field is within its legal range.
    ///
    /// Each check is written as "value is inside the range" and then negated. NaN fails
    /// every comparison, so this form rejects it. The "value is outside the range" form
    /// would let NaN through. That matters because the estimator clamps with `f64::clamp`,
    /// which panics when a bound is NaN.
    ///
    /// # Errors
    ///
    /// Returns a protocol error in any of these cases:
    /// - `alpha` is not in `(0, 1]`;
    /// - `min_kbps` is not strictly less than `max_kbps`;
    /// - `loss_penalty_factor` is not in `[0, 1]`.
    pub fn validate(&self) -> NetResult<()> {
        let alpha_ok = self.alpha > 0.0 && self.alpha <= 1.0;
        if !alpha_ok {
            return Err(NetError::protocol(format!(
                "EWMA alpha must be in (0,1], got {}",
                self.alpha
            )));
        }
        let bounds_ok = self.min_kbps < self.max_kbps;
        if !bounds_ok {
            return Err(NetError::protocol(format!(
                "min_kbps ({}) must be < max_kbps ({})",
                self.min_kbps, self.max_kbps
            )));
        }
        if !(0.0..=1.0).contains(&self.loss_penalty_factor) {
            return Err(NetError::protocol(format!(
                "loss_penalty_factor must be in [0,1], got {}",
                self.loss_penalty_factor
            )));
        }
        Ok(())
    }
}
/// Bandwidth estimator that smooths loss-penalized throughput samples with an
/// exponentially weighted moving average (EWMA).
///
/// The estimate is clamped into `[min_kbps, max_kbps]`. The estimator also keeps the most
/// recent raw samples (taken before smoothing and clamping), and their spread can be read
/// via [`EwmaBandwidthEstimator::raw_variance_kbps`].
#[derive(Debug)]
pub struct EwmaBandwidthEstimator {
    /// Smoothing and clamping parameters, validated in `new`.
    config: EwmaConfig,
    /// Current smoothed estimate in kbps; `None` until the first sample.
    estimate_kbps: Option<f64>,
    /// Sliding window of the most recent loss-penalized samples, oldest first.
    recent_raw: VecDeque<f64>,
    /// Maximum number of samples kept in `recent_raw`.
    history_len: usize,
}

impl EwmaBandwidthEstimator {
    /// Number of raw samples retained for the variance computation.
    const HISTORY_LEN: usize = 20;

    /// Creates an estimator with no estimate yet.
    ///
    /// # Errors
    ///
    /// Returns the error from [`EwmaConfig::validate`] if `config` is invalid.
    pub fn new(config: EwmaConfig) -> NetResult<Self> {
        config.validate()?;
        Ok(Self {
            config,
            estimate_kbps: None,
            recent_raw: VecDeque::with_capacity(Self::HISTORY_LEN),
            history_len: Self::HISTORY_LEN,
        })
    }

    /// Creates an estimator using [`EwmaConfig::default`].
    ///
    /// # Errors
    ///
    /// Returns an error only if the default configuration fails validation.
    pub fn with_defaults() -> NetResult<Self> {
        Self::new(EwmaConfig::default())
    }

    /// Current smoothed estimate in kbps, or `None` if no sample has been folded in
    /// since construction or the last [`reset`](Self::reset).
    #[must_use]
    pub fn estimate_kbps(&self) -> Option<f64> {
        self.estimate_kbps
    }

    /// Folds one throughput sample into the estimate and returns the new estimate in kbps.
    ///
    /// Steps:
    /// 1. Discount the sample for loss: `measured * (1 - loss_penalty_factor * loss)`.
    /// 2. Floor the result at 0. `f64::max` also maps a NaN product to 0.
    /// 3. Record that raw value in the history window.
    /// 4. Blend it into the EWMA. The first sample seeds the average directly.
    /// 5. Clamp the result into `[min_kbps, max_kbps]`.
    ///
    /// `loss_fraction` is clamped into `[0, 1]`, and NaN is treated as `0.0` (no loss).
    /// `f64::clamp` passes NaN through unchanged. Without this check, a NaN loss would zero
    /// the sample even when the loss penalty is disabled.
    ///
    /// # Panics
    ///
    /// `f64::clamp` panics if `min_kbps > max_kbps` or either bound is NaN. `new` runs
    /// [`EwmaConfig::validate`], which is expected to reject such bounds.
    pub fn update(&mut self, measured_kbps: f64, loss_fraction: f64) -> f64 {
        let loss = if loss_fraction.is_nan() {
            0.0
        } else {
            loss_fraction.clamp(0.0, 1.0)
        };
        let penalty = 1.0 - self.config.loss_penalty_factor * loss;
        let raw = (measured_kbps * penalty).max(0.0);

        // Fixed-size window: evict the oldest sample before adding the newest.
        if self.recent_raw.len() == self.history_len {
            self.recent_raw.pop_front();
        }
        self.recent_raw.push_back(raw);

        let smoothed = match self.estimate_kbps {
            None => raw,
            Some(prev) => self.config.alpha * raw + (1.0 - self.config.alpha) * prev,
        };
        let clamped = smoothed.clamp(self.config.min_kbps, self.config.max_kbps);
        self.estimate_kbps = Some(clamped);
        clamped
    }

    /// Forgets the current estimate and the sample history.
    pub fn reset(&mut self) {
        self.estimate_kbps = None;
        self.recent_raw.clear();
    }

    /// Population variance (dividing by `n`) of the retained raw samples.
    ///
    /// The raw samples are in kbps, so despite the name this value is in kbps². Returns
    /// `0.0` when fewer than two samples are held.
    #[must_use]
    pub fn raw_variance_kbps(&self) -> f64 {
        let n = self.recent_raw.len();
        if n < 2 {
            return 0.0;
        }
        let mean: f64 = self.recent_raw.iter().sum::<f64>() / n as f64;
        self.recent_raw
            .iter()
            .map(|x| {
                let d = x - mean;
                d * d
            })
            .sum::<f64>()
            / n as f64
    }
}
/// Phase of the probe scheduler's `Idle -> Probing -> Cooldown` cycle.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProbeState {
    /// Waiting for the probe interval to elapse.
    Idle,
    /// A probe has been started and its result is awaited.
    Probing,
    /// A probe result has been recorded; waiting before going back to `Idle`.
    Cooldown,
}

impl fmt::Display for ProbeState {
    /// Writes the bare variant name; width/fill flags are intentionally not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match *self {
            ProbeState::Probing => "Probing",
            ProbeState::Cooldown => "Cooldown",
            ProbeState::Idle => "Idle",
        })
    }
}
/// Outcome of a single bandwidth probe.
#[derive(Debug, Clone, PartialEq)]
pub struct ProbeResult {
    /// Probe timestamp in milliseconds (not read by anything in this module).
    pub timestamp_ms: u64,
    /// Number of bytes delivered during the probe.
    pub bytes_delivered: u64,
    /// Length of the probe in milliseconds.
    pub duration_ms: u64,
    /// Loss fraction observed during the probe; in `[0, 1]` when built via
    /// [`ProbeResult::new`].
    pub loss_fraction: f64,
}

impl ProbeResult {
    /// Builds a probe result, normalizing `loss_fraction` into `[0, 1]`.
    ///
    /// Out-of-range values are clamped. NaN is treated as `0.0` (no loss) because
    /// `f64::clamp` returns NaN unchanged, which would break the `[0, 1]` guarantee.
    #[must_use]
    pub fn new(
        timestamp_ms: u64,
        bytes_delivered: u64,
        duration_ms: u64,
        loss_fraction: f64,
    ) -> Self {
        let loss_fraction = if loss_fraction.is_nan() {
            0.0
        } else {
            loss_fraction.clamp(0.0, 1.0)
        };
        Self {
            timestamp_ms,
            bytes_delivered,
            duration_ms,
            loss_fraction,
        }
    }

    /// Throughput achieved by the probe, in kilobits per second.
    ///
    /// `bytes * 8` gives bits. Bits per millisecond is numerically equal to kilobits per
    /// second. A zero-length probe yields `0.0` instead of dividing by zero.
    #[must_use]
    pub fn measured_kbps(&self) -> f64 {
        if self.duration_ms == 0 {
            return 0.0;
        }
        (self.bytes_delivered as f64 * 8.0) / self.duration_ms as f64
    }
}
/// Summary returned by `ProbeScheduler::report` after a probe result is recorded.
#[derive(Clone, Debug, PartialEq)]
pub struct ProbeReport {
    /// Throughput measured by this probe in kbps, before loss penalty and smoothing.
    pub measured_kbps: f64,
    /// Smoothed bandwidth estimate in kbps after incorporating this probe.
    pub estimate_kbps: f64,
    /// Loss fraction carried by the probe result.
    pub loss_fraction: f64,
    /// Scheduler state after the report (`Cooldown` when produced by `report`).
    pub new_state: ProbeState,
}
/// Timing and sizing knobs for `ProbeScheduler`.
#[derive(Debug, Clone)]
pub struct SchedulerConfig {
    /// Idle time between probes (the scheduler halves it after a lossy probe).
    pub probe_interval: Duration,
    /// Target probe length; combined with the current estimate to size each probe.
    pub probe_duration: Duration,
    /// Intended length of the post-probe `Cooldown` phase.
    pub cooldown_duration: Duration,
    /// Loss fraction at or above which the probe interval is halved.
    pub loss_probe_threshold: f64,
    /// Upper bound on the bytes requested for a single probe.
    pub max_probe_bytes: u64,
}

impl Default for SchedulerConfig {
    /// Default settings:
    /// - probe every 2 s, targeting 200 ms per probe;
    /// - cool down for 500 ms after each probe;
    /// - treat >= 2 % loss as lossy;
    /// - cap probes at 64 KiB.
    fn default() -> Self {
        let probe_interval = Duration::from_millis(2_000);
        let probe_duration = Duration::from_millis(200);
        let cooldown_duration = Duration::from_millis(500);
        SchedulerConfig {
            probe_interval,
            probe_duration,
            cooldown_duration,
            loss_probe_threshold: 0.02,
            max_probe_bytes: 64 * 1024,
        }
    }
}
#[derive(Debug)]
pub struct ProbeScheduler {
config: SchedulerConfig,
estimator: EwmaBandwidthEstimator,
state: ProbeState,
elapsed_ms: u64,
state_entered_ms: u64,
probe_count: u64,
last_loss: f64,
}
impl ProbeScheduler {
pub fn new(config: SchedulerConfig, ewma: EwmaConfig) -> NetResult<Self> {
let estimator = EwmaBandwidthEstimator::new(ewma)?;
Ok(Self {
config,
estimator,
state: ProbeState::Idle,
elapsed_ms: 0,
state_entered_ms: 0,
probe_count: 0,
last_loss: 0.0,
})
}
pub fn with_defaults() -> NetResult<Self> {
Self::new(SchedulerConfig::default(), EwmaConfig::default())
}
#[must_use]
pub fn state(&self) -> ProbeState {
self.state
}
#[must_use]
pub fn estimate_kbps(&self) -> Option<f64> {
self.estimator.estimate_kbps()
}
#[must_use]
pub fn probe_count(&self) -> u64 {
self.probe_count
}
#[must_use]
pub fn last_loss(&self) -> f64 {
self.last_loss
}
pub fn tick(&mut self, delta_ms: u64) -> Option<u64> {
self.elapsed_ms += delta_ms;
match self.state {
ProbeState::Idle => {
let interval_ms = self.effective_interval_ms();
let time_in_state = self.elapsed_ms - self.state_entered_ms;
if time_in_state >= interval_ms {
self.transition(ProbeState::Probing);
Some(self.probe_bytes())
} else {
None
}
}
ProbeState::Probing | ProbeState::Cooldown => None,
}
}
pub fn report(&mut self, result: ProbeResult) -> NetResult<ProbeReport> {
if self.state != ProbeState::Probing {
return Err(NetError::invalid_state(format!(
"report() called in state {}; expected Probing",
self.state
)));
}
let measured_kbps = result.measured_kbps();
self.last_loss = result.loss_fraction;
let estimate_kbps = self.estimator.update(measured_kbps, result.loss_fraction);
self.probe_count += 1;
self.transition(ProbeState::Cooldown);
Ok(ProbeReport {
measured_kbps,
estimate_kbps,
loss_fraction: result.loss_fraction,
new_state: self.state,
})
}
pub fn end_cooldown(&mut self) {
if self.state == ProbeState::Cooldown {
self.transition(ProbeState::Idle);
}
}
fn transition(&mut self, new_state: ProbeState) {
self.state = new_state;
self.state_entered_ms = self.elapsed_ms;
}
fn effective_interval_ms(&self) -> u64 {
let base = self.config.probe_interval.as_millis() as u64;
if self.last_loss >= self.config.loss_probe_threshold {
base / 2
} else {
base
}
}
fn probe_bytes(&self) -> u64 {
let target = self
.estimator
.estimate_kbps()
.map_or(self.config.max_probe_bytes, |bw| {
let bytes_per_ms = bw / 8.0;
let probe_ms = self.config.probe_duration.as_millis() as f64;
(bytes_per_ms * probe_ms) as u64
});
target.min(self.config.max_probe_bytes).max(1)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ewma_config_bad_alpha() {
        let cfg = EwmaConfig {
            alpha: 0.0,
            ..Default::default()
        };
        assert!(cfg.validate().is_err());
        let cfg2 = EwmaConfig {
            alpha: 1.5,
            ..Default::default()
        };
        assert!(cfg2.validate().is_err());
    }

    #[test]
    fn test_ewma_config_bad_bounds() {
        let cfg = EwmaConfig {
            min_kbps: 1000.0,
            max_kbps: 500.0,
            ..Default::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_ewma_config_bad_loss_penalty() {
        let too_big = EwmaConfig {
            loss_penalty_factor: 1.5,
            ..Default::default()
        };
        assert!(too_big.validate().is_err());
        let negative = EwmaConfig {
            loss_penalty_factor: -0.1,
            ..Default::default()
        };
        assert!(negative.validate().is_err());
    }

    #[test]
    fn test_ewma_first_update() {
        let mut est = EwmaBandwidthEstimator::with_defaults().expect("valid config");
        assert!(est.estimate_kbps().is_none());
        let e = est.update(5_000.0, 0.0);
        assert!((e - 5_000.0).abs() < 1.0, "estimate={e}");
    }

    #[test]
    fn test_ewma_smoothing() {
        let mut est = EwmaBandwidthEstimator::with_defaults().expect("valid config");
        est.update(5_000.0, 0.0);
        // alpha = 0.2: 0.2 * 1000 + 0.8 * 5000 = 4200
        let e = est.update(1_000.0, 0.0);
        assert!((e - 4_200.0).abs() < 1e-6, "estimate={e}");
    }

    #[test]
    fn test_ewma_loss_penalty() {
        let mut est = EwmaBandwidthEstimator::with_defaults().expect("valid config");
        let no_loss = est.update(5_000.0, 0.0);
        est.reset();
        let with_loss = est.update(5_000.0, 0.5);
        assert!(
            with_loss < no_loss,
            "no_loss={no_loss} with_loss={with_loss}"
        );
    }

    #[test]
    fn test_ewma_raw_variance() {
        let mut est = EwmaBandwidthEstimator::with_defaults().expect("valid config");
        assert_eq!(est.raw_variance_kbps(), 0.0);
        est.update(5_000.0, 0.0);
        assert_eq!(est.raw_variance_kbps(), 0.0);
        est.update(1_000.0, 0.0);
        // Population variance of [5000, 1000]: mean 3000, ((2000^2) * 2) / 2 = 4e6
        assert!((est.raw_variance_kbps() - 4_000_000.0).abs() < 1e-6);
    }

    #[test]
    fn test_ewma_reset_clears_state() {
        let mut est = EwmaBandwidthEstimator::with_defaults().expect("valid config");
        est.update(5_000.0, 0.0);
        est.update(1_000.0, 0.0);
        est.reset();
        assert!(est.estimate_kbps().is_none());
        assert_eq!(est.raw_variance_kbps(), 0.0);
    }

    #[test]
    fn test_ewma_clamp_min() {
        let cfg = EwmaConfig {
            min_kbps: 100.0,
            ..Default::default()
        };
        let mut est = EwmaBandwidthEstimator::new(cfg).expect("valid config");
        let e = est.update(0.0, 1.0);
        assert!(e >= 100.0, "e={e}");
    }

    #[test]
    fn test_ewma_clamp_max() {
        let cfg = EwmaConfig {
            max_kbps: 1_000.0,
            ..Default::default()
        };
        let mut est = EwmaBandwidthEstimator::new(cfg).expect("valid config");
        let e = est.update(1_000_000.0, 0.0);
        assert!(e <= 1_000.0, "e={e}");
    }

    #[test]
    fn test_probe_result_kbps() {
        let r = ProbeResult::new(0, 10_000, 100, 0.0);
        assert!((r.measured_kbps() - 800.0).abs() < 1e-6);
    }

    #[test]
    fn test_probe_result_kbps_zero_duration() {
        let r = ProbeResult::new(0, 5_000, 0, 0.0);
        assert_eq!(r.measured_kbps(), 0.0);
    }

    #[test]
    fn test_scheduler_starts_idle() {
        let sched = ProbeScheduler::with_defaults().expect("valid config");
        assert_eq!(sched.state(), ProbeState::Idle);
        assert_eq!(sched.probe_count(), 0);
    }

    #[test]
    fn test_scheduler_tick_triggers_probe() {
        let cfg = SchedulerConfig {
            probe_interval: Duration::from_millis(200),
            ..Default::default()
        };
        let mut sched = ProbeScheduler::new(cfg, EwmaConfig::default()).expect("valid");
        assert!(sched.tick(199).is_none());
        let bytes = sched.tick(1);
        assert!(bytes.is_some(), "expected probe bytes");
        assert_eq!(sched.state(), ProbeState::Probing);
    }

    #[test]
    fn test_scheduler_probing_waits_for_report() {
        let cfg = SchedulerConfig {
            probe_interval: Duration::from_millis(100),
            ..Default::default()
        };
        let mut sched = ProbeScheduler::new(cfg, EwmaConfig::default()).expect("valid");
        assert!(sched.tick(100).is_some());
        assert!(sched.tick(10_000).is_none());
        assert_eq!(sched.state(), ProbeState::Probing);
    }

    #[test]
    fn test_scheduler_report() {
        let cfg = SchedulerConfig {
            probe_interval: Duration::from_millis(100),
            ..Default::default()
        };
        let mut sched = ProbeScheduler::new(cfg, EwmaConfig::default()).expect("valid");
        assert!(sched.tick(100).is_some());
        let result = ProbeResult::new(100, 8_000, 100, 0.01);
        let report = sched.report(result).expect("should succeed");
        assert_eq!(report.new_state, ProbeState::Cooldown);
        assert!(report.estimate_kbps > 0.0);
        assert_eq!(sched.probe_count(), 1);
    }

    #[test]
    fn test_scheduler_report_wrong_state() {
        let mut sched = ProbeScheduler::with_defaults().expect("valid");
        let result = ProbeResult::new(0, 1_000, 100, 0.0);
        assert!(sched.report(result).is_err());
    }

    #[test]
    fn test_scheduler_end_cooldown() {
        let cfg = SchedulerConfig {
            probe_interval: Duration::from_millis(100),
            ..Default::default()
        };
        let mut sched = ProbeScheduler::new(cfg, EwmaConfig::default()).expect("valid");
        sched.tick(100);
        sched
            .report(ProbeResult::new(100, 8_000, 100, 0.0))
            .expect("report ok");
        assert_eq!(sched.state(), ProbeState::Cooldown);
        sched.end_cooldown();
        assert_eq!(sched.state(), ProbeState::Idle);
    }

    #[test]
    fn test_probe_bytes_scale_with_estimate() {
        let cfg = SchedulerConfig {
            probe_interval: Duration::from_millis(100),
            ..Default::default()
        };
        let mut sched = ProbeScheduler::new(cfg, EwmaConfig::default()).expect("valid");
        // No estimate yet: request the maximum probe size.
        assert_eq!(sched.tick(100), Some(65_536));
        sched
            .report(ProbeResult::new(100, 10_000, 100, 0.0))
            .expect("report ok");
        sched.end_cooldown();
        // 800 kbps = 100 bytes/ms, times the default 200 ms probe duration.
        assert_eq!(sched.tick(100), Some(20_000));
    }

    #[test]
    fn test_high_loss_shortens_interval() {
        let cfg = SchedulerConfig {
            probe_interval: Duration::from_millis(2_000),
            loss_probe_threshold: 0.02,
            ..Default::default()
        };
        let mut sched = ProbeScheduler::new(cfg, EwmaConfig::default()).expect("valid");
        sched.tick(2_000);
        sched
            .report(ProbeResult::new(0, 100, 100, 0.05))
            .expect("report ok");
        sched.end_cooldown();
        assert!(sched.tick(999).is_none());
        let bytes = sched.tick(1);
        assert!(
            bytes.is_some(),
            "probe should fire after 1000 ms under high loss"
        );
    }

    #[test]
    fn test_probe_state_display() {
        assert_eq!(format!("{}", ProbeState::Idle), "Idle");
        assert_eq!(format!("{}", ProbeState::Probing), "Probing");
        assert_eq!(format!("{}", ProbeState::Cooldown), "Cooldown");
    }
}