1use ipfrs_core::Cid;
39use std::collections::{HashMap, HashSet, VecDeque};
40use std::sync::{Arc, RwLock};
41use std::time::{Duration, Instant};
42
/// Selects how `PrefetchPredictor::predict` looks ahead from the block just accessed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PrefetchStrategy {
    /// Prediction is disabled; `predict` always returns an empty list.
    None,
    /// Suggest only the recorded direct children of the current block.
    ImmediateChildren,
    /// Suggest successors learned from the access history, falling back to the
    /// direct children when the current block has no learned successors.
    PatternBased,
    /// Breadth-first walk of the recorded DAG links, bounded by `PrefetchConfig::max_depth`.
    Subtree,
    /// Pick between pattern-based and immediate-children prediction depending on
    /// the hit rate observed so far.
    Adaptive,
}
57
58#[derive(Debug, Clone)]
60pub struct PrefetchConfig {
61 pub strategy: PrefetchStrategy,
63 pub max_depth: usize,
65 pub max_concurrent_prefetch: usize,
67 pub prefetch_buffer_size: usize,
69 pub min_confidence: f64,
71 pub pattern_history_size: usize,
73 pub adaptive_tuning: bool,
75 pub prefetch_timeout: Duration,
77}
78
79impl Default for PrefetchConfig {
80 fn default() -> Self {
81 Self {
82 strategy: PrefetchStrategy::PatternBased,
83 max_depth: 2,
84 max_concurrent_prefetch: 16,
85 prefetch_buffer_size: 128,
86 min_confidence: 0.6,
87 pattern_history_size: 1000,
88 adaptive_tuning: true,
89 prefetch_timeout: Duration::from_secs(5),
90 }
91 }
92}
93
94#[derive(Debug, Clone)]
96struct AccessPattern {
97 #[allow(dead_code)]
99 source: Cid,
100 target: Cid,
102 count: usize,
104 last_access: Instant,
106}
107
108#[derive(Debug, Clone)]
110struct DagLink {
111 #[allow(dead_code)]
113 parent: Cid,
114 children: Vec<Cid>,
116 depth: usize,
118}
119
120#[derive(Debug, Clone)]
122pub struct Prediction {
123 pub cid: Cid,
125 pub confidence: f64,
127 pub depth: usize,
129 pub reason: PredictionReason,
131}
132
/// The evidence a `Prediction` is based on.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PredictionReason {
    /// The block is a recorded DAG descendant of the queried block.
    DagChild,
    /// The block has followed the queried block in the observed access history.
    AccessPattern,
    /// NOTE(review): never produced by the predictor in this module; presumably
    /// reserved for sibling-based prediction.
    Sibling,
    /// NOTE(review): never produced by the predictor in this module; presumably
    /// reserved for time-correlated prediction.
    Temporal,
}
145
/// Running counters describing how useful the issued prefetches have been.
#[derive(Debug, Clone, Default)]
pub struct PrefetchStats {
    /// Number of `record_prefetch` calls.
    pub prefetch_requests: u64,
    /// Prefetched blocks that were subsequently accessed.
    pub hits: u64,
    /// Prefetches reported through `record_miss` or expired by `cleanup`.
    pub misses: u64,
    /// Bytes reported as wasted through `record_miss`.
    pub wasted_bytes: u64,
    /// Sum, in milliseconds, of the prefetch-to-access intervals of all hits.
    pub saved_latency_ms: u64,
    /// `hits / (hits + misses)`. It keeps its previous value (initially 0.0) while no
    /// hit or miss has been recorded.
    pub hit_rate: f64,
}

impl PrefetchStats {
    /// Recomputes `hit_rate` from the current hit and miss counters.
    fn update_hit_rate(&mut self) {
        let resolved = self.hits + self.misses;
        if resolved == 0 {
            // Nothing resolved yet: leave the rate untouched rather than divide 0/0.
            return;
        }
        self.hit_rate = self.hits as f64 / resolved as f64;
    }
}
172
173pub struct PrefetchPredictor {
175 config: PrefetchConfig,
176 patterns: Arc<RwLock<HashMap<Cid, Vec<AccessPattern>>>>,
178 dag_links: Arc<RwLock<HashMap<Cid, DagLink>>>,
180 access_history: Arc<RwLock<VecDeque<(Cid, Instant)>>>,
182 prefetched: Arc<RwLock<HashMap<Cid, Instant>>>,
184 stats: Arc<RwLock<PrefetchStats>>,
186}
187
188impl PrefetchPredictor {
189 pub fn new(config: PrefetchConfig) -> Self {
191 Self {
192 config,
193 patterns: Arc::new(RwLock::new(HashMap::new())),
194 dag_links: Arc::new(RwLock::new(HashMap::new())),
195 access_history: Arc::new(RwLock::new(VecDeque::new())),
196 prefetched: Arc::new(RwLock::new(HashMap::new())),
197 stats: Arc::new(RwLock::new(PrefetchStats::default())),
198 }
199 }
200
201 pub fn record_access(&self, cid: &Cid) {
203 let now = Instant::now();
204
205 {
207 let mut prefetched = self.prefetched.write().unwrap();
208 if let Some(prefetch_time) = prefetched.remove(cid) {
209 let mut stats = self.stats.write().unwrap();
210 stats.hits += 1;
211 let saved_ms = now.duration_since(prefetch_time).as_millis() as u64;
212 stats.saved_latency_ms += saved_ms;
213 stats.update_hit_rate();
214 }
215 }
216
217 {
219 let mut history = self.access_history.write().unwrap();
220 history.push_back((*cid, now));
221
222 while history.len() > self.config.pattern_history_size {
224 history.pop_front();
225 }
226 }
227
228 self.update_patterns(cid, now);
230 }
231
232 fn update_patterns(&self, current: &Cid, now: Instant) {
234 let history = self.access_history.read().unwrap();
235 let mut patterns = self.patterns.write().unwrap();
236
237 let recent_window = Duration::from_secs(1);
239
240 for (prev_cid, prev_time) in history.iter().rev() {
241 if now.duration_since(*prev_time) > recent_window {
242 break;
243 }
244
245 if prev_cid == current {
246 continue;
247 }
248
249 let pattern_list = patterns.entry(*prev_cid).or_default();
251
252 if let Some(pattern) = pattern_list.iter_mut().find(|p| p.target == *current) {
253 pattern.count += 1;
254 pattern.last_access = now;
255 } else {
256 pattern_list.push(AccessPattern {
257 source: *prev_cid,
258 target: *current,
259 count: 1,
260 last_access: now,
261 });
262 }
263 }
264 }
265
266 pub fn record_dag_links(&self, parent: &Cid, children: Vec<Cid>, depth: usize) {
268 let mut dag_links = self.dag_links.write().unwrap();
269 dag_links.insert(
270 *parent,
271 DagLink {
272 parent: *parent,
273 children,
274 depth,
275 },
276 );
277 }
278
279 pub fn predict(&self, current: &Cid) -> Vec<Prediction> {
281 match self.config.strategy {
282 PrefetchStrategy::None => Vec::new(),
283 PrefetchStrategy::ImmediateChildren => self.predict_dag_children(current),
284 PrefetchStrategy::PatternBased => self.predict_from_patterns(current),
285 PrefetchStrategy::Subtree => self.predict_subtree(current),
286 PrefetchStrategy::Adaptive => self.predict_adaptive(current),
287 }
288 }
289
290 fn predict_dag_children(&self, current: &Cid) -> Vec<Prediction> {
292 let dag_links = self.dag_links.read().unwrap();
293
294 if let Some(link) = dag_links.get(current) {
295 link.children
296 .iter()
297 .map(|child| Prediction {
298 cid: *child,
299 confidence: 0.95,
300 depth: link.depth + 1,
301 reason: PredictionReason::DagChild,
302 })
303 .collect()
304 } else {
305 Vec::new()
306 }
307 }
308
309 fn predict_from_patterns(&self, current: &Cid) -> Vec<Prediction> {
311 let patterns = self.patterns.read().unwrap();
312
313 if let Some(pattern_list) = patterns.get(current) {
314 let total_count: usize = pattern_list.iter().map(|p| p.count).sum();
315
316 pattern_list
317 .iter()
318 .filter_map(|pattern| {
319 let confidence = pattern.count as f64 / total_count as f64;
320 if confidence >= self.config.min_confidence {
321 Some(Prediction {
322 cid: pattern.target,
323 confidence,
324 depth: 1,
325 reason: PredictionReason::AccessPattern,
326 })
327 } else {
328 None
329 }
330 })
331 .collect()
332 } else {
333 self.predict_dag_children(current)
335 }
336 }
337
338 fn predict_subtree(&self, current: &Cid) -> Vec<Prediction> {
340 let mut predictions = Vec::new();
341 let mut visited = HashSet::new();
342 let mut queue = VecDeque::new();
343
344 queue.push_back((*current, 0));
345 visited.insert(*current);
346
347 let dag_links = self.dag_links.read().unwrap();
348
349 while let Some((cid, depth)) = queue.pop_front() {
350 if depth >= self.config.max_depth {
351 continue;
352 }
353
354 if let Some(link) = dag_links.get(&cid) {
355 for child in &link.children {
356 if visited.insert(*child) {
357 predictions.push(Prediction {
358 cid: *child,
359 confidence: 0.9 * (0.8_f64).powi(depth as i32),
360 depth: depth + 1,
361 reason: PredictionReason::DagChild,
362 });
363 queue.push_back((*child, depth + 1));
364 }
365 }
366 }
367 }
368
369 predictions
370 }
371
372 fn predict_adaptive(&self, current: &Cid) -> Vec<Prediction> {
374 let stats = self.stats.read().unwrap();
375 let hit_rate = stats.hit_rate;
376 drop(stats);
377
378 if hit_rate > 0.5 {
380 self.predict_from_patterns(current)
381 } else {
382 self.predict_dag_children(current)
383 }
384 }
385
386 pub fn record_prefetch(&self, cid: &Cid) {
388 let mut prefetched = self.prefetched.write().unwrap();
389 prefetched.insert(*cid, Instant::now());
390
391 let mut stats = self.stats.write().unwrap();
392 stats.prefetch_requests += 1;
393 }
394
395 pub fn record_miss(&self, cid: &Cid, bytes: u64) {
397 let mut prefetched = self.prefetched.write().unwrap();
398 prefetched.remove(cid);
399
400 let mut stats = self.stats.write().unwrap();
401 stats.misses += 1;
402 stats.wasted_bytes += bytes;
403 stats.update_hit_rate();
404 }
405
406 pub fn cleanup(&self, max_age: Duration) {
408 let now = Instant::now();
409
410 {
412 let mut prefetched = self.prefetched.write().unwrap();
413 let mut to_remove = Vec::new();
414 let mut total_missed = 0u64;
415
416 for (cid, time) in prefetched.iter() {
417 if now.duration_since(*time) >= max_age {
418 to_remove.push(*cid);
419 total_missed += 1;
420 }
421 }
422
423 for cid in to_remove {
424 prefetched.remove(&cid);
425 }
426
427 if total_missed > 0 {
428 let mut stats = self.stats.write().unwrap();
429 stats.misses += total_missed;
430 stats.update_hit_rate();
431 }
432 }
433
434 {
436 let mut patterns = self.patterns.write().unwrap();
437 let max_pattern_age = Duration::from_secs(300); for pattern_list in patterns.values_mut() {
440 pattern_list.retain(|p| now.duration_since(p.last_access) < max_pattern_age);
441 }
442
443 patterns.retain(|_, v| !v.is_empty());
444 }
445 }
446
447 pub fn stats(&self) -> PrefetchStats {
449 self.stats.read().unwrap().clone()
450 }
451
452 pub fn update_config(&mut self, config: PrefetchConfig) {
454 self.config = config;
455 }
456}
457
// Unit tests for `PrefetchPredictor`.
//
// NOTE(review): every "distinct" CID below is built with `Cid::default()`. These are
// presumably all equal, so no test here exercises a real transition between two
// different blocks. Some assertions (e.g. `test_pattern_based_prediction`) only hold
// because of that equality. Several tests also rely on the 10 ms sleeps staying
// inside the predictor's 1 s pattern window.
#[cfg(test)]
mod tests {
    use super::*;

    // Construction with the default configuration must not panic.
    #[test]
    fn test_prefetch_predictor_creation() {
        let config = PrefetchConfig::default();
        let _predictor = PrefetchPredictor::new(config);
    }

    // A single access lands in the access history.
    #[test]
    fn test_record_access() {
        let predictor = PrefetchPredictor::new(PrefetchConfig::default());
        let cid = Cid::default();
        predictor.record_access(&cid);

        let history = predictor.access_history.read().unwrap();
        assert_eq!(history.len(), 1);
    }

    // One prediction per recorded child entry (duplicates in the Vec are kept).
    #[test]
    fn test_dag_children_prediction() {
        let predictor = PrefetchPredictor::new(PrefetchConfig::default());
        let parent = Cid::default();
        let child1 = Cid::default();
        let child2 = Cid::default();

        predictor.record_dag_links(&parent, vec![child1, child2], 0);

        let predictions = predictor.predict_dag_children(&parent);
        assert_eq!(predictions.len(), 2);
    }

    // NOTE(review): expecting no predictions only holds if cid1 == cid2, because an
    // access never records a self-transition and there are no DAG links to fall back
    // on. With genuinely distinct CIDs, this alternating sequence would predict
    // cid1 -> cid2 with confidence 1.0.
    #[test]
    fn test_pattern_based_prediction() {
        let predictor = PrefetchPredictor::new(PrefetchConfig {
            min_confidence: 0.5,
            ..Default::default()
        });

        let cid1 = Cid::default();
        let cid2 = Cid::default();

        predictor.record_access(&cid1);
        std::thread::sleep(Duration::from_millis(10));
        predictor.record_access(&cid2);
        std::thread::sleep(Duration::from_millis(10));
        predictor.record_access(&cid1);
        std::thread::sleep(Duration::from_millis(10));
        predictor.record_access(&cid2);

        let predictions = predictor.predict_from_patterns(&cid1);
        assert!(predictions.is_empty());
    }

    // Accessing a prefetched block counts a hit and the prefetch request.
    #[test]
    fn test_prefetch_stats() {
        let predictor = PrefetchPredictor::new(PrefetchConfig::default());
        let cid = Cid::default();

        predictor.record_prefetch(&cid);
        predictor.record_access(&cid);

        let stats = predictor.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.prefetch_requests, 1);
    }

    // NOTE(review): smoke test only, with no assertion. With equal default CIDs the
    // second `record_dag_links` call overwrites the first entry.
    #[test]
    fn test_subtree_prediction() {
        let predictor = PrefetchPredictor::new(PrefetchConfig {
            max_depth: 3,
            ..Default::default()
        });

        let root = Cid::default();
        let child1 = Cid::default();
        let child2 = Cid::default();
        let grandchild1 = Cid::default();

        predictor.record_dag_links(&root, vec![child1, child2], 0);
        predictor.record_dag_links(&child1, vec![grandchild1], 1);

        let _predictions = predictor.predict_subtree(&root);
    }

    // With no hits recorded (hit rate 0), adaptive prediction uses DAG children.
    #[test]
    fn test_adaptive_prediction_switches_strategy() {
        let predictor = PrefetchPredictor::new(PrefetchConfig {
            strategy: PrefetchStrategy::Adaptive,
            ..Default::default()
        });

        let cid = Cid::default();
        let child = Cid::default();

        predictor.record_dag_links(&cid, vec![child], 0);
        let predictions = predictor.predict_adaptive(&cid);
        assert!(!predictions.is_empty());
    }

    // `record_miss` counts a miss and accumulates the wasted bytes.
    #[test]
    fn test_prefetch_miss_tracking() {
        let predictor = PrefetchPredictor::new(PrefetchConfig::default());
        let cid = Cid::default();

        predictor.record_prefetch(&cid);
        predictor.record_miss(&cid, 1024);

        let stats = predictor.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.wasted_bytes, 1024);
    }

    // Two hits and one miss give a hit rate of roughly 2/3.
    #[test]
    fn test_hit_rate_calculation() {
        let predictor = PrefetchPredictor::new(PrefetchConfig::default());
        let cid1 = Cid::default();
        let cid2 = Cid::default();
        let cid3 = Cid::default();

        predictor.record_prefetch(&cid1);
        predictor.record_access(&cid1);
        predictor.record_prefetch(&cid2);
        predictor.record_access(&cid2);
        predictor.record_prefetch(&cid3);
        predictor.record_miss(&cid3, 100);
        let stats = predictor.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate - 0.666).abs() < 0.01);
    }

    // A zero max_age expires every tracked prefetch, and each one counts as a miss.
    #[test]
    fn test_cleanup_old_prefetches() {
        let predictor = PrefetchPredictor::new(PrefetchConfig::default());
        let cid = Cid::default();

        predictor.record_prefetch(&cid);

        predictor.cleanup(Duration::from_secs(0));

        let stats = predictor.stats();
        assert_eq!(stats.misses, 1);
    }

    // NOTE(review): the inner assertion only runs if both cid2 and cid3 are predicted.
    // With equal default CIDs no transitions are learned, so it is never reached.
    #[test]
    fn test_multiple_predictions_sorted_by_confidence() {
        let predictor = PrefetchPredictor::new(PrefetchConfig {
            min_confidence: 0.3,
            ..Default::default()
        });

        let cid1 = Cid::default();
        let cid2 = Cid::default();
        let cid3 = Cid::default();

        for _ in 0..3 {
            predictor.record_access(&cid1);
            std::thread::sleep(Duration::from_millis(10));
            predictor.record_access(&cid2);
            std::thread::sleep(Duration::from_millis(10));
        }

        predictor.record_access(&cid1);
        std::thread::sleep(Duration::from_millis(10));
        predictor.record_access(&cid3);

        let predictions = predictor.predict_from_patterns(&cid1);

        if !predictions.is_empty() {
            let cid2_pred = predictions.iter().find(|p| p.cid == cid2);
            let cid3_pred = predictions.iter().find(|p| p.cid == cid3);

            if let (Some(p2), Some(p3)) = (cid2_pred, cid3_pred) {
                assert!(p2.confidence > p3.confidence);
            }
        }
    }

    // No DAG links were recorded, so there is nothing to predict.
    #[test]
    fn test_no_predictions_for_unknown_cid() {
        let predictor = PrefetchPredictor::new(PrefetchConfig::default());
        let unknown_cid = Cid::default();

        let predictions = predictor.predict_dag_children(&unknown_cid);
        assert!(predictions.is_empty());
    }

    // Any pattern-based prediction must meet min_confidence.
    #[test]
    fn test_prediction_confidence_thresholds() {
        let predictor = PrefetchPredictor::new(PrefetchConfig {
            min_confidence: 0.8,
            ..Default::default()
        });

        let cid1 = Cid::default();
        let cid2 = Cid::default();

        predictor.record_access(&cid1);
        std::thread::sleep(Duration::from_millis(10));
        predictor.record_access(&cid2);

        let predictions = predictor.predict_from_patterns(&cid1);
        assert!(predictions.is_empty() || predictions[0].confidence >= 0.8);
    }

    // The `None` strategy never predicts.
    #[test]
    fn test_prefetch_strategy_none() {
        let predictor = PrefetchPredictor::new(PrefetchConfig {
            strategy: PrefetchStrategy::None,
            ..Default::default()
        });

        let cid = Cid::default();
        let predictions = predictor.predict(&cid);
        assert!(predictions.is_empty());
    }

    // NOTE(review): only checks an upper bound. With max_depth 1 the grandchild must
    // not be predicted, but equal default CIDs make the links collapse into one entry.
    #[test]
    fn test_depth_limited_subtree() {
        let predictor = PrefetchPredictor::new(PrefetchConfig {
            max_depth: 1,
            ..Default::default()
        });

        let root = Cid::default();
        let child = Cid::default();
        let grandchild = Cid::default();

        predictor.record_dag_links(&root, vec![child], 0);
        predictor.record_dag_links(&child, vec![grandchild], 1);

        let predictions = predictor.predict_subtree(&root);

        assert!(predictions.len() <= 1);
    }

    // The history is trimmed to pattern_history_size entries.
    #[test]
    fn test_access_history_limit() {
        let predictor = PrefetchPredictor::new(PrefetchConfig {
            pattern_history_size: 5,
            ..Default::default()
        });

        for _ in 0..10 {
            let cid = Cid::default();
            predictor.record_access(&cid);
        }

        let history = predictor.access_history.read().unwrap();
        assert!(history.len() <= 5);
    }

    // `update_config` replaces the whole configuration.
    #[test]
    fn test_update_config() {
        let mut predictor = PrefetchPredictor::new(PrefetchConfig::default());

        let new_config = PrefetchConfig {
            strategy: PrefetchStrategy::Subtree,
            max_depth: 5,
            ..Default::default()
        };

        predictor.update_config(new_config.clone());
        assert_eq!(predictor.config.strategy, PrefetchStrategy::Subtree);
        assert_eq!(predictor.config.max_depth, 5);
    }
}