use anyhow::{Context, Result, anyhow};
use async_trait::async_trait;
use reqwest::Client;
use std::collections::HashMap;
use std::time::Duration;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tracing::{debug, info};
use super::super::StreamQuality;
use super::super::backend::{
BackendType, ProgressCallback, StreamBackend, StreamConfig, StreamProgress,
};
/// HLS backend that fetches playlists and media segments itself over HTTP
/// (via `reqwest`) and writes the raw segment bytes to an async writer in
/// playlist order.
pub struct NativeHlsBackend {
    /// HTTP client shared by all playlist and segment requests.
    client: Client,
    /// Upper bound on segment downloads in flight at once for finished (VOD) playlists.
    max_concurrent: usize,
    /// Total number of attempts per segment; despite the name, the first try is
    /// included in this count.
    max_retries: u32,
}
impl NativeHlsBackend {
pub fn new() -> Result<Self> {
let client = Client::builder()
.timeout(Duration::from_secs(30))
.pool_max_idle_per_host(16) .pool_idle_timeout(Duration::from_secs(60))
.tcp_nodelay(true) .build()?;
Ok(Self {
client,
max_concurrent: 8, max_retries: 3,
})
}
#[must_use]
pub fn with_concurrency(mut self, max: usize) -> Self {
self.max_concurrent = max;
self
}
async fn parse_master_playlist(
&self,
url: &str,
headers: &HashMap<String, String>,
) -> Result<Vec<HlsVariant>> {
let content = self.fetch_playlist(url, headers).await?;
let base_url = url.rsplit_once('/').map_or("", |(base, _)| base);
let mut variants = Vec::new();
let mut lines = content.lines().peekable();
while let Some(line) = lines.next() {
if let Some(rest) = line.strip_prefix("#EXT-X-STREAM-INF:") {
let attrs = Self::parse_attributes(rest);
let bandwidth = attrs
.get("BANDWIDTH")
.and_then(|v| v.parse().ok())
.unwrap_or(0);
let resolution = attrs.get("RESOLUTION").cloned();
let codecs = attrs.get("CODECS").cloned();
if let Some(uri_line) = lines.next()
&& !uri_line.starts_with('#')
{
let uri = Self::resolve_url(base_url, uri_line);
let height = resolution
.as_ref()
.and_then(|r| r.split('x').nth(1))
.and_then(|h| h.parse().ok())
.unwrap_or(0);
variants.push(HlsVariant {
bandwidth,
height,
codecs,
uri,
});
}
}
}
variants.sort_by(|a, b| b.bandwidth.cmp(&a.bandwidth));
Ok(variants)
}
async fn parse_media_playlist(
&self,
url: &str,
headers: &HashMap<String, String>,
) -> Result<HlsPlaylist> {
let content = self.fetch_playlist(url, headers).await?;
let base_url = url.rsplit_once('/').map_or("", |(base, _)| base);
let mut segments = Vec::new();
let mut is_live = true;
let mut media_sequence = 0u64;
let mut target_duration = 10.0f64;
let mut current_duration = 0.0f64;
for line in content.lines() {
if line.starts_with("#EXT-X-ENDLIST") {
is_live = false;
} else if let Some(rest) = line.strip_prefix("#EXT-X-MEDIA-SEQUENCE:") {
media_sequence = rest.parse().unwrap_or(0);
} else if let Some(rest) = line.strip_prefix("#EXT-X-TARGETDURATION:") {
target_duration = rest.parse().unwrap_or(10.0);
} else if let Some(rest) = line.strip_prefix("#EXTINF:") {
current_duration = rest
.split(',')
.next()
.and_then(|d| d.parse().ok())
.unwrap_or(target_duration);
} else if !line.starts_with('#') && !line.is_empty() {
let uri = Self::resolve_url(base_url, line);
segments.push(HlsSegment {
sequence: media_sequence + segments.len() as u64,
duration: current_duration,
uri,
});
}
}
Ok(HlsPlaylist {
segments,
is_live,
target_duration,
media_sequence,
})
}
async fn fetch_playlist(&self, url: &str, headers: &HashMap<String, String>) -> Result<String> {
let mut req = self.client.get(url);
for (k, v) in headers {
req = req.header(k.as_str(), v.as_str());
}
let resp = req
.send()
.await
.with_context(|| format!("Failed to fetch playlist from {url}"))?;
if !resp.status().is_success() {
return Err(anyhow!(
"Playlist fetch returned HTTP {}: {}",
resp.status(),
url
));
}
resp.text()
.await
.with_context(|| format!("Failed to read playlist body from {url}"))
}
async fn fetch_segment(&self, url: &str, headers: &HashMap<String, String>) -> Result<Vec<u8>> {
let mut last_error = None;
for attempt in 0..self.max_retries {
let mut req = self.client.get(url);
for (k, v) in headers {
req = req.header(k.as_str(), v.as_str());
}
match req.send().await {
Ok(resp) if resp.status().is_success() => {
return resp
.bytes()
.await
.map(|b| b.to_vec())
.with_context(|| format!("Failed to read segment body from {url}"));
}
Ok(resp) => {
last_error = Some(anyhow!(
"Segment fetch returned HTTP {}: {}",
resp.status(),
url
));
}
Err(e) => {
last_error = Some(
anyhow::Error::new(e).context(format!("Segment request failed: {url}")),
);
}
}
if attempt < self.max_retries - 1 {
tokio::time::sleep(Duration::from_millis(500 * (u64::from(attempt) + 1))).await;
}
}
Err(last_error.unwrap_or_else(|| anyhow!("Segment fetch exhausted all retries: {url}")))
}
fn parse_attributes(attr_str: &str) -> HashMap<String, String> {
let mut attrs = HashMap::new();
let mut chars = attr_str.chars().peekable();
while chars.peek().is_some() {
let key: String = chars.by_ref().take_while(|&c| c != '=').collect();
if key.is_empty() {
break;
}
let value = if chars.peek() == Some(&'"') {
chars.next(); let v: String = chars.by_ref().take_while(|&c| c != '"').collect();
chars.next(); v
} else {
chars.by_ref().take_while(|&c| c != ',').collect()
};
attrs.insert(key.trim().to_string(), value.trim().to_string());
}
attrs
}
fn resolve_url(base: &str, relative: &str) -> String {
if relative.starts_with("http://") || relative.starts_with("https://") {
relative.to_string()
} else if relative.starts_with('/') {
if let Some(idx) = base.find("://") {
if let Some(end) = base[idx + 3..].find('/') {
format!("{}{}", &base[..idx + 3 + end], relative)
} else {
format!("{base}{relative}")
}
} else {
relative.to_string()
}
} else {
format!("{base}/{relative}")
}
}
fn select_variant(variants: &[HlsVariant], quality: StreamQuality) -> Option<&HlsVariant> {
if variants.is_empty() {
return None;
}
match quality {
StreamQuality::Best => variants.first(),
StreamQuality::Worst => variants.last(),
StreamQuality::Specific(height) => {
#[allow(clippy::cast_possible_wrap)]
variants
.iter()
.min_by_key(|v| (v.height as i32 - height as i32).abs())
}
}
}
async fn stream_live_with_duration<W: AsyncWrite + Unpin + Send>(
&self,
playlist_url: &str,
headers: &HashMap<String, String>,
output: &mut W,
progress: Option<&ProgressCallback>,
start_time: std::time::Instant,
duration_secs: Option<u64>,
) -> Result<()> {
let mut last_sequence = 0u64;
let mut bytes_downloaded = 0u64;
let mut segments_completed = 0u32;
loop {
if let Some(max_dur) = duration_secs
&& start_time.elapsed().as_secs() >= max_dur
{
info!("Duration limit reached ({max_dur}s), stopping live stream");
break;
}
let playlist = self.parse_media_playlist(playlist_url, headers).await?;
let new_segments: Vec<_> = playlist
.segments
.iter()
.filter(|s| s.sequence > last_sequence)
.collect();
if !new_segments.is_empty() {
debug!("Found {} new segments", new_segments.len());
for seg in new_segments {
let data = self.fetch_segment(&seg.uri, headers).await?;
bytes_downloaded += data.len() as u64;
segments_completed += 1;
last_sequence = seg.sequence;
output.write_all(&data).await?;
if let Some(cb) = progress {
cb(StreamProgress {
bytes_downloaded,
segments_completed,
segments_total: None,
elapsed_seconds: start_time.elapsed().as_secs_f64(),
});
}
if let Some(max_dur) = duration_secs
&& start_time.elapsed().as_secs() >= max_dur
{
info!("Duration limit reached ({max_dur}s), stopping live stream");
return Ok(());
}
}
}
if !playlist.is_live {
break;
}
tokio::time::sleep(Duration::from_secs_f64(playlist.target_duration / 2.0)).await;
}
Ok(())
}
async fn stream_to_internal<W: AsyncWrite + Unpin + Send>(
&self,
manifest_url: &str,
config: &StreamConfig,
output: &mut W,
progress: Option<ProgressCallback>,
duration_secs: Option<u64>,
) -> Result<()> {
let headers = &config.headers;
let start_time = std::time::Instant::now();
let content = self.fetch_playlist(manifest_url, headers).await?;
let is_master = content.contains("#EXT-X-STREAM-INF:");
let media_url = if is_master {
let variants = self.parse_master_playlist(manifest_url, headers).await?;
debug!("Found {} quality variants", variants.len());
let variant = Self::select_variant(&variants, config.quality)
.ok_or_else(|| anyhow!("No suitable quality variant found"))?;
info!(
"Selected variant: {}p @ {} bps",
variant.height, variant.bandwidth
);
variant.uri.clone()
} else {
manifest_url.to_string()
};
let playlist = self.parse_media_playlist(&media_url, headers).await?;
info!(
"Playlist: {} segments, live={}",
playlist.segments.len(),
playlist.is_live
);
let total_segments = if playlist.is_live {
None
} else {
#[allow(clippy::cast_possible_truncation)]
Some(playlist.segments.len() as u32)
};
let mut bytes_downloaded = 0u64;
let mut segments_completed = 0u32;
if playlist.is_live {
self.stream_live_with_duration(
&media_url,
headers,
output,
progress.as_ref(),
start_time,
duration_secs,
)
.await?;
} else {
let max_segments = duration_secs.and_then(|dur| {
if playlist.segments.is_empty() {
None
} else {
let avg_seg_duration = playlist.target_duration;
#[allow(
clippy::cast_possible_truncation,
clippy::cast_sign_loss,
clippy::cast_precision_loss
)]
Some((dur as f64 / avg_seg_duration).ceil() as usize)
}
});
let segments_to_fetch = if let Some(max) = max_segments {
playlist.segments.iter().take(max).collect::<Vec<_>>()
} else {
playlist.segments.iter().collect::<Vec<_>>()
};
for chunk in segments_to_fetch.chunks(self.max_concurrent) {
let futures: Vec<_> = chunk
.iter()
.map(|seg| self.fetch_segment(&seg.uri, headers))
.collect();
let results = futures::future::join_all(futures).await;
for result in results {
let data = result?;
bytes_downloaded += data.len() as u64;
segments_completed += 1;
output.write_all(&data).await?;
if let Some(cb) = &progress {
cb(StreamProgress {
bytes_downloaded,
segments_completed,
segments_total: total_segments,
elapsed_seconds: start_time.elapsed().as_secs_f64(),
});
}
}
}
}
output.flush().await?;
Ok(())
}
}
#[async_trait]
impl StreamBackend for NativeHlsBackend {
    fn backend_type(&self) -> BackendType {
        BackendType::Native
    }

    /// Accepts manifests whose URL contains `.m3u8`, unless the caller flags the
    /// stream as encrypted.
    fn can_handle(&self, manifest_url: &str, encrypted: bool) -> bool {
        !encrypted && manifest_url.contains(".m3u8")
    }

    /// Streams the manifest to `output` with no duration limit.
    async fn stream_to<W: AsyncWrite + Unpin + Send>(
        &self,
        manifest_url: &str,
        config: &StreamConfig,
        output: &mut W,
        progress: Option<ProgressCallback>,
    ) -> Result<()> {
        self.stream_to_internal(manifest_url, config, output, progress, None)
            .await
    }

    /// Creates (or truncates) the file at `path` and streams the manifest into it
    /// through a buffered writer, optionally stopping after `duration_secs`.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be created (the error names the path), or if
    /// streaming fails.
    async fn stream_to_file(
        &self,
        manifest_url: &str,
        config: &StreamConfig,
        path: &std::path::Path,
        progress: Option<ProgressCallback>,
        duration_secs: Option<u64>,
    ) -> Result<()> {
        let file = tokio::fs::File::create(path)
            .await
            .with_context(|| format!("Failed to create output file {}", path.display()))?;
        let mut writer = tokio::io::BufWriter::new(file);
        self.stream_to_internal(manifest_url, config, &mut writer, progress, duration_secs)
            .await
    }
}
/// One entry of a master playlist (`#EXT-X-STREAM-INF` plus its URI line).
#[derive(Debug, Clone)]
struct HlsVariant {
    /// `BANDWIDTH` attribute in bits per second; `0` when absent or unparseable.
    bandwidth: u64,
    /// Height from the `RESOLUTION` attribute (`WIDTHxHEIGHT`); `0` when absent.
    height: u32,
    /// Raw `CODECS` attribute value, if present. Currently unused.
    #[allow(dead_code)]
    codecs: Option<String>,
    /// Variant media-playlist URL, resolved against the master playlist's URL.
    uri: String,
}
/// Parsed contents of a media playlist.
#[derive(Debug)]
struct HlsPlaylist {
    /// Segments in playlist order.
    segments: Vec<HlsSegment>,
    /// `false` once the parser has seen a marker that the playlist is complete
    /// (e.g. `#EXT-X-ENDLIST`); otherwise it is polled as a live playlist.
    is_live: bool,
    /// `#EXT-X-TARGETDURATION` in seconds; falls back to 10 when absent or unparseable.
    target_duration: f64,
    /// `#EXT-X-MEDIA-SEQUENCE` value (0 when absent). Kept for diagnostics only.
    #[allow(dead_code)]
    media_sequence: u64,
}
/// One media segment listed in a media playlist.
#[derive(Debug, Clone)]
struct HlsSegment {
    /// Media sequence number: the playlist's `#EXT-X-MEDIA-SEQUENCE` plus this
    /// segment's index. Used to detect new segments in live playlists.
    sequence: u64,
    /// Segment duration in seconds, taken from its `#EXTINF` tag when present.
    #[allow(dead_code)]
    duration: f64,
    /// Absolute segment URL, resolved against the media playlist's URL.
    uri: String,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a variant without codec information; only bandwidth, height and
    /// URI matter to these tests.
    fn variant(bandwidth: u64, height: u32, uri: &str) -> HlsVariant {
        HlsVariant {
            bandwidth,
            height,
            codecs: None,
            uri: uri.into(),
        }
    }

    /// A 1080p / 720p / 360p ladder, ordered by descending bandwidth in the
    /// order the master-playlist parser returns it.
    fn three_rung_ladder() -> Vec<HlsVariant> {
        vec![
            variant(5_000_000, 1080, "1080p.m3u8"),
            variant(2_000_000, 720, "720p.m3u8"),
            variant(500_000, 360, "360p.m3u8"),
        ]
    }

    /// Constructs a backend with its default client settings.
    fn backend() -> NativeHlsBackend {
        NativeHlsBackend::new().expect("default reqwest client should build")
    }

    #[test]
    fn test_resolve_url_relative() {
        let resolved = NativeHlsBackend::resolve_url("https://example.com/path", "video.ts");
        assert_eq!(resolved, "https://example.com/path/video.ts");
    }

    #[test]
    fn test_resolve_url_absolute_path() {
        let resolved = NativeHlsBackend::resolve_url("https://example.com/path", "/video.ts");
        assert_eq!(resolved, "https://example.com/video.ts");
    }

    #[test]
    fn test_resolve_url_full_url() {
        let cdn = "https://cdn.example.com/video.ts";
        let resolved = NativeHlsBackend::resolve_url("https://example.com/path", cdn);
        assert_eq!(resolved, cdn);
    }

    #[test]
    fn test_resolve_url_http() {
        let cdn = "http://cdn.example.com/video.ts";
        let resolved = NativeHlsBackend::resolve_url("http://example.com/path", cdn);
        assert_eq!(resolved, cdn);
    }

    #[test]
    fn test_resolve_url_absolute_path_no_trailing_slash() {
        let resolved = NativeHlsBackend::resolve_url("https://example.com", "/video.ts");
        assert_eq!(resolved, "https://example.com/video.ts");
    }

    #[test]
    fn test_parse_attributes() {
        let parsed = NativeHlsBackend::parse_attributes("BANDWIDTH=1280000,RESOLUTION=720x480");
        assert_eq!(parsed.get("BANDWIDTH").map(String::as_str), Some("1280000"));
        assert_eq!(parsed.get("RESOLUTION").map(String::as_str), Some("720x480"));
    }

    #[test]
    fn test_parse_attributes_quoted_codecs() {
        let parsed = NativeHlsBackend::parse_attributes(
            "CODECS=\"avc1.4d401f,mp4a.40.2\",BANDWIDTH=2000000",
        );
        assert_eq!(
            parsed.get("CODECS").map(String::as_str),
            Some("avc1.4d401f,mp4a.40.2")
        );
        assert_eq!(parsed.get("BANDWIDTH").map(String::as_str), Some("2000000"));
    }

    #[test]
    fn test_parse_attributes_empty() {
        assert!(NativeHlsBackend::parse_attributes("").is_empty());
    }

    #[test]
    fn test_select_variant_best() {
        let ladder = three_rung_ladder();
        let chosen = NativeHlsBackend::select_variant(&ladder, StreamQuality::Best)
            .expect("non-empty ladder has a best variant");
        assert_eq!(chosen.height, 1080);
    }

    #[test]
    fn test_select_variant_worst() {
        let ladder = vec![
            variant(5_000_000, 1080, "1080p.m3u8"),
            variant(500_000, 360, "360p.m3u8"),
        ];
        let chosen = NativeHlsBackend::select_variant(&ladder, StreamQuality::Worst)
            .expect("non-empty ladder has a worst variant");
        assert_eq!(chosen.height, 360);
    }

    #[test]
    fn test_select_variant_specific() {
        let ladder = three_rung_ladder();
        let chosen = NativeHlsBackend::select_variant(&ladder, StreamQuality::Specific(700))
            .expect("non-empty ladder has a closest variant");
        assert_eq!(chosen.height, 720, "should pick closest to 700p");
    }

    #[test]
    fn test_select_variant_empty() {
        let ladder: Vec<HlsVariant> = Vec::new();
        assert!(NativeHlsBackend::select_variant(&ladder, StreamQuality::Best).is_none());
    }

    #[test]
    fn test_can_handle_hls() {
        assert!(backend().can_handle("https://example.com/stream.m3u8", false));
    }

    #[test]
    fn test_can_handle_encrypted_rejected() {
        assert!(
            !backend().can_handle("https://example.com/stream.m3u8", true),
            "native backend should not handle encrypted streams"
        );
    }

    #[test]
    fn test_can_handle_non_hls() {
        let hls = backend();
        for url in ["https://example.com/stream.mpd", "https://example.com/video.mp4"] {
            assert!(!hls.can_handle(url, false));
        }
    }
}