use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
/// A key/value store that a `Prefetcher` fills with values ahead of use.
///
/// Every method takes `&self`, so implementors supply their own interior
/// mutability; the `Send + Sync` supertraits let one instance be shared across
/// threads behind an `Arc<dyn Cache>`.
pub trait Cache: Send + Sync {
    /// Reports whether a value is currently stored under `key`.
    fn contains(&self, key: &str) -> bool;

    /// Returns the value stored under `key`, if any, as an owned buffer.
    fn get(&self, key: &str) -> Option<Vec<u8>>;

    /// Stores `value` under `key`.
    fn insert(&self, key: String, value: Vec<u8>);
}
/// How a prefetcher predicts which keys will be requested next.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PrefetchStrategy {
    /// Treat keys as `<prefix><separator><number>` (separator `-`, `_` or `/`)
    /// and predict the following `lookahead` numbers, preserving the digit
    /// width of the current key (e.g. `seg-007` -> `seg-008`).
    Sequential {
        /// Number of successive keys to predict; `0` disables prediction.
        lookahead: usize,
    },
    /// Follow a fixed cyclic list of keys: the entry after the current key is
    /// predicted, wrapping from the last entry back to the first.
    AccessPattern(Vec<String>),
}
impl PrefetchStrategy {
    /// Returns the keys this strategy expects after `current_key`, in fetch
    /// order. An empty vector means there is no prediction for this key.
    pub fn predict_next(&self, current_key: &str) -> Vec<String> {
        match *self {
            Self::Sequential { lookahead } => predict_sequential(current_key, lookahead),
            Self::AccessPattern(ref pattern) => predict_access_pattern(current_key, pattern),
        }
    }
}
/// Splits `key` into its prefix and the numeric value of its trailing digits.
///
/// Returns `None` unless the key ends in at least one ASCII digit, that digit
/// run is directly preceded by a separator (`-`, `_` or `/`), and it fits in a
/// `u64`. The returned prefix keeps the separator: `"seg-007"` -> `("seg-", 7)`.
fn split_numeric_suffix(key: &str) -> Option<(&str, u64)> {
    let prefix = key.trim_end_matches(|c: char| c.is_ascii_digit());
    let digits = &key[prefix.len()..];
    let has_separator = prefix.ends_with(|c: char| matches!(c, '-' | '_' | '/'));
    if digits.is_empty() || !has_separator {
        return None;
    }
    digits.parse::<u64>().ok().map(|n| (prefix, n))
}

/// Predicts up to `lookahead` keys that follow `current_key` numerically.
///
/// The key must end in `<separator><digits>` (see [`split_numeric_suffix`]);
/// otherwise, or when `lookahead` is 0, the result is empty. Successors keep
/// the digit width of `current_key`, so `"seg-099"` is followed by `"seg-100"`
/// and `"seg-007"` by `"seg-008"`. The sequence stops early, rather than
/// overflowing, once the numeric suffix reaches `u64::MAX`.
fn predict_sequential(current_key: &str, lookahead: usize) -> Vec<String> {
    if lookahead == 0 {
        return Vec::new();
    }
    let (prefix, n) = match split_numeric_suffix(current_key) {
        Some(parts) => parts,
        None => return Vec::new(),
    };
    // Pad to the original digit count. A successor never has fewer digits than
    // `n` itself, so this only reproduces leading zeros present in the key.
    let width = current_key.len() - prefix.len();
    (1..=lookahead as u64)
        // `checked_add` ends the sequence at u64::MAX instead of panicking
        // (debug builds) or wrapping back to 0 (release builds).
        .map_while(|offset| n.checked_add(offset))
        .map(|next| format!("{prefix}{next:0>width$}"))
        .collect()
}
/// Predicts the key that follows the first occurrence of `current_key` in the
/// cyclic `pattern`, wrapping from the last entry to the first.
///
/// Returns an empty vector when `current_key` is not in `pattern` (which
/// includes the case of an empty `pattern`).
fn predict_access_pattern(current_key: &str, pattern: &[String]) -> Vec<String> {
    let found = pattern.iter().position(|k| k == current_key);
    match found {
        Some(i) => {
            // Past the end means "wrap around"; `first` exists because `i` was found.
            let successor = pattern.get(i + 1).or_else(|| pattern.first());
            successor.cloned().into_iter().collect()
        }
        None => Vec::new(),
    }
}
/// Drives a [`PrefetchStrategy`]: on each access it asks the strategy for the
/// likely next keys, loads the ones missing from the cache, and records the
/// loaded keys in a bounded FIFO queue.
pub struct Prefetcher {
    /// Strategy used to predict upcoming keys.
    pub strategy: PrefetchStrategy,
    /// Destination for prefetched values.
    cache: Arc<dyn Cache>,
    /// Keys loaded by `trigger_prefetch`, oldest first; never longer than
    /// `max_pending`.
    pending: Mutex<VecDeque<String>>,
    /// Capacity of `pending`; always at least 1.
    max_pending: usize,
    /// Produces the value stored for each prefetched key.
    loader: Arc<dyn Fn(&str) -> Vec<u8> + Send + Sync>,
}

impl std::fmt::Debug for Prefetcher {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `cache` and `loader` are trait objects without a `Debug` bound and
        // `pending` sits behind a lock, so they are elided (`..`), not printed.
        f.debug_struct("Prefetcher")
            .field("strategy", &self.strategy)
            .field("max_pending", &self.max_pending)
            .finish_non_exhaustive()
    }
}

impl Prefetcher {
    /// Pending-queue capacity used until `with_max_pending` overrides it.
    const DEFAULT_MAX_PENDING: usize = 256;

    /// Creates a prefetcher whose loader stores an empty value for every
    /// prefetched key, which is enough to make the key present in `cache`.
    pub fn new(strategy: PrefetchStrategy, cache: Arc<dyn Cache>) -> Self {
        Self::with_loader(strategy, cache, |_key| Vec::new())
    }

    /// Creates a prefetcher that calls `loader` to produce the value for each
    /// key it prefetches.
    pub fn with_loader<F>(strategy: PrefetchStrategy, cache: Arc<dyn Cache>, loader: F) -> Self
    where
        F: Fn(&str) -> Vec<u8> + Send + Sync + 'static,
    {
        Self {
            strategy,
            cache,
            pending: Mutex::new(VecDeque::new()),
            max_pending: Self::DEFAULT_MAX_PENDING,
            loader: Arc::new(loader),
        }
    }

    /// Sets the pending-queue capacity; values below 1 are clamped to 1.
    #[must_use]
    pub fn with_max_pending(mut self, max: usize) -> Self {
        self.max_pending = max.max(1);
        self
    }

    /// Loads every key predicted after `current_key` that the cache does not
    /// already contain, and appends each loaded key to the pending queue,
    /// evicting the oldest entry once the queue is full.
    ///
    /// The loader runs without holding the pending-queue lock. The `contains`
    /// check and the `insert` are separate cache calls, so concurrent callers
    /// may both miss the same key and load and insert it twice.
    pub fn trigger_prefetch(&self, current_key: &str) {
        for key in self.strategy.predict_next(current_key) {
            if self.cache.contains(&key) {
                continue;
            }
            let value = (self.loader)(&key);
            self.cache.insert(key.clone(), value);

            let mut queue = self.lock_pending();
            if queue.len() >= self.max_pending {
                queue.pop_front();
            }
            queue.push_back(key);
        }
    }

    /// Number of keys currently recorded in the pending queue.
    pub fn pending_count(&self) -> usize {
        self.lock_pending().len()
    }

    /// Removes and returns all recorded keys, oldest first.
    pub fn drain_pending(&self) -> Vec<String> {
        self.lock_pending().drain(..).collect()
    }

    /// The cache this prefetcher writes into.
    pub fn cache(&self) -> &Arc<dyn Cache> {
        &self.cache
    }

    /// Locks the pending queue, recovering the guard if a previous holder
    /// panicked. A queue of owned `String`s stays usable after poisoning, and
    /// recovering avoids silently dropping keys or reporting a count of 0.
    fn lock_pending(&self) -> std::sync::MutexGuard<'_, VecDeque<String>> {
        self.pending
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
/// Thread-safe, unbounded in-memory [`Cache`] backed by a mutex-guarded
/// `HashMap`.
#[derive(Debug, Default)]
pub struct MemoryCache {
    store: Mutex<std::collections::HashMap<String, Vec<u8>>>,
}

impl MemoryCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self::default()
    }

    /// Locks the store, recovering the guard if a previous holder panicked.
    fn lock_store(
        &self,
    ) -> std::sync::MutexGuard<'_, std::collections::HashMap<String, Vec<u8>>> {
        // The map holds only owned `String`/`Vec<u8>` data and is touched only
        // by single `HashMap` calls below, so it remains usable after a poison.
        // Recovering avoids reporting false misses and silently dropping inserts.
        self.store
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

impl Cache for MemoryCache {
    fn contains(&self, key: &str) -> bool {
        self.lock_store().contains_key(key)
    }

    /// Stores `value` under `key`, replacing any previous value.
    fn insert(&self, key: String, value: Vec<u8>) {
        self.lock_store().insert(key, value);
    }

    /// Returns a clone of the value stored under `key`, if any.
    fn get(&self, key: &str) -> Option<Vec<u8>> {
        self.lock_store().get(key).cloned()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    fn make_cache() -> Arc<MemoryCache> {
        Arc::new(MemoryCache::new())
    }

    /// Builds a prefetcher with the default (empty-value) loader over `cache`.
    fn make_prefetcher(strategy: PrefetchStrategy, cache: &Arc<MemoryCache>) -> Prefetcher {
        Prefetcher::new(strategy, Arc::clone(cache) as Arc<dyn Cache>)
    }

    #[test]
    fn test_sequential_predict_basic() {
        let strategy = PrefetchStrategy::Sequential { lookahead: 3 };
        let next = strategy.predict_next("segment-005");
        assert_eq!(next, vec!["segment-006", "segment-007", "segment-008"]);
    }

    #[test]
    fn test_sequential_predict_zero_lookahead() {
        let strategy = PrefetchStrategy::Sequential { lookahead: 0 };
        assert!(strategy.predict_next("seg-1").is_empty());
    }

    #[test]
    fn test_sequential_predict_non_numeric() {
        let strategy = PrefetchStrategy::Sequential { lookahead: 2 };
        assert!(strategy.predict_next("manifest.m3u8").is_empty());
    }

    #[test]
    fn test_sequential_requires_separator() {
        let strategy = PrefetchStrategy::Sequential { lookahead: 1 };
        assert!(strategy.predict_next("seg7").is_empty());
        assert!(strategy.predict_next("42").is_empty());
    }

    #[test]
    fn test_access_pattern_predict_next() {
        let keys = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let strategy = PrefetchStrategy::AccessPattern(keys);
        let next = strategy.predict_next("b");
        assert_eq!(next, vec!["c"]);
    }

    #[test]
    fn test_access_pattern_wrap_around() {
        let keys = vec!["x".to_string(), "y".to_string(), "z".to_string()];
        let strategy = PrefetchStrategy::AccessPattern(keys);
        let next = strategy.predict_next("z");
        assert_eq!(next, vec!["x"]);
    }

    #[test]
    fn test_access_pattern_unknown_key() {
        let keys = vec!["a".to_string(), "b".to_string()];
        let strategy = PrefetchStrategy::AccessPattern(keys);
        assert!(strategy.predict_next("unknown").is_empty());
    }

    #[test]
    fn test_trigger_prefetch_sequential() {
        let cache = make_cache();
        let prefetcher = make_prefetcher(PrefetchStrategy::Sequential { lookahead: 2 }, &cache);
        prefetcher.trigger_prefetch("seg-010");
        assert!(cache.contains("seg-011"), "seg-011 should be prefetched");
        assert!(cache.contains("seg-012"), "seg-012 should be prefetched");
        assert!(
            !cache.contains("seg-013"),
            "seg-013 should NOT be prefetched"
        );
    }

    #[test]
    fn test_trigger_prefetch_no_overwrite() {
        let cache = make_cache();
        cache.insert("seg-002".to_string(), vec![0xAB]);
        let prefetcher = make_prefetcher(PrefetchStrategy::Sequential { lookahead: 2 }, &cache);
        prefetcher.trigger_prefetch("seg-001");
        assert_eq!(
            cache.get("seg-002"),
            Some(vec![0xAB]),
            "existing entry should not be overwritten"
        );
        assert!(cache.contains("seg-003"), "missing key should still be loaded");
        // Only the key that was actually loaded is recorded as pending.
        assert_eq!(prefetcher.drain_pending(), vec!["seg-003"]);
    }

    #[test]
    fn test_custom_loader() {
        let cache = make_cache();
        let prefetcher = Prefetcher::with_loader(
            PrefetchStrategy::Sequential { lookahead: 1 },
            Arc::clone(&cache) as Arc<dyn Cache>,
            |key| format!("data-for-{key}").into_bytes(),
        );
        prefetcher.trigger_prefetch("chunk-004");
        let val = cache
            .get("chunk-005")
            .expect("chunk-005 should be in cache");
        assert_eq!(val, b"data-for-chunk-005");
    }

    #[test]
    fn test_pending_queue() {
        let cache = make_cache();
        let prefetcher = make_prefetcher(PrefetchStrategy::Sequential { lookahead: 3 }, &cache);
        prefetcher.trigger_prefetch("frame-100");
        assert_eq!(prefetcher.pending_count(), 3);
        let drained = prefetcher.drain_pending();
        assert_eq!(drained, vec!["frame-101", "frame-102", "frame-103"]);
        assert_eq!(prefetcher.pending_count(), 0);
    }

    #[test]
    fn test_max_pending_limit() {
        let cache = make_cache();
        let prefetcher = make_prefetcher(PrefetchStrategy::Sequential { lookahead: 5 }, &cache)
            .with_max_pending(3);
        prefetcher.trigger_prefetch("v-000");
        assert_eq!(
            prefetcher.pending_count(),
            3,
            "pending should be capped at max_pending=3"
        );
        // Oldest entries are evicted first; every key is still cached.
        assert_eq!(prefetcher.drain_pending(), vec!["v-003", "v-004", "v-005"]);
        assert!(cache.contains("v-001"));
    }

    #[test]
    fn test_max_pending_clamped_to_one() {
        let cache = make_cache();
        let prefetcher = make_prefetcher(PrefetchStrategy::Sequential { lookahead: 2 }, &cache)
            .with_max_pending(0);
        prefetcher.trigger_prefetch("k-1");
        assert_eq!(prefetcher.drain_pending(), vec!["k-3"]);
    }

    #[test]
    fn test_trigger_prefetch_access_pattern() {
        let cache = make_cache();
        let keys = vec![
            "intro".to_string(),
            "main".to_string(),
            "credits".to_string(),
        ];
        let prefetcher = make_prefetcher(PrefetchStrategy::AccessPattern(keys), &cache);
        prefetcher.trigger_prefetch("intro");
        assert!(cache.contains("main"), "main should be prefetched");
        assert!(
            !cache.contains("credits"),
            "credits should NOT be prefetched yet"
        );
    }

    #[test]
    fn test_concurrent_trigger_prefetch() {
        let cache = make_cache();
        let prefetcher = Arc::new(make_prefetcher(
            PrefetchStrategy::Sequential { lookahead: 1 },
            &cache,
        ));
        let threads: Vec<_> = (0..4u32)
            .map(|i| {
                let p = Arc::clone(&prefetcher);
                thread::spawn(move || {
                    for j in 0..25u32 {
                        p.trigger_prefetch(&format!("seg-{}", i * 100 + j));
                    }
                })
            })
            .collect();
        for t in threads {
            t.join().expect("thread panicked");
        }
        // The threads walk disjoint key ranges and each successor is predicted
        // by exactly one trigger, so all 100 keys are loaded exactly once.
        for i in 0..4u32 {
            for j in 1..=25u32 {
                let key = format!("seg-{}", i * 100 + j);
                assert!(cache.contains(&key), "{key} should be prefetched");
            }
        }
        assert_eq!(prefetcher.pending_count(), 100);
    }

    #[test]
    fn test_sequential_zero_padded() {
        let strategy = PrefetchStrategy::Sequential { lookahead: 2 };
        let next = strategy.predict_next("segment-099");
        assert_eq!(next, vec!["segment-100", "segment-101"]);
    }

    #[test]
    fn test_prefetcher_debug_shows_config() {
        let cache = make_cache();
        let prefetcher = make_prefetcher(PrefetchStrategy::Sequential { lookahead: 1 }, &cache);
        let rendered = format!("{prefetcher:?}");
        assert!(rendered.contains("Sequential { lookahead: 1 }"));
        assert!(rendered.contains("max_pending: 256"));
    }
}