use super::ll_hls::{LlHlsConfig, LlHlsPlaylist, MediaPart, PreloadHint, RenditionReport};
use crate::error::{NetError, NetResult};
use parking_lot::RwLock;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{broadcast, Notify};
use tokio::time::timeout;
/// Sending half of the tokio broadcast channel that carries [`PlaylistUpdateEvent`]s.
type PlaylistUpdateTx = broadcast::Sender<PlaylistUpdateEvent>;
/// Notification broadcast to subscribers each time a part is pushed to the playlist.
#[derive(Debug, Clone)]
pub struct PlaylistUpdateEvent {
/// Value reported by the playlist's `last_msn()` right after the part was added.
pub last_msn: u64,
/// Value reported by the playlist's `current_part_count()` right after the part was added.
pub current_part_count: usize,
/// The `segment_complete` flag that was passed along with the pushed part.
pub segment_complete: bool,
}
/// Delivery directives extracted from the query string of a playlist request.
///
/// Unknown keys are ignored; a directive whose value does not parse is treated
/// as absent.
#[derive(Debug, Clone, Copy, Default)]
pub struct BlockingReloadParams {
    /// Value of `_HLS_msn`, if present and a valid unsigned integer.
    pub msn: Option<u64>,
    /// Value of `_HLS_part`, if present and a valid unsigned integer.
    pub part: Option<u32>,
    /// `true` when `_HLS_skip` requested a delta update (`YES` or `v2`).
    pub skip: bool,
}

impl BlockingReloadParams {
    /// Parses a URL query string such as `_HLS_msn=5&_HLS_part=2&_HLS_skip=YES`.
    ///
    /// * A single leading `?` is tolerated, so the raw query portion of a URL
    ///   can be passed without pre-processing.
    /// * Keys are matched exactly (after trimming surrounding whitespace).
    /// * `_HLS_skip` accepts `YES` and `v2`, compared ASCII case-insensitively;
    ///   any other value leaves `skip` as `false`. `v2` is handled like `YES`
    ///   here, i.e. it only requests segment skipping.
    /// * Pairs without an `=` are ignored.
    #[must_use]
    pub fn parse(query: &str) -> Self {
        let query = query.trim();
        let query = query.strip_prefix('?').unwrap_or(query);

        let mut params = Self::default();
        for (key, value) in query.split('&').filter_map(|pair| pair.split_once('=')) {
            let value = value.trim();
            match key.trim() {
                "_HLS_msn" => params.msn = value.parse().ok(),
                "_HLS_part" => params.part = value.parse().ok(),
                "_HLS_skip" => {
                    params.skip =
                        value.eq_ignore_ascii_case("YES") || value.eq_ignore_ascii_case("v2");
                }
                _ => {}
            }
        }
        params
    }

    /// Returns `true` when the request carries `_HLS_msn`, i.e. it asks the
    /// server to hold the response until that media sequence number exists.
    #[must_use]
    pub fn is_blocking(&self) -> bool {
        self.msn.is_some()
    }
}
/// Data needed to render an `#EXT-X-SKIP` tag for a delta playlist.
#[derive(Debug, Clone)]
pub struct SkipDirective {
    /// Number of segments omitted from the playlist.
    pub skipped_segments: u64,
    /// URIs rendered into the `RECENTLY-REMOVED-URIS` attribute, in insertion order.
    pub recently_removed_uris: Vec<String>,
}

impl SkipDirective {
    /// Creates a directive for `skipped_segments` segments with no removed URIs.
    #[must_use]
    pub fn new(skipped_segments: u64) -> Self {
        let recently_removed_uris = Vec::new();
        Self {
            skipped_segments,
            recently_removed_uris,
        }
    }

    /// Appends `uri` to the list of recently removed URIs.
    pub fn add_removed_uri(&mut self, uri: impl Into<String>) {
        let owned: String = uri.into();
        self.recently_removed_uris.push(owned);
    }

    /// Renders the tag line (without a trailing newline).
    ///
    /// The `RECENTLY-REMOVED-URIS` attribute is only emitted when at least one
    /// URI was added; the URIs are joined with commas inside one quoted string.
    // NOTE(review): the HLS draft names a `RECENTLY-REMOVED-DATERANGES`
    // attribute for EXT-X-SKIP — confirm that consumers of this tag expect
    // `RECENTLY-REMOVED-URIS`.
    #[must_use]
    pub fn to_tag(&self) -> String {
        match self.recently_removed_uris.as_slice() {
            [] => format!("#EXT-X-SKIP:SKIPPED-SEGMENTS={}", self.skipped_segments),
            removed => format!(
                "#EXT-X-SKIP:SKIPPED-SEGMENTS={},RECENTLY-REMOVED-URIS=\"{}\"",
                self.skipped_segments,
                removed.join(",")
            ),
        }
    }
}
/// Tunables for [`LlHlsServer`].
#[derive(Debug, Clone)]
pub struct LlHlsServerConfig {
    /// Upper bound on how long a blocking playlist request is held open.
    pub max_blocking_wait: Duration,
    /// Maximum number of blocking requests allowed to wait at the same time.
    pub max_waiters: usize,
    /// Number of most recent segments kept when building a delta playlist.
    pub skip_threshold: u64,
    /// Toggle for server-sent update events.
    // NOTE(review): this flag is not read anywhere in this file — confirm
    // where it is consumed.
    pub enable_sse: bool,
    /// Capacity of the broadcast channel used for playlist update events.
    pub sse_channel_capacity: usize,
}

impl Default for LlHlsServerConfig {
    /// Defaults: 10 s blocking wait, 1000 waiters, keep 6 segments in delta
    /// playlists, SSE enabled with a 256-slot channel.
    fn default() -> Self {
        let max_blocking_wait = Duration::from_secs(10);
        let max_waiters = 1000;
        let skip_threshold = 6;
        let sse_channel_capacity = 256;
        Self {
            max_blocking_wait,
            max_waiters,
            skip_threshold,
            enable_sse: true,
            sse_channel_capacity,
        }
    }
}
/// Serves a live playlist, including requests that block until a given
/// media sequence number / part is available.
pub struct LlHlsServer {
/// Server-side limits and tunables.
config: LlHlsServerConfig,
/// The live playlist, shared behind a read/write lock.
playlist: Arc<RwLock<LlHlsPlaylist>>,
/// Signalled with `notify_waiters()` on every `push_part` to wake blocked `serve` calls.
notify: Arc<Notify>,
/// Broadcast sender used to publish a `PlaylistUpdateEvent` per pushed part.
update_tx: PlaylistUpdateTx,
/// Number of blocking requests currently counted against `max_waiters`.
waiter_count: Arc<std::sync::atomic::AtomicUsize>,
/// Rendition reports copied into the playlist on every `push_part`.
rendition_reports: Vec<RenditionReport>,
}
impl LlHlsServer {
#[must_use]
pub fn new(ll_config: &LlHlsConfig, server_config: LlHlsServerConfig) -> Self {
let playlist = Arc::new(RwLock::new(LlHlsPlaylist::new(ll_config)));
let notify = Arc::new(Notify::new());
let (update_tx, _) = broadcast::channel(server_config.sse_channel_capacity);
let waiter_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
Self {
config: server_config,
playlist,
notify,
update_tx,
waiter_count,
rendition_reports: Vec::new(),
}
}
#[must_use]
pub fn default_config(ll_config: &LlHlsConfig) -> Self {
Self::new(ll_config, LlHlsServerConfig::default())
}
pub fn add_rendition_report(&mut self, report: RenditionReport) {
self.rendition_reports.push(report);
}
pub fn push_part(&self, part: MediaPart, segment_complete: bool) {
let (last_msn, current_part_count) = {
let mut pl = self.playlist.write();
pl.rendition_reports = self.rendition_reports.clone();
pl.add_part(part, segment_complete);
(pl.last_msn(), pl.current_part_count())
};
self.notify.notify_waiters();
let event = PlaylistUpdateEvent {
last_msn,
current_part_count,
segment_complete,
};
let _ = self.update_tx.send(event);
}
pub fn set_current_segment_uri(&self, uri: impl Into<String>) {
self.playlist.write().set_current_segment_uri(uri);
}
#[must_use]
pub fn current_playlist(&self) -> String {
self.playlist.read().to_m3u8()
}
pub async fn serve(&self, params: BlockingReloadParams) -> NetResult<String> {
if !params.is_blocking() {
return Ok(self.build_response(params));
}
let count = self
.waiter_count
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
if count >= self.config.max_waiters {
self.waiter_count
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
return Err(NetError::invalid_state(
"Too many concurrent blocking playlist requests",
));
}
let msn = params.msn.expect("checked above");
let part = params.part;
let max_wait = self.config.max_blocking_wait;
let waiter_count = Arc::clone(&self.waiter_count);
let result = timeout(max_wait, async {
loop {
let response = {
let pl = self.playlist.read();
pl.blocking_playlist_response(msn, part)
};
if let Some(m3u8) = response {
return m3u8;
}
self.notify.notified().await;
}
})
.await;
waiter_count.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
result.map_err(|_| NetError::timeout(format!("Blocking reload timed out for MSN={msn}")))
}
fn build_response(&self, params: BlockingReloadParams) -> String {
let pl = self.playlist.read();
if params.skip {
self.build_delta_playlist(&pl)
} else {
pl.to_m3u8()
}
}
fn build_delta_playlist(&self, pl: &LlHlsPlaylist) -> String {
use std::fmt::Write as FmtWrite;
let full = pl.to_m3u8();
let skip_count = pl
.segments
.len()
.saturating_sub(self.config.skip_threshold as usize);
if skip_count == 0 {
return full;
}
let mut out = String::with_capacity(full.len());
let skip = SkipDirective::new(skip_count as u64);
let mut past_header = false;
let mut skipped = 0usize;
for line in full.lines() {
if !past_header {
let _ = writeln!(out, "{line}");
if line.starts_with("#EXT-X-SERVER-CONTROL:") {
let _ = writeln!(out, "{}", skip.to_tag());
past_header = true;
}
} else {
if skipped < skip_count {
if line.starts_with("#EXTINF:") {
skipped += 1;
continue;
}
if !line.starts_with('#') && skipped <= skip_count {
continue;
}
}
let _ = writeln!(out, "{line}");
}
}
out
}
pub fn subscribe_updates(&self) -> broadcast::Receiver<PlaylistUpdateEvent> {
self.update_tx.subscribe()
}
#[must_use]
pub fn playlist_arc(&self) -> Arc<RwLock<LlHlsPlaylist>> {
Arc::clone(&self.playlist)
}
#[must_use]
pub fn waiter_count(&self) -> usize {
self.waiter_count.load(std::sync::atomic::Ordering::Relaxed)
}
}
impl std::fmt::Debug for LlHlsServer {
    /// Prints the configuration and the current waiter count; the playlist,
    /// channels and rendition reports are intentionally left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let waiters = self.waiter_count();
        let mut builder = f.debug_struct("LlHlsServer");
        builder.field("config", &self.config);
        builder.field("waiter_count", &waiters);
        builder.finish()
    }
}
/// Builds part preload hints from a URI template.
///
/// The template may contain the placeholders `{msn}` and `{part}`; every
/// occurrence is substituted. The `Default` value has an empty template.
#[derive(Debug, Clone, Default)]
pub struct PreloadHintBuilder {
    /// Template with `{msn}` / `{part}` placeholders.
    uri_template: String,
}

impl PreloadHintBuilder {
    /// Creates a builder for the given template, e.g. `"seg{msn}_part{part}.mp4"`.
    #[must_use]
    pub fn new(template: impl Into<String>) -> Self {
        let uri_template = template.into();
        Self { uri_template }
    }

    /// Produces the preload hint for part `part` of media sequence number `msn`.
    ///
    /// `{msn}` is substituted first, then `{part}`.
    #[must_use]
    pub fn build(&self, msn: u64, part: u32) -> PreloadHint {
        let msn_text = msn.to_string();
        let part_text = part.to_string();
        let with_msn = self.uri_template.replace("{msn}", &msn_text);
        let resolved = with_msn.replace("{part}", &part_text);
        PreloadHint::part(resolved)
    }
}
/// Naming scheme for segment and part URIs.
#[derive(Debug, Clone)]
pub enum UriStrategy {
    /// `seg{msn}.ts` for segments and `seg{msn}_part{part}.mp4` for parts.
    Sequential,
    /// Caller-supplied prefixes: `{segment_prefix}{msn}.ts` and
    /// `{part_prefix}{msn}_{part}.mp4`.
    Custom {
        /// Prefix placed before the media sequence number of a segment URI.
        segment_prefix: String,
        /// Prefix placed before the media sequence number of a part URI.
        part_prefix: String,
    },
}

impl Default for UriStrategy {
    /// The default strategy is [`UriStrategy::Sequential`].
    fn default() -> Self {
        UriStrategy::Sequential
    }
}

impl UriStrategy {
    /// URI of the segment with media sequence number `msn`.
    #[must_use]
    pub fn segment_uri(&self, msn: u64) -> String {
        let prefix = match self {
            Self::Sequential => "seg",
            Self::Custom { segment_prefix, .. } => segment_prefix.as_str(),
        };
        format!("{prefix}{msn}.ts")
    }

    /// URI of part `part` within the segment with media sequence number `msn`.
    #[must_use]
    pub fn part_uri(&self, msn: u64, part: u32) -> String {
        // The two schemes differ in both the prefix and the msn/part separator.
        let (prefix, separator) = match self {
            Self::Sequential => ("seg", "_part"),
            Self::Custom { part_prefix, .. } => (part_prefix.as_str(), "_"),
        };
        format!("{prefix}{msn}{separator}{part}.mp4")
    }
}
// Unit tests for query parsing, skip-tag rendering, the blocking-reload
// server and the URI helpers defined in this file.
#[cfg(test)]
mod tests {
use super::*;
use crate::hls::ll_hls::LlHlsConfig;
// Server built from the default LL-HLS and server configurations.
fn default_server() -> LlHlsServer {
LlHlsServer::default_config(&LlHlsConfig::default())
}
// Builds a 0.2-duration part named `seg0_part{idx}.mp4`, optionally marked independent.
fn make_part(idx: u32, independent: bool) -> MediaPart {
let uri = format!("seg0_part{idx}.mp4");
let mut p = MediaPart::new(uri, 0.2);
if independent {
p = p.independent();
}
p
}
// `_HLS_msn` alone sets only `msn`.
#[test]
fn test_params_parse_msn() {
let p = BlockingReloadParams::parse("_HLS_msn=5");
assert_eq!(p.msn, Some(5));
assert!(p.part.is_none());
assert!(!p.skip);
}
// `_HLS_msn` and `_HLS_part` are both picked up from one query.
#[test]
fn test_params_parse_msn_part() {
let p = BlockingReloadParams::parse("_HLS_msn=3&_HLS_part=2");
assert_eq!(p.msn, Some(3));
assert_eq!(p.part, Some(2));
}
// `_HLS_skip=YES` sets `skip`; the request is still blocking because `_HLS_msn` is present.
#[test]
fn test_params_parse_skip() {
let p = BlockingReloadParams::parse("_HLS_msn=1&_HLS_skip=YES");
assert!(p.skip);
assert!(p.is_blocking());
}
// An empty query is a plain, non-blocking request.
#[test]
fn test_params_non_blocking() {
let p = BlockingReloadParams::parse("");
assert!(!p.is_blocking());
}
// The skip tag carries the skipped-segment count.
#[test]
fn test_skip_directive_tag() {
let skip = SkipDirective::new(3);
let tag = skip.to_tag();
assert!(tag.contains("EXT-X-SKIP"));
assert!(tag.contains("SKIPPED-SEGMENTS=3"));
}
// Added URIs show up in the `RECENTLY-REMOVED-URIS` attribute.
#[test]
fn test_skip_directive_removed_uris() {
let mut skip = SkipDirective::new(2);
skip.add_removed_uri("seg0.ts");
skip.add_removed_uri("seg1.ts");
let tag = skip.to_tag();
assert!(tag.contains("RECENTLY-REMOVED-URIS"));
assert!(tag.contains("seg0.ts"));
}
// A freshly created server already renders a playlist header.
#[test]
fn test_server_new() {
let server = default_server();
let m3u8 = server.current_playlist();
assert!(m3u8.contains("#EXTM3U"));
}
// Five parts, the last one flagged segment-complete, yield an `#EXTINF:` entry.
#[test]
fn test_server_push_part() {
let server = default_server();
for i in 0..5u32 {
server.push_part(make_part(i, i == 0), i == 4);
}
let m3u8 = server.current_playlist();
assert!(m3u8.contains("#EXTINF:"));
}
// Without `_HLS_msn`, `serve` answers immediately with the current playlist.
#[tokio::test]
async fn test_serve_non_blocking() {
let server = default_server();
let params = BlockingReloadParams::parse("");
let result = server.serve(params).await;
assert!(result.is_ok());
assert!(result.expect("should succeed").contains("#EXTM3U"));
}
// A blocking request for MSN 1 succeeds; a background task pushes five more
// parts (last one segment-complete) after a 50 ms delay.
#[tokio::test]
async fn test_serve_blocking_resolves() {
use std::sync::Arc;
let server = Arc::new(default_server());
for i in 0..5u32 {
server.push_part(make_part(i, i == 0), i == 4);
}
let server2 = Arc::clone(&server);
let handle = tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(50)).await;
for i in 0..5u32 {
server2.push_part(make_part(i + 10, i == 0), i == 4);
}
});
let params = BlockingReloadParams::parse("_HLS_msn=1");
let result = server.serve(params).await;
handle.await.expect("task should complete");
assert!(result.is_ok());
let m3u8 = result.expect("should have playlist");
assert!(m3u8.contains("#EXTINF:"));
}
// With a 50 ms wait limit and nothing pushed, a request for MSN 999 times out.
#[tokio::test]
async fn test_serve_blocking_timeout() {
let mut srv_config = LlHlsServerConfig::default();
srv_config.max_blocking_wait = Duration::from_millis(50);
let server = LlHlsServer::new(&LlHlsConfig::default(), srv_config);
let params = BlockingReloadParams::parse("_HLS_msn=999");
let result = server.serve(params).await;
assert!(result.is_err());
assert!(result.expect_err("should time out").is_timeout());
}
// `{msn}` and `{part}` placeholders are substituted in the hint URI.
#[test]
fn test_preload_hint_builder() {
let builder = PreloadHintBuilder::new("seg{msn}_part{part}.mp4");
let hint = builder.build(5, 3);
assert!(hint.uri.contains("seg5_part3.mp4"));
}
// Default strategy names segments `seg{msn}.ts`.
#[test]
fn test_uri_strategy_sequential_segment() {
let strategy = UriStrategy::default();
assert_eq!(strategy.segment_uri(7), "seg7.ts");
}
// Default strategy names parts `seg{msn}_part{part}.mp4`.
#[test]
fn test_uri_strategy_sequential_part() {
let strategy = UriStrategy::default();
assert_eq!(strategy.part_uri(3, 2), "seg3_part2.mp4");
}
// Custom prefixes are prepended to the sequence numbers.
#[test]
fn test_uri_strategy_custom() {
let strategy = UriStrategy::Custom {
segment_prefix: "video/".to_owned(),
part_prefix: "chunks/".to_owned(),
};
assert_eq!(strategy.segment_uri(1), "video/1.ts");
assert_eq!(strategy.part_uri(1, 0), "chunks/1_0.mp4");
}
// A subscriber created before the push receives the update event.
#[tokio::test]
async fn test_sse_subscriber_receives_event() {
let server = default_server();
let mut rx = server.subscribe_updates();
server.push_part(make_part(0, true), false);
let event = rx.recv().await;
assert!(event.is_ok());
assert_eq!(event.expect("should receive").current_part_count, 1);
}
// No request is in flight on a fresh server.
#[test]
fn test_waiter_count_initial() {
let server = default_server();
assert_eq!(server.waiter_count(), 0);
}
// With `max_waiters = 0` every blocking request is rejected up front.
#[tokio::test]
async fn test_max_waiters_rejected() {
let mut srv_config = LlHlsServerConfig::default();
srv_config.max_waiters = 0; let server = LlHlsServer::new(&LlHlsConfig::default(), srv_config);
let params = BlockingReloadParams::parse("_HLS_msn=1");
let result = server.serve(params).await;
assert!(result.is_err());
}
// The shared playlist handle renders the same kind of playlist.
#[test]
fn test_playlist_arc() {
let server = default_server();
let arc = server.playlist_arc();
let m3u8 = arc.read().to_m3u8();
assert!(m3u8.contains("#EXTM3U"));
}
// A registered rendition report appears in the playlist after the next push.
#[test]
fn test_rendition_reports_in_push() {
let mut server = default_server();
server.add_rendition_report(crate::hls::ll_hls::RenditionReport {
uri: "audio.m3u8".to_owned(),
last_msn: 0,
last_part: 0,
});
server.push_part(make_part(0, true), false);
let m3u8 = server.current_playlist();
assert!(m3u8.contains("EXT-X-RENDITION-REPORT"));
}
// A custom segment URI set before the parts are pushed shows up in the playlist.
#[test]
fn test_set_segment_uri() {
let server = default_server();
server.set_current_segment_uri("custom_seg.ts");
for i in 0..5u32 {
server.push_part(make_part(i, i == 0), i == 4);
}
let m3u8 = server.current_playlist();
assert!(m3u8.contains("custom_seg.ts"));
}
// Requests a delta playlist after pushing three complete segments with `skip_threshold = 1`.
// NOTE(review): the first assertion is satisfied by `#EXTM3U` alone, so this
// test does not actually require an EXT-X-SKIP tag in the output.
#[test]
fn test_delta_playlist_contains_skip() {
let mut srv_config = LlHlsServerConfig::default();
srv_config.skip_threshold = 1; let server = LlHlsServer::new(&LlHlsConfig::default(), srv_config);
for seg in 0..3u32 {
for part in 0..5u32 {
let uri = format!("seg{seg}_part{part}.mp4");
let p = if part == 0 {
MediaPart::new(uri, 0.2).independent()
} else {
MediaPart::new(uri, 0.2)
};
server.push_part(p, part == 4);
}
}
let params = BlockingReloadParams {
skip: true,
..Default::default()
};
let m3u8 = server.build_response(params);
assert!(m3u8.contains("EXT-X-SKIP") || m3u8.contains("#EXTM3U"));
assert!(m3u8.contains("#EXTM3U"));
}
// Unrelated query keys do not disturb directive parsing.
#[test]
fn test_params_unknown_ignored() {
let p = BlockingReloadParams::parse("foo=bar&_HLS_msn=2&baz=qux");
assert_eq!(p.msn, Some(2));
}
// `_HLS_skip=NO` leaves `skip` unset.
#[test]
fn test_params_skip_no() {
let p = BlockingReloadParams::parse("_HLS_skip=NO");
assert!(!p.skip);
}
// After five pushes, the last queued event carries `segment_complete = true`.
#[tokio::test]
async fn test_sse_event_last_msn() {
let server = default_server();
let mut rx = server.subscribe_updates();
for i in 0..5u32 {
server.push_part(make_part(i, i == 0), i == 4);
}
let mut last_event = None;
while let Ok(event) = rx.try_recv() {
last_event = Some(event);
}
let event = last_event.expect("should have at least one event");
assert!(event.segment_complete);
}
}