#![allow(dead_code)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::similar_names)]
#![allow(clippy::unreadable_literal)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::match_same_arms)]
#![allow(clippy::many_single_char_names)]
#![allow(clippy::unnecessary_wraps)]
#![allow(clippy::range_plus_one)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::manual_div_ceil)]
#![allow(clippy::comparison_chain)]
#![allow(clippy::unused_self)]
#![allow(clippy::trivially_copy_pass_by_ref)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::struct_excessive_bools)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::redundant_closure_for_method_calls)]
#![allow(clippy::must_use_candidate)]
#![allow(clippy::should_implement_trait)]
#![allow(clippy::items_after_statements)]
#![allow(clippy::if_not_else)]
#![allow(clippy::format_push_string)]
#![allow(clippy::single_match_else)]
#![allow(clippy::redundant_slicing)]
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::map_unwrap_or)]
#![allow(clippy::derivable_impls)]
#![allow(clippy::assigning_clones)]
#![allow(clippy::if_same_then_else)]
#![allow(clippy::format_collect)]
#![allow(clippy::useless_conversion)]
#![allow(clippy::unused_async)]
#![allow(clippy::identity_op)]
#[cfg(test)]
use super::mpd::SegmentTemplate;
use super::mpd::{AdaptationSet, Mpd, Period, Representation};
use super::segment::{DashSegment, SegmentGenerator, SegmentInfo};
use crate::abr::{AbrDecision, AdaptiveBitrateController, QualityLevel};
use crate::error::{NetError, NetResult};
use bytes::Bytes;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Tunables for [`DashClient`]: retry/timeout policy, segment size limit,
/// base-URL override, init-segment caching, ABR, and buffer durations.
///
/// Build one with [`DashClientConfig::new`] (equivalent to `Default`) and the
/// `with_*` methods; every field also stays `pub` for struct-update syntax.
#[derive(Debug, Clone)]
pub struct DashClientConfig {
    /// Retries performed after the first failed fetch attempt (default 3).
    pub max_retries: u32,
    /// Delay before the first retry; doubled after each further failure (default 500 ms).
    pub retry_delay: Duration,
    /// Per-attempt fetch timeout (default 30 s).
    pub timeout: Duration,
    /// Maximum segment payload size in bytes (default 50 MiB).
    pub max_segment_size: usize,
    /// Base URL that overrides any `BaseURL` found in the manifest (default `None`).
    pub base_url: Option<String>,
    /// Whether initialization segments are kept in memory after download (default `true`).
    pub cache_init_segments: bool,
    /// Whether adaptive bitrate switching is enabled (default `true`).
    pub enable_abr: bool,
    /// Lower buffer duration bound (default 2 s).
    pub min_buffer_duration: Duration,
    /// Desired buffer duration (default 10 s).
    pub target_buffer_duration: Duration,
    /// Upper buffer duration bound (default 30 s).
    pub max_buffer_duration: Duration,
}

impl Default for DashClientConfig {
    fn default() -> Self {
        Self {
            max_retries: 3,
            retry_delay: Duration::from_millis(500),
            timeout: Duration::from_secs(30),
            max_segment_size: 50 * 1024 * 1024,
            base_url: None,
            cache_init_segments: true,
            enable_abr: true,
            min_buffer_duration: Duration::from_secs(2),
            target_buffer_duration: Duration::from_secs(10),
            max_buffer_duration: Duration::from_secs(30),
        }
    }
}

impl DashClientConfig {
    /// Returns the default configuration.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets a base URL that takes precedence over manifest `BaseURL`s.
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = Some(base_url.into());
        self
    }

    /// Sets how many retries follow a failed fetch attempt.
    #[must_use]
    pub const fn with_max_retries(mut self, max_retries: u32) -> Self {
        self.max_retries = max_retries;
        self
    }

    /// Sets the initial retry delay (it doubles after each further failure).
    #[must_use]
    pub const fn with_retry_delay(mut self, retry_delay: Duration) -> Self {
        self.retry_delay = retry_delay;
        self
    }

    /// Sets the per-attempt fetch timeout.
    #[must_use]
    pub const fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Sets the maximum segment payload size in bytes.
    #[must_use]
    pub const fn with_max_segment_size(mut self, max_segment_size: usize) -> Self {
        self.max_segment_size = max_segment_size;
        self
    }

    /// Enables or disables in-memory caching of initialization segments.
    #[must_use]
    pub const fn with_init_cache(mut self, enable: bool) -> Self {
        self.cache_init_segments = enable;
        self
    }

    /// Enables or disables adaptive bitrate switching.
    #[must_use]
    pub const fn with_abr(mut self, enable: bool) -> Self {
        self.enable_abr = enable;
        self
    }

    /// Sets the lower buffer duration bound.
    #[must_use]
    pub const fn with_min_buffer(mut self, duration: Duration) -> Self {
        self.min_buffer_duration = duration;
        self
    }

    /// Sets the desired buffer duration.
    #[must_use]
    pub const fn with_target_buffer(mut self, duration: Duration) -> Self {
        self.target_buffer_duration = duration;
        self
    }

    /// Sets the upper buffer duration bound.
    #[must_use]
    pub const fn with_max_buffer(mut self, duration: Duration) -> Self {
        self.max_buffer_duration = duration;
        self
    }
}
/// Outcome of one segment download: the payload plus size and timing figures.
#[derive(Debug, Clone)]
pub struct FetchResult {
    /// Segment payload.
    pub data: Bytes,
    /// Numbering/timing information of the fetched segment.
    pub segment_info: SegmentInfo,
    /// Wall-clock time the download took.
    pub download_time: Duration,
    /// Length of `data` in bytes, captured at construction.
    pub bytes_downloaded: usize,
    /// Bytes per second; `0.0` when `download_time` is zero.
    pub throughput: f64,
    /// `true` for initialization segments, `false` for media segments.
    pub is_init: bool,
}

impl FetchResult {
    /// Wraps a downloaded payload, deriving its size and byte throughput.
    #[must_use]
    pub fn new(
        data: Bytes,
        segment_info: SegmentInfo,
        download_time: Duration,
        is_init: bool,
    ) -> Self {
        let elapsed_secs = download_time.as_secs_f64();
        let size = data.len();
        // A zero duration (e.g. a result served from memory) would divide by
        // zero, so report no throughput instead.
        let rate = if elapsed_secs <= 0.0 {
            0.0
        } else {
            size as f64 / elapsed_secs
        };
        Self {
            bytes_downloaded: size,
            throughput: rate,
            data,
            segment_info,
            download_time,
            is_init,
        }
    }

    /// Throughput expressed in bits per second.
    #[must_use]
    pub fn throughput_bps(&self) -> f64 {
        let bytes_per_sec = self.throughput;
        bytes_per_sec * 8.0
    }
}
/// A chosen representation together with the segment generator built for it.
#[derive(Debug)]
pub struct RepresentationSelection {
    /// Position of `representation` within its adaptation set.
    pub representation_index: usize,
    /// The selected representation.
    pub representation: Representation,
    /// Generator producing segment descriptors for `representation`.
    pub generator: SegmentGenerator,
}

impl RepresentationSelection {
    /// Builds the segment generator for `representation`, optionally rooted at
    /// `base_url`.
    ///
    /// # Errors
    ///
    /// Returns an `invalid_state` error when no generator can be derived from
    /// the representation (no segment template).
    pub fn new(
        representation_index: usize,
        representation: Representation,
        base_url: Option<&str>,
    ) -> NetResult<Self> {
        let bare = SegmentGenerator::from_representation(&representation)
            .ok_or_else(|| NetError::invalid_state("No segment template in representation"))?;
        let generator = match base_url {
            Some(url) => bare.with_base_url(url),
            None => bare,
        };
        Ok(Self {
            representation_index,
            representation,
            generator,
        })
    }

    /// Declared bandwidth of the selected representation.
    #[must_use]
    pub const fn bandwidth(&self) -> u64 {
        let rep = &self.representation;
        rep.bandwidth
    }

    /// Id of the selected representation.
    #[must_use]
    pub fn representation_id(&self) -> &str {
        let rep = &self.representation;
        &rep.id
    }
}
/// Download state for one adaptation set of one period.
///
/// Tracks the next segment to request, the presentation time at which that
/// segment starts, download counters, and an approximate buffer level.
#[derive(Debug)]
pub struct StreamSession {
    /// Index of the period inside the MPD.
    pub period_index: usize,
    /// Index of the adaptation set inside the period.
    pub adaptation_set_index: usize,
    /// Active representation and its segment generator.
    pub representation: RepresentationSelection,
    /// Number of the segment [`StreamSession::next_segment`] will return.
    pub current_segment_number: u64,
    /// Start time, in seconds, of the next segment to download. It begins at
    /// 0.0 and is advanced by [`StreamSession::advance`] by each passed
    /// segment's duration. It is used to re-align after a representation switch.
    pub current_time: f64,
    /// Segments reported via [`StreamSession::report_download`].
    pub segments_downloaded: u64,
    /// Bytes reported via [`StreamSession::report_download`].
    pub bytes_downloaded: u64,
    /// Media time downloaded but not yet consumed.
    pub buffer_duration: Duration,
}

impl StreamSession {
    /// Starts a session positioned at the first segment of `representation`.
    pub fn new(
        period_index: usize,
        adaptation_set_index: usize,
        representation: RepresentationSelection,
    ) -> Self {
        // Falls back to 1 when the generator yields no segment numbered 1.
        // NOTE(review): the probe itself requests number 1, so a template whose
        // startNumber is not 1 may not start where intended — confirm against
        // SegmentGenerator.
        let start_number = representation
            .generator
            .segment_by_number(1)
            .map_or(1, |seg| seg.info.number);
        Self {
            period_index,
            adaptation_set_index,
            representation,
            current_segment_number: start_number,
            current_time: 0.0,
            segments_downloaded: 0,
            bytes_downloaded: 0,
            buffer_duration: Duration::ZERO,
        }
    }

    /// Segment numbered `current_segment_number`, if the generator has one.
    #[must_use]
    pub fn next_segment(&self) -> Option<DashSegment> {
        self.representation
            .generator
            .segment_by_number(self.current_segment_number)
    }

    /// Segment of the active representation at `time_secs`, if any.
    #[must_use]
    pub fn segment_at_time(&self, time_secs: f64) -> Option<DashSegment> {
        self.representation.generator.segment_at_time(time_secs)
    }

    /// Segments of the active representation covering `[start_secs, end_secs]`
    /// as defined by the generator.
    pub fn segments_in_range(&self, start_secs: f64, end_secs: f64) -> Vec<DashSegment> {
        self.representation
            .generator
            .segments_for_range(start_secs, end_secs)
    }

    /// Moves past the current segment.
    ///
    /// Besides bumping `current_segment_number`, this advances `current_time`
    /// by the duration of the segment being left behind. That way
    /// [`StreamSession::switch_representation`] re-aligns to the download
    /// position instead of the start of the presentation.
    pub fn advance(&mut self) {
        if let Some(segment) = self.next_segment() {
            self.current_time += segment.info.segment_duration().as_secs_f64();
        }
        self.current_segment_number += 1;
    }

    /// Records a completed download of `bytes` covering `segment_duration` of media.
    pub fn report_download(&mut self, bytes: usize, segment_duration: Duration) {
        self.segments_downloaded += 1;
        self.bytes_downloaded += bytes as u64;
        self.buffer_duration += segment_duration;
    }

    /// Drains `consumed` from the buffer, clamping at zero.
    pub fn update_buffer(&mut self, consumed: Duration) {
        self.buffer_duration = self.buffer_duration.saturating_sub(consumed);
    }

    /// Replaces the active representation and re-aligns `current_segment_number`
    /// to the new representation's segment at `current_time`. The number is
    /// left unchanged if the new representation has no segment at that time.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept for API stability.
    pub fn switch_representation(
        &mut self,
        new_representation: RepresentationSelection,
    ) -> NetResult<()> {
        let current_time = self.current_time;
        self.representation = new_representation;
        if let Some(segment) = self.representation.generator.segment_at_time(current_time) {
            self.current_segment_number = segment.info.number;
        }
        Ok(())
    }
}
/// One entry of [`DashClient`]'s initialization-segment cache.
#[derive(Debug, Clone)]
struct InitSegmentCache {
    /// Raw initialization segment bytes returned on cache hits.
    data: Bytes,
    /// Id of the representation the segment was fetched for.
    /// NOTE(review): written on insert but not read anywhere in this file.
    representation_id: String,
    /// Time the entry was inserted.
    /// NOTE(review): written on insert but not read anywhere in this file (no expiry logic here).
    cached_at: Instant,
}
/// DASH segment client driven by a parsed [`Mpd`].
///
/// Holds the manifest and configuration, a shared cache of initialization
/// segments, an optional ABR controller, and cumulative statistics for media
/// segments fetched through [`DashClient::fetch_next_segment`].
pub struct DashClient {
    config: DashClientConfig,
    mpd: Mpd,
    /// Initialization segments keyed by `"{period_index}_{representation_id}"`.
    init_cache: Arc<RwLock<HashMap<String, InitSegmentCache>>>,
    /// Controller consulted by `perform_abr_decision`; `None` disables switching.
    abr_controller: Option<Box<dyn AdaptiveBitrateController>>,
    /// Bytes received by `fetch_next_segment` (other fetches are not counted).
    total_bytes_downloaded: u64,
    /// Download time accumulated by `fetch_next_segment`.
    total_download_time: Duration,
}

impl DashClient {
    /// Creates a client for `mpd` using `config`; no ABR controller is installed.
    #[must_use]
    pub fn new(mpd: Mpd, config: DashClientConfig) -> Self {
        Self {
            config,
            mpd,
            init_cache: Arc::new(RwLock::new(HashMap::new())),
            abr_controller: None,
            total_bytes_downloaded: 0,
            total_download_time: Duration::ZERO,
        }
    }

    /// Creates a client for `mpd` with [`DashClientConfig::default`].
    #[must_use]
    pub fn with_defaults(mpd: Mpd) -> Self {
        Self::new(mpd, DashClientConfig::default())
    }

    /// Installs (or replaces) the ABR controller.
    pub fn set_abr_controller(&mut self, controller: Box<dyn AdaptiveBitrateController>) {
        self.abr_controller = Some(controller);
    }

    /// Manifest this client was created with.
    #[must_use]
    pub const fn mpd(&self) -> &Mpd {
        &self.mpd
    }

    /// Active configuration.
    #[must_use]
    pub const fn config(&self) -> &DashClientConfig {
        &self.config
    }

    /// Number of periods in the manifest.
    #[must_use]
    pub fn period_count(&self) -> usize {
        self.mpd.periods.len()
    }

    /// Period at `index`, if present.
    #[must_use]
    pub fn period(&self, index: usize) -> Option<&Period> {
        self.mpd.periods.get(index)
    }

    /// All periods of the manifest.
    #[must_use]
    pub fn periods(&self) -> &[Period] {
        &self.mpd.periods
    }

    /// Like [`Self::period`], but reports a missing period as a `not_found` error.
    fn period_or_err(&self, period_index: usize) -> NetResult<&Period> {
        self.period(period_index)
            .ok_or_else(|| NetError::not_found(format!("Period {period_index} not found")))
    }

    /// Adaptation sets of a period accepted by `keep`, paired with their indices.
    fn adaptation_sets_matching(
        &self,
        period_index: usize,
        keep: impl Fn(&AdaptationSet) -> bool,
    ) -> NetResult<Vec<(usize, &AdaptationSet)>> {
        let period = self.period_or_err(period_index)?;
        Ok(period
            .adaptation_sets
            .iter()
            .enumerate()
            .filter(|&(_, set)| keep(set))
            .collect())
    }

    /// Video adaptation sets of a period with their indices.
    ///
    /// # Errors
    ///
    /// `not_found` when `period_index` is out of range.
    pub fn video_adaptation_sets(
        &self,
        period_index: usize,
    ) -> NetResult<Vec<(usize, &AdaptationSet)>> {
        self.adaptation_sets_matching(period_index, |set| set.is_video())
    }

    /// Audio adaptation sets of a period with their indices.
    ///
    /// # Errors
    ///
    /// `not_found` when `period_index` is out of range.
    pub fn audio_adaptation_sets(
        &self,
        period_index: usize,
    ) -> NetResult<Vec<(usize, &AdaptationSet)>> {
        self.adaptation_sets_matching(period_index, |set| set.is_audio())
    }

    /// ABR quality levels derived from an adaptation set's representations.
    pub fn quality_levels_from_adaptation_set(
        &self,
        adaptation_set: &AdaptationSet,
    ) -> Vec<QualityLevel> {
        crate::abr::dash::representations_to_quality_levels(&adaptation_set.representations)
    }

    /// Selects a representation and builds its segment generator.
    ///
    /// With `preferred_bandwidth`, picks the highest-bandwidth representation not
    /// exceeding it, or the cheapest one when none fits. Without it, picks index 0.
    ///
    /// # Errors
    ///
    /// - `not_found` for a bad period, adaptation set, or an empty representation list.
    /// - The error from [`RepresentationSelection::new`].
    pub fn select_representation(
        &self,
        period_index: usize,
        adaptation_set_index: usize,
        preferred_bandwidth: Option<u64>,
    ) -> NetResult<RepresentationSelection> {
        let period = self.period_or_err(period_index)?;
        let adaptation_set = period
            .adaptation_sets
            .get(adaptation_set_index)
            .ok_or_else(|| {
                NetError::not_found(format!("Adaptation set {adaptation_set_index} not found"))
            })?;
        let rep_index = preferred_bandwidth.map_or(0, |target_bw| {
            self.find_best_representation_for_bandwidth(adaptation_set, target_bw)
        });
        let representation = adaptation_set
            .representations
            .get(rep_index)
            .ok_or_else(|| NetError::not_found(format!("Representation {rep_index} not found")))?
            .clone();
        // Resolve against the representation actually chosen, not the first one.
        let base_url = self.resolve_base_url(period, adaptation_set, rep_index);
        RepresentationSelection::new(rep_index, representation, base_url.as_deref())
    }

    /// Index of the highest-bandwidth representation with bandwidth `<= bandwidth`.
    ///
    /// On ties the first such representation wins. When nothing fits the
    /// budget, falls back to the lowest-bandwidth representation rather than
    /// blindly to index 0, so an unsorted list does not yield the most
    /// expensive stream. Returns 0 for an empty list.
    fn find_best_representation_for_bandwidth(
        &self,
        adaptation_set: &AdaptationSet,
        bandwidth: u64,
    ) -> usize {
        let reps = &adaptation_set.representations;
        // `min_by_key` returns the first minimum, so `Reverse` gives first-max.
        let best_fit = reps
            .iter()
            .enumerate()
            .filter(|(_, rep)| rep.bandwidth <= bandwidth)
            .min_by_key(|(_, rep)| std::cmp::Reverse(rep.bandwidth));
        let cheapest = || reps.iter().enumerate().min_by_key(|(_, rep)| rep.bandwidth);
        best_fit.or_else(cheapest).map_or(0, |(idx, _)| idx)
    }

    /// Base URL for the representation at `rep_index` in `adaptation_set`.
    ///
    /// Candidates are checked in this order:
    /// 1. The configured override.
    /// 2. The first MPD-level `BaseURL`.
    /// 3. The first period-level `BaseURL`.
    /// 4. The selected representation's first `BaseURL`.
    ///
    /// NOTE(review): DASH resolves nested BaseURLs relative to their parents;
    /// this takes the first candidate found and does not join levels.
    fn resolve_base_url(
        &self,
        period: &Period,
        adaptation_set: &AdaptationSet,
        rep_index: usize,
    ) -> Option<String> {
        self.config
            .base_url
            .clone()
            .or_else(|| self.mpd.base_urls.first().cloned())
            .or_else(|| period.base_urls.first().cloned())
            .or_else(|| {
                adaptation_set
                    .representations
                    .get(rep_index)
                    .and_then(|rep| rep.base_urls.first())
                    .cloned()
            })
    }

    /// Creates a [`StreamSession`] for the given period/adaptation set.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Self::select_representation`].
    pub fn create_session(
        &self,
        period_index: usize,
        adaptation_set_index: usize,
        initial_bandwidth: Option<u64>,
    ) -> NetResult<StreamSession> {
        let representation =
            self.select_representation(period_index, adaptation_set_index, initial_bandwidth)?;
        Ok(StreamSession::new(
            period_index,
            adaptation_set_index,
            representation,
        ))
    }

    /// Fetches the session's initialization segment, using the in-memory cache
    /// when `cache_init_segments` is enabled.
    ///
    /// Cache hits report a zero download time. Init downloads are not added to
    /// the client's throughput totals.
    ///
    /// # Errors
    ///
    /// - `segment` if the generator has no initialization segment.
    /// - Otherwise, any fetch error.
    pub async fn fetch_initialization_segment(
        &mut self,
        session: &StreamSession,
    ) -> NetResult<FetchResult> {
        let init_segment = session
            .representation
            .generator
            .initialization_segment()
            .ok_or_else(|| NetError::segment("No initialization segment available"))?;
        let cache_key = format!(
            "{}_{}",
            session.period_index,
            session.representation.representation_id()
        );
        if self.config.cache_init_segments {
            // The read guard is dropped at the end of this block, before the
            // write lock below is requested.
            let cache = self.init_cache.read().await;
            if let Some(cached) = cache.get(&cache_key) {
                return Ok(FetchResult::new(
                    cached.data.clone(),
                    init_segment.info.clone(),
                    Duration::ZERO,
                    true,
                ));
            }
        }
        let (data, download_time) = self
            .timed_fetch(&init_segment.url, init_segment.byte_range)
            .await?;
        if self.config.cache_init_segments {
            self.init_cache.write().await.insert(
                cache_key,
                InitSegmentCache {
                    data: data.clone(),
                    representation_id: session.representation.representation_id().to_string(),
                    cached_at: Instant::now(),
                },
            );
        }
        Ok(FetchResult::new(
            data,
            init_segment.info,
            download_time,
            true,
        ))
    }

    /// Fetches the session's next media segment and advances the session.
    ///
    /// Updates the client's throughput totals, the session's counters and
    /// buffer, and (if installed) the ABR controller.
    ///
    /// # Errors
    ///
    /// - `segment` when no further segment exists.
    /// - Otherwise, any fetch error.
    pub async fn fetch_next_segment(
        &mut self,
        session: &mut StreamSession,
    ) -> NetResult<FetchResult> {
        let segment = session
            .next_segment()
            .ok_or_else(|| NetError::segment("No more segments available"))?;
        let (data, download_time) = self.timed_fetch(&segment.url, segment.byte_range).await?;
        let size = data.len();
        self.total_bytes_downloaded += size as u64;
        self.total_download_time += download_time;
        // Credit the session first so the controller sees a buffer level that
        // already includes the segment just downloaded.
        session.report_download(size, segment.info.segment_duration());
        if let Some(ref mut abr) = self.abr_controller {
            abr.report_segment_download(size, download_time);
            abr.report_buffer_level(session.buffer_duration);
        }
        session.advance();
        Ok(FetchResult::new(data, segment.info, download_time, false))
    }

    /// Fetches a specific media segment without changing session state or totals.
    ///
    /// # Errors
    ///
    /// - `segment` when the number is not available.
    /// - Otherwise, any fetch error.
    pub async fn fetch_segment_by_number(
        &mut self,
        session: &StreamSession,
        segment_number: u64,
    ) -> NetResult<FetchResult> {
        let segment = session
            .representation
            .generator
            .segment_by_number(segment_number)
            .ok_or_else(|| NetError::segment(format!("Segment {segment_number} not available")))?;
        let (data, download_time) = self.timed_fetch(&segment.url, segment.byte_range).await?;
        Ok(FetchResult::new(data, segment.info, download_time, false))
    }

    /// Fetches the media segment at `time_secs` without changing session state or totals.
    ///
    /// # Errors
    ///
    /// - `segment` when no segment covers the time.
    /// - Otherwise, any fetch error.
    pub async fn fetch_segment_at_time(
        &mut self,
        session: &StreamSession,
        time_secs: f64,
    ) -> NetResult<FetchResult> {
        let segment = session
            .segment_at_time(time_secs)
            .ok_or_else(|| NetError::segment(format!("No segment at time {time_secs}s")))?;
        let (data, download_time) = self.timed_fetch(&segment.url, segment.byte_range).await?;
        Ok(FetchResult::new(data, segment.info, download_time, false))
    }

    /// Asks the ABR controller for a quality decision and switches the
    /// session's representation when the target differs from the current one.
    ///
    /// Returns `Ok(None)` when ABR is disabled or no controller is installed.
    ///
    /// # Errors
    ///
    /// - `invalid_state` if the session points at a missing period/adaptation
    ///   set, or the controller returns a level index outside the level list.
    /// - Otherwise, errors from representation selection.
    pub async fn perform_abr_decision(
        &mut self,
        session: &mut StreamSession,
    ) -> NetResult<Option<AbrDecision>> {
        if !self.config.enable_abr {
            return Ok(None);
        }
        let Some(ref abr) = self.abr_controller else {
            return Ok(None);
        };
        let period = self
            .period(session.period_index)
            .ok_or_else(|| NetError::invalid_state("Period not found"))?;
        let adaptation_set = period
            .adaptation_sets
            .get(session.adaptation_set_index)
            .ok_or_else(|| NetError::invalid_state("Adaptation set not found"))?;
        let levels = self.quality_levels_from_adaptation_set(adaptation_set);
        let current_index = session.representation.representation_index;
        let decision = abr.select_quality(&levels, current_index);
        if let Some(target_idx) = decision.switch_target() {
            if target_idx != current_index {
                // Guard against a controller answering with an unknown level
                // instead of panicking on the index.
                let target_bandwidth = levels
                    .get(target_idx)
                    .map(|level| level.bandwidth)
                    .ok_or_else(|| {
                        NetError::invalid_state(
                            "ABR selected a quality level outside the adaptation set",
                        )
                    })?;
                let new_rep = self.select_representation(
                    session.period_index,
                    session.adaptation_set_index,
                    Some(target_bandwidth),
                )?;
                // The bandwidth mapping can land on the current representation;
                // switching to it again would only reset the segment position.
                if new_rep.representation_index != current_index {
                    session.switch_representation(new_rep)?;
                }
            }
        }
        Ok(Some(decision))
    }

    /// Fetches a segment and returns its payload with the elapsed wall-clock time.
    async fn timed_fetch(
        &self,
        url: &str,
        byte_range: Option<(u64, u64)>,
    ) -> NetResult<(Bytes, Duration)> {
        let start = Instant::now();
        let data = self.fetch_segment_data(url, byte_range).await?;
        Ok((data, start.elapsed()))
    }

    /// Fetches `url` with retries.
    ///
    /// Makes up to `max_retries + 1` attempts, sleeping `retry_delay` before
    /// the first retry and doubling (saturating) the delay afterwards. A
    /// payload larger than `max_segment_size` is rejected without retrying.
    async fn fetch_segment_data(
        &self,
        url: &str,
        byte_range: Option<(u64, u64)>,
    ) -> NetResult<Bytes> {
        let mut last_error = None;
        let mut delay = self.config.retry_delay;
        for attempt in 0..=self.config.max_retries {
            match self.fetch_http_segment(url, byte_range).await {
                Ok(data) => {
                    let (size, limit) = (data.len(), self.config.max_segment_size);
                    if size > limit {
                        return Err(NetError::segment(format!(
                            "Segment {url} is {size} bytes, exceeding the {limit}-byte limit"
                        )));
                    }
                    return Ok(data);
                }
                Err(e) => {
                    last_error = Some(e);
                    if attempt < self.config.max_retries {
                        tokio::time::sleep(delay).await;
                        // `*= 2` panics on overflow; saturate instead.
                        delay = delay.saturating_mul(2);
                    }
                }
            }
        }
        Err(last_error.unwrap_or_else(|| NetError::segment("Unknown fetch error")))
    }

    /// Placeholder transport: no HTTP client is wired in yet.
    ///
    /// Every call resolves to a `not_found` error, bounded by `config.timeout`.
    /// The `Range` header value is built from `byte_range` but not yet used.
    async fn fetch_http_segment(
        &self,
        url: &str,
        byte_range: Option<(u64, u64)>,
    ) -> NetResult<Bytes> {
        let _range_header = byte_range.map(|(start, end)| format!("bytes={start}-{end}"));
        tokio::time::timeout(self.config.timeout, async {
            Err(NetError::not_found(format!(
                "HTTP client not implemented: {url}"
            )))
        })
        .await
        .map_err(|_| NetError::timeout("Segment fetch timed out"))?
    }

    /// Average media-segment throughput in bytes/second (0.0 before any download time).
    #[must_use]
    pub fn average_throughput(&self) -> f64 {
        let secs = self.total_download_time.as_secs_f64();
        if secs > 0.0 {
            self.total_bytes_downloaded as f64 / secs
        } else {
            0.0
        }
    }

    /// Average media-segment throughput in bits/second.
    #[must_use]
    pub fn average_throughput_bps(&self) -> f64 {
        self.average_throughput() * 8.0
    }

    /// Total bytes downloaded by `fetch_next_segment`.
    #[must_use]
    pub const fn total_bytes_downloaded(&self) -> u64 {
        self.total_bytes_downloaded
    }

    /// Drops every cached initialization segment.
    pub async fn clear_init_cache(&self) {
        self.init_cache.write().await.clear();
    }

    /// Number of cached initialization segments.
    pub async fn init_cache_size(&self) -> usize {
        self.init_cache.read().await.len()
    }

    /// Presentation duration reported by the manifest, if any.
    #[must_use]
    pub fn presentation_duration(&self) -> Option<Duration> {
        self.mpd.duration()
    }

    /// Whether the manifest describes a live presentation.
    #[must_use]
    pub const fn is_live(&self) -> bool {
        self.mpd.is_live()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dash::mpd::MpdType;

    /// Static 600 s presentation with one video adaptation set holding two
    /// representations that share a segment template (timescale 90000,
    /// segment duration 180000 ticks):
    /// - "low": 500 kbps, 640x360
    /// - "high": 2 Mbps, 1920x1080
    fn create_test_mpd() -> Mpd {
        let mut template = SegmentTemplate::new(90000)
            .with_media("video_$RepresentationID$_$Number$.m4s")
            .with_initialization("video_$RepresentationID$_init.mp4");
        template.duration = Some(180000);

        let mut low = Representation::new("low", 500_000);
        low.width = Some(640);
        low.height = Some(360);
        low.segment_template = Some(template.clone());

        let mut high = Representation::new("high", 2_000_000);
        high.width = Some(1920);
        high.height = Some(1080);
        high.segment_template = Some(template);

        let mut video = AdaptationSet::new();
        video.content_type = Some("video".to_string());
        video.mime_type = Some("video/mp4".to_string());
        video.representations.push(low);
        video.representations.push(high);

        let mut period = Period::new();
        period.adaptation_sets.push(video);

        let mut mpd = Mpd::new();
        mpd.mpd_type = MpdType::Static;
        mpd.media_presentation_duration = Some(Duration::from_secs(600));
        mpd.periods.push(period);
        mpd
    }

    /// Client over the test manifest with the default configuration.
    fn test_client() -> DashClient {
        DashClient::with_defaults(create_test_mpd())
    }

    #[test]
    fn test_client_creation() {
        let client = test_client();
        assert_eq!(client.period_count(), 1);
        assert!(!client.is_live());
        assert_eq!(
            client.presentation_duration(),
            Some(Duration::from_secs(600))
        );
    }

    #[test]
    fn test_find_adaptation_sets() {
        let client = test_client();
        let video = client
            .video_adaptation_sets(0)
            .expect("should succeed in test");
        assert_eq!(video.len(), 1);
        let audio = client
            .audio_adaptation_sets(0)
            .expect("should succeed in test");
        assert_eq!(audio.len(), 0);
    }

    #[test]
    fn test_representation_selection() {
        let client = test_client();
        let cases = [(500_000, "low"), (2_000_000, "high"), (100_000, "low")];
        for (bandwidth, expected_id) in cases {
            let rep = client
                .select_representation(0, 0, Some(bandwidth))
                .expect("should succeed in test");
            assert_eq!(rep.representation.id, expected_id);
        }
    }

    #[test]
    fn test_session_creation() {
        let session = test_client()
            .create_session(0, 0, Some(500_000))
            .expect("should succeed in test");
        assert_eq!(session.period_index, 0);
        assert_eq!(session.adaptation_set_index, 0);
        assert_eq!(session.current_segment_number, 1);
        assert_eq!(session.segments_downloaded, 0);
    }

    #[test]
    fn test_session_segment_navigation() {
        let mut session = test_client()
            .create_session(0, 0, None)
            .expect("should succeed in test");
        let first = session.next_segment().expect("should succeed in test");
        assert_eq!(first.info.number, 1);
        session.advance();
        let second = session.next_segment().expect("should succeed in test");
        assert_eq!(second.info.number, 2);
    }

    #[test]
    fn test_quality_levels_from_adaptation_set() {
        let client = test_client();
        let period = client.period(0).expect("should succeed in test");
        let levels = client.quality_levels_from_adaptation_set(&period.adaptation_sets[0]);
        let bandwidths: Vec<u64> = levels.iter().map(|level| level.bandwidth).collect();
        assert_eq!(bandwidths, vec![500_000, 2_000_000]);
    }

    #[test]
    fn test_config_builder() {
        let config = DashClientConfig::new()
            .with_base_url("https://example.com/")
            .with_max_retries(5)
            .with_timeout(Duration::from_secs(60))
            .with_init_cache(false)
            .with_abr(true);
        assert_eq!(config.base_url, Some("https://example.com/".to_string()));
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.timeout, Duration::from_secs(60));
        assert!(!config.cache_init_segments);
        assert!(config.enable_abr);
    }

    #[test]
    fn test_session_buffer_management() {
        let mut session = test_client()
            .create_session(0, 0, None)
            .expect("should succeed in test");
        session.report_download(1000, Duration::from_secs(2));
        assert_eq!(session.buffer_duration, Duration::from_secs(2));
        assert_eq!(session.bytes_downloaded, 1000);
        session.update_buffer(Duration::from_secs(1));
        assert_eq!(session.buffer_duration, Duration::from_secs(1));
    }

    #[test]
    fn test_fetch_result() {
        let payload = Bytes::from(vec![0u8; 1000]);
        let info = SegmentInfo::new(1, 0, 90000, 90000);
        let result = FetchResult::new(payload, info, Duration::from_secs(1), false);
        assert_eq!(result.bytes_downloaded, 1000);
        assert!((result.throughput - 1000.0).abs() < 0.1);
        assert!((result.throughput_bps() - 8000.0).abs() < 0.1);
        assert!(!result.is_init);
    }
}