use crate::bandwidth_trigger::{
BandwidthObservation, BandwidthTrigger, TriggerAction, TriggerConfig,
};
use crate::error::NetResult;
/// Couples a [`BandwidthTrigger`] with a user-supplied reaction callback.
///
/// Observations are pushed in through `update`; the trigger is re-evaluated
/// after each one and every non-`Hold` [`TriggerAction`] it produces is handed
/// to the stored callback.
pub struct BandwidthAdaptationController {
    /// Decision engine that accumulates observations and yields actions.
    trigger: BandwidthTrigger,
    /// Reaction hook; boxed so the controller type is not generic over it.
    callback: Box<dyn Fn(TriggerAction) + Send + 'static>,
}
impl BandwidthAdaptationController {
    /// Creates a controller around a trigger built from `config`.
    ///
    /// `callback` is stored and later invoked by [`update`](Self::update) for
    /// every evaluated action other than [`TriggerAction::Hold`].
    ///
    /// # Errors
    ///
    /// Returns any error produced by [`BandwidthTrigger::new`] when it rejects
    /// `config` (the tests in this file show an empty tier list is rejected).
    pub fn new(
        config: TriggerConfig,
        callback: impl Fn(TriggerAction) + Send + 'static,
    ) -> NetResult<Self> {
        let trigger = BandwidthTrigger::new(config)?;
        let callback: Box<dyn Fn(TriggerAction) + Send> = Box::new(callback);
        Ok(Self { trigger, callback })
    }

    /// Records `obs`, re-evaluates the trigger, and forwards the resulting
    /// action to the callback unless it is [`TriggerAction::Hold`].
    pub fn update(&mut self, obs: BandwidthObservation) {
        self.trigger.add_observation(obs);
        let action = self.trigger.evaluate();
        // `Hold` means "nothing to do", so the callback is only woken for changes.
        if matches!(action, TriggerAction::Hold) {
            return;
        }
        (self.callback)(action);
    }

    /// Current bandwidth estimate as reported by the underlying trigger's
    /// `ema_bps`.
    #[must_use]
    pub fn ema_bps(&self) -> f64 {
        self.trigger.ema_bps()
    }

    /// Index of the tier the underlying trigger currently considers active.
    #[must_use]
    pub fn current_tier(&self) -> usize {
        self.trigger.current_tier()
    }

    /// Resets the underlying trigger; the registered callback is left in place.
    pub fn reset(&mut self) {
        self.trigger.reset();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    /// Default trigger config with `upgrade_hold` set to zero, presumably so
    /// upgrades are not delayed by the hold window.
    fn zero_upgrade_hold_config() -> TriggerConfig {
        TriggerConfig {
            upgrade_hold: std::time::Duration::ZERO,
            ..TriggerConfig::default()
        }
    }

    /// Builds a controller whose callback bumps the returned counter for every
    /// delivered action that `counts` accepts.
    ///
    /// Replaces the previous `feed_bps_and_count` helper, whose counter was
    /// never connected to any callback and therefore always reported 0.
    fn counting_controller(
        config: TriggerConfig,
        counts: impl Fn(&TriggerAction) -> bool + Send + 'static,
    ) -> (BandwidthAdaptationController, Arc<AtomicUsize>) {
        let fired = Arc::new(AtomicUsize::new(0));
        let fired_in_cb = Arc::clone(&fired);
        let ctrl = BandwidthAdaptationController::new(config, move |action| {
            if counts(&action) {
                fired_in_cb.fetch_add(1, Ordering::Relaxed);
            }
        })
        .expect("valid config");
        (ctrl, fired)
    }

    /// Feeds `n` identical observations of `bps` into `ctrl`.
    fn feed(ctrl: &mut BandwidthAdaptationController, bps: f64, n: usize) {
        for _ in 0..n {
            ctrl.update(BandwidthObservation::new(bps));
        }
    }

    #[test]
    fn test_controller_fires_on_downgrade() {
        let (mut ctrl, downgrades) = counting_controller(TriggerConfig::default(), |action| {
            matches!(action, TriggerAction::Downgrade { .. })
        });
        ctrl.trigger.force_tier(3).expect("tier 3 exists");
        feed(&mut ctrl, 100_000.0, 15);
        assert!(
            downgrades.load(Ordering::Relaxed) > 0,
            "callback should have fired at least once with Downgrade"
        );
    }

    #[test]
    fn test_controller_no_callback_on_hold() {
        let config = TriggerConfig {
            upgrade_hold: std::time::Duration::from_secs(100),
            ..TriggerConfig::default()
        };
        let (mut ctrl, fired) = counting_controller(config, |_| true);
        feed(&mut ctrl, 501_000.0, 10);
        assert_eq!(
            fired.load(Ordering::Relaxed),
            0,
            "callback must not fire on Hold"
        );
    }

    #[test]
    fn test_controller_fires_on_upgrade() {
        let (mut ctrl, upgrades) = counting_controller(zero_upgrade_hold_config(), |action| {
            matches!(action, TriggerAction::Upgrade { .. })
        });
        feed(&mut ctrl, 50_000_000.0, 20);
        assert!(
            upgrades.load(Ordering::Relaxed) > 0,
            "callback should have fired at least once with Upgrade"
        );
    }

    #[test]
    fn test_controller_invalid_config_returns_error() {
        let mut config = TriggerConfig::default();
        config.tiers.clear();
        let result = BandwidthAdaptationController::new(config, |_| {});
        assert!(result.is_err(), "empty tier list should be rejected");
    }

    #[test]
    fn test_controller_reset_clears_state() {
        let (mut ctrl, _fired) = counting_controller(zero_upgrade_hold_config(), |_| true);
        ctrl.update(BandwidthObservation::new(5_000_000.0));
        assert!(ctrl.ema_bps() > 0.0);
        ctrl.reset();
        assert_eq!(ctrl.ema_bps(), 0.0);
        assert_eq!(ctrl.current_tier(), 0);
    }
}