/// One throughput measurement retained by [`AbrBandwidthEstimator`].
#[derive(Debug, Clone)]
pub struct BandwidthSample {
    /// Bytes transferred.
    pub bytes: u64,
    /// Duration of the transfer in milliseconds, as reported by the caller.
    pub duration_ms: u64,
    /// `Instant::now()` at the moment the sample was recorded.
    pub timestamp: std::time::Instant,
}

impl BandwidthSample {
    /// Converts `bytes` moved in `duration_ms` milliseconds into bits per second.
    ///
    /// Returns `None` when `duration_ms == 0`: the transfer finished below timer
    /// resolution, so no meaningful rate can be derived.
    fn rate_bps(bytes: u64, duration_ms: u64) -> Option<f64> {
        if duration_ms == 0 {
            None
        } else {
            Some((bytes as f64 * 8.0 * 1_000.0) / duration_ms as f64)
        }
    }

    /// Throughput of this sample in bits per second, or `None` for a zero-duration sample.
    #[must_use]
    pub fn bps(&self) -> Option<f64> {
        Self::rate_bps(self.bytes, self.duration_ms)
    }
}

/// Throughput estimator combining an exponentially weighted moving average (EWMA)
/// with a bounded window of recent samples used for percentile queries.
#[derive(Debug)]
pub struct AbrBandwidthEstimator {
    /// Most recent accepted samples, oldest first, never longer than `window_size`.
    window: std::collections::VecDeque<BandwidthSample>,
    /// Maximum number of samples retained (always at least 1).
    window_size: usize,
    /// EWMA weight given to each new sample.
    smoothing_factor: f64,
    /// Current EWMA estimate in bits per second (0.0 until the first accepted sample).
    current_estimate_bps: f64,
}

impl AbrBandwidthEstimator {
    /// Creates an estimator retaining up to `window_size` samples (a size of 0 is
    /// treated as 1), with an EWMA smoothing factor of 0.3.
    #[must_use]
    pub fn new(window_size: usize) -> Self {
        let window_size = window_size.max(1);
        Self {
            window: std::collections::VecDeque::with_capacity(window_size),
            window_size,
            smoothing_factor: 0.3,
            current_estimate_bps: 0.0,
        }
    }

    /// Smoothed throughput estimate in bits per second.
    #[must_use]
    pub fn estimate_bps(&self) -> f64 {
        self.current_estimate_bps
    }

    /// Smoothed throughput estimate in kilobits per second.
    #[must_use]
    pub fn estimate_kbps(&self) -> f64 {
        self.current_estimate_bps / 1_000.0
    }

    /// Smoothed throughput estimate in megabits per second.
    #[must_use]
    pub fn estimate_mbps(&self) -> f64 {
        self.current_estimate_bps / 1_000_000.0
    }

    /// Records that `bytes` were transferred in `duration_ms` milliseconds.
    ///
    /// The first accepted sample seeds the estimate. Later samples are blended in as
    /// `alpha * sample + (1 - alpha) * estimate`. When the window is full, the oldest
    /// sample is evicted.
    ///
    /// Samples with `duration_ms == 0` are ignored entirely. Their true rate is unknown,
    /// and counting them as 0 bps would wrongly pull the estimate (and percentiles) down.
    pub fn add_sample(&mut self, bytes: u64, duration_ms: u64) {
        let sample_bps = match BandwidthSample::rate_bps(bytes, duration_ms) {
            Some(bps) => bps,
            None => return,
        };
        // Seed only on the very first accepted sample. After a genuine 0-bps stall the
        // estimate is 0.0 but must still be smoothed rather than reset.
        self.current_estimate_bps = if self.window.is_empty() {
            sample_bps
        } else {
            self.smoothing_factor * sample_bps
                + (1.0 - self.smoothing_factor) * self.current_estimate_bps
        };
        if self.window.len() >= self.window_size {
            self.window.pop_front();
        }
        self.window.push_back(BandwidthSample {
            bytes,
            duration_ms,
            timestamp: std::time::Instant::now(),
        });
    }

    /// Returns the sample rate (bps) at `percentile` over the retained window.
    ///
    /// `percentile` is a fraction in `0.0..=1.0`. Out-of-range values are clamped and
    /// NaN is treated as 0.0. The lower nearest rank is used: index
    /// `floor((n - 1) * p)` of the ascending rates. Returns 0.0 when no samples are
    /// retained.
    #[must_use]
    pub fn percentile_bps(&self, percentile: f64) -> f64 {
        let mut rates: Vec<f64> = self
            .window
            .iter()
            .filter_map(BandwidthSample::bps)
            .collect();
        if rates.is_empty() {
            return 0.0;
        }
        // Rates come from u64 bytes over a positive duration, so they are finite and
        // partial_cmp never returns None; the fallback is purely defensive.
        rates.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let last = rates.len() - 1;
        let p = if percentile.is_nan() {
            0.0
        } else {
            percentile.clamp(0.0, 1.0)
        };
        // `as usize` truncates toward zero, i.e. picks the lower neighbouring rank.
        let idx = (last as f64 * p) as usize;
        rates[idx.min(last)]
    }

    /// Number of samples currently retained in the window.
    #[must_use]
    pub fn sample_count(&self) -> usize {
        self.window.len()
    }
}
/// One selectable rendition (quality level) of the stream.
///
/// `AbrController::new` sorts the variants it receives by ascending `bandwidth`, so
/// index 0 there is always the lowest-bandwidth rendition.
#[derive(Debug, Clone)]
pub struct AbrVariant {
/// Declared bitrate in bits per second. The controller compares it against the
/// estimator's bps estimate scaled by its safety factor.
pub bandwidth: u64,
/// Presumably the frame width in pixels (the tests use 640 for "360p") — confirm.
pub width: u32,
/// Presumably the frame height in pixels (the tests use 360 for "360p") — confirm.
pub height: u32,
/// Codec list string, e.g. `avc1.42c01e,mp4a.40.2` in the tests; presumably the
/// HLS CODECS attribute — verify against the playlist parser.
pub codecs: String,
/// URI for this variant (the tests use values like `low.m3u8`).
pub uri: String,
/// Human-readable label such as `360p`.
pub name: String,
/// Optional frame rate (the tests use 30.0 and 60.0); `None` when not provided.
pub frame_rate: Option<f64>,
/// Optional HDCP level string; not read anywhere in this file.
pub hdcp_level: Option<String>,
}
/// Why the controller changed variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AbrSwitchReason {
    /// The throughput estimate supports a higher variant.
    BandwidthIncrease,
    /// The throughput estimate no longer supports the current variant.
    BandwidthDecrease,
    /// Buffer starvation (not produced by `AbrController` in this file).
    BufferStarvation,
    /// Explicit request (not produced by `AbrController` in this file).
    UserRequested,
}

/// Outcome of one ABR decision. Indices refer to the controller's variant list.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SelectionResult {
    /// Keep playing `variant`.
    Stay {
        variant: usize,
    },
    /// Move from `from` to a higher index `to`.
    SwitchUp {
        from: usize,
        to: usize,
        reason: AbrSwitchReason,
    },
    /// Move from `from` to a lower index `to`.
    SwitchDown {
        from: usize,
        to: usize,
        reason: AbrSwitchReason,
    },
    /// Forced drop caused by a critically low buffer.
    EmergencySwitch {
        from: usize,
        to: usize,
    },
}

impl SelectionResult {
    /// Index of the variant that is active once this decision is applied.
    #[must_use]
    pub const fn variant_index(&self) -> usize {
        match *self {
            Self::Stay { variant } => variant,
            Self::SwitchUp { to, .. }
            | Self::SwitchDown { to, .. }
            | Self::EmergencySwitch { to, .. } => to,
        }
    }

    /// `true` for every outcome except [`SelectionResult::Stay`].
    #[must_use]
    pub const fn is_switch(&self) -> bool {
        match self {
            Self::Stay { .. } => false,
            Self::SwitchUp { .. } | Self::SwitchDown { .. } | Self::EmergencySwitch { .. } => {
                true
            }
        }
    }

    /// `true` only for [`SelectionResult::EmergencySwitch`].
    #[must_use]
    pub const fn is_emergency(&self) -> bool {
        match self {
            Self::EmergencySwitch { .. } => true,
            Self::Stay { .. } | Self::SwitchUp { .. } | Self::SwitchDown { .. } => false,
        }
    }
}
/// Throughput- and buffer-driven adaptive-bitrate controller.
///
/// Variants are kept sorted by ascending `bandwidth`, so index 0 is always the lowest
/// rendition. Each call to [`AbrController::select_variant`] makes one decision and
/// applies it to the current index.
#[derive(Debug)]
pub struct AbrController {
    /// Renditions sorted ascending by `bandwidth` (established in `new`, never reordered).
    variants: Vec<AbrVariant>,
    /// Index into `variants` of the active rendition.
    current_index: usize,
    /// Throughput estimator fed by `update_bandwidth`.
    bandwidth_estimator: AbrBandwidthEstimator,
    /// Latest buffer level passed to `update_buffer` (seconds, per the `_s` naming).
    buffer_duration_s: f64,
    /// Buffer level required before stepping up a quality level.
    min_buffer_s: f64,
    /// Buffer level below which the controller drops straight to index 0.
    panic_buffer_s: f64,
    /// Fraction of the estimated throughput a variant's bandwidth may use.
    safety_factor: f64,
    /// Number of `select_variant` calls that must follow a switch before another
    /// bandwidth-driven decision is evaluated.
    switch_cooldown_segments: u32,
    /// `select_variant` calls since the last switch (saturates at `u32::MAX`).
    segments_since_switch: u32,
}

impl AbrController {
    /// Builds a controller over `variants` and starts on the lowest-bandwidth one.
    ///
    /// Variants are stably sorted by `bandwidth`. Defaults: estimator window 10,
    /// `min_buffer_s` 15.0, `panic_buffer_s` 5.0, `safety_factor` 0.8, cooldown of
    /// 3 calls. The cooldown also applies from construction: the first 3 selections stay.
    ///
    /// # Errors
    /// Returns an error message if `variants` is empty.
    pub fn new(mut variants: Vec<AbrVariant>) -> Result<Self, String> {
        if variants.is_empty() {
            return Err("AbrController requires at least one variant".into());
        }
        variants.sort_by_key(|v| v.bandwidth);
        Ok(Self {
            bandwidth_estimator: AbrBandwidthEstimator::new(10),
            variants,
            current_index: 0,
            buffer_duration_s: 0.0,
            min_buffer_s: 15.0,
            panic_buffer_s: 5.0,
            safety_factor: 0.8,
            switch_cooldown_segments: 3,
            segments_since_switch: 0,
        })
    }

    /// The currently active variant.
    #[must_use]
    pub fn current_variant(&self) -> &AbrVariant {
        &self.variants[self.current_index]
    }

    /// Number of available variants (always at least 1).
    #[must_use]
    pub fn variant_count(&self) -> usize {
        self.variants.len()
    }

    /// Feeds a download measurement (`bytes` in `duration_ms` ms) to the estimator.
    pub fn update_bandwidth(&mut self, bytes: u64, duration_ms: u64) {
        self.bandwidth_estimator.add_sample(bytes, duration_ms);
    }

    /// Records the current buffer level used by the next `select_variant` call.
    pub fn update_buffer(&mut self, buffer_duration_s: f64) {
        self.buffer_duration_s = buffer_duration_s;
    }

    /// Decides which variant to play next and applies the decision.
    ///
    /// Decisions are evaluated in this order:
    /// 1. If the buffer is below `panic_buffer_s` and we are not already on index 0,
    ///    switch straight to index 0. This ignores the cooldown.
    /// 2. Within the cooldown window after a switch, stay.
    /// 3. Otherwise compute the highest variant whose bandwidth fits within
    ///    `estimate * safety_factor`:
    ///    - if it is higher, step up one level, but only when the buffer is at least
    ///      `min_buffer_s`;
    ///    - if it is lower, drop straight to it;
    ///    - otherwise stay.
    pub fn select_variant(&mut self) -> SelectionResult {
        let old = self.current_index;

        if self.buffer_duration_s < self.panic_buffer_s && old > 0 {
            self.commit_switch(0);
            return SelectionResult::EmergencySwitch { from: old, to: 0 };
        }

        if self.segments_since_switch < self.switch_cooldown_segments {
            return self.record_stay();
        }

        let target = self.bandwidth_target_index();
        match target.cmp(&old) {
            std::cmp::Ordering::Greater if self.buffer_duration_s >= self.min_buffer_s => {
                // Step up a single level per decision even when the target is several
                // levels higher (target > old guarantees old + 1 <= target).
                let next = old + 1;
                self.commit_switch(next);
                SelectionResult::SwitchUp {
                    from: old,
                    to: next,
                    reason: AbrSwitchReason::BandwidthIncrease,
                }
            }
            std::cmp::Ordering::Less => {
                // Step down straight to the target, possibly several levels at once.
                self.commit_switch(target);
                SelectionResult::SwitchDown {
                    from: old,
                    to: target,
                    reason: AbrSwitchReason::BandwidthDecrease,
                }
            }
            // Either already at the target, or a higher level fits but the buffer is
            // still below `min_buffer_s`.
            std::cmp::Ordering::Greater | std::cmp::Ordering::Equal => self.record_stay(),
        }
    }

    /// Forces the controller onto `index` and restarts the cooldown.
    ///
    /// # Errors
    /// Returns an error message if `index` is not a valid variant index.
    pub fn force_variant(&mut self, index: usize) -> Result<(), String> {
        if index >= self.variants.len() {
            // `variants` is non-empty (enforced by `new`), so `len() - 1` cannot underflow.
            return Err(format!(
                "Variant index {index} out of range (max {})",
                self.variants.len() - 1
            ));
        }
        self.commit_switch(index);
        Ok(())
    }

    /// Makes `index` the current variant and restarts the cooldown counter.
    fn commit_switch(&mut self, index: usize) {
        self.current_index = index;
        self.segments_since_switch = 0;
    }

    /// Counts one more decision without a switch and reports the current variant.
    ///
    /// The counter saturates instead of overflowing during long steady-state playback.
    fn record_stay(&mut self) -> SelectionResult {
        self.segments_since_switch = self.segments_since_switch.saturating_add(1);
        SelectionResult::Stay {
            variant: self.current_index,
        }
    }

    /// Highest variant index whose bandwidth is within `estimate * safety_factor`, or 0
    /// when none fits.
    ///
    /// Relies on `variants` being sorted ascending, which makes the predicate true for
    /// a prefix of the slice.
    fn bandwidth_target_index(&self) -> usize {
        let safe_bw = self.bandwidth_estimator.estimate_bps() * self.safety_factor;
        self.variants
            .partition_point(|v| v.bandwidth as f64 <= safe_bw)
            .saturating_sub(1)
    }
}
/// A downloaded segment held in the fetcher's playback buffer.
#[derive(Debug, Clone)]
pub struct BufferedSegment {
    /// Sequence number supplied by the caller of `record_download`.
    pub sequence: u64,
    /// Controller variant index that was current when the download was recorded.
    pub variant_index: usize,
    /// Segment payload; `SegmentFetcher::record_download` currently stores an empty
    /// vector, so only metadata is tracked.
    pub data: Vec<u8>,
    /// Playback duration of the segment in seconds.
    pub duration_s: f64,
    /// Download time in milliseconds, as reported to `record_download`.
    pub download_time_ms: u64,
}

/// Couples an [`AbrController`] with a bounded FIFO of downloaded segments.
#[derive(Debug)]
pub struct SegmentFetcher {
    controller: AbrController,
    /// Nominal segment duration supplied at construction. Not currently read:
    /// `record_download` takes each segment's duration explicitly.
    segment_duration_s: f64,
    /// Capacity of `buffered_segments`; the oldest entry is dropped when exceeded.
    max_buffer_segments: usize,
    buffered_segments: std::collections::VecDeque<BufferedSegment>,
}

impl SegmentFetcher {
    /// Creates a fetcher with an empty buffer holding at most 30 segments.
    #[must_use]
    pub fn new(controller: AbrController, segment_duration_s: f64) -> Self {
        Self {
            controller,
            segment_duration_s,
            max_buffer_segments: 30,
            buffered_segments: std::collections::VecDeque::new(),
        }
    }

    /// Total playback duration (seconds) of all buffered segments.
    #[must_use]
    pub fn buffer_level_s(&self) -> f64 {
        self.buffered_segments.iter().map(|s| s.duration_s).sum()
    }

    /// Runs one ABR decision and returns the variant to fetch next.
    ///
    /// The controller is given the *current* buffer level before it decides. Both its
    /// emergency check and its step-up check read that level, so reporting it
    /// afterwards would make every decision act on the previous call's stale value.
    pub fn next_variant(&mut self) -> &AbrVariant {
        let buffered_s = self.buffer_level_s();
        self.controller.update_buffer(buffered_s);
        let _decision = self.controller.select_variant();
        self.controller.current_variant()
    }

    /// Records a completed download. It feeds the bandwidth estimator and appends the
    /// segment, tagged with the current variant index, to the buffer.
    ///
    /// When the buffer already holds `max_buffer_segments` entries, the oldest one is
    /// dropped to make room.
    pub fn record_download(
        &mut self,
        sequence: u64,
        bytes: u64,
        duration_ms: u64,
        segment_duration_s: f64,
    ) {
        self.controller.update_bandwidth(bytes, duration_ms);
        let segment = BufferedSegment {
            sequence,
            variant_index: self.controller.current_index,
            data: Vec::new(),
            duration_s: segment_duration_s,
            download_time_ms: duration_ms,
        };
        if self.buffered_segments.len() >= self.max_buffer_segments {
            self.buffered_segments.pop_front();
        }
        self.buffered_segments.push_back(segment);
    }

    /// Removes and returns the oldest buffered segment, if any.
    pub fn pop_segment(&mut self) -> Option<BufferedSegment> {
        self.buffered_segments.pop_front()
    }

    /// Number of segments currently buffered.
    #[must_use]
    pub fn buffered_count(&self) -> usize {
        self.buffered_segments.len()
    }
}
#[cfg(test)]
mod streaming_abr_tests {
    use super::*;

    /// Three-rung ladder: 360p @ 500 kbps, 720p @ 1.5 Mbps, 1080p @ 4 Mbps.
    fn make_variants() -> Vec<AbrVariant> {
        vec![
            AbrVariant {
                bandwidth: 500_000,
                width: 640,
                height: 360,
                codecs: "avc1.42c01e,mp4a.40.2".into(),
                uri: "low.m3u8".into(),
                name: "360p".into(),
                frame_rate: Some(30.0),
                hdcp_level: None,
            },
            AbrVariant {
                bandwidth: 1_500_000,
                width: 1280,
                height: 720,
                codecs: "avc1.42c01e,mp4a.40.2".into(),
                uri: "mid.m3u8".into(),
                name: "720p".into(),
                frame_rate: Some(30.0),
                hdcp_level: None,
            },
            AbrVariant {
                bandwidth: 4_000_000,
                width: 1920,
                height: 1080,
                codecs: "avc1.640028,mp4a.40.2".into(),
                uri: "high.m3u8".into(),
                name: "1080p".into(),
                frame_rate: Some(60.0),
                hdcp_level: None,
            },
        ]
    }

    #[test]
    fn test_bandwidth_estimator_basic() {
        let mut est = AbrBandwidthEstimator::new(10);
        est.add_sample(1_000_000, 1_000); // 8 Mbps
        est.add_sample(2_000_000, 1_000); // 16 Mbps
        est.add_sample(1_500_000, 1_000); // 12 Mbps
        assert!(est.estimate_bps() > 0.0, "estimate should be positive");
        assert_eq!(est.sample_count(), 3);
    }

    #[test]
    fn test_bandwidth_estimator_percentile() {
        let mut est = AbrBandwidthEstimator::new(20);
        for _ in 0..5 {
            est.add_sample(125_000, 1_000); // 1 Mbps
        }
        for _ in 0..5 {
            est.add_sample(1_250_000, 1_000); // 10 Mbps
        }
        let p15 = est.percentile_bps(0.15);
        let p85 = est.percentile_bps(0.85);
        assert!(p15 < p85, "15th percentile should be lower than 85th");
        assert!(p15 > 0.0, "percentile should be positive");
        // Lower nearest rank over 10 sorted rates: indices 1 and 7.
        assert!((p15 - 1_000_000.0).abs() < 1e-6, "p15 should be 1 Mbps, got {p15}");
        assert!((p85 - 10_000_000.0).abs() < 1e-6, "p85 should be 10 Mbps, got {p85}");
    }

    #[test]
    fn test_abr_controller_creation() {
        let mut variants = make_variants();
        variants.reverse(); // descending input; the constructor must re-sort ascending
        let ctrl = AbrController::new(variants).expect("should create controller");
        assert_eq!(ctrl.variant_count(), 3);
        assert_eq!(ctrl.current_variant().bandwidth, 500_000);
    }

    #[test]
    fn test_abr_emergency_switch_on_low_buffer() {
        let mut ctrl = AbrController::new(make_variants()).expect("should succeed in test");
        ctrl.force_variant(2).expect("should succeed in test");
        ctrl.update_bandwidth(500_000, 1_000); // 4 Mbps
        ctrl.update_buffer(2.0); // below the 5 s panic threshold configured in `new`
        ctrl.segments_since_switch = ctrl.switch_cooldown_segments;
        let result = ctrl.select_variant();
        assert!(
            result.is_emergency(),
            "expected emergency switch, got {result:?}"
        );
        assert_eq!(
            result.variant_index(),
            0,
            "emergency switch must go to index 0"
        );
        assert_eq!(result, SelectionResult::EmergencySwitch { from: 2, to: 0 });
    }

    #[test]
    fn test_abr_switch_up_good_bandwidth() {
        let mut ctrl = AbrController::new(make_variants()).expect("should succeed in test");
        ctrl.update_bandwidth(5_000_000, 1_000); // 40 Mbps
        ctrl.update_buffer(20.0);
        ctrl.segments_since_switch = ctrl.switch_cooldown_segments;
        let result = ctrl.select_variant();
        assert!(
            result.is_switch(),
            "expected a switch with excellent bandwidth"
        );
        assert!(
            result.variant_index() > 0,
            "should switch up from index 0, got {}",
            result.variant_index()
        );
    }

    #[test]
    fn test_abr_cooldown() {
        let mut ctrl = AbrController::new(make_variants()).expect("should succeed in test");
        ctrl.update_bandwidth(5_000_000, 1_000); // 40 Mbps
        ctrl.update_buffer(20.0);
        ctrl.segments_since_switch = ctrl.switch_cooldown_segments;
        let first = ctrl.select_variant();
        // Without this precondition the cooldown assertion below would pass vacuously.
        assert!(
            first.is_switch(),
            "precondition: first selection should switch, got {first:?}"
        );
        let second = ctrl.select_variant();
        assert!(
            matches!(second, SelectionResult::Stay { .. }),
            "cooldown should prevent immediate second switch, got {second:?}"
        );
    }

    #[test]
    fn test_selection_result_accessors() {
        let stay = SelectionResult::Stay { variant: 1 };
        assert_eq!(stay.variant_index(), 1);
        assert!(!stay.is_switch());
        assert!(!stay.is_emergency());

        let up = SelectionResult::SwitchUp {
            from: 0,
            to: 1,
            reason: AbrSwitchReason::BandwidthIncrease,
        };
        assert_eq!(up.variant_index(), 1);
        assert!(up.is_switch());
        assert!(!up.is_emergency());

        let down = SelectionResult::SwitchDown {
            from: 2,
            to: 1,
            reason: AbrSwitchReason::BandwidthDecrease,
        };
        assert_eq!(down.variant_index(), 1);
        assert!(down.is_switch());
        assert!(!down.is_emergency());

        let emergency = SelectionResult::EmergencySwitch { from: 2, to: 0 };
        assert_eq!(emergency.variant_index(), 0);
        assert!(emergency.is_switch());
        assert!(emergency.is_emergency());
    }

    #[test]
    fn test_segment_fetcher_buffer_level() {
        let ctrl = AbrController::new(make_variants()).expect("should succeed in test");
        let mut fetcher = SegmentFetcher::new(ctrl, 4.0);
        fetcher.record_download(0, 500_000, 1_000, 4.0);
        fetcher.record_download(1, 500_000, 1_000, 4.0);
        fetcher.record_download(2, 500_000, 1_000, 4.0);
        let level = fetcher.buffer_level_s();
        assert!(
            (level - 12.0).abs() < f64::EPSILON,
            "3 × 4 s segments = 12 s, got {level}"
        );
        assert_eq!(fetcher.buffered_count(), 3);
    }

    #[test]
    fn test_segment_fetcher_pop() {
        let ctrl = AbrController::new(make_variants()).expect("should succeed in test");
        let mut fetcher = SegmentFetcher::new(ctrl, 6.0);
        fetcher.record_download(0, 750_000, 800, 6.0);
        fetcher.record_download(1, 750_000, 800, 6.0);
        assert_eq!(fetcher.buffered_count(), 2);
        let seg = fetcher.pop_segment().expect("should return a segment");
        assert_eq!(seg.sequence, 0);
        assert_eq!(fetcher.buffered_count(), 1);
        let level = fetcher.buffer_level_s();
        assert!(
            (level - 6.0).abs() < f64::EPSILON,
            "after pop, 1 × 6 s segment remains, got {level}"
        );
    }
}