#![allow(dead_code)]
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Kind of streaming manifest held in the cache.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ManifestType {
    /// HLS master playlist.
    HlsMaster,
    /// HLS media playlist.
    HlsMedia,
    /// DASH MPD document.
    DashMpd,
}

impl ManifestType {
    /// MIME type to send as the HTTP `Content-Type`; both HLS kinds share one.
    #[must_use]
    pub const fn content_type(&self) -> &'static str {
        match *self {
            ManifestType::DashMpd => "application/dash+xml",
            ManifestType::HlsMaster | ManifestType::HlsMedia => "application/vnd.apple.mpegurl",
        }
    }

    /// Short display label for this manifest kind.
    #[must_use]
    pub const fn label(&self) -> &'static str {
        match *self {
            ManifestType::DashMpd => "DASH-MPD",
            ManifestType::HlsMedia => "HLS-Media",
            ManifestType::HlsMaster => "HLS-Master",
        }
    }
}
/// A cached manifest body plus the metadata used to decide whether it may
/// still be served.
#[derive(Debug, Clone)]
pub struct CachedManifest {
    /// Which kind of manifest `content` holds.
    pub manifest_type: ManifestType,
    /// The raw manifest text.
    pub content: String,
    /// Unquoted ETag derived from `content`; see [`CachedManifest::etag_header`]
    /// for the quoted header form.
    pub etag: String,
    /// Creation time; every age calculation is measured from here.
    pub created_at: Instant,
    /// How long the entry counts as fresh.
    pub ttl: Duration,
    /// Extra time after `ttl` during which the entry may still be served
    /// while a refresh is in progress.
    pub stale_while_revalidate: Duration,
    /// Per-entry hit counter, incremented by the cache when the entry is served.
    pub hits: u64,
    /// Set once a caller has been told to revalidate this (stale) entry.
    pub revalidating: bool,
}

impl CachedManifest {
    /// Creates a new entry for `content`, timestamped now.
    ///
    /// The ETag is computed from `content`, and the stale-while-revalidate
    /// window defaults to half of `ttl`.
    #[must_use]
    pub fn new(manifest_type: ManifestType, content: String, ttl: Duration) -> Self {
        let etag = compute_etag(&content);
        Self {
            manifest_type,
            content,
            etag,
            created_at: Instant::now(),
            ttl,
            stale_while_revalidate: ttl / 2,
            hits: 0,
            revalidating: false,
        }
    }

    /// Returns `true` while the entry is younger than its TTL.
    #[must_use]
    pub fn is_fresh(&self) -> bool {
        self.age() < self.ttl
    }

    /// Returns `true` when the TTL has elapsed but the entry is still inside
    /// its stale-while-revalidate window.
    #[must_use]
    pub fn is_stale_but_usable(&self) -> bool {
        let age = self.age();
        // `Duration + Duration` panics on overflow, and both fields are public,
        // so saturate: a window that would overflow means "usable indefinitely".
        age >= self.ttl && age < self.ttl.saturating_add(self.stale_while_revalidate)
    }

    /// Time elapsed since the entry was created.
    #[must_use]
    pub fn age(&self) -> Duration {
        self.created_at.elapsed()
    }

    /// Time left until the entry stops being fresh, or zero if it already has.
    #[must_use]
    pub fn remaining_ttl(&self) -> Duration {
        self.ttl.saturating_sub(self.age())
    }

    /// `Cache-Control` header value advertising the full TTL and stale window,
    /// both truncated to whole seconds.
    #[must_use]
    pub fn cache_control_header(&self) -> String {
        let max_age = self.ttl.as_secs();
        let swr = self.stale_while_revalidate.as_secs();
        format!("public, max-age={max_age}, stale-while-revalidate={swr}")
    }

    /// The ETag wrapped in double quotes, as sent in an `ETag` header.
    #[must_use]
    pub fn etag_header(&self) -> String {
        format!("\"{}\"", self.etag)
    }
}

/// Hashes `content` with 64-bit FNV-1a and renders the result as 16 lowercase
/// hex digits.
///
/// This is a fast change detector for ETags, not a cryptographic hash.
fn compute_etag(content: &str) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    let hash = content
        .bytes()
        .fold(FNV_OFFSET_BASIS, |acc, byte| (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME));
    format!("{hash:016x}")
}
/// Limits and default TTLs for a manifest cache.
#[derive(Debug, Clone)]
pub struct ManifestCacheConfig {
    /// Maximum number of cached manifests.
    pub max_entries: usize,
    /// Default TTL applied to HLS media playlists.
    pub hls_media_ttl: Duration,
    /// Default TTL applied to HLS master playlists.
    pub hls_master_ttl: Duration,
    /// Default TTL applied to DASH MPDs.
    pub dash_mpd_ttl: Duration,
    /// Interval between expiry sweeps. Not read by the cache itself;
    /// presumably consumed by a background task (TODO confirm).
    pub eviction_interval: Duration,
    /// Budget for the summed byte length of all cached manifest bodies.
    pub max_size_bytes: usize,
}

impl Default for ManifestCacheConfig {
    fn default() -> Self {
        // Short TTLs for media playlists and MPDs, a long one for master
        // playlists (presumably because the former change often on live
        // streams; confirm with the packager's update cadence).
        Self {
            max_entries: 1024,
            hls_media_ttl: Duration::from_secs(2),
            hls_master_ttl: Duration::from_secs(300),
            dash_mpd_ttl: Duration::from_secs(2),
            eviction_interval: Duration::from_secs(10),
            // 64 MiB
            max_size_bytes: 64 << 20,
        }
    }
}

impl ManifestCacheConfig {
    /// Returns the configured TTL for the given manifest kind.
    #[must_use]
    pub fn default_ttl(&self, manifest_type: ManifestType) -> Duration {
        let ttl = match manifest_type {
            ManifestType::DashMpd => &self.dash_mpd_ttl,
            ManifestType::HlsMaster => &self.hls_master_ttl,
            ManifestType::HlsMedia => &self.hls_media_ttl,
        };
        *ttl
    }
}
/// Lock-protected state of the manifest cache: the URL-keyed entry map plus
/// hit/miss counters and byte accounting.
struct CacheInner {
    /// Cached manifests keyed by URL.
    entries: HashMap<String, CachedManifest>,
    /// Maximum number of entries kept at once.
    max_entries: usize,
    /// Number of lookups counted as hits.
    total_hits: u64,
    /// Number of lookups counted as misses.
    total_misses: u64,
    /// Sum of `content.len()` over all stored entries.
    total_bytes: usize,
    /// Budget for `total_bytes`.
    max_bytes: usize,
}

impl CacheInner {
    fn new(max_entries: usize, max_bytes: usize) -> Self {
        Self {
            // Cap the up-front allocation; the map grows on demand beyond this.
            entries: HashMap::with_capacity(max_entries.min(256)),
            max_entries,
            total_hits: 0,
            total_misses: 0,
            total_bytes: 0,
            max_bytes,
        }
    }

    /// Stores `entry` under `url`, evicting the oldest entries as needed to
    /// respect both the entry-count and byte budgets.
    ///
    /// Any previous entry for `url` is dropped first. Replacing an entry
    /// therefore never evicts unrelated entries on its behalf, and the old
    /// body's bytes are not counted against the new one. An entry that can
    /// never fit (`max_entries == 0`, or a body larger than `max_bytes`) is not
    /// stored at all. Evicting everything to make room for it would still leave
    /// the cache over budget, and with `max_entries == 0` the eviction loop
    /// would never terminate.
    fn insert(&mut self, url: String, entry: CachedManifest) {
        let entry_bytes = entry.content.len();
        self.remove(&url);
        if self.max_entries == 0 || entry_bytes > self.max_bytes {
            return;
        }
        while self.entries.len() >= self.max_entries && !self.entries.is_empty() {
            self.evict_oldest();
        }
        while self.total_bytes.saturating_add(entry_bytes) > self.max_bytes
            && !self.entries.is_empty()
        {
            self.evict_oldest();
        }
        self.total_bytes += entry_bytes;
        self.entries.insert(url, entry);
    }

    /// Looks up `url`, counting a hit when present and a miss otherwise.
    fn get(&mut self, url: &str) -> Option<&mut CachedManifest> {
        match self.entries.get_mut(url) {
            Some(entry) => {
                self.total_hits += 1;
                Some(entry)
            }
            None => {
                self.total_misses += 1;
                None
            }
        }
    }

    /// Removes and returns the entry for `url`, keeping `total_bytes` in sync.
    fn remove(&mut self, url: &str) -> Option<CachedManifest> {
        let removed = self.entries.remove(url)?;
        self.total_bytes = self.total_bytes.saturating_sub(removed.content.len());
        Some(removed)
    }

    /// Drops every entry that is neither fresh nor stale-but-usable.
    fn evict_expired(&mut self) {
        // Borrow the counter separately so the `retain` closure can update it
        // while `entries` is mutably borrowed.
        let total_bytes = &mut self.total_bytes;
        self.entries.retain(|_, entry| {
            let keep = entry.is_fresh() || entry.is_stale_but_usable();
            if !keep {
                *total_bytes = total_bytes.saturating_sub(entry.content.len());
            }
            keep
        });
    }

    /// Removes the entry with the earliest `created_at` (linear scan); no-op
    /// when the cache is empty.
    fn evict_oldest(&mut self) {
        let oldest = self
            .entries
            .iter()
            .min_by_key(|(_, entry)| entry.created_at)
            .map(|(key, _)| key.clone());
        if let Some(key) = oldest {
            self.remove(&key);
        }
    }

    fn len(&self) -> usize {
        self.entries.len()
    }

    fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
/// Async, URL-keyed cache for HLS and DASH manifests.
///
/// All state lives behind an `Arc<RwLock<_>>`, so clones of a
/// `ManifestCache` share the same entries and statistics.
#[derive(Clone)]
pub struct ManifestCache {
    inner: Arc<RwLock<CacheInner>>,
    config: ManifestCacheConfig,
}

impl ManifestCache {
    /// Creates an empty cache whose limits and default TTLs come from `config`.
    #[must_use]
    pub fn new(config: ManifestCacheConfig) -> Self {
        let inner = CacheInner::new(config.max_entries, config.max_size_bytes);
        Self {
            inner: Arc::new(RwLock::new(inner)),
            config,
        }
    }

    /// Caches `content` under `url` with the configured default TTL for
    /// `manifest_type`. Any existing entry for `url` is replaced.
    pub async fn put(&self, url: impl Into<String>, content: String, manifest_type: ManifestType) {
        let ttl = self.config.default_ttl(manifest_type);
        self.put_with_ttl(url, content, manifest_type, ttl).await;
    }

    /// Caches `content` under `url` with an explicit `ttl`. Any existing entry
    /// for `url` is replaced.
    pub async fn put_with_ttl(
        &self,
        url: impl Into<String>,
        content: String,
        manifest_type: ManifestType,
        ttl: Duration,
    ) {
        // Hash the content and build the key before taking the write lock so
        // the critical section only covers the map update.
        let url = url.into();
        let entry = CachedManifest::new(manifest_type, content, ttl);
        self.inner.write().await.insert(url, entry);
    }

    /// Returns a copy of the entry for `url` if it is fresh or stale-but-usable.
    ///
    /// Served entries count as hits. Missing and expired entries count as
    /// misses, and an expired entry is dropped from the cache on the spot.
    pub async fn get(&self, url: &str) -> Option<CachedManifest> {
        let mut inner = self.inner.write().await;
        let entry = Self::usable_entry(&mut inner, url)?;
        entry.hits += 1;
        Some(entry.clone())
    }

    /// Like [`ManifestCache::get`], but also reports whether the caller should
    /// refresh the entry in the background.
    ///
    /// The flag is `true` only for a stale (not fresh) entry that nobody has
    /// been asked to revalidate yet. Returning `true` marks the entry as
    /// revalidating, and that mark is only cleared when the entry is replaced
    /// (e.g. by `put`). If the refresh fails, no further caller is asked to
    /// revalidate until the entry expires or is replaced.
    pub async fn get_with_revalidation(&self, url: &str) -> (Option<CachedManifest>, bool) {
        let mut inner = self.inner.write().await;
        let entry = match Self::usable_entry(&mut inner, url) {
            Some(entry) => entry,
            None => return (None, false),
        };
        entry.hits += 1;
        let needs_revalidation = !entry.is_fresh() && !entry.revalidating;
        if needs_revalidation {
            entry.revalidating = true;
        }
        (Some(entry.clone()), needs_revalidation)
    }

    /// Reports whether `client_etag` matches the servable cached entry for `url`.
    ///
    /// `client_etag` may be a single tag or a comma-separated list, quoted or
    /// unquoted. A `W/` prefix is ignored (weak comparison, as RFC 9110
    /// requires for `If-None-Match`), and `*` matches any servable entry.
    /// Missing or expired entries never match, so callers do not answer
    /// `304 Not Modified` for content the cache itself would no longer serve.
    pub async fn etag_matches(&self, url: &str, client_etag: &str) -> bool {
        let inner = self.inner.read().await;
        let entry = match inner.entries.get(url) {
            Some(entry) if entry.is_fresh() || entry.is_stale_but_usable() => entry,
            _ => return false,
        };
        client_etag.split(',').map(str::trim).any(|tag| {
            if tag == "*" {
                return true;
            }
            let opaque = tag.strip_prefix("W/").unwrap_or(tag);
            opaque.trim_matches('"') == entry.etag
        })
    }

    /// Removes the entry for `url`, if any.
    pub async fn invalidate(&self, url: &str) {
        self.inner.write().await.remove(url);
    }

    /// Drops every entry that is neither fresh nor stale-but-usable.
    pub async fn evict_expired(&self) {
        self.inner.write().await.evict_expired();
    }

    /// Removes all entries. Hit/miss counters are left untouched.
    pub async fn clear(&self) {
        let mut inner = self.inner.write().await;
        inner.entries.clear();
        inner.total_bytes = 0;
    }

    /// Number of entries currently stored (expired ones included until evicted).
    pub async fn len(&self) -> usize {
        self.inner.read().await.len()
    }

    /// Returns `true` when no entries are stored.
    pub async fn is_empty(&self) -> bool {
        self.inner.read().await.is_empty()
    }

    /// Returns `(hits, misses)` accumulated since the cache was created.
    pub async fn hit_stats(&self) -> (u64, u64) {
        let inner = self.inner.read().await;
        (inner.total_hits, inner.total_misses)
    }

    /// Summed byte length of all stored manifest bodies.
    pub async fn total_bytes(&self) -> usize {
        self.inner.read().await.total_bytes
    }

    /// Looks up `url` and returns it only if it is still servable.
    ///
    /// The hit/miss counters therefore reflect what callers actually receive.
    /// An entry that has outlived its stale window is removed and counted as
    /// a miss.
    fn usable_entry<'a>(inner: &'a mut CacheInner, url: &str) -> Option<&'a mut CachedManifest> {
        let usable = match inner.entries.get(url) {
            Some(entry) => entry.is_fresh() || entry.is_stale_but_usable(),
            None => {
                inner.total_misses += 1;
                return None;
            }
        };
        if !usable {
            inner.total_misses += 1;
            inner.remove(url);
            return None;
        }
        inner.total_hits += 1;
        inner.entries.get_mut(url)
    }
}

impl Default for ManifestCache {
    fn default() -> Self {
        Self::new(ManifestCacheConfig::default())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    fn sample_hls() -> String {
        "#EXTM3U\n#EXT-X-VERSION:3\n#EXTINF:6.006,\nseg0.ts\n".to_owned()
    }

    fn sample_mpd() -> String {
        "<?xml version=\"1.0\"?><MPD></MPD>".to_owned()
    }

    #[test]
    fn test_manifest_type_content_type() {
        assert!(ManifestType::HlsMaster.content_type().contains("mpegurl"));
        assert!(ManifestType::HlsMedia.content_type().contains("mpegurl"));
        assert!(ManifestType::DashMpd.content_type().contains("dash+xml"));
    }

    #[test]
    fn test_manifest_type_label() {
        assert_eq!(ManifestType::HlsMaster.label(), "HLS-Master");
        assert_eq!(ManifestType::HlsMedia.label(), "HLS-Media");
        assert_eq!(ManifestType::DashMpd.label(), "DASH-MPD");
    }

    #[test]
    fn test_cached_manifest_is_fresh() {
        let entry = CachedManifest::new(ManifestType::HlsMedia, sample_hls(), Duration::from_secs(60));
        assert!(entry.is_fresh());
    }

    #[test]
    fn test_cached_manifest_stale_but_usable() {
        let mut entry = CachedManifest::new(ManifestType::HlsMedia, sample_hls(), Duration::ZERO);
        entry.stale_while_revalidate = Duration::from_secs(60);
        assert!(entry.is_stale_but_usable());
    }

    #[test]
    fn test_cache_control_header() {
        let entry = CachedManifest::new(ManifestType::HlsMedia, sample_hls(), Duration::from_secs(10));
        let cc = entry.cache_control_header();
        assert!(cc.contains("max-age=10"));
        assert!(cc.contains("stale-while-revalidate"));
    }

    #[test]
    fn test_etag_header_quoted() {
        let entry = CachedManifest::new(ManifestType::HlsMedia, sample_hls(), Duration::from_secs(10));
        let etag = entry.etag_header();
        assert!(etag.starts_with('"') && etag.ends_with('"'));
    }

    #[test]
    fn test_etag_changes_with_content() {
        let e1 = CachedManifest::new(ManifestType::HlsMedia, "content A".to_owned(), Duration::from_secs(10));
        let e2 = CachedManifest::new(ManifestType::HlsMedia, "content B".to_owned(), Duration::from_secs(10));
        assert_ne!(e1.etag, e2.etag);
    }

    #[test]
    fn test_remaining_ttl() {
        let entry = CachedManifest::new(ManifestType::HlsMedia, sample_hls(), Duration::from_secs(30));
        let rem = entry.remaining_ttl();
        assert!(rem > Duration::from_secs(25));
        assert!(rem <= Duration::from_secs(30));
    }

    #[test]
    fn test_cache_config_defaults() {
        let cfg = ManifestCacheConfig::default();
        assert!(cfg.hls_master_ttl > cfg.hls_media_ttl);
        assert_eq!(cfg.default_ttl(ManifestType::HlsMaster), cfg.hls_master_ttl);
    }

    #[tokio::test]
    async fn test_cache_put_get() {
        let cache = ManifestCache::default();
        cache
            .put("http://example.com/master.m3u8", sample_hls(), ManifestType::HlsMaster)
            .await;
        let result = cache.get("http://example.com/master.m3u8").await;
        assert_eq!(result.expect("should be some").manifest_type, ManifestType::HlsMaster);
    }

    #[tokio::test]
    async fn test_cache_miss() {
        let cache = ManifestCache::default();
        let result = cache.get("http://example.com/missing.m3u8").await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_cache_expired_entry() {
        let cache = ManifestCache::default();
        cache
            .put_with_ttl(
                "http://example.com/live.m3u8",
                sample_hls(),
                ManifestType::HlsMedia,
                Duration::from_nanos(1),
            )
            .await;
        tokio::time::sleep(Duration::from_millis(5)).await;
        let result = cache.get("http://example.com/live.m3u8").await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_cache_put_replaces_existing_entry() {
        let cache = ManifestCache::default();
        let url = "http://x.com/m.m3u8";
        cache.put(url, sample_hls(), ManifestType::HlsMedia).await;
        let updated = format!("{}#EXT-X-ENDLIST\n", sample_hls());
        cache.put(url, updated.clone(), ManifestType::HlsMedia).await;
        assert_eq!(cache.len().await, 1);
        assert_eq!(cache.total_bytes().await, updated.len());
        let entry = cache.get(url).await.expect("should exist");
        assert_eq!(entry.content, updated);
    }

    #[tokio::test]
    async fn test_etag_match() {
        let cache = ManifestCache::default();
        cache
            .put("http://x.com/m.m3u8", sample_hls(), ManifestType::HlsMedia)
            .await;
        let entry = cache.get("http://x.com/m.m3u8").await.expect("should exist");
        assert!(cache.etag_matches("http://x.com/m.m3u8", &entry.etag).await);
    }

    #[tokio::test]
    async fn test_etag_match_quoted_header() {
        let cache = ManifestCache::default();
        cache
            .put("http://x.com/m.m3u8", sample_hls(), ManifestType::HlsMedia)
            .await;
        let entry = cache.get("http://x.com/m.m3u8").await.expect("should exist");
        assert!(cache.etag_matches("http://x.com/m.m3u8", &entry.etag_header()).await);
    }

    #[tokio::test]
    async fn test_etag_mismatch() {
        let cache = ManifestCache::default();
        cache
            .put("http://x.com/m.m3u8", sample_hls(), ManifestType::HlsMedia)
            .await;
        assert!(!cache.etag_matches("http://x.com/m.m3u8", "stale_etag").await);
    }

    #[tokio::test]
    async fn test_etag_unknown_url() {
        let cache = ManifestCache::default();
        assert!(!cache.etag_matches("http://x.com/none.m3u8", "\"abc\"").await);
    }

    #[tokio::test]
    async fn test_cache_invalidate() {
        let cache = ManifestCache::default();
        cache
            .put("http://x.com/m.m3u8", sample_hls(), ManifestType::HlsMedia)
            .await;
        cache.invalidate("http://x.com/m.m3u8").await;
        assert!(cache.is_empty().await);
        assert_eq!(cache.total_bytes().await, 0);
    }

    #[tokio::test]
    async fn test_cache_clear() {
        let cache = ManifestCache::default();
        cache
            .put("http://a.com/a.m3u8", sample_hls(), ManifestType::HlsMedia)
            .await;
        cache
            .put("http://b.com/b.m3u8", sample_hls(), ManifestType::HlsMedia)
            .await;
        cache.clear().await;
        assert_eq!(cache.len().await, 0);
        assert_eq!(cache.total_bytes().await, 0);
    }

    #[tokio::test]
    async fn test_cache_hit_stats() {
        let cache = ManifestCache::default();
        cache
            .put("http://x.com/m.m3u8", sample_hls(), ManifestType::HlsMedia)
            .await;
        assert!(cache.get("http://x.com/m.m3u8").await.is_some());
        assert!(cache.get("http://x.com/m.m3u8").await.is_some());
        assert!(cache.get("http://x.com/missing.m3u8").await.is_none());
        let (hits, misses) = cache.hit_stats().await;
        assert_eq!(hits, 2);
        assert_eq!(misses, 1);
    }

    #[tokio::test]
    async fn test_cache_total_bytes() {
        let cache = ManifestCache::default();
        let content = sample_hls();
        let expected_bytes = content.len();
        cache
            .put("http://x.com/m.m3u8", content, ManifestType::HlsMedia)
            .await;
        assert_eq!(cache.total_bytes().await, expected_bytes);
    }

    #[tokio::test]
    async fn test_cache_evict_expired() {
        let cache = ManifestCache::default();
        cache
            .put_with_ttl(
                "http://x.com/a.m3u8",
                sample_hls(),
                ManifestType::HlsMedia,
                Duration::from_nanos(1),
            )
            .await;
        cache
            .put("http://x.com/b.m3u8", sample_hls(), ManifestType::HlsMedia)
            .await;
        tokio::time::sleep(Duration::from_millis(5)).await;
        cache.evict_expired().await;
        assert_eq!(cache.len().await, 1);
    }

    #[tokio::test]
    async fn test_cache_max_entries() {
        let cfg = ManifestCacheConfig {
            max_entries: 2,
            ..ManifestCacheConfig::default()
        };
        let cache = ManifestCache::new(cfg);
        for url in ["http://x.com/a.m3u8", "http://x.com/b.m3u8", "http://x.com/c.m3u8"] {
            cache.put(url, sample_hls(), ManifestType::HlsMedia).await;
        }
        assert_eq!(cache.len().await, 2);
        assert!(cache.get("http://x.com/c.m3u8").await.is_some());
    }

    #[tokio::test]
    async fn test_cache_revalidation_fresh() {
        let cache = ManifestCache::default();
        cache
            .put("http://x.com/m.m3u8", sample_hls(), ManifestType::HlsMedia)
            .await;
        let (entry, needs_revalidation) = cache.get_with_revalidation("http://x.com/m.m3u8").await;
        assert_eq!(entry.expect("should exist").hits, 1);
        assert!(!needs_revalidation);
    }

    #[tokio::test]
    async fn test_cache_dash_mpd() {
        let cache = ManifestCache::default();
        let mpd = sample_mpd();
        cache
            .put("http://x.com/manifest.mpd", mpd.clone(), ManifestType::DashMpd)
            .await;
        let entry = cache
            .get("http://x.com/manifest.mpd")
            .await
            .expect("should exist");
        assert_eq!(entry.content, mpd);
        assert_eq!(entry.manifest_type, ManifestType::DashMpd);
    }
}