Skip to main content

kino_core/
abr.rs

//! Adaptive Bitrate (ABR) Engine
//!
//! Implements multiple ABR algorithms:
//! - Throughput-based: Simple bandwidth estimation
//! - BOLA: Buffer Occupancy based Lyapunov Algorithm
//! - Hybrid: Combines throughput and buffer metrics
7
8use crate::types::*;
9use std::collections::VecDeque;
10use std::time::{Duration, Instant};
11use tracing::{debug, instrument};
12
13/// ABR algorithm trait
14pub trait AbrAlgorithm: Send + Sync {
15    /// Select the best rendition given current conditions
16    fn select_rendition<'a>(
17        &self,
18        renditions: &'a [Rendition],
19        context: &AbrContext,
20    ) -> Option<&'a Rendition>;
21
22    /// Update algorithm state with new measurement
23    fn update(&mut self, measurement: &BandwidthMeasurement);
24
25    /// Get algorithm name
26    fn name(&self) -> &'static str;
27}
28
29/// Context for ABR decisions
30#[derive(Debug, Clone, Default)]
31pub struct AbrContext {
32    /// Current buffer level in seconds
33    pub buffer_level: f64,
34    /// Target buffer level
35    pub target_buffer: f64,
36    /// Current playback rate (1.0 = normal)
37    pub playback_rate: f64,
38    /// Is stream live
39    pub is_live: bool,
40    /// Screen width for resolution capping
41    pub screen_width: Option<u32>,
42    /// Maximum allowed bitrate (0 = unlimited)
43    pub max_bitrate: u64,
44    /// Network info
45    pub network: NetworkInfo,
46}
47
/// A single observed download, used as a raw bandwidth sample.
#[derive(Debug, Clone)]
pub struct BandwidthMeasurement {
    /// Number of bytes downloaded.
    pub bytes: usize,
    /// How long the download took.
    pub duration: Duration,
    /// When the sample was taken.
    pub timestamp: Instant,
}

impl BandwidthMeasurement {
    /// Returns the observed throughput in bits per second.
    ///
    /// A zero-length `duration` yields `0` rather than dividing by zero.
    pub fn throughput_bps(&self) -> u64 {
        let seconds = self.duration.as_secs_f64();
        if seconds <= 0.0 {
            return 0;
        }

        let bits = self.bytes as f64 * 8.0;
        (bits / seconds) as u64
    }
}
69
70/// ABR Engine combining multiple algorithms
71pub struct AbrEngine {
72    /// Active algorithm
73    algorithm: Box<dyn AbrAlgorithm>,
74    /// Bandwidth history
75    bandwidth_history: VecDeque<BandwidthMeasurement>,
76    /// Maximum history size
77    max_history: usize,
78    /// Current bandwidth estimate
79    bandwidth_estimate: u64,
80    /// Last selected rendition index
81    last_selection: Option<usize>,
82    /// Stability counter (prevent oscillation)
83    stability_counter: u32,
84}
85
86impl AbrEngine {
87    /// Create new ABR engine with specified algorithm
88    pub fn new(algorithm_type: AbrAlgorithmType) -> Self {
89        let algorithm: Box<dyn AbrAlgorithm> = match algorithm_type {
90            AbrAlgorithmType::Throughput => Box::new(ThroughputAlgorithm::new()),
91            AbrAlgorithmType::Bola => Box::new(BolaAlgorithm::new()),
92            AbrAlgorithmType::Hybrid => Box::new(HybridAlgorithm::new()),
93            AbrAlgorithmType::Ml => Box::new(ThroughputAlgorithm::new()), // Fallback
94        };
95
96        Self {
97            algorithm,
98            bandwidth_history: VecDeque::with_capacity(20),
99            max_history: 20,
100            bandwidth_estimate: 0,
101            last_selection: None,
102            stability_counter: 0,
103        }
104    }
105
106    /// Record a bandwidth measurement
107    #[instrument(skip(self))]
108    pub fn record_measurement(&mut self, bytes: usize, duration: Duration) {
109        let measurement = BandwidthMeasurement {
110            bytes,
111            duration,
112            timestamp: Instant::now(),
113        };
114
115        // Update history
116        if self.bandwidth_history.len() >= self.max_history {
117            self.bandwidth_history.pop_front();
118        }
119        self.bandwidth_history.push_back(measurement.clone());
120
121        // Update estimate using EWMA
122        let sample = measurement.throughput_bps();
123        if self.bandwidth_estimate == 0 {
124            self.bandwidth_estimate = sample;
125        } else {
126            // EWMA with alpha = 0.2
127            self.bandwidth_estimate =
128                ((self.bandwidth_estimate as f64 * 0.8) + (sample as f64 * 0.2)) as u64;
129        }
130
131        // Update algorithm
132        self.algorithm.update(&measurement);
133
134        debug!(
135            bytes = bytes,
136            duration_ms = duration.as_millis(),
137            throughput_mbps = sample as f64 / 1_000_000.0,
138            estimate_mbps = self.bandwidth_estimate as f64 / 1_000_000.0,
139            "Bandwidth measurement recorded"
140        );
141    }
142
143    /// Select best rendition
144    #[instrument(skip(self, renditions))]
145    pub fn select_rendition<'a>(
146        &mut self,
147        renditions: &'a [Rendition],
148        context: &AbrContext,
149    ) -> Option<&'a Rendition> {
150        if renditions.is_empty() {
151            return None;
152        }
153
154        // Get algorithm recommendation
155        let selected = self.algorithm.select_rendition(renditions, context)?;
156
157        // Find index
158        let new_index = renditions.iter().position(|r| r.id == selected.id)?;
159
160        // Apply stability filter to prevent oscillation
161        if let Some(last) = self.last_selection {
162            if new_index != last {
163                self.stability_counter += 1;
164                if self.stability_counter < 3 {
165                    // Don't switch yet
166                    return renditions.get(last);
167                }
168            }
169            self.stability_counter = 0;
170        }
171
172        self.last_selection = Some(new_index);
173
174        debug!(
175            selected_id = %selected.id,
176            bandwidth = selected.bandwidth,
177            resolution = ?selected.resolution,
178            "Rendition selected"
179        );
180
181        Some(selected)
182    }
183
184    /// Get current bandwidth estimate
185    pub fn bandwidth_estimate(&self) -> u64 {
186        self.bandwidth_estimate
187    }
188
189    /// Get algorithm name
190    pub fn algorithm_name(&self) -> &'static str {
191        self.algorithm.name()
192    }
193
194    /// Force switch algorithm
195    pub fn set_algorithm(&mut self, algorithm_type: AbrAlgorithmType) {
196        self.algorithm = match algorithm_type {
197            AbrAlgorithmType::Throughput => Box::new(ThroughputAlgorithm::new()),
198            AbrAlgorithmType::Bola => Box::new(BolaAlgorithm::new()),
199            AbrAlgorithmType::Hybrid => Box::new(HybridAlgorithm::new()),
200            AbrAlgorithmType::Ml => Box::new(ThroughputAlgorithm::new()),
201        };
202    }
203}
204
/// ABR strategy that picks renditions purely from available bandwidth.
pub struct ThroughputAlgorithm {
    /// Fraction of the bandwidth estimate treated as usable (0.0-1.0).
    safety_factor: f64,
    /// Smoothed throughput estimate maintained from recorded samples.
    throughput_estimate: u64,
}

impl ThroughputAlgorithm {
    /// Builds the algorithm with its default tuning and no throughput
    /// estimate yet.
    pub fn new() -> Self {
        // Budget only 80% of the estimated bandwidth.
        let safety_factor = 0.8;
        let throughput_estimate = 0;

        Self {
            safety_factor,
            throughput_estimate,
        }
    }
}

impl Default for ThroughputAlgorithm {
    /// Equivalent to [`ThroughputAlgorithm::new`].
    fn default() -> Self {
        Self::new()
    }
}
227
228impl AbrAlgorithm for ThroughputAlgorithm {
229    fn select_rendition<'a>(
230        &self,
231        renditions: &'a [Rendition],
232        context: &AbrContext,
233    ) -> Option<&'a Rendition> {
234        let available_bandwidth =
235            (context.network.bandwidth_estimate as f64 * self.safety_factor) as u64;
236
237        // Filter by max bitrate if set
238        let max_bitrate = if context.max_bitrate > 0 {
239            context.max_bitrate.min(available_bandwidth)
240        } else {
241            available_bandwidth
242        };
243
244        // Select highest quality that fits in bandwidth
245        renditions
246            .iter()
247            .filter(|r| r.bandwidth <= max_bitrate)
248            .filter(|r| {
249                // Filter by screen resolution if available
250                if let (Some(res), Some(screen_w)) = (&r.resolution, context.screen_width) {
251                    res.width <= screen_w
252                } else {
253                    true
254                }
255            })
256            .max_by_key(|r| r.bandwidth)
257    }
258
259    fn update(&mut self, measurement: &BandwidthMeasurement) {
260        let sample = measurement.throughput_bps();
261        if self.throughput_estimate == 0 {
262            self.throughput_estimate = sample;
263        } else {
264            self.throughput_estimate =
265                ((self.throughput_estimate as f64 * 0.7) + (sample as f64 * 0.3)) as u64;
266        }
267    }
268
269    fn name(&self) -> &'static str {
270        "throughput"
271    }
272}
273
274/// BOLA (Buffer Occupancy based Lyapunov Algorithm)
275/// Paper: https://arxiv.org/abs/1601.06748
276pub struct BolaAlgorithm {
277    /// Minimum buffer (seconds)
278    buffer_min: f64,
279    /// Maximum buffer (seconds)
280    _buffer_max: f64,
281    /// BOLA parameter V
282    v: f64,
283    /// BOLA parameter gamma
284    gamma: f64,
285}
286
287impl BolaAlgorithm {
288    pub fn new() -> Self {
289        Self {
290            buffer_min: 5.0,
291            _buffer_max: 30.0,
292            v: 0.93,
293            gamma: 5.0,
294        }
295    }
296
297    /// Calculate utility for a rendition
298    fn utility(&self, rendition: &Rendition) -> f64 {
299        // Logarithmic utility function
300        (rendition.bandwidth as f64).ln()
301    }
302}
303
304impl Default for BolaAlgorithm {
305    fn default() -> Self {
306        Self::new()
307    }
308}
309
310impl AbrAlgorithm for BolaAlgorithm {
311    fn select_rendition<'a>(
312        &self,
313        renditions: &'a [Rendition],
314        context: &AbrContext,
315    ) -> Option<&'a Rendition> {
316        if renditions.is_empty() {
317            return None;
318        }
319
320        let buffer = context.buffer_level;
321
322        // BOLA formula: maximize (V * utility - buffer_level) / (bitrate + gamma)
323        let mut best: Option<&Rendition> = None;
324        let mut best_score = f64::NEG_INFINITY;
325
326        for rendition in renditions {
327            // Skip if over max bitrate
328            if context.max_bitrate > 0 && rendition.bandwidth > context.max_bitrate {
329                continue;
330            }
331
332            let utility = self.utility(rendition);
333            let size = rendition.bandwidth as f64;
334
335            // BOLA objective function
336            let score = (self.v * utility - buffer) / (size / 1_000_000.0 + self.gamma);
337
338            if score > best_score {
339                best_score = score;
340                best = Some(rendition);
341            }
342        }
343
344        // Safety: if buffer is very low, pick lowest quality
345        if buffer < self.buffer_min {
346            return renditions.first();
347        }
348
349        best
350    }
351
352    fn update(&mut self, _measurement: &BandwidthMeasurement) {
353        // BOLA doesn't use throughput measurements directly
354    }
355
356    fn name(&self) -> &'static str {
357        "bola"
358    }
359}
360
361/// Hybrid algorithm combining throughput and buffer metrics
362pub struct HybridAlgorithm {
363    throughput: ThroughputAlgorithm,
364    bola: BolaAlgorithm,
365    /// Weight for throughput (0.0-1.0)
366    _throughput_weight: f64,
367}
368
369impl HybridAlgorithm {
370    pub fn new() -> Self {
371        Self {
372            throughput: ThroughputAlgorithm::new(),
373            bola: BolaAlgorithm::new(),
374            _throughput_weight: 0.5,
375        }
376    }
377}
378
379impl Default for HybridAlgorithm {
380    fn default() -> Self {
381        Self::new()
382    }
383}
384
385impl AbrAlgorithm for HybridAlgorithm {
386    fn select_rendition<'a>(
387        &self,
388        renditions: &'a [Rendition],
389        context: &AbrContext,
390    ) -> Option<&'a Rendition> {
391        let throughput_pick = self.throughput.select_rendition(renditions, context);
392        let bola_pick = self.bola.select_rendition(renditions, context);
393
394        match (throughput_pick, bola_pick) {
395            (Some(t), Some(b)) => {
396                // If buffer is low, prefer BOLA (more conservative)
397                if context.buffer_level < 10.0 {
398                    Some(b)
399                } else if t.bandwidth <= b.bandwidth {
400                    Some(t)
401                } else {
402                    // Average the two
403                    let t_idx = renditions.iter().position(|r| r.id == t.id).unwrap_or(0);
404                    let b_idx = renditions.iter().position(|r| r.id == b.id).unwrap_or(0);
405                    let avg_idx = (t_idx + b_idx) / 2;
406                    renditions.get(avg_idx)
407                }
408            }
409            (Some(t), None) => Some(t),
410            (None, Some(b)) => Some(b),
411            (None, None) => renditions.first(),
412        }
413    }
414
415    fn update(&mut self, measurement: &BandwidthMeasurement) {
416        self.throughput.update(measurement);
417        self.bola.update(measurement);
418    }
419
420    fn name(&self) -> &'static str {
421        "hybrid"
422    }
423}
424
#[cfg(test)]
mod tests {
    use super::*;
    use url::Url;

    /// Builds an H.264/AAC rendition whose URI is derived from `id`.
    fn rendition(id: &str, bandwidth: u64, resolution: Resolution) -> Rendition {
        Rendition {
            id: id.to_string(),
            bandwidth,
            resolution: Some(resolution),
            frame_rate: None,
            video_codec: Some(VideoCodec::H264),
            audio_codec: Some(AudioCodec::Aac),
            uri: Url::parse(&format!("https://example.com/{}.m3u8", id)).unwrap(),
            hdr: None,
            language: None,
            name: None,
        }
    }

    /// Three-step ladder (360p / 720p / 1080p) in ascending bandwidth order.
    fn create_test_renditions() -> Vec<Rendition> {
        vec![
            rendition("360p", 800_000, Resolution::new(640, 360)),
            rendition("720p", 2_800_000, Resolution::new(1280, 720)),
            rendition("1080p", 5_000_000, Resolution::new(1920, 1080)),
        ]
    }

    #[test]
    fn test_throughput_selection() {
        let renditions = create_test_renditions();
        let algorithm = ThroughputAlgorithm::new();

        // High bandwidth should select 1080p; low bandwidth should select 360p.
        for (estimate, expected_id) in [(10_000_000, "1080p"), (1_000_000, "360p")] {
            let context = AbrContext {
                buffer_level: 20.0,
                network: NetworkInfo {
                    bandwidth_estimate: estimate,
                    ..Default::default()
                },
                ..Default::default()
            };

            let selected = algorithm.select_rendition(&renditions, &context);
            assert_eq!(selected.map(|r| r.id.as_str()), Some(expected_id));
        }
    }

    #[test]
    fn test_bola_low_buffer() {
        let renditions = create_test_renditions();
        let algorithm = BolaAlgorithm::new();

        // Low buffer - should select lowest quality
        let context = AbrContext {
            buffer_level: 2.0,
            ..Default::default()
        };

        let selected = algorithm.select_rendition(&renditions, &context);
        assert_eq!(selected.map(|r| r.id.as_str()), Some("360p"));
    }
}