#![allow(dead_code)]
use std::collections::VecDeque;
use std::time::{Duration, Instant};
use crate::error::{NetError, NetResult};
/// One selectable quality level: a label, the bitrate it needs and its
/// resolution. Audio-only tiers use a `0 × 0` resolution
/// (see [`QualityTier::audio_only`]).
#[derive(Debug, Clone, PartialEq)]
pub struct QualityTier {
    /// Human-readable name, e.g. `"720p"`.
    pub label: String,
    /// Bitrate of this tier in bits per second.
    pub bitrate_bps: u64,
    /// Frame width in pixels (0 for audio-only tiers).
    pub width: u32,
    /// Frame height in pixels (0 for audio-only tiers).
    pub height: u32,
}

impl QualityTier {
    /// Builds a tier from its label, bitrate (bits per second) and frame size.
    #[must_use]
    pub fn new(label: impl Into<String>, bitrate_bps: u64, width: u32, height: u32) -> Self {
        let label: String = label.into();
        Self {
            width,
            height,
            bitrate_bps,
            label,
        }
    }

    /// Builds a tier without video dimensions (width and height are both 0).
    #[must_use]
    pub fn audio_only(label: impl Into<String>, bitrate_bps: u64) -> Self {
        Self {
            label: label.into(),
            bitrate_bps,
            width: 0,
            height: 0,
        }
    }
}
/// Tuning parameters for the bandwidth trigger.
#[derive(Debug, Clone)]
pub struct TriggerConfig {
    /// Selectable tiers; `prepare` sorts them by ascending bitrate.
    pub tiers: Vec<QualityTier>,
    /// Weight of the newest sample in the exponential moving average; must be in `(0, 1]`.
    pub ema_alpha: f64,
    /// Multiplier applied to a tier's bitrate to obtain the bandwidth required
    /// to sustain it; must be in `(0, 10]`.
    pub safety_factor: f64,
    /// How long the EMA must stay at or above the next tier's threshold before upgrading.
    pub upgrade_hold: Duration,
    /// Number of most recent observations kept in the sample window.
    pub window_depth: usize,
    /// Minimum time after a downgrade before another downgrade may fire.
    pub downgrade_cooldown: Duration,
}

impl Default for TriggerConfig {
    /// Four video tiers (240p to 1080p), an EMA weight of 0.25, a 1.25× safety
    /// factor, a 3 s upgrade hold, an 8-sample window and a 2 s downgrade cooldown.
    fn default() -> Self {
        Self {
            tiers: vec![
                QualityTier::new("240p", 400_000, 426, 240),
                QualityTier::new("480p", 1_200_000, 854, 480),
                QualityTier::new("720p", 2_500_000, 1280, 720),
                QualityTier::new("1080p", 5_000_000, 1920, 1080),
            ],
            ema_alpha: 0.25,
            safety_factor: 1.25,
            upgrade_hold: Duration::from_millis(3_000),
            window_depth: 8,
            downgrade_cooldown: Duration::from_millis(2_000),
        }
    }
}
impl TriggerConfig {
    /// Checks that the configuration can drive a bandwidth trigger.
    ///
    /// # Errors
    ///
    /// Returns an invalid-state [`NetError`] in any of these cases:
    /// - there are no tiers;
    /// - a tier has a zero bitrate;
    /// - `safety_factor` is not in `(0, 10]`;
    /// - `ema_alpha` is not in `(0, 1]`;
    /// - `window_depth` is 0.
    ///
    /// NaN is rejected for both factors.
    pub fn validate(&self) -> NetResult<()> {
        if self.tiers.is_empty() {
            return Err(NetError::invalid_state(
                "at least one quality tier required",
            ));
        }
        if let Some(tier) = self.tiers.iter().find(|t| t.bitrate_bps == 0) {
            return Err(NetError::invalid_state(format!(
                "tier '{}' has zero bitrate",
                tier.label
            )));
        }
        // NaN compares false against every bound, so it would slip through a
        // plain range check; test for it explicitly.
        if self.safety_factor.is_nan() || self.safety_factor <= 0.0 || self.safety_factor > 10.0
        {
            return Err(NetError::invalid_state("safety_factor must be in (0, 10]"));
        }
        if self.ema_alpha.is_nan() || self.ema_alpha <= 0.0 || self.ema_alpha > 1.0 {
            return Err(NetError::invalid_state("ema_alpha must be in (0, 1]"));
        }
        // A zero-length sample window is meaningless; reject it up front.
        if self.window_depth == 0 {
            return Err(NetError::invalid_state("window_depth must be at least 1"));
        }
        Ok(())
    }

    /// Sorts the tiers by ascending bitrate, then validates the configuration.
    ///
    /// The sort is stable, so tiers with equal bitrates keep their relative order.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`TriggerConfig::validate`].
    pub fn prepare(&mut self) -> NetResult<()> {
        self.tiers.sort_by_key(|t| t.bitrate_bps);
        self.validate()
    }
}
/// Outcome of one trigger evaluation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TriggerAction {
    /// Stay on the current tier.
    Hold,
    /// Switch down to `tier_index`; `reason` is a human-readable explanation.
    Downgrade {
        tier_index: usize,
        reason: String,
    },
    /// Switch up to `tier_index`; `reason` is a human-readable explanation.
    Upgrade {
        tier_index: usize,
        reason: String,
    },
}

impl TriggerAction {
    /// `true` for `Downgrade`/`Upgrade`, `false` for `Hold`.
    #[must_use]
    pub const fn is_change(&self) -> bool {
        match self {
            Self::Hold => false,
            Self::Downgrade { .. } | Self::Upgrade { .. } => true,
        }
    }

    /// The target tier of a change, or `None` for `Hold`.
    #[must_use]
    pub const fn tier_index(&self) -> Option<usize> {
        match *self {
            Self::Downgrade { tier_index, .. } => Some(tier_index),
            Self::Upgrade { tier_index, .. } => Some(tier_index),
            Self::Hold => None,
        }
    }
}
/// A tier change recorded in the trigger's history when `evaluate` emits
/// a downgrade or an upgrade.
#[derive(Debug, Clone)]
pub struct TriggerEvent {
/// Instant at which `evaluate` decided on the change.
pub when: Instant,
/// The emitted change (`Downgrade` or `Upgrade`).
pub action: TriggerAction,
/// EMA of observed bandwidth (bps) at the moment of the decision.
pub ema_bps: f64,
/// Tier index that was in effect before the change.
pub from_tier: usize,
}
/// A single throughput measurement.
#[derive(Debug, Clone, Copy)]
pub struct BandwidthObservation {
    /// Measured throughput in bits per second.
    pub bps: f64,
    /// Timestamp; [`BandwidthObservation::new`] sets it to the moment of construction.
    pub measured_at: Instant,
}

impl BandwidthObservation {
    /// Captures `bps`, timestamped with `Instant::now()`.
    ///
    /// Negative values are clamped to 0. NaN also becomes 0, because
    /// [`f64::max`] returns the non-NaN operand.
    #[must_use]
    pub fn new(bps: f64) -> Self {
        let measured_at = Instant::now();
        Self {
            bps: f64::max(bps, 0.0),
            measured_at,
        }
    }
}
/// Chooses among quality tiers based on an exponential moving average (EMA)
/// of observed bandwidth.
///
/// Feed measurements with `add_observation`, then call `evaluate` to obtain a
/// `TriggerAction`. Every upgrade or downgrade is appended to `history`.
#[derive(Debug)]
pub struct BandwidthTrigger {
/// Configuration after `TriggerConfig::prepare` (tiers sorted by ascending bitrate).
config: TriggerConfig,
/// Current EMA of observed bandwidth in bps; 0.0 until the first observation.
ema_bps: f64,
/// Whether `ema_bps` has been seeded by a first observation.
ema_initialised: bool,
/// Index into `config.tiers` of the currently selected tier.
current_tier: usize,
/// Recent observations, trimmed against `config.window_depth` in `add_observation`.
samples: VecDeque<BandwidthObservation>,
/// When the EMA first reached the next tier's upgrade threshold; `None` when
/// no upgrade is pending.
upgrade_candidate_since: Option<Instant>,
/// Time of the most recent downgrade, used to enforce `config.downgrade_cooldown`.
last_downgrade: Option<Instant>,
/// Every tier change emitted by `evaluate`; only cleared by `reset`.
history: Vec<TriggerEvent>,
}
impl BandwidthTrigger {
pub fn new(mut config: TriggerConfig) -> NetResult<Self> {
config.prepare()?;
let window = config.window_depth;
Ok(Self {
config,
ema_bps: 0.0,
ema_initialised: false,
current_tier: 0,
samples: VecDeque::with_capacity(window),
upgrade_candidate_since: None,
last_downgrade: None,
history: Vec::new(),
})
}
pub fn add_observation(&mut self, obs: BandwidthObservation) {
if self.samples.len() == self.config.window_depth {
self.samples.pop_front();
}
self.samples.push_back(obs);
if !self.ema_initialised {
self.ema_bps = obs.bps;
self.ema_initialised = true;
} else {
let alpha = self.config.ema_alpha;
self.ema_bps = alpha * obs.bps + (1.0 - alpha) * self.ema_bps;
}
}
#[must_use]
pub const fn ema_bps(&self) -> f64 {
self.ema_bps
}
#[must_use]
pub const fn current_tier(&self) -> usize {
self.current_tier
}
#[must_use]
pub fn current_tier_info(&self) -> &QualityTier {
&self.config.tiers[self.current_tier]
}
#[must_use]
pub fn tiers(&self) -> &[QualityTier] {
&self.config.tiers
}
#[must_use]
pub fn history(&self) -> &[TriggerEvent] {
&self.history
}
pub fn force_tier(&mut self, tier_index: usize) -> NetResult<()> {
if tier_index >= self.config.tiers.len() {
return Err(NetError::invalid_state(format!(
"tier_index {tier_index} out of range (max {})",
self.config.tiers.len() - 1
)));
}
self.current_tier = tier_index;
self.upgrade_candidate_since = None;
Ok(())
}
#[must_use]
pub fn evaluate(&mut self) -> TriggerAction {
if !self.ema_initialised {
return TriggerAction::Hold;
}
let now = Instant::now();
let sf = self.config.safety_factor;
let tier_count = self.config.tiers.len();
let current_tier_bps = self.config.tiers[self.current_tier].bitrate_bps as f64;
let downgrade_threshold = current_tier_bps * sf;
let cooldown_elapsed = self
.last_downgrade
.map(|t| now.duration_since(t) >= self.config.downgrade_cooldown)
.unwrap_or(true);
if self.ema_bps < downgrade_threshold && self.current_tier > 0 && cooldown_elapsed {
let new_tier = self.best_tier_for(self.ema_bps);
if new_tier < self.current_tier {
let reason = format!(
"EMA {:.0} bps < threshold {:.0} bps (tier '{}' × {:.2})",
self.ema_bps,
downgrade_threshold,
self.config.tiers[self.current_tier].label,
sf,
);
self.last_downgrade = Some(now);
self.upgrade_candidate_since = None;
let action = TriggerAction::Downgrade {
tier_index: new_tier,
reason: reason.clone(),
};
self.record_event(now, action.clone());
self.current_tier = new_tier;
return action;
}
}
if self.current_tier + 1 < tier_count {
let next_tier_bps = self.config.tiers[self.current_tier + 1].bitrate_bps as f64;
let upgrade_threshold = next_tier_bps * sf;
if self.ema_bps >= upgrade_threshold {
match self.upgrade_candidate_since {
None => {
self.upgrade_candidate_since = Some(now);
}
Some(since) => {
if now.duration_since(since) >= self.config.upgrade_hold {
let new_tier = self.current_tier + 1;
let reason = format!(
"EMA {:.0} bps ≥ threshold {:.0} bps (tier '{}' × {:.2}) for {:.1}s",
self.ema_bps,
upgrade_threshold,
self.config.tiers[new_tier].label,
sf,
now.duration_since(since).as_secs_f64(),
);
self.upgrade_candidate_since = None;
let action = TriggerAction::Upgrade {
tier_index: new_tier,
reason: reason.clone(),
};
self.record_event(now, action.clone());
self.current_tier = new_tier;
return action;
}
}
}
} else {
self.upgrade_candidate_since = None;
}
}
TriggerAction::Hold
}
fn best_tier_for(&self, ema_bps: f64) -> usize {
let sf = self.config.safety_factor;
let mut best = 0;
for (idx, tier) in self.config.tiers.iter().enumerate() {
if tier.bitrate_bps as f64 * sf <= ema_bps {
best = idx;
}
}
best
}
fn record_event(&mut self, when: Instant, action: TriggerAction) {
self.history.push(TriggerEvent {
when,
action,
ema_bps: self.ema_bps,
from_tier: self.current_tier,
});
}
pub fn reset(&mut self) {
self.ema_bps = 0.0;
self.ema_initialised = false;
self.current_tier = 0;
self.samples.clear();
self.upgrade_candidate_since = None;
self.last_downgrade = None;
self.history.clear();
}
#[must_use]
pub fn snapshot(&self) -> TriggerSnapshot {
TriggerSnapshot {
ema_bps: self.ema_bps,
current_tier: self.current_tier,
current_tier_label: self.config.tiers[self.current_tier].label.clone(),
current_tier_bitrate_bps: self.config.tiers[self.current_tier].bitrate_bps,
sample_count: self.samples.len(),
event_count: self.history.len(),
}
}
}
/// Point-in-time summary of a `BandwidthTrigger`, produced by `BandwidthTrigger::snapshot`.
#[derive(Debug, Clone)]
pub struct TriggerSnapshot {
/// EMA of observed bandwidth (bps) at snapshot time.
pub ema_bps: f64,
/// Index of the selected tier.
pub current_tier: usize,
/// Label of the selected tier.
pub current_tier_label: String,
/// Bitrate of the selected tier.
pub current_tier_bitrate_bps: u64,
/// Number of observations held in the sample window.
pub sample_count: usize,
/// Number of tier changes recorded in the history.
pub event_count: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Trigger built from the default configuration.
    fn make_trigger() -> BandwidthTrigger {
        BandwidthTrigger::new(TriggerConfig::default()).expect("valid config")
    }

    /// Trigger built from the default configuration with a custom `upgrade_hold`.
    fn trigger_with_hold(upgrade_hold: Duration) -> BandwidthTrigger {
        BandwidthTrigger::new(TriggerConfig {
            upgrade_hold,
            ..TriggerConfig::default()
        })
        .expect("valid config")
    }

    /// Adds `n` observations of `bps`, evaluating after each; returns the last action.
    fn feed_bps(trigger: &mut BandwidthTrigger, bps: f64, n: usize) -> TriggerAction {
        let mut action = TriggerAction::Hold;
        for _ in 0..n {
            trigger.add_observation(BandwidthObservation::new(bps));
            action = trigger.evaluate();
        }
        action
    }

    #[test]
    fn test_default_config_validates() {
        let mut cfg = TriggerConfig::default();
        cfg.prepare().expect("default config should be valid");
    }

    #[test]
    fn test_empty_tiers_rejected() {
        let mut cfg = TriggerConfig {
            tiers: vec![],
            ..TriggerConfig::default()
        };
        assert!(cfg.prepare().is_err());
    }

    #[test]
    fn test_zero_bitrate_tier_rejected() {
        let mut cfg = TriggerConfig {
            tiers: vec![QualityTier::new("bad", 0, 0, 0)],
            ..TriggerConfig::default()
        };
        assert!(cfg.prepare().is_err());
    }

    #[test]
    fn test_invalid_safety_factor_rejected() {
        let mut cfg = TriggerConfig {
            safety_factor: -1.0,
            ..TriggerConfig::default()
        };
        assert!(cfg.prepare().is_err());
    }

    #[test]
    fn test_invalid_ema_alpha_rejected() {
        for ema_alpha in [0.0, -0.5, 1.5] {
            let mut cfg = TriggerConfig {
                ema_alpha,
                ..TriggerConfig::default()
            };
            assert!(cfg.prepare().is_err(), "ema_alpha {ema_alpha} should be rejected");
        }
    }

    #[test]
    fn test_prepare_sorts_tiers_by_bitrate() {
        let trigger = BandwidthTrigger::new(TriggerConfig {
            tiers: vec![
                QualityTier::new("high", 5_000_000, 1920, 1080),
                QualityTier::new("low", 400_000, 426, 240),
            ],
            ..TriggerConfig::default()
        })
        .expect("valid config");
        let labels: Vec<&str> = trigger.tiers().iter().map(|t| t.label.as_str()).collect();
        assert_eq!(labels, ["low", "high"]);
    }

    #[test]
    fn test_hold_before_observations() {
        let mut trigger = make_trigger();
        assert_eq!(trigger.evaluate(), TriggerAction::Hold);
    }

    #[test]
    fn test_stays_at_lowest_tier_on_low_bandwidth() {
        let mut trigger = make_trigger();
        feed_bps(&mut trigger, 100_000.0, 10);
        assert_eq!(trigger.current_tier(), 0);
        assert!(trigger.history().is_empty());
    }

    #[test]
    fn test_downgrade_emitted_on_bandwidth_drop() {
        let mut trigger = make_trigger();
        trigger.force_tier(3).expect("tier 3 exists");
        feed_bps(&mut trigger, 200_000.0, 10);
        // 200 kbps cannot sustain even 240p × 1.25, so it drops straight to tier 0.
        assert_eq!(trigger.current_tier(), 0, "should have downgraded to the lowest tier");
        let events = trigger.history();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].from_tier, 3);
        assert!(matches!(
            events[0].action,
            TriggerAction::Downgrade { tier_index: 0, .. }
        ));
    }

    #[test]
    fn test_downgrade_respects_cooldown() {
        let mut trigger = make_trigger();
        trigger.force_tier(3).expect("tier 3 exists");
        trigger.add_observation(BandwidthObservation::new(3_500_000.0));
        // 3.5 Mbps sustains 720p (2.5M × 1.25) but not 1080p.
        assert_eq!(trigger.evaluate().tier_index(), Some(2));
        // A further drop inside the 2 s cooldown must not downgrade again.
        let action = feed_bps(&mut trigger, 100_000.0, 5);
        assert_eq!(action, TriggerAction::Hold);
        assert_eq!(trigger.current_tier(), 2);
        assert_eq!(trigger.history().len(), 1);
    }

    #[test]
    fn test_upgrade_requires_hold_period() {
        let mut trigger = trigger_with_hold(Duration::from_secs(100));
        feed_bps(&mut trigger, 50_000_000.0, 5);
        assert_eq!(
            trigger.current_tier(),
            0,
            "upgrade should not have fired yet"
        );
        assert!(trigger.history().is_empty());
    }

    #[test]
    fn test_upgrade_fires_after_hold() {
        let mut trigger = trigger_with_hold(Duration::ZERO);
        feed_bps(&mut trigger, 2_000_000.0, 5);
        // 2 Mbps sustains 480p (1.2M × 1.25) but not 720p (2.5M × 1.25).
        assert_eq!(trigger.current_tier(), 1, "should have upgraded exactly to 480p");
    }

    #[test]
    fn test_ema_initialised_from_first_sample() {
        let mut trigger = make_trigger();
        trigger.add_observation(BandwidthObservation::new(3_000_000.0));
        assert!((trigger.ema_bps() - 3_000_000.0).abs() < 1.0);
    }

    #[test]
    fn test_ema_smoothing() {
        let mut trigger = make_trigger();
        trigger.add_observation(BandwidthObservation::new(10_000_000.0));
        trigger.add_observation(BandwidthObservation::new(0.0));
        assert!(trigger.ema_bps() > 0.0);
        assert!(trigger.ema_bps() < 10_000_000.0);
        // alpha = 0.25: 0.25 × 0 + 0.75 × 10M.
        assert!((trigger.ema_bps() - 7_500_000.0).abs() < 1.0);
    }

    #[test]
    fn test_sample_window_is_bounded() {
        let mut trigger = make_trigger();
        for _ in 0..20 {
            trigger.add_observation(BandwidthObservation::new(1_000_000.0));
        }
        assert_eq!(
            trigger.snapshot().sample_count,
            TriggerConfig::default().window_depth
        );
    }

    #[test]
    fn test_force_tier_out_of_range_errors() {
        let mut trigger = make_trigger();
        assert!(trigger.force_tier(99).is_err());
        trigger.force_tier(3).expect("tier 3 exists");
        assert_eq!(trigger.current_tier_info().label, "1080p");
    }

    #[test]
    fn test_snapshot_reflects_state() {
        let mut trigger = make_trigger();
        trigger.add_observation(BandwidthObservation::new(5_000_000.0));
        let snap = trigger.snapshot();
        assert_eq!(snap.sample_count, 1);
        assert!(snap.ema_bps > 0.0);
        assert_eq!(snap.current_tier, 0);
        assert_eq!(snap.current_tier_label, "240p");
        assert_eq!(snap.current_tier_bitrate_bps, 400_000);
        assert_eq!(snap.event_count, 0);
    }

    #[test]
    fn test_reset_clears_state() {
        let mut trigger = make_trigger();
        feed_bps(&mut trigger, 5_000_000.0, 5);
        trigger.reset();
        assert_eq!(trigger.snapshot().sample_count, 0);
        assert_eq!(trigger.snapshot().ema_bps, 0.0);
        assert_eq!(trigger.current_tier(), 0);
        assert!(trigger.history().is_empty());
    }

    #[test]
    fn test_history_records_events() {
        let mut trigger = trigger_with_hold(Duration::ZERO);
        feed_bps(&mut trigger, 10_000_000.0, 20);
        assert!(
            !trigger.history().is_empty(),
            "expected at least one event in history"
        );
        // Upgrades climb one tier per step, each recorded from the previous tier.
        let steps: Vec<(usize, Option<usize>)> = trigger
            .history()
            .iter()
            .map(|e| (e.from_tier, e.action.tier_index()))
            .collect();
        assert_eq!(steps, vec![(0, Some(1)), (1, Some(2)), (2, Some(3))]);
        assert!(trigger
            .history()
            .iter()
            .all(|e| matches!(e.action, TriggerAction::Upgrade { .. })));
        assert_eq!(trigger.current_tier(), 3);
    }

    #[test]
    fn test_tier_label() {
        let trigger = make_trigger();
        let info = trigger.current_tier_info();
        assert!(!info.label.is_empty());
        assert_eq!(info.label, "240p");
    }

    #[test]
    fn test_quality_tier_audio_only() {
        let tier = QualityTier::audio_only("AAC 128k", 128_000);
        assert_eq!(tier.width, 0);
        assert_eq!(tier.height, 0);
        assert_eq!(tier.bitrate_bps, 128_000);
    }

    #[test]
    fn test_trigger_action_is_change() {
        assert!(!TriggerAction::Hold.is_change());
        assert!(TriggerAction::Downgrade {
            tier_index: 0,
            reason: String::new()
        }
        .is_change());
        assert!(TriggerAction::Upgrade {
            tier_index: 1,
            reason: String::new()
        }
        .is_change());
    }

    #[test]
    fn test_trigger_action_tier_index() {
        assert_eq!(TriggerAction::Hold.tier_index(), None);
        assert_eq!(
            TriggerAction::Upgrade {
                tier_index: 2,
                reason: String::new()
            }
            .tier_index(),
            Some(2)
        );
        assert_eq!(
            TriggerAction::Downgrade {
                tier_index: 0,
                reason: String::new()
            }
            .tier_index(),
            Some(0)
        );
    }
}