use std::collections::VecDeque;
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
8
/// Key/value byte store that can be shared between threads.
///
/// Every method takes `&self`, so implementors provide their own interior
/// synchronisation; the `Send + Sync` supertraits let a single instance be
/// used behind an `Arc<dyn Cache>`.
pub trait Cache: Send + Sync {
    /// Reports whether an entry is currently stored under `key`.
    fn contains(&self, key: &str) -> bool;

    /// Stores `value` under `key`.
    fn insert(&self, key: String, value: Vec<u8>);

    /// Returns an owned copy of the bytes stored under `key`, or `None` when
    /// there is no such entry.
    fn get(&self, key: &str) -> Option<Vec<u8>>;
}
25
/// Policy used to guess which keys will be wanted after the current one.
#[derive(Debug, Clone)]
pub enum PrefetchStrategy {
    /// Predict the keys that numerically follow the current key.
    Sequential {
        /// How many keys past the current one to predict.
        lookahead: usize,
    },

    /// Predict from a fixed, ordered list of keys.
    AccessPattern(Vec<String>),
}
46
47impl PrefetchStrategy {
48 pub fn predict_next(&self, current_key: &str) -> Vec<String> {
52 match self {
53 PrefetchStrategy::Sequential { lookahead } => {
54 predict_sequential(current_key, *lookahead)
55 }
56 PrefetchStrategy::AccessPattern(pattern) => {
57 predict_access_pattern(current_key, pattern)
58 }
59 }
60 }
61}
62
/// Splits `key` into the text before its trailing decimal digits and the
/// number those digits spell.
///
/// Returns `None` when:
/// - `key` does not end in an ASCII digit,
/// - the character right before the digit run is not `-`, `_` or `/`
///   (so a key such as `manifest.m3u8` is not treated as numbered), or
/// - the digit run does not fit in a `u64`.
fn split_numeric_suffix(key: &str) -> Option<(&str, u64)> {
    // ASCII digits are one byte each, so `prefix.len()` is a valid char boundary.
    let prefix = key.trim_end_matches(|c: char| c.is_ascii_digit());
    if prefix.len() == key.len() || !prefix.ends_with(['-', '_', '/']) {
        return None;
    }
    key[prefix.len()..]
        .parse::<u64>()
        .ok()
        .map(|n| (prefix, n))
}

/// Predicts up to `lookahead` keys that numerically follow `current_key`.
///
/// The numeric suffix is incremented by `1..=lookahead` and re-attached to the
/// unchanged prefix, left-padded with zeros to the width of the original digit
/// run (`segment-099` is followed by `segment-100`).
///
/// Returns an empty vector when `lookahead` is zero or when `current_key` has
/// no numeric suffix recognised by `split_numeric_suffix`. Prediction stops
/// early, without panicking or wrapping, once the next number would exceed
/// `u64::MAX`.
fn predict_sequential(current_key: &str, lookahead: usize) -> Vec<String> {
    let (prefix, n) = match split_numeric_suffix(current_key) {
        Some(parts) => parts,
        None => return Vec::new(),
    };
    // Width of the original digit run; identical for every predicted key.
    let width = current_key.len() - prefix.len();

    (1..=lookahead as u64)
        // Stop at the first offset that would overflow instead of wrapping.
        .map_while(|offset| n.checked_add(offset))
        .map(|next| format!("{prefix}{next:0>width$}"))
        .collect()
}
120
/// Predicts the key that follows `current_key` in `pattern`.
///
/// Only the first occurrence of `current_key` is considered. Its successor is
/// returned as a one-element vector; the last element's successor is the first
/// element, so the pattern is treated as a cycle. An empty pattern or an
/// unknown key yields an empty vector.
fn predict_access_pattern(current_key: &str, pattern: &[String]) -> Vec<String> {
    let found = pattern
        .iter()
        .position(|candidate| candidate == current_key);

    match found {
        Some(idx) => {
            // Wrap back to the start after the final element.
            let successor = if idx + 1 == pattern.len() { 0 } else { idx + 1 };
            vec![pattern[successor].clone()]
        }
        // Covers both an empty pattern and a key that is not part of it.
        None => Vec::new(),
    }
}
135
136pub struct Prefetcher {
149 pub strategy: PrefetchStrategy,
151 cache: Arc<dyn Cache>,
153 pending: Mutex<VecDeque<String>>,
155 max_pending: usize,
157 loader: Arc<dyn Fn(&str) -> Vec<u8> + Send + Sync>,
160}
161
162impl std::fmt::Debug for Prefetcher {
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 f.debug_struct("Prefetcher")
165 .field("strategy", &self.strategy)
166 .field("max_pending", &self.max_pending)
167 .finish()
168 }
169}
170
171impl Prefetcher {
172 pub fn new(strategy: PrefetchStrategy, cache: Arc<dyn Cache>) -> Self {
176 Self {
177 strategy,
178 cache,
179 pending: Mutex::new(VecDeque::new()),
180 max_pending: 256,
181 loader: Arc::new(|_key| Vec::new()),
182 }
183 }
184
185 pub fn with_loader<F>(strategy: PrefetchStrategy, cache: Arc<dyn Cache>, loader: F) -> Self
190 where
191 F: Fn(&str) -> Vec<u8> + Send + Sync + 'static,
192 {
193 Self {
194 strategy,
195 cache,
196 pending: Mutex::new(VecDeque::new()),
197 max_pending: 256,
198 loader: Arc::new(loader),
199 }
200 }
201
202 pub fn with_max_pending(mut self, max: usize) -> Self {
204 self.max_pending = max.max(1);
205 self
206 }
207
208 pub fn trigger_prefetch(&self, current_key: &str) {
215 let predicted = self.strategy.predict_next(current_key);
216 for key in predicted {
217 if !self.cache.contains(&key) {
218 let value = (self.loader)(&key);
219 self.cache.insert(key.clone(), value);
220 if let Ok(mut q) = self.pending.lock() {
222 if q.len() >= self.max_pending {
223 q.pop_front();
224 }
225 q.push_back(key);
226 }
227 }
228 }
229 }
230
231 pub fn pending_count(&self) -> usize {
233 self.pending.lock().map(|q| q.len()).unwrap_or(0)
234 }
235
236 pub fn drain_pending(&self) -> Vec<String> {
241 self.pending
242 .lock()
243 .map(|mut q| q.drain(..).collect())
244 .unwrap_or_default()
245 }
246
247 pub fn cache(&self) -> &Arc<dyn Cache> {
249 &self.cache
250 }
251}
252
/// Cache that keeps every entry in a mutex-guarded `HashMap`.
#[derive(Default)]
pub struct MemoryCache {
    /// Stored values, keyed by cache key.
    store: Mutex<std::collections::HashMap<String, Vec<u8>>>,
}

impl MemoryCache {
    /// Creates an empty cache; equivalent to `MemoryCache::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
276
277impl Cache for MemoryCache {
278 fn contains(&self, key: &str) -> bool {
279 self.store
280 .lock()
281 .map(|m| m.contains_key(key))
282 .unwrap_or(false)
283 }
284
285 fn insert(&self, key: String, value: Vec<u8>) {
286 if let Ok(mut m) = self.store.lock() {
287 m.insert(key, value);
288 }
289 }
290
291 fn get(&self, key: &str) -> Option<Vec<u8>> {
292 self.store.lock().ok().and_then(|m| m.get(key).cloned())
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Builds an empty in-memory cache behind an `Arc`.
    fn make_cache() -> Arc<MemoryCache> {
        let cache = MemoryCache::new();
        Arc::new(cache)
    }

    /// Hands out another handle to `cache`, typed as the trait object that
    /// `Prefetcher` expects.
    fn as_dyn(cache: &Arc<MemoryCache>) -> Arc<dyn Cache> {
        Arc::clone(cache) as Arc<dyn Cache>
    }

    /// Shorthand for a sequential strategy with the given lookahead.
    fn sequential(lookahead: usize) -> PrefetchStrategy {
        PrefetchStrategy::Sequential { lookahead }
    }

    /// Shorthand for an access-pattern strategy over `keys`, in order.
    fn pattern(keys: &[&str]) -> PrefetchStrategy {
        PrefetchStrategy::AccessPattern(keys.iter().map(|k| k.to_string()).collect())
    }

    // --- PrefetchStrategy::predict_next -------------------------------------

    #[test]
    fn test_sequential_predict_basic() {
        let predicted = sequential(3).predict_next("segment-005");
        assert_eq!(
            predicted,
            vec!["segment-006", "segment-007", "segment-008"]
        );
    }

    #[test]
    fn test_sequential_predict_zero_lookahead() {
        let predicted = sequential(0).predict_next("seg-1");
        assert!(predicted.is_empty());
    }

    #[test]
    fn test_sequential_predict_non_numeric() {
        let predicted = sequential(2).predict_next("manifest.m3u8");
        assert!(predicted.is_empty());
    }

    #[test]
    fn test_access_pattern_predict_next() {
        let predicted = pattern(&["a", "b", "c"]).predict_next("b");
        assert_eq!(predicted, vec!["c"]);
    }

    #[test]
    fn test_access_pattern_wrap_around() {
        let predicted = pattern(&["x", "y", "z"]).predict_next("z");
        assert_eq!(predicted, vec!["x"]);
    }

    #[test]
    fn test_access_pattern_unknown_key() {
        let predicted = pattern(&["a", "b"]).predict_next("unknown");
        assert!(predicted.is_empty());
    }

    // --- Prefetcher::trigger_prefetch ---------------------------------------

    #[test]
    fn test_trigger_prefetch_sequential() {
        let cache = make_cache();
        let prefetcher = Prefetcher::new(sequential(2), as_dyn(&cache));

        prefetcher.trigger_prefetch("seg-010");

        assert!(cache.contains("seg-011"), "seg-011 should be prefetched");
        assert!(cache.contains("seg-012"), "seg-012 should be prefetched");
        assert!(
            !cache.contains("seg-013"),
            "seg-013 should NOT be prefetched"
        );
    }

    #[test]
    fn test_trigger_prefetch_no_overwrite() {
        let cache = make_cache();
        // Seed the entry that the prefetcher would otherwise load.
        cache.insert("seg-002".to_string(), vec![0xAB]);
        let prefetcher = Prefetcher::new(sequential(2), as_dyn(&cache));

        prefetcher.trigger_prefetch("seg-001");

        let stored = cache.get("seg-002");
        assert_eq!(
            stored,
            Some(vec![0xAB]),
            "existing entry should not be overwritten"
        );
    }

    #[test]
    fn test_custom_loader() {
        let cache = make_cache();
        let prefetcher = Prefetcher::with_loader(sequential(1), as_dyn(&cache), |key| {
            format!("data-for-{key}").into_bytes()
        });

        prefetcher.trigger_prefetch("chunk-004");

        let val = cache
            .get("chunk-005")
            .expect("chunk-005 should be in cache");
        assert_eq!(val, b"data-for-chunk-005");
    }

    // --- pending queue --------------------------------------------------------

    #[test]
    fn test_pending_queue() {
        let cache = make_cache();
        let prefetcher = Prefetcher::new(sequential(3), as_dyn(&cache));

        prefetcher.trigger_prefetch("frame-100");
        assert_eq!(prefetcher.pending_count(), 3);

        let drained = prefetcher.drain_pending();
        assert_eq!(drained.len(), 3);
        assert_eq!(prefetcher.pending_count(), 0);
    }

    #[test]
    fn test_max_pending_limit() {
        let cache = make_cache();
        let prefetcher = Prefetcher::new(sequential(5), as_dyn(&cache)).with_max_pending(3);

        prefetcher.trigger_prefetch("v-000");

        assert!(
            prefetcher.pending_count() <= 3,
            "pending should not exceed max_pending=3"
        );
    }

    #[test]
    fn test_trigger_prefetch_access_pattern() {
        let cache = make_cache();
        let prefetcher = Prefetcher::new(pattern(&["intro", "main", "credits"]), as_dyn(&cache));

        prefetcher.trigger_prefetch("intro");

        assert!(cache.contains("main"), "main should be prefetched");
        assert!(
            !cache.contains("credits"),
            "credits should NOT be prefetched yet"
        );
    }

    // --- concurrency ----------------------------------------------------------

    /// Four threads prefetch disjoint key ranges through one shared
    /// prefetcher; the test passes as long as no thread panics.
    #[test]
    fn test_concurrent_trigger_prefetch() {
        let cache = Arc::new(MemoryCache::new());
        let prefetcher = Arc::new(Prefetcher::new(sequential(1), as_dyn(&cache)));

        let mut workers = Vec::with_capacity(4);
        for i in 0..4u32 {
            let p = Arc::clone(&prefetcher);
            workers.push(thread::spawn(move || {
                for j in 0..25u32 {
                    let key = format!("seg-{}", i * 100 + j);
                    p.trigger_prefetch(&key);
                }
            }));
        }

        for worker in workers {
            worker.join().expect("thread panicked");
        }
    }

    #[test]
    fn test_sequential_zero_padded() {
        let predicted = sequential(2).predict_next("segment-099");
        assert_eq!(predicted, vec!["segment-100", "segment-101"]);
    }
}