1use crate::types::*;
9use std::collections::VecDeque;
10use std::time::{Duration, Instant};
11use tracing::{debug, instrument};
12
13pub trait AbrAlgorithm: Send + Sync {
15 fn select_rendition<'a>(
17 &self,
18 renditions: &'a [Rendition],
19 context: &AbrContext,
20 ) -> Option<&'a Rendition>;
21
22 fn update(&mut self, measurement: &BandwidthMeasurement);
24
25 fn name(&self) -> &'static str;
27}
28
29#[derive(Debug, Clone, Default)]
31pub struct AbrContext {
32 pub buffer_level: f64,
34 pub target_buffer: f64,
36 pub playback_rate: f64,
38 pub is_live: bool,
40 pub screen_width: Option<u32>,
42 pub max_bitrate: u64,
44 pub network: NetworkInfo,
46}
47
/// A single completed transfer: how many bytes arrived and how long it took.
#[derive(Debug, Clone)]
pub struct BandwidthMeasurement {
    /// Payload size in bytes.
    pub bytes: usize,
    /// Time the transfer took.
    pub duration: Duration,
    /// Moment the measurement was recorded.
    pub timestamp: Instant,
}

impl BandwidthMeasurement {
    /// Observed throughput in bits per second.
    ///
    /// A zero-length `duration` has undefined throughput and yields `0`.
    /// The final float-to-integer conversion truncates (and saturates on overflow).
    pub fn throughput_bps(&self) -> u64 {
        let seconds = self.duration.as_secs_f64();
        if seconds <= 0.0 {
            return 0;
        }
        let bits = self.bytes as f64 * 8.0;
        (bits / seconds) as u64
    }
}
69
70pub struct AbrEngine {
72 algorithm: Box<dyn AbrAlgorithm>,
74 bandwidth_history: VecDeque<BandwidthMeasurement>,
76 max_history: usize,
78 bandwidth_estimate: u64,
80 last_selection: Option<usize>,
82 stability_counter: u32,
84}
85
86impl AbrEngine {
87 pub fn new(algorithm_type: AbrAlgorithmType) -> Self {
89 let algorithm: Box<dyn AbrAlgorithm> = match algorithm_type {
90 AbrAlgorithmType::Throughput => Box::new(ThroughputAlgorithm::new()),
91 AbrAlgorithmType::Bola => Box::new(BolaAlgorithm::new()),
92 AbrAlgorithmType::Hybrid => Box::new(HybridAlgorithm::new()),
93 AbrAlgorithmType::Ml => Box::new(ThroughputAlgorithm::new()), };
95
96 Self {
97 algorithm,
98 bandwidth_history: VecDeque::with_capacity(20),
99 max_history: 20,
100 bandwidth_estimate: 0,
101 last_selection: None,
102 stability_counter: 0,
103 }
104 }
105
106 #[instrument(skip(self))]
108 pub fn record_measurement(&mut self, bytes: usize, duration: Duration) {
109 let measurement = BandwidthMeasurement {
110 bytes,
111 duration,
112 timestamp: Instant::now(),
113 };
114
115 if self.bandwidth_history.len() >= self.max_history {
117 self.bandwidth_history.pop_front();
118 }
119 self.bandwidth_history.push_back(measurement.clone());
120
121 let sample = measurement.throughput_bps();
123 if self.bandwidth_estimate == 0 {
124 self.bandwidth_estimate = sample;
125 } else {
126 self.bandwidth_estimate =
128 ((self.bandwidth_estimate as f64 * 0.8) + (sample as f64 * 0.2)) as u64;
129 }
130
131 self.algorithm.update(&measurement);
133
134 debug!(
135 bytes = bytes,
136 duration_ms = duration.as_millis(),
137 throughput_mbps = sample as f64 / 1_000_000.0,
138 estimate_mbps = self.bandwidth_estimate as f64 / 1_000_000.0,
139 "Bandwidth measurement recorded"
140 );
141 }
142
143 #[instrument(skip(self, renditions))]
145 pub fn select_rendition<'a>(
146 &mut self,
147 renditions: &'a [Rendition],
148 context: &AbrContext,
149 ) -> Option<&'a Rendition> {
150 if renditions.is_empty() {
151 return None;
152 }
153
154 let selected = self.algorithm.select_rendition(renditions, context)?;
156
157 let new_index = renditions.iter().position(|r| r.id == selected.id)?;
159
160 if let Some(last) = self.last_selection {
162 if new_index != last {
163 self.stability_counter += 1;
164 if self.stability_counter < 3 {
165 return renditions.get(last);
167 }
168 }
169 self.stability_counter = 0;
170 }
171
172 self.last_selection = Some(new_index);
173
174 debug!(
175 selected_id = %selected.id,
176 bandwidth = selected.bandwidth,
177 resolution = ?selected.resolution,
178 "Rendition selected"
179 );
180
181 Some(selected)
182 }
183
184 pub fn bandwidth_estimate(&self) -> u64 {
186 self.bandwidth_estimate
187 }
188
189 pub fn algorithm_name(&self) -> &'static str {
191 self.algorithm.name()
192 }
193
194 pub fn set_algorithm(&mut self, algorithm_type: AbrAlgorithmType) {
196 self.algorithm = match algorithm_type {
197 AbrAlgorithmType::Throughput => Box::new(ThroughputAlgorithm::new()),
198 AbrAlgorithmType::Bola => Box::new(BolaAlgorithm::new()),
199 AbrAlgorithmType::Hybrid => Box::new(HybridAlgorithm::new()),
200 AbrAlgorithmType::Ml => Box::new(ThroughputAlgorithm::new()),
201 };
202 }
203}
204
/// Rate-based ABR: picks the highest-bandwidth rendition that fits within a
/// safety-discounted bandwidth estimate.
pub struct ThroughputAlgorithm {
    /// Fraction of the estimated bandwidth treated as usable.
    safety_factor: f64,
    /// Smoothed throughput maintained by [`AbrAlgorithm::update`], in bits per second.
    throughput_estimate: u64,
}

impl ThroughputAlgorithm {
    /// Budget only 80% of the estimated bandwidth by default.
    const DEFAULT_SAFETY_FACTOR: f64 = 0.8;

    /// Creates the algorithm with the default safety factor and no throughput history.
    pub fn new() -> Self {
        Self {
            throughput_estimate: 0,
            safety_factor: Self::DEFAULT_SAFETY_FACTOR,
        }
    }
}

impl Default for ThroughputAlgorithm {
    fn default() -> Self {
        Self::new()
    }
}
227
228impl AbrAlgorithm for ThroughputAlgorithm {
229 fn select_rendition<'a>(
230 &self,
231 renditions: &'a [Rendition],
232 context: &AbrContext,
233 ) -> Option<&'a Rendition> {
234 let available_bandwidth =
235 (context.network.bandwidth_estimate as f64 * self.safety_factor) as u64;
236
237 let max_bitrate = if context.max_bitrate > 0 {
239 context.max_bitrate.min(available_bandwidth)
240 } else {
241 available_bandwidth
242 };
243
244 renditions
246 .iter()
247 .filter(|r| r.bandwidth <= max_bitrate)
248 .filter(|r| {
249 if let (Some(res), Some(screen_w)) = (&r.resolution, context.screen_width) {
251 res.width <= screen_w
252 } else {
253 true
254 }
255 })
256 .max_by_key(|r| r.bandwidth)
257 }
258
259 fn update(&mut self, measurement: &BandwidthMeasurement) {
260 let sample = measurement.throughput_bps();
261 if self.throughput_estimate == 0 {
262 self.throughput_estimate = sample;
263 } else {
264 self.throughput_estimate =
265 ((self.throughput_estimate as f64 * 0.7) + (sample as f64 * 0.3)) as u64;
266 }
267 }
268
269 fn name(&self) -> &'static str {
270 "throughput"
271 }
272}
273
274pub struct BolaAlgorithm {
277 buffer_min: f64,
279 _buffer_max: f64,
281 v: f64,
283 gamma: f64,
285}
286
287impl BolaAlgorithm {
288 pub fn new() -> Self {
289 Self {
290 buffer_min: 5.0,
291 _buffer_max: 30.0,
292 v: 0.93,
293 gamma: 5.0,
294 }
295 }
296
297 fn utility(&self, rendition: &Rendition) -> f64 {
299 (rendition.bandwidth as f64).ln()
301 }
302}
303
304impl Default for BolaAlgorithm {
305 fn default() -> Self {
306 Self::new()
307 }
308}
309
310impl AbrAlgorithm for BolaAlgorithm {
311 fn select_rendition<'a>(
312 &self,
313 renditions: &'a [Rendition],
314 context: &AbrContext,
315 ) -> Option<&'a Rendition> {
316 if renditions.is_empty() {
317 return None;
318 }
319
320 let buffer = context.buffer_level;
321
322 let mut best: Option<&Rendition> = None;
324 let mut best_score = f64::NEG_INFINITY;
325
326 for rendition in renditions {
327 if context.max_bitrate > 0 && rendition.bandwidth > context.max_bitrate {
329 continue;
330 }
331
332 let utility = self.utility(rendition);
333 let size = rendition.bandwidth as f64;
334
335 let score = (self.v * utility - buffer) / (size / 1_000_000.0 + self.gamma);
337
338 if score > best_score {
339 best_score = score;
340 best = Some(rendition);
341 }
342 }
343
344 if buffer < self.buffer_min {
346 return renditions.first();
347 }
348
349 best
350 }
351
352 fn update(&mut self, _measurement: &BandwidthMeasurement) {
353 }
355
356 fn name(&self) -> &'static str {
357 "bola"
358 }
359}
360
361pub struct HybridAlgorithm {
363 throughput: ThroughputAlgorithm,
364 bola: BolaAlgorithm,
365 _throughput_weight: f64,
367}
368
369impl HybridAlgorithm {
370 pub fn new() -> Self {
371 Self {
372 throughput: ThroughputAlgorithm::new(),
373 bola: BolaAlgorithm::new(),
374 _throughput_weight: 0.5,
375 }
376 }
377}
378
379impl Default for HybridAlgorithm {
380 fn default() -> Self {
381 Self::new()
382 }
383}
384
385impl AbrAlgorithm for HybridAlgorithm {
386 fn select_rendition<'a>(
387 &self,
388 renditions: &'a [Rendition],
389 context: &AbrContext,
390 ) -> Option<&'a Rendition> {
391 let throughput_pick = self.throughput.select_rendition(renditions, context);
392 let bola_pick = self.bola.select_rendition(renditions, context);
393
394 match (throughput_pick, bola_pick) {
395 (Some(t), Some(b)) => {
396 if context.buffer_level < 10.0 {
398 Some(b)
399 } else if t.bandwidth <= b.bandwidth {
400 Some(t)
401 } else {
402 let t_idx = renditions.iter().position(|r| r.id == t.id).unwrap_or(0);
404 let b_idx = renditions.iter().position(|r| r.id == b.id).unwrap_or(0);
405 let avg_idx = (t_idx + b_idx) / 2;
406 renditions.get(avg_idx)
407 }
408 }
409 (Some(t), None) => Some(t),
410 (None, Some(b)) => Some(b),
411 (None, None) => renditions.first(),
412 }
413 }
414
415 fn update(&mut self, measurement: &BandwidthMeasurement) {
416 self.throughput.update(measurement);
417 self.bola.update(measurement);
418 }
419
420 fn name(&self) -> &'static str {
421 "hybrid"
422 }
423}
424
#[cfg(test)]
mod tests {
    use super::*;
    use url::Url;

    /// Builds an H.264/AAC rendition whose URI is derived from `id`.
    fn rendition(id: &str, bandwidth: u64, resolution: Resolution) -> Rendition {
        Rendition {
            id: id.to_string(),
            bandwidth,
            resolution: Some(resolution),
            frame_rate: None,
            video_codec: Some(VideoCodec::H264),
            audio_codec: Some(AudioCodec::Aac),
            uri: Url::parse(&format!("https://example.com/{}.m3u8", id)).unwrap(),
            hdr: None,
            language: None,
            name: None,
        }
    }

    /// Three-rung ladder in ascending bandwidth order: 360p, 720p, 1080p.
    fn create_test_renditions() -> Vec<Rendition> {
        vec![
            rendition("360p", 800_000, Resolution::new(640, 360)),
            rendition("720p", 2_800_000, Resolution::new(1280, 720)),
            rendition("1080p", 5_000_000, Resolution::new(1920, 1080)),
        ]
    }

    #[test]
    fn test_throughput_selection() {
        let renditions = create_test_renditions();
        let algorithm = ThroughputAlgorithm::new();
        let context_for = |bandwidth_estimate| AbrContext {
            buffer_level: 20.0,
            network: NetworkInfo {
                bandwidth_estimate,
                ..Default::default()
            },
            ..Default::default()
        };

        // Ample bandwidth -> top rung; 1 Mbps (800 kbps after the safety factor) -> bottom rung.
        for (estimate, expected) in [(10_000_000, "1080p"), (1_000_000, "360p")] {
            let selected = algorithm.select_rendition(&renditions, &context_for(estimate));
            assert_eq!(selected.map(|r| r.id.as_str()), Some(expected));
        }
    }

    #[test]
    fn test_bola_low_buffer() {
        let renditions = create_test_renditions();
        let context = AbrContext {
            buffer_level: 2.0,
            ..Default::default()
        };

        let selected = BolaAlgorithm::new().select_rendition(&renditions, &context);
        assert_eq!(selected.map(|r| r.id.as_str()), Some("360p"));
    }
}