use std::collections::{BTreeMap, HashMap};
use std::fmt;
/// One media segment together with its payload bytes.
///
/// The pair (`stream_id`, `sequence`) is what `SegmentCache::insert` uses to
/// build the cache key for the segment.
#[derive(Debug, Clone)]
pub struct MediaSegment {
    /// Identifier of the segment; the cache code in this file never reads it.
    pub segment_id: String,
    /// Stream this segment belongs to (first half of the cache key).
    pub stream_id: String,
    /// Sequence number within the stream (second half of the cache key).
    pub sequence: u64,
    /// Presumably the playback duration in seconds; not used by the cache.
    pub duration_secs: f32,
    /// Raw payload; its length is what counts against the byte budget.
    pub data: Vec<u8>,
    /// Presumably a MIME type (tests use "video/mp2t"); not used by the cache.
    pub content_type: String,
}
impl MediaSegment {
    /// Returns the size of the payload (`data`) in bytes.
    #[inline]
    pub fn byte_len(&self) -> u64 {
        let payload_len: usize = self.data.len();
        payload_len as u64
    }
}
/// Lookup key for a cached segment: a stream plus a sequence number.
///
/// Derives `Eq` and `Hash` so it can be used as a `HashMap` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SegmentRef {
    /// Stream the referenced segment belongs to.
    pub stream_id: String,
    /// Sequence number of the referenced segment within that stream.
    pub sequence: u64,
}
/// Limits and policy switches for a `SegmentCache`.
#[derive(Debug, Clone)]
pub struct SegmentCacheConfig {
    /// Upper bound on the number of cached segments.
    pub max_segments: usize,
    /// Upper bound on the summed payload size of all cached segments, in bytes.
    pub max_bytes: u64,
    /// How many sequence numbers past the current one `prefetch_hints` inspects.
    pub prefetch_ahead: u8,
    /// When `true`, segments marked as played are evicted before any others.
    pub evict_played: bool,
}
impl Default for SegmentCacheConfig {
fn default() -> Self {
Self {
max_segments: 512,
max_bytes: 256 * 1024 * 1024, prefetch_ahead: 3,
evict_played: true,
}
}
}
/// Reasons `SegmentCache::insert` can reject a segment.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SegmentCacheError {
    /// No room could be made for the segment, even after attempting eviction.
    CacheFull,
    /// The segment's payload alone is larger than the configured byte budget.
    SegmentTooLarge,
    /// A segment with the same `(stream_id, sequence)` is already cached.
    DuplicateSegment,
}
impl fmt::Display for SegmentCacheError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SegmentCacheError::CacheFull => write!(f, "segment cache is full"),
SegmentCacheError::SegmentTooLarge => {
write!(f, "segment exceeds cache byte budget")
}
SegmentCacheError::DuplicateSegment => {
write!(f, "segment with this (stream_id, sequence) already cached")
}
}
}
}
// `Error` is implemented with its default methods only, so there is no
// `source()` chain; the message text comes from the `Display` impl.
impl std::error::Error for SegmentCacheError {}
/// Point-in-time counters, as returned by `SegmentCache::stats`.
#[derive(Debug, Clone, Default)]
pub struct SegmentCacheStats {
    /// Number of segments held when the snapshot was taken.
    pub total_segments: usize,
    /// Summed payload size of those segments, in bytes.
    pub total_bytes: u64,
    /// Number of `get` calls that found their segment.
    pub hit_count: u64,
    /// Number of `get` calls that found nothing.
    pub miss_count: u64,
}
/// Per-stream index of the sequence numbers currently cached for that stream.
struct StreamMeta {
    /// Cached sequence numbers in ascending order, each with its play state.
    sequences: BTreeMap<u64, PlayState>,
}
/// Whether a cached segment has been reported as played.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PlayState {
    /// State every segment receives when it is inserted.
    Unplayed,
    /// State set by `SegmentCache::mark_played`.
    Played,
}
impl StreamMeta {
fn new() -> Self {
Self {
sequences: BTreeMap::new(),
}
}
fn oldest_sequence(&self) -> Option<u64> {
self.sequences.keys().next().copied()
}
}
/// Bounded in-memory store of media segments with hit/miss accounting.
pub struct SegmentCache {
    /// Limits and eviction policy the cache was created with.
    config: SegmentCacheConfig,
    /// Cached segments, keyed by stream and sequence number.
    segments: HashMap<SegmentRef, MediaSegment>,
    /// Per-stream index of cached sequence numbers, keyed by stream id.
    streams: HashMap<String, StreamMeta>,
    /// Running sum of the payload sizes of everything in `segments`.
    total_bytes: u64,
    /// Number of successful `get` lookups so far.
    hit_count: u64,
    /// Number of `get` lookups that found nothing so far.
    miss_count: u64,
}
impl SegmentCache {
    /// Creates an empty cache governed by `config`.
    pub fn new(config: SegmentCacheConfig) -> Self {
        Self {
            config,
            segments: HashMap::new(),
            streams: HashMap::new(),
            total_bytes: 0,
            hit_count: 0,
            miss_count: 0,
        }
    }

    /// Inserts `segment`, evicting cached segments one at a time until both
    /// the segment-count limit and the byte limit can accommodate it.
    ///
    /// Victims are chosen by the eviction policy: the played segment with the
    /// lowest sequence number first (when `evict_played` is set), otherwise
    /// the lowest sequence number across all streams. Equal sequence numbers
    /// are ordered by stream id so the choice is reproducible.
    ///
    /// # Errors
    ///
    /// * [`SegmentCacheError::SegmentTooLarge`] if the payload alone exceeds
    ///   `max_bytes`.
    /// * [`SegmentCacheError::DuplicateSegment`] if the same
    ///   `(stream_id, sequence)` is already cached.
    /// * [`SegmentCacheError::CacheFull`] if nothing is left to evict and
    ///   there is still no room (for example when `max_segments` is 0).
    pub fn insert(&mut self, segment: MediaSegment) -> Result<(), SegmentCacheError> {
        let byte_len = segment.byte_len();
        if byte_len > self.config.max_bytes {
            return Err(SegmentCacheError::SegmentTooLarge);
        }
        let key = SegmentRef {
            stream_id: segment.stream_id.clone(),
            sequence: segment.sequence,
        };
        if self.segments.contains_key(&key) {
            return Err(SegmentCacheError::DuplicateSegment);
        }

        // Each successful eviction removes exactly one segment, so the loop is
        // bounded by the current segment count. Success is signalled by
        // `Some(_)`, not by the number of bytes freed: evicting a zero-byte
        // segment frees 0 bytes but still makes room under the count limit.
        while !self.has_room_for(byte_len) {
            if self.evict_one().is_none() {
                return Err(SegmentCacheError::CacheFull);
            }
        }

        self.streams
            .entry(key.stream_id.clone())
            .or_insert_with(StreamMeta::new)
            .sequences
            .insert(key.sequence, PlayState::Unplayed);
        // Cannot overflow: `has_room_for` verified the checked sum above.
        self.total_bytes += byte_len;
        self.segments.insert(key, segment);
        Ok(())
    }

    /// Looks up a segment, counting the call as a hit or a miss.
    pub fn get(&mut self, ref_: &SegmentRef) -> Option<&MediaSegment> {
        let found = self.segments.get(ref_);
        if found.is_some() {
            self.hit_count += 1;
        } else {
            self.miss_count += 1;
        }
        found
    }

    /// Marks the referenced segment as played; does nothing if it is not cached.
    pub fn mark_played(&mut self, ref_: &SegmentRef) {
        let state = self
            .streams
            .get_mut(&ref_.stream_id)
            .and_then(|meta| meta.sequences.get_mut(&ref_.sequence));
        if let Some(state) = state {
            *state = PlayState::Played;
        }
    }

    /// Returns references to the segments of `stream_id` that follow
    /// `current_seq` (up to `prefetch_ahead` of them) and are not cached yet,
    /// in ascending sequence order.
    ///
    /// Candidates stop early if the sequence number would overflow `u64`.
    pub fn prefetch_hints(&self, current_seq: u64, stream_id: &str) -> Vec<SegmentRef> {
        // The per-stream index mirrors `segments`, so it answers "is this
        // sequence cached?" without building a `SegmentRef` per candidate.
        let cached = self.streams.get(stream_id).map(|meta| &meta.sequences);
        (1..=u64::from(self.config.prefetch_ahead))
            .map_while(|delta| current_seq.checked_add(delta))
            .filter(|seq| cached.map_or(true, |seqs| !seqs.contains_key(seq)))
            .map(|sequence| SegmentRef {
                stream_id: stream_id.to_string(),
                sequence,
            })
            .collect()
    }

    /// Evicts one segment chosen by the eviction policy and returns the number
    /// of payload bytes freed, or 0 when the cache is empty.
    ///
    /// Despite the name, at most one segment is removed, not a whole stream.
    /// A return value of 0 is also produced when a zero-byte segment was
    /// evicted; compare [`SegmentCache::segment_count`] before and after to
    /// distinguish the two.
    pub fn evict_oldest_stream(&mut self) -> usize {
        self.evict_one().unwrap_or(0)
    }

    /// Returns a snapshot of the current size and hit/miss counters.
    pub fn stats(&self) -> SegmentCacheStats {
        SegmentCacheStats {
            total_segments: self.segments.len(),
            total_bytes: self.total_bytes,
            hit_count: self.hit_count,
            miss_count: self.miss_count,
        }
    }

    /// Returns the summed payload size of all cached segments, in bytes.
    pub fn total_bytes(&self) -> u64 {
        self.total_bytes
    }

    /// Returns the number of cached segments.
    pub fn segment_count(&self) -> usize {
        self.segments.len()
    }

    /// Reports whether a segment of `byte_len` bytes fits without eviction.
    fn has_room_for(&self, byte_len: u64) -> bool {
        let count_ok = self.segments.len() < self.config.max_segments;
        // Checked addition: an overflowing sum can never fit the budget.
        let bytes_ok = matches!(
            self.total_bytes.checked_add(byte_len),
            Some(total) if total <= self.config.max_bytes
        );
        count_ok && bytes_ok
    }

    /// Evicts one segment; returns the bytes freed, or `None` if nothing was
    /// removed.
    fn evict_one(&mut self) -> Option<usize> {
        let target = self.find_eviction_target()?;
        self.remove_segment(&target)
    }

    /// Picks the next victim: a played segment when the policy asks for it and
    /// one exists, otherwise the lowest sequence number overall.
    fn find_eviction_target(&self) -> Option<SegmentRef> {
        let played = if self.config.evict_played {
            self.oldest_played()
        } else {
            None
        };
        played.or_else(|| self.globally_oldest())
    }

    /// Returns the played segment with the lowest sequence number, breaking
    /// ties between streams by stream id.
    fn oldest_played(&self) -> Option<SegmentRef> {
        self.streams
            .iter()
            .filter_map(|(stream_id, meta)| {
                // Sequences iterate in ascending order, so the first played
                // entry is this stream's lowest played sequence.
                meta.sequences
                    .iter()
                    .find(|(_, state)| **state == PlayState::Played)
                    .map(|(seq, _)| (*seq, stream_id.as_str()))
            })
            // Tuple ordering: sequence first, then stream id.
            .min()
            .map(|(sequence, stream_id)| SegmentRef {
                stream_id: stream_id.to_string(),
                sequence,
            })
    }

    /// Returns the segment with the lowest sequence number across all streams,
    /// breaking ties between streams by stream id.
    fn globally_oldest(&self) -> Option<SegmentRef> {
        self.streams
            .iter()
            .filter_map(|(stream_id, meta)| {
                meta.oldest_sequence().map(|seq| (seq, stream_id.as_str()))
            })
            // Tuple ordering: sequence first, then stream id.
            .min()
            .map(|(sequence, stream_id)| SegmentRef {
                stream_id: stream_id.to_string(),
                sequence,
            })
    }

    /// Removes the referenced segment and its index entry; returns the payload
    /// bytes freed, or `None` if the segment was not cached.
    fn remove_segment(&mut self, ref_: &SegmentRef) -> Option<usize> {
        let seg = self.segments.remove(ref_)?;
        self.total_bytes = self.total_bytes.saturating_sub(seg.byte_len());
        if let Some(meta) = self.streams.get_mut(&ref_.stream_id) {
            meta.sequences.remove(&ref_.sequence);
            // Drop the stream's index once its last segment is gone.
            if meta.sequences.is_empty() {
                self.streams.remove(&ref_.stream_id);
            }
        }
        Some(seg.data.len())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a segment for `stream_id`/`seq` carrying `bytes` zeroed payload bytes.
    fn seg(stream_id: &str, seq: u64, bytes: usize) -> MediaSegment {
        let payload = vec![0u8; bytes];
        MediaSegment {
            segment_id: format!("{stream_id}-{seq:04}"),
            stream_id: String::from(stream_id),
            sequence: seq,
            duration_secs: 6.0,
            data: payload,
            content_type: String::from("video/mp2t"),
        }
    }

    /// Baseline configuration: 16 segments, 1 MiB, look 3 ahead, played-first eviction.
    fn default_config() -> SegmentCacheConfig {
        config(16, 1 << 20, 3, true)
    }

    /// Spells out a configuration positionally to keep the tests compact.
    fn config(
        max_segments: usize,
        max_bytes: u64,
        prefetch_ahead: u8,
        evict_played: bool,
    ) -> SegmentCacheConfig {
        SegmentCacheConfig {
            max_segments,
            max_bytes,
            prefetch_ahead,
            evict_played,
        }
    }

    /// Shorthand for the lookup key of `stream_id`/`sequence`.
    fn key(stream_id: &str, sequence: u64) -> SegmentRef {
        SegmentRef {
            stream_id: String::from(stream_id),
            sequence,
        }
    }

    #[test]
    fn test_insert_and_get() {
        let mut cache = SegmentCache::new(default_config());
        cache.insert(seg("s1", 0, 100)).expect("insert");

        let wanted = key("s1", 0);
        assert!(cache.get(&wanted).is_some());
        assert_eq!(cache.stats().hit_count, 1);
    }

    #[test]
    fn test_miss_increments_miss_count() {
        let mut cache = SegmentCache::new(default_config());

        let absent = key("s1", 99);
        assert!(cache.get(&absent).is_none());
        assert_eq!(cache.stats().miss_count, 1);
    }

    #[test]
    fn test_duplicate_segment_rejected() {
        let mut cache = SegmentCache::new(default_config());
        cache.insert(seg("s1", 0, 100)).expect("first insert");

        let second_attempt = cache.insert(seg("s1", 0, 100));
        let err = second_attempt.expect_err("should fail");
        assert_eq!(err, SegmentCacheError::DuplicateSegment);
    }

    #[test]
    fn test_segment_too_large_rejected() {
        let tight = SegmentCacheConfig {
            max_bytes: 500,
            ..default_config()
        };
        let mut cache = SegmentCache::new(tight);

        let oversized = cache.insert(seg("s1", 0, 1000));
        let err = oversized.expect_err("too large");
        assert_eq!(err, SegmentCacheError::SegmentTooLarge);
    }

    #[test]
    fn test_byte_limit_evicts_oldest() {
        let mut cache = SegmentCache::new(config(100, 1000, 2, false));
        // Five 200-byte segments fill the 1000-byte budget exactly.
        for sequence in 0..5u64 {
            cache.insert(seg("s1", sequence, 200)).expect("insert");
        }
        cache.insert(seg("s1", 5, 200)).expect("insert 6th");

        let first = key("s1", 0);
        assert!(cache.get(&first).is_none());
        assert!(cache.total_bytes() <= 1000);
    }

    #[test]
    fn test_segment_count_limit_evicts() {
        let mut cache = SegmentCache::new(config(3, 10 * 1024 * 1024, 1, false));
        for sequence in 0..4u64 {
            cache.insert(seg("s1", sequence, 10)).expect("insert");
        }
        assert_eq!(cache.segment_count(), 3);

        let first = key("s1", 0);
        assert!(cache.get(&first).is_none());
    }

    #[test]
    fn test_played_eviction_prioritised() {
        let mut cache = SegmentCache::new(config(3, 10 * 1024 * 1024, 1, true));
        for sequence in 0..3u64 {
            cache.insert(seg("s1", sequence, 100)).expect("insert");
        }

        // Only the newest segment is played; it must go before older unplayed ones.
        let played = key("s1", 2);
        cache.mark_played(&played);
        cache.insert(seg("s1", 3, 100)).expect("insert");
        assert!(cache.get(&played).is_none(), "played segment should be evicted");

        let unplayed = key("s1", 0);
        assert!(cache.get(&unplayed).is_some(), "unplayed seq 0 should remain");
    }

    #[test]
    fn test_prefetch_hints_excludes_cached() {
        let mut cache = SegmentCache::new(default_config());
        for sequence in [1u64, 3] {
            cache.insert(seg("stream", sequence, 50)).expect("insert");
        }

        let hints = cache.prefetch_hints(0, "stream");
        assert_eq!(hints.len(), 1);
        assert_eq!(hints[0].sequence, 2);
    }

    #[test]
    fn test_prefetch_hints_respects_ahead_count() {
        let wide = SegmentCacheConfig {
            prefetch_ahead: 5,
            ..default_config()
        };
        let cache = SegmentCache::new(wide);

        let hints = cache.prefetch_hints(10, "stream");
        assert_eq!(hints.len(), 5);
        let sequences: Vec<u64> = hints.iter().map(|hint| hint.sequence).collect();
        assert_eq!(sequences, vec![11, 12, 13, 14, 15]);
    }

    #[test]
    fn test_evict_oldest_stream_returns_bytes_freed() {
        let mut cache = SegmentCache::new(default_config());
        cache.insert(seg("s1", 0, 512)).expect("insert");

        let freed = cache.evict_oldest_stream();
        assert_eq!(freed, 512);
        assert_eq!(cache.segment_count(), 0);
        assert_eq!(cache.total_bytes(), 0);
    }

    #[test]
    fn test_stats_total_bytes() {
        let mut cache = SegmentCache::new(default_config());
        cache.insert(seg("s1", 0, 100)).expect("insert");
        cache.insert(seg("s1", 1, 200)).expect("insert");

        let snapshot = cache.stats();
        assert_eq!(snapshot.total_segments, 2);
        assert_eq!(snapshot.total_bytes, 300);
    }

    #[test]
    fn test_multi_stream_eviction_fairness() {
        let mut cache = SegmentCache::new(config(4, 10 * 1024 * 1024, 2, false));
        let fill = [("streamA", 0u64), ("streamB", 0), ("streamA", 1), ("streamB", 1)];
        for (stream_id, sequence) in fill {
            cache.insert(seg(stream_id, sequence, 10)).expect("insert");
        }

        // A fifth segment must push one out and leave the count at the limit.
        cache.insert(seg("streamA", 2, 10)).expect("insert 5th");
        assert_eq!(cache.segment_count(), 4);
    }
}