use crate::traits::BlockStore;
use dashmap::DashMap;
use ipfrs_core::Cid;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::sync::Semaphore;
use tracing::{debug, trace};
/// Category of block-access behavior a prediction is attributed to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AccessPattern {
    /// One block consistently accessed right after another.
    Sequential,
    /// No recognizable ordering between accesses.
    Random,
    /// Blocks that tend to be accessed close together (co-location).
    Clustered,
    /// Accesses correlated in time.
    /// NOTE(review): no temporal predictor is implemented in this file yet.
    Temporal,
}
/// One observed access to a block, with links to its neighbors in the
/// access stream. Stored per-CID in `access_history`.
#[derive(Debug, Clone)]
struct AccessRecord {
    // When the access happened. Currently unread; presumably kept for a
    // future temporal predictor — confirm before removing.
    #[allow(dead_code)]
    timestamp: SystemTime,
    // Block accessed immediately before this one, if any. Currently unread.
    #[allow(dead_code)]
    previous_cid: Option<Cid>,
    // Block accessed immediately after this one; filled in retroactively by
    // the next call to `record_access`.
    next_cid: Option<Cid>,
}
/// Strength of an observed "B tends to follow A" relationship.
#[derive(Debug, Clone)]
struct CoLocationPattern {
    // Number of times the pair has been observed. Incremented but never read
    // back in this file — NOTE(review): confirm it is used elsewhere.
    count: u64,
    // Most recent observation; patterns older than `pattern_window` are
    // ignored by the co-location predictor.
    last_seen: SystemTime,
    // Exponentially smoothed confidence, clamped to at most 1.0.
    confidence: f64,
}
/// A single prefetch candidate produced by one of the predictors.
#[derive(Debug, Clone)]
pub struct PrefetchPrediction {
    /// Block predicted to be accessed soon.
    pub cid: Cid,
    /// Predictor confidence; compared against `PrefetchConfig::min_confidence`.
    pub confidence: f64,
    /// Timestamp of the prediction (the predictors set this to "now").
    pub predicted_access: SystemTime,
    /// Which pattern detector produced this prediction.
    pub pattern: AccessPattern,
}
/// Tuning knobs for the predictive prefetcher.
#[derive(Debug, Clone)]
pub struct PrefetchConfig {
    /// Upper bound on the adaptive prediction depth (`adapt_depth` never
    /// grows the depth past this).
    pub max_prefetch_depth: usize,
    /// Predictions with confidence below this value are discarded.
    pub min_confidence: f64,
    /// Size of the semaphore bounding concurrent prefetch fetches.
    pub max_concurrent_prefetch: usize,
    /// Co-location patterns last seen longer ago than this are ignored.
    pub pattern_window: Duration,
    /// Enable the sequential-access predictor.
    pub enable_sequential: bool,
    /// Enable co-location tracking and prediction.
    pub enable_colocation: bool,
    /// Enable temporal prediction.
    /// NOTE(review): not consulted anywhere in this file — confirm intent.
    pub enable_temporal: bool,
}
impl Default for PrefetchConfig {
fn default() -> Self {
Self {
max_prefetch_depth: 5,
min_confidence: 0.6,
max_concurrent_prefetch: 3,
pattern_window: Duration::from_secs(300), enable_sequential: true,
enable_colocation: true,
enable_temporal: true,
}
}
}
/// Shared counters tracking prefetch effectiveness.
#[derive(Debug, Default)]
pub struct PrefetchStats {
    /// Total prefetch fetches attempted.
    pub prefetch_attempts: AtomicU64,
    /// Prefetched blocks accessed while their cache entry was still fresh.
    pub prefetch_hits: AtomicU64,
    /// Prefetched blocks whose cache entry was already stale when accessed.
    pub prefetch_misses: AtomicU64,
    /// Total bytes credited on hits.
    pub bytes_prefetched: AtomicU64,
    /// Running average prediction confidence.
    /// NOTE(review): never written anywhere in this file — confirm whether
    /// it is updated elsewhere or is dead.
    pub avg_confidence: parking_lot::Mutex<f64>,
}
impl PrefetchStats {
    /// Counts one prefetch attempt.
    fn record_attempt(&self) {
        self.prefetch_attempts.fetch_add(1, Ordering::Relaxed);
    }

    /// Counts one prefetch hit and credits the bytes it saved.
    fn record_hit(&self, bytes: u64) {
        self.prefetch_hits.fetch_add(1, Ordering::Relaxed);
        self.bytes_prefetched.fetch_add(bytes, Ordering::Relaxed);
    }

    /// Counts one prefetch miss.
    fn record_miss(&self) {
        self.prefetch_misses.fetch_add(1, Ordering::Relaxed);
    }

    /// Fraction of attempts that became hits; 0.0 before any attempt.
    pub fn hit_rate(&self) -> f64 {
        match self.prefetch_attempts.load(Ordering::Relaxed) {
            0 => 0.0,
            total => self.prefetch_hits.load(Ordering::Relaxed) as f64 / total as f64,
        }
    }
}
/// Learns block-access patterns and speculatively loads blocks from the
/// backing store so later accesses can be served from the prefetch cache.
pub struct PredictivePrefetcher<S: BlockStore> {
    /// Backing store blocks are prefetched from.
    store: Arc<S>,
    /// Runtime-adjustable configuration.
    config: parking_lot::RwLock<PrefetchConfig>,
    /// Recent access records per block (capped at 100 records each).
    access_history: DashMap<Cid, VecDeque<AccessRecord>>,
    /// For each block, the blocks observed to follow it and the strength of
    /// that relationship.
    colocation_patterns: DashMap<Cid, DashMap<Cid, CoLocationPattern>>,
    /// Most recently accessed block, used to chain consecutive accesses.
    last_accessed: parking_lot::Mutex<Option<Cid>>,
    // Pending predictions. Never read or written in this file.
    #[allow(dead_code)]
    prefetch_queue: DashMap<Cid, PrefetchPrediction>,
    /// Prefetched block bytes keyed by CID, paired with the fetch time used
    /// to decide hit vs. stale-miss.
    prefetch_cache: DashMap<Cid, (Vec<u8>, SystemTime)>,
    /// Hit/miss/byte counters.
    stats: PrefetchStats,
    /// Bounds concurrent prefetch fetches to `max_concurrent_prefetch`.
    prefetch_semaphore: Arc<Semaphore>,
    /// Adaptive prediction depth, tuned by `adapt_depth`.
    current_depth: AtomicUsize,
}
impl<S: BlockStore + Send + Sync + 'static> PredictivePrefetcher<S> {
    /// Creates a prefetcher over `store`, sizing the concurrency limiter and
    /// initial adaptive depth from `config`.
    pub fn new(store: Arc<S>, config: PrefetchConfig) -> Self {
        let max_concurrent = config.max_concurrent_prefetch;
        let initial_depth = config.max_prefetch_depth;
        Self {
            store,
            config: parking_lot::RwLock::new(config),
            access_history: DashMap::new(),
            colocation_patterns: DashMap::new(),
            last_accessed: parking_lot::Mutex::new(None),
            prefetch_queue: DashMap::new(),
            prefetch_cache: DashMap::new(),
            stats: PrefetchStats::default(),
            prefetch_semaphore: Arc::new(Semaphore::new(max_concurrent)),
            current_depth: AtomicUsize::new(initial_depth),
        }
    }

    /// Records an access to `cid`.
    ///
    /// Appends to the block's history (capped at 100 records), retroactively
    /// links the previous block's latest record forward to this access,
    /// feeds the co-location detector, and scores the prefetch cache: an
    /// access to a block prefetched less than 60 seconds ago counts as a
    /// hit, an older entry counts as a miss.
    pub fn record_access(&self, cid: &Cid) {
        let now = SystemTime::now();
        let previous = self.last_accessed.lock().clone();
        {
            let mut history = self
                .access_history
                .entry(*cid)
                .or_insert_with(VecDeque::new);
            history.push_back(AccessRecord {
                timestamp: now,
                previous_cid: previous,
                next_cid: None,
            });
            // Bound per-block history so memory stays O(1) per CID.
            if history.len() > 100 {
                history.pop_front();
            }
        } // Shard guard dropped here before touching any other entry.
        if let Some(prev_cid) = previous {
            if prev_cid != *cid {
                // Link the previous access forward so the sequential
                // predictor can follow the prev -> current chain.
                if let Some(mut prev_history) = self.access_history.get_mut(&prev_cid) {
                    if let Some(last_record) = prev_history.back_mut() {
                        last_record.next_cid = Some(*cid);
                    }
                }
            }
            if self.config.read().enable_colocation {
                self.update_colocation_pattern(&prev_cid, cid);
            }
        }
        *self.last_accessed.lock() = Some(*cid);
        if let Some(entry) = self.prefetch_cache.get(cid) {
            let (data, prefetch_time) = entry.value();
            let age = now.duration_since(*prefetch_time).unwrap_or_default();
            if age < Duration::from_secs(60) {
                // Fix: credit the actual prefetched byte count (was a
                // hard-coded 0, leaving `bytes_prefetched` permanently zero).
                self.stats.record_hit(data.len() as u64);
            } else {
                self.stats.record_miss();
            }
        }
    }

    /// Strengthens the "`cid2` tends to follow `cid1`" pattern: created with
    /// confidence 0.5 on first sighting, then exponentially smoothed toward
    /// 1.0 on each repeat.
    fn update_colocation_pattern(&self, cid1: &Cid, cid2: &Cid) {
        let now = SystemTime::now();
        let patterns = self
            .colocation_patterns
            .entry(*cid1)
            .or_insert_with(DashMap::new);
        patterns
            .entry(*cid2)
            .and_modify(|pattern| {
                pattern.count += 1;
                pattern.last_seen = now;
                // Exponential update keeps confidence in (0, 1].
                let recency_factor = 0.9;
                pattern.confidence = (pattern.confidence * recency_factor + 0.1).min(1.0);
            })
            .or_insert_with(|| CoLocationPattern {
                count: 1,
                last_seen: now,
                confidence: 0.5,
            });
    }

    /// Predicts which blocks are likely to be accessed after `current_cid`:
    /// gathers candidates from the enabled predictors, drops those below the
    /// configured confidence floor, sorts by descending confidence, and
    /// truncates to the current adaptive depth.
    pub fn predict_next_blocks(&self, current_cid: &Cid) -> Vec<PrefetchPrediction> {
        let config = self.config.read();
        let mut predictions = Vec::new();
        if config.enable_sequential {
            if let Some(seq_predictions) = self.predict_sequential(current_cid) {
                predictions.extend(seq_predictions);
            }
        }
        if config.enable_colocation {
            if let Some(coloc_predictions) = self.predict_colocation(current_cid) {
                predictions.extend(coloc_predictions);
            }
        }
        predictions.retain(|p| p.confidence >= config.min_confidence);
        // Fix: `total_cmp` cannot panic, unlike the previous
        // `partial_cmp(..).unwrap()`, which would abort on a NaN confidence.
        predictions.sort_by(|a, b| b.confidence.total_cmp(&a.confidence));
        let depth = self.current_depth.load(Ordering::Relaxed);
        predictions.truncate(depth);
        predictions
    }

    /// Sequential predictor: for each block historically seen right after
    /// `cid`, confidence = (times it followed) / (recorded accesses of
    /// `cid`); candidates below 0.3 are dropped. Returns `None` when `cid`
    /// has no history or no successor was ever recorded.
    fn predict_sequential(&self, cid: &Cid) -> Option<Vec<PrefetchPrediction>> {
        let history = self.access_history.get(cid)?;
        // A plain HashMap suffices here: the map is local to this call, so
        // the concurrent DashMap used previously was pure overhead.
        let mut next_counts: HashMap<Cid, u64> = HashMap::new();
        for record in history.iter() {
            if let Some(next_cid) = record.next_cid {
                *next_counts.entry(next_cid).or_insert(0) += 1;
            }
        }
        if next_counts.is_empty() {
            return None;
        }
        let total_accesses = history.len() as f64;
        let mut predictions = Vec::new();
        for (next_cid, count) in next_counts {
            let confidence = count as f64 / total_accesses;
            if confidence >= 0.3 {
                predictions.push(PrefetchPrediction {
                    cid: next_cid,
                    confidence,
                    predicted_access: SystemTime::now(),
                    pattern: AccessPattern::Sequential,
                });
            }
        }
        Some(predictions)
    }

    /// Co-location predictor: emits every recorded partner of `cid` whose
    /// last sighting falls within the configured `pattern_window`. Returns
    /// `None` when `cid` has no recorded partners.
    fn predict_colocation(&self, cid: &Cid) -> Option<Vec<PrefetchPrediction>> {
        let patterns = self.colocation_patterns.get(cid)?;
        // Hoisted out of the loop: one config-lock read instead of one per entry.
        let window = self.config.read().pattern_window;
        let mut predictions = Vec::new();
        for entry in patterns.iter() {
            let pattern = entry.value();
            let age = SystemTime::now()
                .duration_since(pattern.last_seen)
                .unwrap_or_default();
            if age < window {
                predictions.push(PrefetchPrediction {
                    cid: *entry.key(),
                    confidence: pattern.confidence,
                    predicted_access: SystemTime::now(),
                    pattern: AccessPattern::Clustered,
                });
            }
        }
        Some(predictions)
    }

    /// Fetches the predicted blocks into the shared prefetch cache, bounded
    /// by the prefetch semaphore.
    ///
    /// Bug fix: the previous implementation moved `self.prefetch_cache.clone()`
    /// into a `tokio::spawn`ed task. `DashMap::clone` deep-copies the map,
    /// so every fetched block was inserted into a throwaway copy and the
    /// real cache stayed empty — hits could never occur and `adapt_depth`
    /// inevitably shrank the depth to 1. Fetching inline on `&self` writes
    /// the shared cache; callers wanting fire-and-forget behavior can spawn
    /// this future themselves.
    pub async fn prefetch_background(&self, predictions: Vec<PrefetchPrediction>) {
        for prediction in predictions {
            self.stats.record_attempt();
            let cid = prediction.cid;
            trace!(
                "Prefetching block {} (confidence: {:.2})",
                cid,
                prediction.confidence
            );
            // Hold one permit for the duration of the fetch so overall
            // prefetch concurrency stays bounded across concurrent callers.
            let _permit = self.prefetch_semaphore.acquire().await.ok();
            if let Ok(Some(block)) = self.store.get(&cid).await {
                self.prefetch_cache
                    .insert(cid, (block.data().to_vec(), SystemTime::now()));
                debug!("Prefetched block {}", cid);
            }
        }
    }

    /// Adjusts the prefetch depth from the observed hit rate: grow by one
    /// (up to the configured maximum) above 0.8, shrink by one (never below
    /// 1) under 0.4, otherwise hold steady.
    pub fn adapt_depth(&self) {
        let hit_rate = self.stats.hit_rate();
        let current = self.current_depth.load(Ordering::Relaxed);
        let max_depth = self.config.read().max_prefetch_depth;
        let new_depth = if hit_rate > 0.8 {
            (current + 1).min(max_depth)
        } else if hit_rate < 0.4 {
            current.saturating_sub(1).max(1)
        } else {
            current
        };
        if new_depth != current {
            self.current_depth.store(new_depth, Ordering::Relaxed);
            debug!(
                "Adapted prefetch depth: {} -> {} (hit rate: {:.2})",
                current, new_depth, hit_rate
            );
        }
    }

    /// Returns a point-in-time copy of the prefetch counters and depth.
    pub fn stats(&self) -> PrefetchStatsSnapshot {
        PrefetchStatsSnapshot {
            prefetch_attempts: self.stats.prefetch_attempts.load(Ordering::Relaxed),
            prefetch_hits: self.stats.prefetch_hits.load(Ordering::Relaxed),
            prefetch_misses: self.stats.prefetch_misses.load(Ordering::Relaxed),
            bytes_prefetched: self.stats.bytes_prefetched.load(Ordering::Relaxed),
            hit_rate: self.stats.hit_rate(),
            current_depth: self.current_depth.load(Ordering::Relaxed),
        }
    }

    /// Drops all prefetched block data.
    pub fn clear_cache(&self) {
        self.prefetch_cache.clear();
    }

    /// Number of blocks currently held in the prefetch cache.
    pub fn cache_size(&self) -> usize {
        self.prefetch_cache.len()
    }
}
/// Serializable point-in-time copy of [`PrefetchStats`] plus the current
/// adaptive depth, as returned by `PredictivePrefetcher::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefetchStatsSnapshot {
    /// Total prefetch fetches attempted.
    pub prefetch_attempts: u64,
    /// Prefetched blocks that were subsequently accessed while fresh.
    pub prefetch_hits: u64,
    /// Prefetched blocks whose cache entry was stale when accessed.
    pub prefetch_misses: u64,
    /// Total bytes credited on hits.
    pub bytes_prefetched: u64,
    /// hits / attempts at snapshot time (0.0 when no attempts).
    pub hit_rate: f64,
    /// Adaptive prediction depth at snapshot time.
    pub current_depth: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::MemoryBlockStore;
    use ipfrs_core::cid::CidBuilder;

    /// Builds a deterministic CID from an integer seed.
    fn test_cid(index: u64) -> Cid {
        CidBuilder::new()
            .build(&index.to_le_bytes())
            .expect("failed to create test cid")
    }

    #[tokio::test]
    async fn test_prefetcher_creation() {
        let prefetcher = PredictivePrefetcher::new(
            Arc::new(MemoryBlockStore::new()),
            PrefetchConfig::default(),
        );
        // A fresh prefetcher has empty stats.
        let snapshot = prefetcher.stats();
        assert_eq!(snapshot.prefetch_attempts, 0);
        assert_eq!(snapshot.hit_rate, 0.0);
    }

    #[tokio::test]
    async fn test_access_recording() {
        let prefetcher = PredictivePrefetcher::new(
            Arc::new(MemoryBlockStore::new()),
            PrefetchConfig::default(),
        );
        let (first, second) = (test_cid(1), test_cid(2));
        prefetcher.record_access(&first);
        prefetcher.record_access(&second);
        // Accessing `second` right after `first` must create a co-location
        // entry keyed by `first`.
        assert!(prefetcher.colocation_patterns.contains_key(&first));
    }

    #[tokio::test]
    async fn test_sequential_prediction() {
        let prefetcher = PredictivePrefetcher::new(
            Arc::new(MemoryBlockStore::new()),
            PrefetchConfig::default(),
        );
        let (first, second) = (test_cid(1), test_cid(2));
        // Repeat the first -> second sequence so the pattern dominates.
        for _ in 0..5 {
            prefetcher.record_access(&first);
            prefetcher.record_access(&second);
        }
        let predictions = prefetcher.predict_next_blocks(&first);
        assert!(!predictions.is_empty());
        assert!(predictions
            .iter()
            .any(|p| p.pattern == AccessPattern::Sequential));
    }
}