use std::collections::HashMap;
use std::time::Instant;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum WeightedCacheError {
#[error("invalid weight config: {0}")]
InvalidConfig(String),
}
/// Category of a cached entry; selects which [`TypeWeights`] are used to score it.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CacheMediaType {
    VideoSegment,
    AudioSegment,
    Image,
    Manifest,
    Thumbnail,
    Metadata,
    Generic,
}
/// Relative importance of the three signals that make up an entry's eviction score.
///
/// The values do not have to sum to 1; they are rescaled before use.
#[derive(Debug, Clone, Copy)]
pub struct TypeWeights {
    /// Weight of the recency signal (recently accessed entries score higher).
    pub recency: f64,
    /// Weight of the caller-supplied entry priority.
    pub priority: f64,
    /// Weight of the size penalty, which is subtracted from the score.
    pub size_penalty: f64,
}

impl TypeWeights {
    /// Returns a copy rescaled so that the three weights sum to 1.
    ///
    /// When the weights sum to exactly zero there is nothing to rescale, so equal
    /// weights of 1/3 are returned. No other input is special-cased: a negative or
    /// NaN sum is divided through as-is.
    #[must_use]
    fn normalise(self) -> Self {
        let total = self.recency + self.priority + self.size_penalty;
        if total != 0.0 {
            return Self {
                recency: self.recency / total,
                priority: self.priority / total,
                size_penalty: self.size_penalty / total,
            };
        }
        let equal_share = 1.0 / 3.0;
        Self {
            recency: equal_share,
            priority: equal_share,
            size_penalty: equal_share,
        }
    }
}

impl Default for TypeWeights {
    /// Recency-leaning defaults: 0.5 recency, 0.3 priority, 0.2 size penalty.
    fn default() -> Self {
        let (recency, priority, size_penalty) = (0.5, 0.3, 0.2);
        Self {
            recency,
            priority,
            size_penalty,
        }
    }
}
/// Scoring weights for a [`WeightedCache`], optionally specialised per media type.
#[derive(Debug, Clone)]
pub struct WeightConfig {
    /// Weights used for any media type that has no override.
    pub default_weights: TypeWeights,
    /// Per-media-type weights that take precedence over `default_weights`;
    /// populated via `set_weights`.
    overrides: HashMap<CacheMediaType, TypeWeights>,
}
impl Default for WeightConfig {
fn default() -> Self {
let mut cfg = Self {
default_weights: TypeWeights::default(),
overrides: HashMap::new(),
};
cfg.set_weights(
CacheMediaType::Manifest,
TypeWeights { recency: 0.5, priority: 0.45, size_penalty: 0.05 },
);
cfg.set_weights(
CacheMediaType::Thumbnail,
TypeWeights { recency: 0.4, priority: 0.4, size_penalty: 0.2 },
);
cfg.set_weights(
CacheMediaType::VideoSegment,
TypeWeights { recency: 0.4, priority: 0.25, size_penalty: 0.35 },
);
cfg.set_weights(
CacheMediaType::AudioSegment,
TypeWeights { recency: 0.4, priority: 0.3, size_penalty: 0.3 },
);
cfg.set_weights(
CacheMediaType::Image,
TypeWeights { recency: 0.35, priority: 0.25, size_penalty: 0.40 },
);
cfg.set_weights(
CacheMediaType::Metadata,
TypeWeights { recency: 0.5, priority: 0.35, size_penalty: 0.15 },
);
cfg
}
}
impl WeightConfig {
    /// Creates a config with [`TypeWeights::default`] as the fallback and no
    /// per-type overrides.
    ///
    /// This differs from [`WeightConfig::default`], which also installs tuned
    /// overrides for the specialised media types.
    #[must_use]
    pub fn new() -> Self {
        Self {
            default_weights: TypeWeights::default(),
            overrides: HashMap::new(),
        }
    }

    /// Sets the weights used for `media_type`, replacing any previous override.
    ///
    /// The weights are stored exactly as given and normalised on read by
    /// [`WeightConfig::weights_for`], the same way `default_weights` is treated.
    /// Storing them raw lets [`WeightConfig::validate`] see what the caller actually
    /// supplied. Normalising first would hide bad input: all-negative weights divided
    /// by their negative sum come out looking positive.
    pub fn set_weights(&mut self, media_type: CacheMediaType, weights: TypeWeights) {
        self.overrides.insert(media_type, weights);
    }

    /// Returns the normalised weights for `media_type`: its override if one was set,
    /// otherwise `default_weights`.
    #[must_use]
    pub fn weights_for(&self, media_type: CacheMediaType) -> TypeWeights {
        self.overrides
            .get(&media_type)
            .copied()
            .unwrap_or(self.default_weights)
            .normalise()
    }

    /// Checks that every configured weight is finite and non-negative.
    ///
    /// # Errors
    ///
    /// Returns [`WeightedCacheError::InvalidConfig`] describing the first offending
    /// weight. `default_weights` is checked first. Overrides are then checked in
    /// `HashMap` iteration order, so if several are invalid, which one is reported is
    /// unspecified.
    pub fn validate(&self) -> Result<(), WeightedCacheError> {
        let check = |w: TypeWeights, label: &str| -> Result<(), WeightedCacheError> {
            let fields = [
                ("recency", w.recency),
                ("priority", w.priority),
                ("size_penalty", w.size_penalty),
            ];
            for (field, value) in fields {
                // `is_finite` also rejects NaN, which `value < 0.0` alone would miss.
                if value < 0.0 || !value.is_finite() {
                    return Err(WeightedCacheError::InvalidConfig(format!(
                        "{label}.{field} must be finite and >= 0"
                    )));
                }
            }
            Ok(())
        };
        check(self.default_weights, "default_weights")?;
        for (media_type, weights) in &self.overrides {
            check(*weights, &format!("{media_type:?}"))?;
        }
        Ok(())
    }
}
/// A single cached value plus the metadata used to score it for eviction.
struct Entry {
    /// The cached bytes.
    value: Vec<u8>,
    /// Length of `value` in bytes, captured at insertion.
    size_bytes: usize,
    /// Selects which [`TypeWeights`] score this entry.
    media_type: CacheMediaType,
    /// Caller-supplied priority; with a positive priority weight, larger values
    /// raise the score.
    priority: u8,
    /// Set on insertion and refreshed on every successful `get`.
    last_accessed: Instant,
}
/// A bounded cache of byte values. When it goes over capacity, it evicts the entry
/// with the lowest weighted score. With non-negative weights, recency and priority
/// raise the score and size lowers it.
pub struct WeightedCache {
    /// Maximum number of entries kept after any `insert` or `resize`.
    capacity: usize,
    /// Cached entries keyed by caller-supplied string keys.
    entries: HashMap<String, Entry>,
    /// Per-media-type scoring weights.
    weights: WeightConfig,
    /// Lookups via `get` that found their key.
    hits: u64,
    /// Lookups via `get` that did not find their key.
    misses: u64,
    /// Entries removed to make room; explicit `remove` calls are not counted.
    evictions: u64,
}
impl WeightedCache {
    /// Creates an empty cache holding at most `capacity` entries, scored with `weights`.
    ///
    /// `weights` is used as-is. Call [`WeightConfig::validate`] beforehand if it may
    /// contain negative or non-finite values.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero.
    #[must_use]
    pub fn new(capacity: usize, weights: WeightConfig) -> Self {
        assert!(capacity > 0, "WeightedCache: capacity must be > 0");
        Self {
            capacity,
            weights,
            entries: HashMap::with_capacity(capacity),
            hits: 0,
            misses: 0,
            evictions: 0,
        }
    }

    /// Stores `value` under `key`, replacing any existing entry for that key. It then
    /// evicts lowest-scoring entries until the cache is back within capacity.
    ///
    /// The freshly inserted entry is itself an eviction candidate, so a large,
    /// low-priority value may be evicted by the same call that inserted it.
    /// Replacing an existing key does not change the entry count and so never evicts.
    pub fn insert(
        &mut self,
        key: impl Into<String>,
        value: Vec<u8>,
        media_type: CacheMediaType,
        priority: u8,
    ) {
        let size_bytes = value.len();
        self.entries.insert(
            key.into(),
            Entry {
                value,
                media_type,
                priority,
                last_accessed: Instant::now(),
                size_bytes,
            },
        );
        while self.entries.len() > self.capacity {
            self.evict_one();
        }
    }

    /// Looks up `key` and counts a hit or a miss. A hit also refreshes the entry's
    /// last-access time, which raises its recency score.
    pub fn get(&mut self, key: &str) -> Option<&[u8]> {
        if let Some(entry) = self.entries.get_mut(key) {
            self.hits += 1;
            entry.last_accessed = Instant::now();
            Some(&entry.value)
        } else {
            self.misses += 1;
            None
        }
    }

    /// Removes `key`, returning its value if it was present.
    ///
    /// This is not counted as an eviction and does not touch the hit/miss counters.
    pub fn remove(&mut self, key: &str) -> Option<Vec<u8>> {
        self.entries.remove(key).map(|e| e.value)
    }

    /// Returns `true` if `key` is cached. Unlike `get`, this is not counted as a
    /// lookup and does not refresh recency.
    #[must_use]
    pub fn contains(&self, key: &str) -> bool {
        self.entries.contains_key(key)
    }

    /// Number of entries currently cached.
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` if no entries are cached.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Maximum number of entries the cache keeps.
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Number of `get` calls that found their key.
    #[must_use]
    pub fn hits(&self) -> u64 {
        self.hits
    }

    /// Number of `get` calls that did not find their key.
    #[must_use]
    pub fn misses(&self) -> u64 {
        self.misses
    }

    /// Number of entries evicted to stay within capacity.
    #[must_use]
    pub fn evictions(&self) -> u64 {
        self.evictions
    }

    /// Fraction of `get` calls that were hits, or `0.0` before any lookup.
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
        let total = self.hits + self.misses;
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }

    /// Changes the capacity, evicting lowest-scoring entries if the cache now exceeds it.
    ///
    /// # Panics
    ///
    /// Panics if `new_capacity` is zero.
    pub fn resize(&mut self, new_capacity: usize) {
        assert!(new_capacity > 0, "WeightedCache: capacity must be > 0");
        self.capacity = new_capacity;
        while self.entries.len() > self.capacity {
            self.evict_one();
        }
    }

    /// Removes every entry and resets the hit, miss, and eviction counters.
    /// The capacity is unchanged.
    pub fn clear(&mut self) {
        self.entries.clear();
        self.hits = 0;
        self.misses = 0;
        self.evictions = 0;
    }

    /// Nanoseconds between `entry`'s last access and `now`. It saturates at zero if
    /// `now` is earlier, and at `u64::MAX` for absurdly long ages.
    fn age_ns(entry: &Entry, now: Instant) -> u64 {
        u64::try_from(now.saturating_duration_since(entry.last_accessed).as_nanos())
            .unwrap_or(u64::MAX)
    }

    /// Eviction score of `entry` as of `now`; the lowest-scoring entry is evicted first.
    ///
    /// Each factor is scaled into `[0, 1]`:
    /// - recency, relative to the oldest entry (`max_age_ns`);
    /// - size, relative to the largest entry (`max_size`);
    /// - priority, relative to 255.
    ///
    /// The score is the weighted recency and priority factors minus the weighted
    /// size factor.
    fn score(&self, entry: &Entry, now: Instant, max_age_ns: u64, max_size: usize) -> f64 {
        let w = self.weights.weights_for(entry.media_type);
        let age_ns = Self::age_ns(entry, now) as f64;
        let max_age = max_age_ns as f64;
        // If every entry shares the same timestamp there is no recency spread to rank by.
        let recency_factor = if max_age == 0.0 {
            1.0
        } else {
            (1.0 - (age_ns / max_age)).clamp(0.0, 1.0)
        };
        let priority_factor = f64::from(entry.priority) / 255.0;
        let size_factor = if max_size == 0 {
            0.0
        } else {
            (entry.size_bytes as f64 / max_size as f64).clamp(0.0, 1.0)
        };
        w.recency * recency_factor + w.priority * priority_factor - w.size_penalty * size_factor
    }

    /// Removes the single lowest-scoring entry, if any, and bumps the eviction counter.
    fn evict_one(&mut self) {
        if self.entries.is_empty() {
            return;
        }
        // Use one timestamp for the whole pass. The ages used for normalisation and
        // for scoring must be measured from the same instant; otherwise the clock
        // advancing between passes skews the recency factor.
        let now = Instant::now();
        let max_age_ns = self
            .entries
            .values()
            .map(|e| Self::age_ns(e, now))
            .max()
            .unwrap_or_default();
        let max_size = self
            .entries
            .values()
            .map(|e| e.size_bytes)
            .max()
            .unwrap_or_default();
        // `total_cmp` is a total order, so the pick stays well-defined even if a
        // misconfigured weight produces NaN scores. Only the chosen victim's key is
        // cloned, not every key in the map.
        let victim_key = self
            .entries
            .iter()
            .map(|(key, entry)| (key, self.score(entry, now, max_age_ns, max_size)))
            .min_by(|(_, a), (_, b)| a.total_cmp(b))
            .map(|(key, _)| key.clone());
        if let Some(key) = victim_key {
            self.entries.remove(&key);
            self.evictions += 1;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Cache with the stock `WeightConfig::default()` weights.
    fn default_cache(cap: usize) -> WeightedCache {
        WeightedCache::new(cap, WeightConfig::default())
    }

    #[test]
    fn test_new_cache_is_empty() {
        let cache = default_cache(8);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_insert_and_get() {
        let mut cache = default_cache(4);
        cache.insert("key1", vec![1, 2, 3], CacheMediaType::Generic, 5);
        let val = cache.get("key1").expect("should find key1");
        assert_eq!(val, &[1u8, 2, 3]);
    }

    #[test]
    fn test_get_absent_returns_none() {
        let mut cache = default_cache(4);
        assert!(cache.get("absent").is_none());
    }

    #[test]
    fn test_hit_miss_counters() {
        let mut cache = default_cache(4);
        cache.insert("k", vec![0], CacheMediaType::Generic, 1);
        let _ = cache.get("k");
        let _ = cache.get("missing");
        assert_eq!(cache.hits(), 1);
        assert_eq!(cache.misses(), 1);
    }

    #[test]
    fn test_capacity_eviction() {
        let mut cache = default_cache(3);
        cache.insert("a", vec![0; 100], CacheMediaType::VideoSegment, 3);
        cache.insert("b", vec![0; 100], CacheMediaType::VideoSegment, 3);
        cache.insert("c", vec![0; 100], CacheMediaType::VideoSegment, 3);
        cache.insert("d", vec![0; 100], CacheMediaType::VideoSegment, 3);
        assert_eq!(cache.len(), 3, "cache should still be at capacity");
        assert!(cache.evictions() > 0);
    }

    #[test]
    fn test_high_priority_survives_eviction() {
        let mut cfg = WeightConfig::new();
        cfg.set_weights(
            CacheMediaType::Manifest,
            TypeWeights { recency: 0.1, priority: 0.85, size_penalty: 0.05 },
        );
        cfg.set_weights(
            CacheMediaType::Generic,
            TypeWeights { recency: 0.5, priority: 0.05, size_penalty: 0.45 },
        );
        let mut cache = WeightedCache::new(2, cfg);
        cache.insert("manifest", vec![0u8; 10], CacheMediaType::Manifest, 255);
        cache.insert("generic", vec![0u8; 10], CacheMediaType::Generic, 0);
        cache.insert("third", vec![0u8; 10], CacheMediaType::Generic, 0);
        assert!(cache.contains("manifest"), "Manifest should survive eviction");
    }

    #[test]
    fn test_remove() {
        let mut cache = default_cache(4);
        cache.insert("k", vec![9], CacheMediaType::Metadata, 1);
        let removed = cache.remove("k");
        assert_eq!(removed, Some(vec![9]));
        assert!(!cache.contains("k"));
    }

    #[test]
    fn test_remove_absent() {
        let mut cache = default_cache(4);
        assert!(cache.remove("nope").is_none());
    }

    #[test]
    fn test_remove_is_not_an_eviction() {
        let mut cache = default_cache(4);
        cache.insert("k", vec![1], CacheMediaType::Generic, 1);
        let _ = cache.remove("k");
        assert_eq!(cache.evictions(), 0);
    }

    #[test]
    fn test_overwrite_key() {
        let mut cache = default_cache(4);
        cache.insert("k", vec![1], CacheMediaType::Generic, 1);
        cache.insert("k", vec![2, 3], CacheMediaType::Generic, 5);
        assert_eq!(cache.len(), 1, "overwrite should not duplicate");
        let val = cache.get("k").expect("should exist");
        assert_eq!(val, &[2u8, 3]);
    }

    #[test]
    fn test_overwrite_at_capacity_does_not_evict() {
        let mut cache = default_cache(2);
        cache.insert("a", vec![1], CacheMediaType::Generic, 1);
        cache.insert("b", vec![1], CacheMediaType::Generic, 1);
        cache.insert("a", vec![2], CacheMediaType::Generic, 1);
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.evictions(), 0);
        assert!(cache.contains("b"));
    }

    #[test]
    fn test_hit_rate() {
        let mut cache = default_cache(4);
        cache.insert("a", vec![0], CacheMediaType::Generic, 1);
        let _ = cache.get("a");
        let _ = cache.get("a");
        let _ = cache.get("b");
        assert!((cache.hit_rate() - 2.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_hit_rate_without_lookups_is_zero() {
        let cache = default_cache(4);
        assert_eq!(cache.hit_rate(), 0.0);
    }

    #[test]
    fn test_resize_shrinks() {
        let mut cache = default_cache(5);
        for i in 0..5u8 {
            cache.insert(format!("k{i}"), vec![i], CacheMediaType::Generic, i);
        }
        assert_eq!(cache.len(), 5);
        cache.resize(3);
        assert_eq!(cache.len(), 3);
    }

    #[test]
    #[should_panic(expected = "capacity must be > 0")]
    fn test_new_rejects_zero_capacity() {
        let _ = WeightedCache::new(0, WeightConfig::default());
    }

    #[test]
    #[should_panic(expected = "capacity must be > 0")]
    fn test_resize_rejects_zero_capacity() {
        let mut cache = default_cache(4);
        cache.resize(0);
    }

    #[test]
    fn test_validate_rejects_negative_weights() {
        let mut cfg = WeightConfig::new();
        cfg.default_weights = TypeWeights { recency: -0.1, priority: 0.5, size_penalty: 0.5 };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_default_configs_validate() {
        assert!(WeightConfig::default().validate().is_ok());
        assert!(WeightConfig::new().validate().is_ok());
    }

    #[test]
    fn test_set_weights_normalises() {
        let mut cfg = WeightConfig::new();
        cfg.set_weights(
            CacheMediaType::Image,
            TypeWeights { recency: 2.0, priority: 2.0, size_penalty: 6.0 },
        );
        let w = cfg.weights_for(CacheMediaType::Image);
        let sum = w.recency + w.priority + w.size_penalty;
        assert!((sum - 1.0).abs() < 1e-9, "weights should normalise to 1.0, got {sum}");
    }

    #[test]
    fn test_clear() {
        let mut cache = default_cache(4);
        cache.insert("x", vec![1], CacheMediaType::Image, 3);
        let _ = cache.get("x");
        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.hits(), 0);
        assert_eq!(cache.misses(), 0);
        assert_eq!(cache.evictions(), 0);
    }

    #[test]
    fn test_multiple_media_types() {
        let mut cache = default_cache(10);
        cache.insert("m", vec![0; 5], CacheMediaType::Manifest, 10);
        cache.insert("v", vec![0; 200], CacheMediaType::VideoSegment, 5);
        cache.insert("t", vec![0; 8], CacheMediaType::Thumbnail, 8);
        cache.insert("a", vec![0; 50], CacheMediaType::AudioSegment, 4);
        assert_eq!(cache.len(), 4);
    }

    #[test]
    fn test_evictions_counter() {
        let mut cache = default_cache(2);
        cache.insert("a", vec![0], CacheMediaType::Generic, 1);
        cache.insert("b", vec![0], CacheMediaType::Generic, 1);
        cache.insert("c", vec![0], CacheMediaType::Generic, 1);
        cache.insert("d", vec![0], CacheMediaType::Generic, 1);
        assert_eq!(cache.evictions(), 2);
    }

    #[test]
    fn test_capacity_getter() {
        let cache = default_cache(42);
        assert_eq!(cache.capacity(), 42);
    }

    #[test]
    fn test_default_fallback_weights() {
        let cfg = WeightConfig::new();
        let w = cfg.weights_for(CacheMediaType::VideoSegment);
        let sum = w.recency + w.priority + w.size_penalty;
        assert!((sum - 1.0).abs() < 1e-9);
    }
}