use crate::types::*;
use std::collections::VecDeque;
use std::time::{Duration, Instant};
use tracing::{debug, instrument};
/// Strategy interface for adaptive-bitrate (ABR) rendition selection.
///
/// [`AbrEngine`] stores the active implementation as `Box<dyn AbrAlgorithm>`;
/// the `Send + Sync` supertrait bound is what allows that box to be moved or
/// shared across threads.
pub trait AbrAlgorithm: Send + Sync {
/// Chooses one rendition from `renditions` for the playback state in `context`.
///
/// The returned reference borrows from `renditions`; `None` means the
/// implementation made no choice (e.g. an empty slice).
fn select_rendition<'a>(
&self,
renditions: &'a [Rendition],
context: &AbrContext,
) -> Option<&'a Rendition>;
/// Feeds a completed download measurement into the algorithm's internal state.
fn update(&mut self, measurement: &BandwidthMeasurement);
/// Short, static identifier for the algorithm (e.g. used by `AbrEngine::algorithm_name`).
fn name(&self) -> &'static str;
}
/// Playback state supplied to [`AbrAlgorithm::select_rendition`] on each decision.
#[derive(Debug, Clone, Default)]
pub struct AbrContext {
/// Current buffer level. Compared against BOLA's `buffer_min` and the hybrid
/// algorithm's 10.0 threshold. NOTE(review): presumably seconds — confirm.
pub buffer_level: f64,
/// Target buffer level; not read by the algorithms shown here.
pub target_buffer: f64,
/// Playback rate; not read by the algorithms shown here.
pub playback_rate: f64,
/// Whether the stream is live; not read by the algorithms shown here.
pub is_live: bool,
/// Display width. When `Some`, the throughput algorithm skips renditions whose
/// resolution is wider than this.
pub screen_width: Option<u32>,
/// Upper bitrate cap, in the same units as `Rendition::bandwidth`; `0` means uncapped.
pub max_bitrate: u64,
/// Network conditions; `bandwidth_estimate` drives throughput-based selection.
pub network: NetworkInfo,
}
/// One completed transfer, used as a throughput sample.
#[derive(Debug, Clone)]
pub struct BandwidthMeasurement {
    /// Number of bytes transferred.
    pub bytes: usize,
    /// Time the transfer took.
    pub duration: Duration,
    /// Instant at which the sample was captured.
    pub timestamp: Instant,
}

impl BandwidthMeasurement {
    /// Average throughput of this sample in bits per second.
    ///
    /// A zero-length `duration` has no meaningful rate, so `0` is returned
    /// for it instead of dividing by zero.
    pub fn throughput_bps(&self) -> u64 {
        let seconds = self.duration.as_secs_f64();
        if seconds <= 0.0 {
            return 0;
        }
        let bits = self.bytes as f64 * 8.0;
        (bits / seconds) as u64
    }
}
/// Drives an [`AbrAlgorithm`], tracks measured bandwidth, and damps rendition switches.
pub struct AbrEngine {
/// Active selection strategy.
algorithm: Box<dyn AbrAlgorithm>,
/// Most recent measurements, oldest at the front, bounded by `max_history`.
bandwidth_history: VecDeque<BandwidthMeasurement>,
/// Maximum number of measurements kept in `bandwidth_history`.
max_history: usize,
/// Exponentially smoothed throughput in bits per second (`0` until the first sample).
bandwidth_estimate: u64,
/// Index into the renditions slice of the last committed selection.
last_selection: Option<usize>,
/// Consecutive calls in which the algorithm proposed a different index than `last_selection`.
stability_counter: u32,
}
impl AbrEngine {
    /// Number of consecutive calls proposing a rendition other than the current
    /// one that are required before the engine actually switches.
    const SWITCH_THRESHOLD: u32 = 3;
    /// Number of bandwidth measurements retained in the history window.
    const DEFAULT_MAX_HISTORY: usize = 20;

    /// Creates an engine that uses the algorithm selected by `algorithm_type`.
    pub fn new(algorithm_type: AbrAlgorithmType) -> Self {
        Self {
            algorithm: Self::build_algorithm(algorithm_type),
            bandwidth_history: VecDeque::with_capacity(Self::DEFAULT_MAX_HISTORY),
            max_history: Self::DEFAULT_MAX_HISTORY,
            bandwidth_estimate: 0,
            last_selection: None,
            stability_counter: 0,
        }
    }

    /// Instantiates the concrete algorithm for `algorithm_type`.
    ///
    /// `Ml` has no dedicated implementation here and falls back to
    /// throughput-based selection.
    fn build_algorithm(algorithm_type: AbrAlgorithmType) -> Box<dyn AbrAlgorithm> {
        match algorithm_type {
            AbrAlgorithmType::Throughput => Box::new(ThroughputAlgorithm::new()),
            AbrAlgorithmType::Bola => Box::new(BolaAlgorithm::new()),
            AbrAlgorithmType::Hybrid => Box::new(HybridAlgorithm::new()),
            AbrAlgorithmType::Ml => Box::new(ThroughputAlgorithm::new()),
        }
    }

    /// Records a completed transfer of `bytes` over `duration`.
    ///
    /// Updates the engine's smoothed estimate (80% previous, 20% new sample),
    /// forwards the sample to the active algorithm, and appends it to the
    /// bounded history (evicting the oldest entry when full).
    #[instrument(skip(self))]
    pub fn record_measurement(&mut self, bytes: usize, duration: Duration) {
        let measurement = BandwidthMeasurement {
            bytes,
            duration,
            timestamp: Instant::now(),
        };
        let sample = measurement.throughput_bps();

        // `throughput_bps` reports 0 for a zero-length duration. That is not an
        // observed rate, so keep it out of the estimate instead of letting it
        // pull the estimate down.
        if !duration.is_zero() {
            self.bandwidth_estimate = if self.bandwidth_estimate == 0 {
                sample
            } else {
                ((self.bandwidth_estimate as f64 * 0.8) + (sample as f64 * 0.2)) as u64
            };
        }

        self.algorithm.update(&measurement);

        debug!(
            bytes = bytes,
            duration_ms = duration.as_millis(),
            throughput_mbps = sample as f64 / 1_000_000.0,
            estimate_mbps = self.bandwidth_estimate as f64 / 1_000_000.0,
            "Bandwidth measurement recorded"
        );

        if self.bandwidth_history.len() >= self.max_history {
            self.bandwidth_history.pop_front();
        }
        // The measurement is no longer needed locally, so move it in (no clone).
        self.bandwidth_history.push_back(measurement);
    }

    /// Asks the active algorithm for a rendition and applies switch damping.
    ///
    /// A new choice is only committed after the algorithm has proposed a
    /// different rendition on `SWITCH_THRESHOLD` consecutive calls; until then
    /// the previously committed rendition is returned. Returns `None` for an
    /// empty slice or when the algorithm makes no choice.
    #[instrument(skip(self, renditions))]
    pub fn select_rendition<'a>(
        &mut self,
        renditions: &'a [Rendition],
        context: &AbrContext,
    ) -> Option<&'a Rendition> {
        if renditions.is_empty() {
            return None;
        }
        let selected = self.algorithm.select_rendition(renditions, context)?;
        let new_index = renditions.iter().position(|r| r.id == selected.id)?;

        if let Some(last) = self.last_selection {
            if new_index != last {
                self.stability_counter += 1;
                if self.stability_counter < Self::SWITCH_THRESHOLD {
                    // Hold the previous rendition only while it still exists in
                    // this list. If the list shrank, fall through and commit the
                    // new pick instead of returning no selection.
                    if let Some(held) = renditions.get(last) {
                        return Some(held);
                    }
                }
            }
            self.stability_counter = 0;
        }
        self.last_selection = Some(new_index);

        debug!(
            selected_id = %selected.id,
            bandwidth = selected.bandwidth,
            resolution = ?selected.resolution,
            "Rendition selected"
        );
        Some(selected)
    }

    /// Current smoothed throughput estimate in bits per second (`0` before any sample).
    pub fn bandwidth_estimate(&self) -> u64 {
        self.bandwidth_estimate
    }

    /// Name of the active algorithm.
    pub fn algorithm_name(&self) -> &'static str {
        self.algorithm.name()
    }

    /// Replaces the active algorithm.
    ///
    /// The retained measurement history is replayed into the new algorithm so
    /// it does not start from an empty estimate after a switch.
    pub fn set_algorithm(&mut self, algorithm_type: AbrAlgorithmType) {
        let mut algorithm = Self::build_algorithm(algorithm_type);
        for measurement in &self.bandwidth_history {
            algorithm.update(measurement);
        }
        self.algorithm = algorithm;
    }
}
/// Rate-based ABR: chooses renditions against a discounted bandwidth estimate.
pub struct ThroughputAlgorithm {
    /// Fraction of the estimated bandwidth considered usable.
    safety_factor: f64,
    /// Exponentially weighted throughput in bits per second; `0` means "no sample yet".
    throughput_estimate: u64,
}

impl ThroughputAlgorithm {
    /// Creates the algorithm with a 0.8 safety factor and an empty estimate.
    pub fn new() -> Self {
        let safety_factor = 0.8;
        let throughput_estimate = 0;
        Self {
            safety_factor,
            throughput_estimate,
        }
    }
}

impl Default for ThroughputAlgorithm {
    /// Same as [`ThroughputAlgorithm::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl AbrAlgorithm for ThroughputAlgorithm {
    /// Picks the highest-bandwidth rendition that fits the usable bandwidth.
    ///
    /// Usable bandwidth is `safety_factor` times the estimate, further capped
    /// by `context.max_bitrate` when that is non-zero. Renditions wider than
    /// `context.screen_width` (when set) are skipped. The estimate is
    /// `context.network.bandwidth_estimate` when non-zero, otherwise this
    /// algorithm's own measured estimate.
    ///
    /// When nothing fits, the cheapest rendition is returned so playback can
    /// continue; `None` is returned only for an empty slice.
    fn select_rendition<'a>(
        &self,
        renditions: &'a [Rendition],
        context: &AbrContext,
    ) -> Option<&'a Rendition> {
        // Without an externally supplied estimate, fall back to the EWMA built
        // from `update`; previously that state was maintained but never read.
        let estimate = if context.network.bandwidth_estimate > 0 {
            context.network.bandwidth_estimate as f64
        } else {
            self.throughput_estimate as f64
        };
        let available_bandwidth = (estimate * self.safety_factor) as u64;
        let max_bitrate = if context.max_bitrate > 0 {
            context.max_bitrate.min(available_bandwidth)
        } else {
            available_bandwidth
        };
        renditions
            .iter()
            .filter(|r| r.bandwidth <= max_bitrate)
            .filter(|r| match (&r.resolution, context.screen_width) {
                (Some(res), Some(screen_w)) => res.width <= screen_w,
                _ => true,
            })
            .max_by_key(|r| r.bandwidth)
            // Nothing fits (e.g. no estimate yet): degrade to the cheapest
            // rendition rather than reporting no selection at all.
            .or_else(|| renditions.iter().min_by_key(|r| r.bandwidth))
    }

    /// Folds a measurement into the throughput EWMA (70% previous, 30% new).
    ///
    /// Zero-duration samples are ignored: `throughput_bps` reports them as 0,
    /// which is not an observed rate.
    fn update(&mut self, measurement: &BandwidthMeasurement) {
        if measurement.duration.is_zero() {
            return;
        }
        let sample = measurement.throughput_bps();
        self.throughput_estimate = if self.throughput_estimate == 0 {
            sample
        } else {
            ((self.throughput_estimate as f64 * 0.7) + (sample as f64 * 0.3)) as u64
        };
    }

    fn name(&self) -> &'static str {
        "throughput"
    }
}
/// Buffer-driven ABR loosely following BOLA: each rendition is scored from a
/// log-bitrate utility and the current buffer level.
pub struct BolaAlgorithm {
    /// Buffer threshold; below it `select_rendition` ignores the scores and
    /// returns a fallback rendition.
    buffer_min: f64,
    /// Upper buffer bound; stored but not currently read.
    _buffer_max: f64,
    /// Weight applied to the utility in the score numerator.
    v: f64,
    /// Added to the rendition bitrate (in Mbit/s) in the score denominator.
    gamma: f64,
}

impl BolaAlgorithm {
    /// Creates the algorithm with its fixed tuning parameters.
    pub fn new() -> Self {
        Self {
            gamma: 5.0,
            v: 0.93,
            _buffer_max: 30.0,
            buffer_min: 5.0,
        }
    }

    /// Utility of a rendition: natural log of its bandwidth.
    fn utility(&self, rendition: &Rendition) -> f64 {
        let bitrate = rendition.bandwidth as f64;
        bitrate.ln()
    }
}

impl Default for BolaAlgorithm {
    /// Same as [`BolaAlgorithm::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl AbrAlgorithm for BolaAlgorithm {
    /// Selects the rendition with the highest BOLA-style score.
    ///
    /// score = (v * ln(bandwidth) - buffer) / (bandwidth_in_Mbps + gamma)
    ///
    /// Renditions above a non-zero `context.max_bitrate` are skipped. When the
    /// buffer is below `buffer_min`, the cheapest rendition is returned
    /// immediately to protect against a stall. Returns `None` for an empty
    /// slice, or when every rendition is filtered out.
    fn select_rendition<'a>(
        &self,
        renditions: &'a [Rendition],
        context: &AbrContext,
    ) -> Option<&'a Rendition> {
        let buffer = context.buffer_level;

        // Low buffer: pick by bandwidth, not by position. The slice is not
        // guaranteed to be sorted ascending (playlists often list the highest
        // variant first), so `first()` could select the most expensive one.
        if buffer < self.buffer_min {
            return renditions.iter().min_by_key(|r| r.bandwidth);
        }

        let mut best: Option<&'a Rendition> = None;
        let mut best_score = f64::NEG_INFINITY;
        for rendition in renditions {
            if context.max_bitrate > 0 && rendition.bandwidth > context.max_bitrate {
                continue;
            }
            let utility = self.utility(rendition);
            let size_mbps = rendition.bandwidth as f64 / 1_000_000.0;
            let score = (self.v * utility - buffer) / (size_mbps + self.gamma);
            if score > best_score {
                best_score = score;
                best = Some(rendition);
            }
        }
        best
    }

    /// BOLA is purely buffer-driven; throughput samples are not used.
    fn update(&mut self, _measurement: &BandwidthMeasurement) {}

    fn name(&self) -> &'static str {
        "bola"
    }
}
/// Combines the recommendations of [`ThroughputAlgorithm`] and [`BolaAlgorithm`].
pub struct HybridAlgorithm {
    /// Rate-based sub-algorithm.
    throughput: ThroughputAlgorithm,
    /// Buffer-based sub-algorithm.
    bola: BolaAlgorithm,
    /// Presumably intended as a blend weight for the throughput pick; not currently read.
    _throughput_weight: f64,
}

impl HybridAlgorithm {
    /// Creates a hybrid with freshly constructed sub-algorithms.
    pub fn new() -> Self {
        let throughput = ThroughputAlgorithm::new();
        let bola = BolaAlgorithm::new();
        Self {
            throughput,
            bola,
            _throughput_weight: 0.5,
        }
    }
}

impl Default for HybridAlgorithm {
    /// Same as [`HybridAlgorithm::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl AbrAlgorithm for HybridAlgorithm {
    /// Merges the throughput and BOLA picks.
    ///
    /// - Buffer below 10.0: use the BOLA pick.
    /// - Otherwise, if the throughput pick is no more expensive than BOLA's, use it.
    /// - Otherwise, compromise on the rendition halfway between the two picks
    ///   on the bandwidth-ordered ladder.
    ///
    /// If only one sub-algorithm chooses, its pick is used. If neither does,
    /// the cheapest rendition is returned (`None` for an empty slice).
    fn select_rendition<'a>(
        &self,
        renditions: &'a [Rendition],
        context: &AbrContext,
    ) -> Option<&'a Rendition> {
        let throughput_pick = self.throughput.select_rendition(renditions, context);
        let bola_pick = self.bola.select_rendition(renditions, context);
        match (throughput_pick, bola_pick) {
            (Some(t), Some(b)) => {
                if context.buffer_level < 10.0 {
                    Some(b)
                } else if t.bandwidth <= b.bandwidth {
                    Some(t)
                } else {
                    // Averaging positions is only meaningful on a ladder sorted by
                    // bandwidth; the input slice is not guaranteed to be sorted.
                    // The sort is stable, so already-sorted input keeps its order.
                    let mut ladder: Vec<&'a Rendition> = renditions.iter().collect();
                    ladder.sort_by_key(|r| r.bandwidth);
                    let t_idx = ladder.iter().position(|r| std::ptr::eq(*r, t)).unwrap_or(0);
                    let b_idx = ladder.iter().position(|r| std::ptr::eq(*r, b)).unwrap_or(0);
                    ladder.get((t_idx + b_idx) / 2).copied()
                }
            }
            (Some(t), None) => Some(t),
            (None, Some(b)) => Some(b),
            (None, None) => renditions.iter().min_by_key(|r| r.bandwidth),
        }
    }

    /// Forwards the measurement to both sub-algorithms.
    fn update(&mut self, measurement: &BandwidthMeasurement) {
        self.throughput.update(measurement);
        self.bola.update(measurement);
    }

    fn name(&self) -> &'static str {
        "hybrid"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use url::Url;

    /// Builds an H.264 + AAC rendition served from `https://example.com/<id>.m3u8`.
    fn make_rendition(id: &str, bandwidth: u64, resolution: Resolution) -> Rendition {
        Rendition {
            id: id.to_string(),
            bandwidth,
            resolution: Some(resolution),
            frame_rate: None,
            video_codec: Some(VideoCodec::H264),
            audio_codec: Some(AudioCodec::Aac),
            uri: Url::parse(&format!("https://example.com/{}.m3u8", id)).unwrap(),
            hdr: None,
            language: None,
            name: None,
        }
    }

    /// Three-step ladder (360p / 720p / 1080p), ordered by ascending bandwidth.
    fn create_test_renditions() -> Vec<Rendition> {
        vec![
            make_rendition("360p", 800_000, Resolution::new(640, 360)),
            make_rendition("720p", 2_800_000, Resolution::new(1280, 720)),
            make_rendition("1080p", 5_000_000, Resolution::new(1920, 1080)),
        ]
    }

    #[test]
    fn test_throughput_selection() {
        let renditions = create_test_renditions();
        let algorithm = ThroughputAlgorithm::new();
        // (network bandwidth estimate, expected rendition id)
        for &(bandwidth_estimate, expected) in &[(10_000_000, "1080p"), (1_000_000, "360p")] {
            let context = AbrContext {
                buffer_level: 20.0,
                network: NetworkInfo {
                    bandwidth_estimate,
                    ..Default::default()
                },
                ..Default::default()
            };
            let selected = algorithm.select_rendition(&renditions, &context);
            assert_eq!(selected.map(|r| r.id.as_str()), Some(expected));
        }
    }

    #[test]
    fn test_bola_low_buffer() {
        let renditions = create_test_renditions();
        let algorithm = BolaAlgorithm::new();
        let context = AbrContext {
            buffer_level: 2.0,
            ..Default::default()
        };
        let selected = algorithm.select_rendition(&renditions, &context);
        assert_eq!(selected.map(|r| r.id.as_str()), Some("360p"));
    }
}