use std::collections::VecDeque;
use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

use dashmap::DashMap;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tracing::{debug, info, warn};

use crate::backend::strategy::L2BackendStrategy;
use crate::error::Result;
/// Concurrent tracker of per-key access patterns, backed by `DashMap`s.
///
/// Records per-key hit counts, last-access times, a bounded sliding window of
/// access timestamps, and pairwise key co-access counts used to predict which
/// keys will be accessed next. All fields are `Arc`-shared, so clones observe
/// and mutate the same underlying state.
#[derive(Clone)]
pub struct AccessPatternTracker {
    /// Lifetime hit count per key.
    frequency: Arc<DashMap<String, AtomicU64>>,
    /// Timestamp of the most recent access per key.
    recent_access: Arc<DashMap<String, Instant>>,
    /// Bounded sliding window of recent access timestamps per key.
    access_window: Arc<DashMap<String, VecDeque<Instant>>>,
    /// Global access counter across all keys.
    total_accesses: Arc<AtomicU64>,
    /// key -> (correlated key -> co-access count); drives prediction.
    key_correlations: Arc<DashMap<String, DashMap<String, u32>>>,
}

impl AccessPatternTracker {
    /// Creates an empty tracker with no recorded accesses.
    pub fn new() -> Self {
        Self {
            frequency: Arc::new(DashMap::new()),
            recent_access: Arc::new(DashMap::new()),
            access_window: Arc::new(DashMap::new()),
            total_accesses: Arc::new(AtomicU64::new(0)),
            key_correlations: Arc::new(DashMap::new()),
        }
    }

    /// Records one access to `key` and bumps co-access counts for each of
    /// `correlated_keys`.
    ///
    /// Each `DashMap` entry guard is scoped tightly and dropped before the
    /// next map is touched, minimizing shard-lock hold time under concurrent
    /// callers.
    pub fn record_access(&self, key: &str, correlated_keys: &[&str]) {
        // Per-key hit counter; the entry guard (a shard write lock) is a
        // temporary dropped immediately after the atomic increment.
        self.frequency
            .entry(key.to_string())
            .or_default()
            .fetch_add(1, Ordering::Relaxed);

        let now = Instant::now();
        self.recent_access.insert(key.to_string(), now);

        // Sliding window of timestamps, capped at MAX_WINDOW_SIZE entries
        // (oldest evicted first). The guard must be a `mut` binding because
        // `RefMut::deref_mut` takes `&mut self`.
        const MAX_WINDOW_SIZE: usize = 100;
        {
            let mut window = self.access_window.entry(key.to_string()).or_default();
            window.push_back(now);
            while window.len() > MAX_WINDOW_SIZE {
                window.pop_front();
            }
        }

        // Pairwise co-access counts: key_correlations[key][corr_key] += 1.
        for &corr_key in correlated_keys {
            let corr_map = self.key_correlations.entry(key.to_string()).or_default();
            *corr_map.entry(corr_key.to_string()).or_default() += 1;
        }

        self.total_accesses.fetch_add(1, Ordering::Relaxed);
    }

    /// Returns a combined hotness score for `key`: 60% lifetime frequency,
    /// 40% accesses-per-second over the trailing 60-second window.
    /// Unknown keys score 0.0.
    pub fn get_access_score(&self, key: &str) -> f64 {
        let freq = self
            .frequency
            .get(key)
            .map(|v| v.load(Ordering::Relaxed))
            .unwrap_or(0) as f64;
        let recency_score = if let Some(accesses) = self.access_window.get(key) {
            let now = Instant::now();
            let recent_count = accesses
                .iter()
                .filter(|&&t| now.duration_since(t) < Duration::from_secs(60))
                .count() as f64;
            // Normalize to accesses per second over the one-minute window.
            recent_count / 60.0
        } else {
            0.0
        };
        (freq * 0.6 + recency_score * 0.4).max(0.0)
    }

    /// Returns up to `limit` keys predicted to follow `accessed_key`, sorted
    /// by descending score.
    ///
    /// Score = 70% co-access count with `accessed_key` + 30% the candidate's
    /// own hotness (`get_access_score`).
    pub fn get_predicted_keys(&self, accessed_key: &str, limit: usize) -> Vec<(String, f64)> {
        let mut predictions = Vec::new();
        if let Some(corr_map) = self.key_correlations.get(accessed_key) {
            // NOTE: `get_access_score` only touches `frequency` and
            // `access_window`, so holding this `key_correlations` guard while
            // calling it cannot self-deadlock.
            for entry in corr_map.iter() {
                let key = entry.key().clone();
                let corr_score = *entry.value() as f64;
                let combined = corr_score * 0.7 + self.get_access_score(&key) * 0.3;
                predictions.push((key, combined));
            }
        }
        predictions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        predictions.truncate(limit);
        predictions
    }

    /// Returns up to `limit` keys whose lifetime hit count is at least
    /// `min_frequency`, sorted by descending count.
    pub fn get_hot_keys(&self, min_frequency: u64, limit: usize) -> Vec<(String, u64)> {
        // Load each counter exactly once so the filter decision and the
        // reported count agree even under concurrent increments.
        let mut hot_keys: Vec<_> = self
            .frequency
            .iter()
            .filter_map(|e| {
                let count = e.value().load(Ordering::Relaxed);
                (count >= min_frequency).then(|| (e.key().clone(), count))
            })
            .collect();
        hot_keys.sort_by(|a, b| b.1.cmp(&a.1));
        hot_keys.truncate(limit);
        hot_keys
    }

    /// Total number of accesses recorded across all keys.
    pub fn total_accesses(&self) -> u64 {
        self.total_accesses.load(Ordering::Relaxed)
    }

    /// Evicts stale tracking state.
    ///
    /// Last-access timestamps older than `2 * max_age` are dropped; window
    /// timestamps older than `max_age` are dropped and keys whose windows
    /// become empty are removed. Correlation maps are pruned of zero counts
    /// (counts start at 1 and only grow, so this is currently a no-op — kept
    /// for parity with the original behavior).
    pub fn cleanup(&self, max_age: Duration) {
        let now = Instant::now();
        self.recent_access
            .retain(|_, &mut t| now.duration_since(t) < max_age * 2);
        self.access_window.retain(|_, accesses| {
            accesses.retain(|&t| now.duration_since(t) < max_age);
            !accesses.is_empty()
        });
        self.key_correlations.retain(|_, corr_map| {
            corr_map.retain(|_, &mut v| v > 0);
            !corr_map.is_empty()
        });
    }
}

impl Default for AccessPatternTracker {
    fn default() -> Self {
        Self::new()
    }
}
/// A single prefetch work item sent to the background worker.
#[derive(Debug)]
pub struct PrefetchRequest {
    /// Cache key to warm.
    pub key: String,
    /// Scheduling priority (higher = more urgent).
    pub priority: u8,
    /// When this request was created.
    pub created_at: Instant,
}

impl PrefetchRequest {
    /// Builds a request for `key` at `priority`, stamped with the current time.
    pub fn new(key: String, priority: u8) -> Self {
        let created_at = Instant::now();
        Self {
            created_at,
            key,
            priority,
        }
    }
}
#[derive(Clone)]
pub struct AdaptivePrefetcher {
pub tracker: Arc<AccessPatternTracker>,
prefetch_tx: Arc<mpsc::Sender<PrefetchRequest>>,
in_progress: Arc<DashMap<String, Instant>>,
_prefetch_task: JoinHandle<()>,
enabled: Arc<AtomicU8>,
batch_size: usize,
prefetch_interval_ms: u64,
}
impl AdaptivePrefetcher {
pub fn new(
l2_backend: Arc<dyn L2BackendStrategy>,
tracker: Arc<AccessPatternTracker>,
batch_size: usize,
prefetch_interval_ms: u64,
) -> Self {
let (tx, rx) = mpsc::channel(1000);
let enabled = Arc::new(AtomicU8::new(1)); let in_progress = Arc::new(DashMap::new());
let prefetch_task = tokio::spawn(Self::prefetch_worker(
rx,
l2_backend,
tracker.clone(),
in_progress.clone(),
Arc::clone(&enabled),
batch_size,
prefetch_interval_ms,
));
Self {
tracker,
prefetch_tx: Arc::new(tx),
in_progress,
_prefetch_task: prefetch_task,
enabled,
batch_size,
prefetch_interval_ms,
}
}
async fn prefetch_worker(
mut rx: mpsc::Receiver<PrefetchRequest>,
l2_backend: Arc<dyn L2BackendStrategy>,
tracker: Arc<AccessPatternTracker>,
in_progress: Arc<DashMap<String, Instant>>,
enabled: Arc<AtomicU8>,
_batch_size: usize,
interval_ms: u64,
) {
let interval = Duration::from_millis(interval_ms);
while let Some(request) = rx.recv().await {
if enabled.load(Ordering::Relaxed) == 0 {
continue;
}
if in_progress.contains_key(&request.key) {
continue;
}
in_progress.insert(request.key.clone(), Instant::now());
match l2_backend.get(&request.key).await {
Ok(Some(_)) => {
debug!("Prefetched key: {}", request.key);
}
Ok(None) => {
}
Err(e) => {
warn!("Prefetch failed for key {}: {}", request.key, e);
}
}
in_progress.remove(&request.key);
tokio::time::sleep(interval).await;
}
}
pub async fn record_and_prefetch(&self, key: &str, correlated_keys: &[&str]) {
self.tracker.record_access(key, correlated_keys);
if self.enabled.load(Ordering::Relaxed) == 0 {
return;
}
let predictions = self.tracker.get_predicted_keys(key, self.batch_size);
for (pred_key, score) in predictions {
if score > 1.0 {
if !self.in_progress.contains_key(&pred_key) {
let request = PrefetchRequest::new(pred_key, (score.min(255.0) as u8));
if let Err(e) = self.prefetch_tx.send(request).await {
warn!("Failed to send prefetch request: {}", e);
}
}
}
}
}
pub async fn prefetch(&self, key: &str) -> Result<()> {
if self.enabled.load(Ordering::Relaxed) == 0 {
return Ok(());
}
if self.in_progress.contains_key(key) {
return Ok(());
}
let request = PrefetchRequest::new(key.to_string(), 128);
self.prefetch_tx.send(request).await.map_err(|e| {
crate::error::CacheError::L2Error(format!("Failed to send prefetch request: {}", e))
})?;
Ok(())
}
pub async fn prefetch_batch(&self, keys: &[&str]) {
for &key in keys {
if let Err(e) = self.prefetch(key).await {
warn!("Batch prefetch failed for key {}: {}", key, e);
}
}
}
pub fn set_enabled(&self, enabled: bool) {
self.enabled.store(if enabled { 1 } else { 0 }, Ordering::Relaxed);
info!("Adaptive prefetch {}", if enabled { "enabled" } else { "disabled" });
}
pub fn stats(&self) -> PrefetchStats {
PrefetchStats {
total_accesses: self.tracker.total_accesses(),
hot_keys_count: self.frequency_count(),
in_progress_count: self.in_progress.len(),
enabled: self.enabled.load(Ordering::Relaxed) == 1,
}
}
fn frequency_count(&self) -> usize {
self.tracker.frequency.len()
}
}
/// Point-in-time snapshot of prefetcher activity, produced by
/// `AdaptivePrefetcher::stats`.
#[derive(Debug, Clone)]
pub struct PrefetchStats {
    /// Total accesses recorded by the pattern tracker across all keys.
    pub total_accesses: u64,
    /// Number of distinct keys with any recorded accesses.
    pub hot_keys_count: usize,
    /// Number of keys currently being prefetched by the worker.
    pub in_progress_count: usize,
    /// Whether prefetching was enabled when the snapshot was taken.
    pub enabled: bool,
}