use std::{
path::PathBuf,
str::FromStr,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
};
use m3u8_rs::MediaPlaylist;
use reqwest::{Client, Url};
use tokio::io::AsyncWriteExt;
use super::{decrypt::M3u8Key, utils::load_m3u8, M3u8Segment, M3u8StreamingSegment};
use crate::{
consumer::Consumer,
error::{IoriError, IoriResult},
};
pub struct M3u8Source {
m3u8_url: String,
key: Option<String>,
shaka_packager_command: Option<PathBuf>,
sequence: AtomicU64,
client: Arc<Client>,
}
impl M3u8Source {
pub fn new(
client: Arc<Client>,
m3u8_url: String,
key: Option<String>,
shaka_packager_command: Option<PathBuf>,
) -> Self {
Self {
m3u8_url,
key,
shaka_packager_command,
sequence: AtomicU64::new(0),
client,
}
}
pub async fn load_segments(
&self,
latest_media_sequence: Option<u64>,
) -> IoriResult<(Vec<M3u8Segment>, Url, MediaPlaylist)> {
let (playlist_url, playlist) =
load_m3u8(&self.client, Url::from_str(&self.m3u8_url)?).await?;
let mut key = None;
let mut initial_block = None;
let mut segments = Vec::with_capacity(playlist.segments.len());
for (i, segment) in playlist.segments.iter().enumerate() {
if let Some(k) = &segment.key {
key = M3u8Key::from_key(
&self.client,
k,
&playlist_url,
playlist.media_sequence,
self.key.clone(),
self.shaka_packager_command.clone(),
)
.await?
.map(Arc::new);
}
if let Some(m) = &segment.map {
let url = playlist_url.join(&m.uri)?;
let bytes = self.client.get(url).send().await?.bytes().await?.to_vec();
initial_block = Some(Arc::new(bytes));
}
let url = playlist_url.join(&segment.uri)?;
let filename = url
.path_segments()
.and_then(|c| c.last())
.map(|r| {
if r.ends_with(".m4s") {
format!("{}.mp4", &r[..r.len() - 4])
} else {
r.to_string()
}
})
.unwrap_or("output.ts".to_string());
let media_sequence = playlist.media_sequence + i as u64;
if let Some(latest_media_sequence) = latest_media_sequence {
if media_sequence <= latest_media_sequence as u64 {
continue;
}
}
let segment = M3u8Segment {
url,
filename,
key: key.clone(),
initial_segment: initial_block.clone(),
sequence: self.sequence.fetch_add(1, Ordering::Relaxed),
media_sequence,
byte_range: segment.byte_range.clone(),
};
segments.push(segment);
}
Ok((segments, playlist_url, playlist))
}
}
pub struct HlsSegmentFetcher {
client: Arc<Client>,
consumer: Consumer,
}
impl HlsSegmentFetcher {
pub fn new(client: Arc<Client>, consumer: Consumer) -> Self {
Self { client, consumer }
}
pub async fn fetch<S>(&self, segment: &S) -> IoriResult<()>
where
S: M3u8StreamingSegment + Send + Sync + 'static,
{
let tmp_file = self.consumer.open_writer(segment).await?;
let mut tmp_file = match tmp_file {
Some(f) => f,
None => return Ok(()),
};
let mut request = self.client.get(segment.url());
if let Some(byte_range) = segment.byte_range() {
let start = byte_range.offset.unwrap_or(0);
let end = start + byte_range.length - 1;
request = request.header("Range", format!("bytes={}-{}", start, end));
}
let response = request.send().await?;
if !response.status().is_success() {
return Err(IoriError::HttpError(response.status()));
}
let bytes = response.bytes().await?;
let decryptor = segment.key().map(|key| key.to_decryptor());
if let Some(decryptor) = decryptor {
let bytes = if let Some(initial_segment) = segment.initial_segment() {
let mut result = initial_segment.to_vec();
result.extend_from_slice(&bytes);
result
} else {
bytes.to_vec()
};
let bytes = decryptor.decrypt(&bytes)?;
tmp_file.write_all(&bytes).await?;
} else {
if let Some(initial_segment) = segment.initial_segment() {
tmp_file.write_all(&initial_segment).await?;
}
tmp_file.write_all(&bytes).await?;
}
tmp_file.finish().await?;
Ok(())
}
}