use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{RwLock, mpsc};
use tracing::{debug, warn};

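/// Configuration for the chunk prefetch cache.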
#[derive(Debug, Clone)]
pub struct PrefetchConfig {
    /// Maximum number of chunks to keep in the cache.
    pub max_cached_chunks: usize,
    /// Number of chunks to prefetch ahead of the current access.
    pub prefetch_ahead: u64,
    /// Maximum amount of memory the cache may use, in bytes.
    pub max_cache_memory: usize,
    /// How long a cached chunk remains valid.
    pub cache_ttl: Duration,
    /// Whether to predict and prefetch based on sequential access patterns.
    pub enable_sequential_prediction: bool,
    /// Whether to enable popularity-based prefetching.
    pub enable_popularity_prefetch: bool,
}

impl Default for PrefetchConfig {
    fn default() -> Self {
        Self {
            max_cached_chunks: 100,
            prefetch_ahead: 3,
            max_cache_memory: 256 * 1024 * 1024,
            cache_ttl: Duration::from_secs(300),
            enable_sequential_prediction: true,
            enable_popularity_prefetch: true,
        }
    }
}

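/// A single chunk held in the in-memory cache.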
#[derive(Debug, Clone)]
pub struct CachedChunk {
    pub cid: String,
    pub chunk_index: u64,
    pub data: Vec<u8>,
    pub cached_at: Instant,
    pub access_count: u32,
}

impl CachedChunk {
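    /// Returns `true` if the entry has been cached for longer than `ttl`.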
    #[inline]
    #[must_use]
    pub fn is_expired(&self, ttl: Duration) -> bool {
        self.cached_at.elapsed() > ttl
    }

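    /// Size of the cached chunk data in bytes.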
    #[inline]
    #[must_use]
    pub fn size(&self) -> usize {
        self.data.len()
    }
}

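/// Per-CID record of recent chunk accesses and the pattern detected from them.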
#[derive(Debug, Clone)]
struct AccessPattern {
    recent_accesses: VecDeque<u64>,
    pattern_type: PatternType,
    confidence: f64,
    last_access: Instant,
}

impl Default for AccessPattern {
    fn default() -> Self {
        Self {
            recent_accesses: VecDeque::with_capacity(10),
            pattern_type: PatternType::Unknown,
            confidence: 0.0,
            last_access: Instant::now(),
        }
    }
}

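/// The kind of access pattern detected for a CID.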
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PatternType {
    Sequential,
    ReverseSequential,
    Strided { stride: i64 },
    Random,
    Unknown,
}

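/// Cache key identifying a chunk by CID and chunk index.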
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct CacheKey {
    cid: String,
    chunk_index: u64,
}

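/// A request to prefetch specific chunks of a CID.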
#[derive(Debug, Clone)]
pub struct PrefetchHint {
    pub cid: String,
    pub chunk_indices: Vec<u64>,
    pub priority: u8,
}

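/// Caches chunks in memory, learns per-CID access patterns, and emits prefetch hints.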
pub struct ChunkPrefetcher {
    config: PrefetchConfig,
    cache: Arc<RwLock<HashMap<CacheKey, CachedChunk>>>,
    patterns: Arc<RwLock<HashMap<String, AccessPattern>>>,
    cache_memory: Arc<RwLock<usize>>,
    prefetch_tx: Option<mpsc::Sender<PrefetchHint>>,
    stats: Arc<RwLock<PrefetchStats>>,
}

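/// Counters describing cache and prediction behavior.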
#[derive(Debug, Clone, Default)]
pub struct PrefetchStats {
    pub cache_hits: u64,
    pub cache_misses: u64,
    pub chunks_prefetched: u64,
    pub successful_predictions: u64,
    pub failed_predictions: u64,
    pub cache_entries: usize,
    pub cache_memory_bytes: usize,
}

impl PrefetchStats {
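    /// Fraction of lookups served from the cache (0.0 when there have been none).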
    #[inline]
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
        let total = self.cache_hits + self.cache_misses;
        if total == 0 {
            0.0
        } else {
            self.cache_hits as f64 / total as f64
        }
    }

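    /// Fraction of predictions that turned out to be correct (0.0 when there have been none).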
    #[inline]
    #[must_use]
    pub fn prediction_accuracy(&self) -> f64 {
        let total = self.successful_predictions + self.failed_predictions;
        if total == 0 {
            0.0
        } else {
            self.successful_predictions as f64 / total as f64
        }
    }
}

impl ChunkPrefetcher {
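    /// Creates a prefetcher with the given configuration and an empty cache.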
    #[inline]
    pub fn new(config: PrefetchConfig) -> Self {
        Self {
            config,
            cache: Arc::new(RwLock::new(HashMap::new())),
            patterns: Arc::new(RwLock::new(HashMap::new())),
            cache_memory: Arc::new(RwLock::new(0)),
            prefetch_tx: None,
            stats: Arc::new(RwLock::new(PrefetchStats::default())),
        }
    }

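    /// Registers the channel on which prefetch hints are sent.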
    pub fn set_prefetch_channel(&mut self, tx: mpsc::Sender<PrefetchHint>) {
        self.prefetch_tx = Some(tx);
    }

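    /// Returns the cached data for `cid`/`chunk_index`, or `None` on a miss or expired entry.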
    pub async fn get_cached(&self, cid: &str, chunk_index: u64) -> Option<Vec<u8>> {
        let key = CacheKey {
            cid: cid.to_string(),
            chunk_index,
        };

        let mut cache = self.cache.write().await;
        let mut stats = self.stats.write().await;

        if let Some(entry) = cache.get_mut(&key) {
            if entry.is_expired(self.config.cache_ttl) {
                let size = entry.size();
                cache.remove(&key);
                let mut mem = self.cache_memory.write().await;
                *mem = mem.saturating_sub(size);
                stats.cache_misses += 1;
                return None;
            }

            entry.access_count += 1;
            stats.cache_hits += 1;
            return Some(entry.data.clone());
        }

        stats.cache_misses += 1;
        None
    }

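    /// Inserts a chunk into the cache, evicting entries first if the memory or entry-count limits would be exceeded.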
    pub async fn put_cached(&self, cid: &str, chunk_index: u64, data: Vec<u8>) {
        let key = CacheKey {
            cid: cid.to_string(),
            chunk_index,
        };

        let entry = CachedChunk {
            cid: cid.to_string(),
            chunk_index,
            data,
            cached_at: Instant::now(),
            access_count: 1,
        };

        let entry_size = entry.size();

        // Release the read guard before evicting: `evict_entries` takes the write
        // lock on `cache_memory`, which would deadlock while a read guard is held.
        let needs_eviction = {
            let mem = self.cache_memory.read().await;
            *mem + entry_size > self.config.max_cache_memory
        };
        if needs_eviction {
            self.evict_entries(entry_size).await;
        }

        let mut cache = self.cache.write().await;

        let mut freed_bytes = 0;
        if cache.len() >= self.config.max_cached_chunks {
            freed_bytes += self.evict_lru(&mut cache).await;
        }

        // Account for any existing entry that this insert replaces.
        if let Some(old) = cache.insert(key, entry) {
            freed_bytes += old.size();
        }

        let mut mem = self.cache_memory.write().await;
        *mem = mem.saturating_sub(freed_bytes) + entry_size;

        let mut stats = self.stats.write().await;
        stats.cache_entries = cache.len();
        stats.cache_memory_bytes = *mem;
    }

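    /// Records a chunk access, updates the detected pattern for the CID, and returns the chunk indices predicted to be requested next.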
    pub async fn record_access(&self, cid: &str, chunk_index: u64) -> Vec<u64> {
        let mut patterns = self.patterns.write().await;
        let pattern = patterns
            .entry(cid.to_string())
            .or_insert_with(AccessPattern::default);

        pattern.recent_accesses.push_back(chunk_index);
        if pattern.recent_accesses.len() > 10 {
            pattern.recent_accesses.pop_front();
        }
        pattern.last_access = Instant::now();

        if pattern.recent_accesses.len() >= 3 {
            pattern.pattern_type = self.detect_pattern(&pattern.recent_accesses);
            pattern.confidence =
                self.calculate_confidence(&pattern.recent_accesses, pattern.pattern_type);
        }

        self.predict_next_chunks(chunk_index, pattern)
    }

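    /// Sends a prefetch hint for the given chunk indices over the registered channel, if any.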
    pub async fn request_prefetch(&self, cid: &str, chunk_indices: Vec<u64>) {
        if chunk_indices.is_empty() {
            return;
        }

        // Count individual chunks, not hints, so the stat matches its name.
        let chunk_count = chunk_indices.len() as u64;

        if let Some(tx) = &self.prefetch_tx {
            let hint = PrefetchHint {
                cid: cid.to_string(),
                chunk_indices,
                priority: 128,
            };

            if let Err(e) = tx.try_send(hint) {
                warn!("Failed to send prefetch hint: {}", e);
            }
        }

        let mut stats = self.stats.write().await;
        stats.chunks_prefetched += chunk_count;
    }

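    /// Returns a snapshot of the current prefetch statistics.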
    pub async fn stats(&self) -> PrefetchStats {
        self.stats.read().await.clone()
    }

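    /// Removes all cached chunks and resets the memory accounting.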
    pub async fn clear_cache(&self) {
        let mut cache = self.cache.write().await;
        cache.clear();

        let mut mem = self.cache_memory.write().await;
        *mem = 0;

        let mut stats = self.stats.write().await;
        stats.cache_entries = 0;
        stats.cache_memory_bytes = 0;
    }

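    /// Forgets the recorded access pattern for a CID.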
    pub async fn clear_pattern(&self, cid: &str) {
        let mut patterns = self.patterns.write().await;
        patterns.remove(cid);
    }

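    /// Removes every cache entry whose TTL has elapsed.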
    pub async fn evict_expired(&self) {
        let mut cache = self.cache.write().await;
        let mut mem = self.cache_memory.write().await;

        let expired: Vec<CacheKey> = cache
            .iter()
            .filter(|(_, entry)| entry.is_expired(self.config.cache_ttl))
            .map(|(key, _)| key.clone())
            .collect();

        for key in expired {
            if let Some(entry) = cache.remove(&key) {
                *mem = mem.saturating_sub(entry.size());
            }
        }

        let mut stats = self.stats.write().await;
        stats.cache_entries = cache.len();
        stats.cache_memory_bytes = *mem;
    }

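    /// Classifies the recent access sequence as sequential, reverse-sequential, strided, or random.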
    fn detect_pattern(&self, accesses: &VecDeque<u64>) -> PatternType {
        if accesses.len() < 3 {
            return PatternType::Unknown;
        }

        let diffs: Vec<i64> = accesses
            .iter()
            .zip(accesses.iter().skip(1))
            .map(|(a, b)| *b as i64 - *a as i64)
            .collect();

        if diffs.iter().all(|&d| d == 1) {
            return PatternType::Sequential;
        }

        if diffs.iter().all(|&d| d == -1) {
            return PatternType::ReverseSequential;
        }

        if diffs.len() >= 2 {
            let first_diff = diffs[0];
            if first_diff != 0 && diffs.iter().all(|&d| d == first_diff) {
                return PatternType::Strided { stride: first_diff };
            }
        }

        PatternType::Random
    }

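    /// Scores confidence in the detected pattern, scaled by how many samples back it.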
    fn calculate_confidence(&self, accesses: &VecDeque<u64>, pattern: PatternType) -> f64 {
        if accesses.len() < 3 {
            return 0.0;
        }

        let base_confidence = match pattern {
            PatternType::Sequential | PatternType::ReverseSequential => 0.9,
            PatternType::Strided { .. } => 0.8,
            PatternType::Random => 0.1,
            PatternType::Unknown => 0.0,
        };

        let sample_factor = (accesses.len() as f64 / 10.0).min(1.0);
        base_confidence * sample_factor
    }

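    /// Predicts the next chunk indices to prefetch; falls back to simple read-ahead when confidence is low.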
    fn predict_next_chunks(&self, current: u64, pattern: &AccessPattern) -> Vec<u64> {
        if !self.config.enable_sequential_prediction {
            return vec![];
        }

        if pattern.confidence < 0.5 {
            return (1..=self.config.prefetch_ahead)
                .map(|i| current + i)
                .collect();
        }

        let prefetch_count = self.config.prefetch_ahead;

        match pattern.pattern_type {
            PatternType::Sequential => (1..=prefetch_count).map(|i| current + i).collect(),
            PatternType::ReverseSequential => (1..=prefetch_count)
                .filter_map(|i| current.checked_sub(i))
                .collect(),
            PatternType::Strided { stride } => (1..=prefetch_count)
                .filter_map(|i| {
                    let next = current as i64 + stride * i as i64;
                    if next >= 0 { Some(next as u64) } else { None }
                })
                .collect(),
            PatternType::Random | PatternType::Unknown => {
                (1..=prefetch_count).map(|i| current + i).collect()
            }
        }
    }

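    /// Evicts entries until `needed_bytes` of additional data would fit under the memory limit.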
    async fn evict_entries(&self, needed_bytes: usize) {
        let mut cache = self.cache.write().await;
        let mut mem = self.cache_memory.write().await;

        while *mem + needed_bytes > self.config.max_cache_memory && !cache.is_empty() {
            let freed = self.evict_lru(&mut cache).await;
            *mem = mem.saturating_sub(freed);
        }
    }

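    /// Evicts the entry with the lowest access-frequency score and returns the number of bytes freed.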
    async fn evict_lru(&self, cache: &mut HashMap<CacheKey, CachedChunk>) -> usize {
        let lru_key = cache
            .iter()
            .min_by(|a, b| {
                let score_a =
                    a.1.access_count as f64 / a.1.cached_at.elapsed().as_secs_f64().max(1.0);
                let score_b =
                    b.1.access_count as f64 / b.1.cached_at.elapsed().as_secs_f64().max(1.0);
                score_a
                    .partial_cmp(&score_b)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(k, _)| k.clone());

        if let Some(key) = lru_key {
            if let Some(entry) = cache.remove(&key) {
                debug!(
                    "Evicted chunk from cache: {}:{}",
                    entry.cid, entry.chunk_index
                );
                return entry.size();
            }
        }

        0
    }
}

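/// Builder for constructing a `ChunkPrefetcher` with custom settings.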
#[derive(Debug, Default)]
pub struct PrefetcherBuilder {
    config: PrefetchConfig,
}

impl PrefetcherBuilder {
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    #[inline]
    #[must_use]
    pub fn max_cached_chunks(mut self, count: usize) -> Self {
        self.config.max_cached_chunks = count;
        self
    }

    #[inline]
    #[must_use]
    pub fn prefetch_ahead(mut self, count: u64) -> Self {
        self.config.prefetch_ahead = count;
        self
    }

    #[inline]
    #[must_use]
    pub fn max_cache_memory(mut self, bytes: usize) -> Self {
        self.config.max_cache_memory = bytes;
        self
    }

    #[inline]
    #[must_use]
    pub fn cache_ttl(mut self, ttl: Duration) -> Self {
        self.config.cache_ttl = ttl;
        self
    }

    #[inline]
    #[must_use]
    pub fn enable_sequential_prediction(mut self, enable: bool) -> Self {
        self.config.enable_sequential_prediction = enable;
        self
    }

    #[inline]
    #[must_use]
    pub fn build(self) -> ChunkPrefetcher {
        ChunkPrefetcher::new(self.config)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_cache_put_get() {
        let prefetcher = ChunkPrefetcher::new(PrefetchConfig::default());

        let data = vec![1, 2, 3, 4, 5];
        prefetcher.put_cached("cid1", 0, data.clone()).await;

        let cached = prefetcher.get_cached("cid1", 0).await;
        assert_eq!(cached, Some(data));

        let not_cached = prefetcher.get_cached("cid1", 1).await;
        assert_eq!(not_cached, None);
    }

    #[tokio::test]
    async fn test_pattern_detection_sequential() {
        let prefetcher = ChunkPrefetcher::new(PrefetchConfig::default());

        for i in 0..5 {
            prefetcher.record_access("cid1", i).await;
        }

        let predicted = prefetcher.record_access("cid1", 5).await;
        assert!(predicted.contains(&6));
        assert!(predicted.contains(&7));
    }

    #[tokio::test]
    async fn test_stats() {
        let prefetcher = ChunkPrefetcher::new(PrefetchConfig::default());

        prefetcher.get_cached("cid1", 0).await;

        prefetcher.put_cached("cid1", 0, vec![1, 2, 3]).await;
        prefetcher.get_cached("cid1", 0).await;

        let stats = prefetcher.stats().await;
        assert_eq!(stats.cache_hits, 1);
        assert_eq!(stats.cache_misses, 1);
        assert_eq!(stats.cache_entries, 1);
    }
}